# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for where op."""

# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu
# pylint: enable=g-direct-tensorflow-import


class WhereOpTest(xla_test.XLATestCase):
  """Checks the single-argument (index-returning) form of `array_ops.where`.

  Each test builds its graph inside `self.test_scope()` and evaluates it with
  the session returned by `self.session()`. The condition is always fed
  through a placeholder, so the number of returned index rows is only known
  once the feed values are supplied.
  """

  def __init__(self, method_name="runTest"):
    """Sets up the test case and initializes the TPU system when one exists.

    Args:
      method_name: Name of the test method to run, forwarded to the base
        class constructor.
    """
    super().__init__(method_name)
    tpu_devices = config.list_logical_devices("TPU")
    if not tpu_devices:
      # No TPU is listed, so there is nothing to initialize.
      return
    with self.session() as session:
      init_op = tpu.initialize_system()
      session.run(init_op)

  def testWhere(self):
    """Index form of where on a rank-2 boolean input."""
    mask_values = [[True, False, False], [False, True, True]]
    expected_indices = [[0, 0], [1, 1], [1, 2]]

    with self.session() as session:
      with self.test_scope():
        mask = array_ops.placeholder(dtypes.bool)
        indices = array_ops.where(mask)

      # The result size is dynamic: it depends on the values being fed.
      actual_indices = session.run(indices, {mask: mask_values})
      self.assertAllEqual(expected_indices, actual_indices)

  def testWhereGather(self):
    """Indices produced by where are consumed by gather_nd."""
    mask_values = [[True, False], [True, True]]
    expected_values = [0, 2, 3]

    with self.session() as session:
      with self.test_scope():
        mask = array_ops.placeholder(dtypes.bool)
        table = array_ops.constant([[0, 1], [2, 3]], dtypes.float32)
        indices = array_ops.where(mask)

        # Picks the table entries 0, 2 and 3.
        picked = array_ops.gather_nd(table, indices)

      actual_values = session.run(picked, {mask: mask_values})
      self.assertAllEqual(expected_values, actual_values)

  def testWhereGatherReduce(self):
    """Indices produced by where feed a gather_nd followed by a reduce_sum."""
    mask_values = [[True, False], [True, True]]
    expected_total = 5

    with self.session() as session:
      with self.test_scope():
        mask = array_ops.placeholder(dtypes.bool)
        table = array_ops.constant([[0, 1], [2, 3]], dtypes.float32)
        selected = array_ops.where(mask)

        # The gathered entries 0, 2 and 3 add up to 5.
        picked = array_ops.gather_nd(table, selected)
        total = math_ops.reduce_sum(picked)

      actual_total = session.run(total, {mask: mask_values})
      self.assertAllEqual(expected_total, actual_total)

  def testWhere1D(self):
    """Index form of where on a rank-1 boolean input."""
    mask_values = [True, False, True]
    expected_indices = [[0], [2]]

    with self.session() as session:
      with self.test_scope():
        mask = array_ops.placeholder(dtypes.bool)
        indices = array_ops.where(mask)

      # The result size is dynamic: it depends on the values being fed.
      actual_indices = session.run(indices, {mask: mask_values})
      self.assertAllEqual(expected_indices, actual_indices)

  def testWhereInt(self):
    """Index form of where on int32 input; only the zero entry is skipped."""
    int_values = [-1, 0, 1]
    expected_indices = [[0], [2]]

    with self.session() as session:
      with self.test_scope():
        condition = array_ops.placeholder(dtypes.int32)
        indices = array_ops.where(condition)

      # The result size is dynamic: it depends on the values being fed.
      actual_indices = session.run(indices, {condition: int_values})
      self.assertAllEqual(expected_indices, actual_indices)

  def testWhereFloat(self):
    """Index form of where on float32 input.

    Both negative zero and positive zero are expected to be skipped.
    """
    float_values = [-1.0, -0.0, 0.0, 1.0]
    expected_indices = [[0], [3]]

    with self.session() as session:
      with self.test_scope():
        condition = array_ops.placeholder(dtypes.float32)
        indices = array_ops.where(condition)

      # The result size is dynamic: it depends on the values being fed.
      actual_indices = session.run(indices, {condition: float_values})
      self.assertAllEqual(expected_indices, actual_indices)

  def testWhereComplex(self):
    """Index form of where on complex64 input.

    Entries whose real and imaginary parts are both (signed) zero are expected
    to be skipped; an entry with either part nonzero is expected to be kept.
    """
    complex_values = [
        -1.0 + 0.0j, -0.0 + 0.0j, 0.0 - 0.0j, 1.0 - 1.0j, 1.0 + 0.0j,
        0.0 + 1.0j
    ]
    expected_indices = [[0], [3], [4], [5]]

    with self.session() as session:
      with self.test_scope():
        condition = array_ops.placeholder(dtypes.complex64)
        indices = array_ops.where(condition)

      # The result size is dynamic: it depends on the values being fed.
      actual_indices = session.run(indices, {condition: complex_values})
      self.assertAllEqual(expected_indices, actual_indices)

# Invoke the TensorFlow platform test entry point only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
  test.main()
