# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for multinomial generation ops in the XLA JIT compiler."""

import collections

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test


# TODO(srvasude): Merge this with
# third_party/tensorflow/python/kernel_tests/random/multinomial_op_test.py.
class CategoricalTest(xla_test.XLATestCase):
  """Test cases for random-number generating operators."""

  def output_dtypes(self):
    return set(self.int_types).intersection([np.int32, np.int64])

  def _chi2(self, expected, actual):
    """Returns Chi2 GOF statistic."""
    actual = np.asarray(actual)
    expected = np.asarray(expected)
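    # Pearson's chi-squared statistic: sum((actual - expected)**2 / expected).
    # For example, expected=[0.5, 0.5] and actual=[0.6, 0.4] give
    # chi2 = 0.1**2 / 0.5 + 0.1**2 / 0.5 = 0.04.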
    diff = actual - expected
    chi2 = np.sum(diff * diff / expected)
    return chi2

  def _do_sampling(self, logits, num_samples):
    """Categorical samples from given input.

    Args:
      logits: Numpy ndarray of shape [batch_size, num_classes].
      num_samples: Int; number of samples to draw.

    Returns:
      Empirical class frequencies; shape [batch_size, num_classes].
    """
    with self.session(), self.test_scope():
      random_seed.set_random_seed(1618)
      op = random_ops.multinomial(logits, num_samples,
                                  output_dtype=dtypes.int32)
      d = self.evaluate(op)

    batch_size, num_classes = logits.shape
    freqs_mat = []
    for i in range(batch_size):
      cnts = dict(collections.Counter(d[i, :]))

      # Require that all drawn class labels are in range.
      self.assertLess(max(cnts.keys()), num_classes)
      self.assertGreaterEqual(min(cnts.keys()), 0)

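      # Convert raw counts to empirical frequencies; classes that were never
      # drawn get a frequency of 0.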
      freqs = [cnts.get(k, 0) * 1. / num_samples for k in range(num_classes)]
      freqs_mat.append(freqs)

    return freqs_mat

  def _testRngIsNotConstant(self, rng, dtype, output_dtype):
    # Tests that 'rng' does not always return the same value.
    with self.session():
      with self.test_scope():
        x = rng(dtype, output_dtype)

      # A correctly working random-number generator should only rarely
      # produce the same output multiple times.
      y = self.evaluate(x)
      z = self.evaluate(x)
      w = self.evaluate(x)

      # We use exact equality here. If the random-number generator is producing
      # deterministic output, all three outputs will be bitwise identical.
      self.assertTrue((not np.array_equal(y, z)) or
                      (not np.array_equal(z, w)) or
                      (not np.array_equal(y, w)))

  def testCategoricalIsNotConstant(self):
    def rng(dtype, output_dtype):
      return random_ops.multinomial(np.array([[1., 1., 1.]], dtype=dtype), 10,
                                    output_dtype=output_dtype)

    dtype = np.float32
    for output_dtype in self.output_dtypes():
      self._testRngIsNotConstant(rng, dtype, output_dtype)

  @test.disable_with_predicate(
      pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.")
  def testCategoricalIsInRange(self):
    for dtype in self.float_types:
      for output_dtype in self.output_dtypes():
        with self.session():
          with self.test_scope():
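            # Uniform logits over 20 classes: every drawn label must lie in
            # [0, 20).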
            x = random_ops.multinomial(
                array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
                output_dtype=output_dtype)
          y = self.evaluate(x)
          self.assertEqual((y >= 0).sum(), 1000)
          self.assertEqual((y < 20).sum(), 1000)

  def testSamplingCorrectness(self):
    np.random.seed(1618)  # Make it reproducible.
    num_samples = 40000

    rand_probs = np.random.dirichlet([1., 1., 2., 3.])
    rand_probs2 = np.random.dirichlet([1., 4., 5.], size=3)  # batched
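    # The Dirichlet draws provide random probability vectors that sum to 1:
    # a single distribution and a batch of three distributions.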
    for probs in [[.5, .5], [.85, .05, .1], rand_probs, rand_probs2]:
      probs = np.asarray(probs)
      if len(probs.shape) == 1:
        probs = probs.reshape(1, probs.size)  # singleton batch

      logits = np.log(probs).astype(np.float32)
      freqs = self._do_sampling(logits, num_samples)

      # The test here is similar to
      # python/kernel_tests/random/multinomial_op_test.py.
      # Note that df >= 1 in all these cases. Choosing a cutoff of 1e-3
      # corresponds to an alpha value of 2.5% for df = 1, and smaller for
      # larger df.
      chi2 = self._chi2(probs, freqs)
      self.assertLess(chi2, 1e-3)

  def testStatelessMultinomialIsInRange(self):
    for dtype in self.float_types.intersection(
        [dtypes.float32, dtypes.bfloat16]):
      for output_dtype in self.output_dtypes():
        with self.session() as sess:
          with self.test_scope():
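            # The stateless op consumes an explicit 2-element seed tensor
            # instead of relying on graph-level seeding.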
            seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
            x = stateless_random_ops.stateless_multinomial(
                array_ops.ones(shape=[1, 20], dtype=dtype),
                1000,
                seed_t,
                output_dtype=output_dtype)
          y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
          self.assertEqual((y >= 0).sum(), 1000)
          self.assertEqual((y < 20).sum(), 1000)

  def testDeterminismMultinomial(self):
    # Stateless values should be equal iff the seeds are equal (roughly).
    num_samples = 10
    with self.session(), self.test_scope():
      seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
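      # Each (x, y) seed pair appears three times, so both the same-seed and
      # different-seed cases are exercised.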
      seeds = [(x, y) for x in range(5) for y in range(5)] * 3
      for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
                                                [0.25, 0.75]]):
        pure = stateless_random_ops.stateless_multinomial(
            logits, num_samples, seed=seed_t)
        values = [(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds]
        for s0, v0 in values:
          for s1, v1 in values:
            self.assertEqual(s0 == s1, np.all(v0 == v1))

  def testEmpty(self):
    with self.session():
      with self.test_scope():
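        # Drawing zero samples should produce an empty output of shape
        # [batch_size, 0].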
        x = random_ops.multinomial(
            array_ops.zeros([42, 40]), 0, output_dtype=dtypes.int32)
        y = self.evaluate(x)
        self.assertEqual(y.shape, (42, 0))

  def testEmptyStateless(self):
    with self.session() as sess:
      with self.test_scope():
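        # Same as testEmpty above, but for the stateless op.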
        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
        x = stateless_random_ops.stateless_multinomial(
            array_ops.zeros([42, 40]),
            0,
            seed=seed_t,
            output_dtype=dtypes.int32)
        y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
        self.assertEqual(y.shape, (42, 0))


if __name__ == '__main__':
  googletest.main()
