# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for XLA TensorArray Ops."""

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


def _make_converter(dtype):
  def _converter(x):
    return np.asarray(x).astype(dtype.as_numpy_dtype)
  return _converter


# This lets me define `fn` repeatedly to pass to xla.compile.
#
# pylint: disable=function-redefined
@test_util.run_v1_only("b/")  # Support TF2 list operations
@test_util.with_control_flow_v2
class TensorArrayTest(xla_test.XLATestCase):

  @test_util.disable_control_flow_v2("Tries to evaluate flow")
  def testTensorArrayWriteRead(self):
    """Writes three 1x2 rows, reads each one back and fetches the flow."""
    rows = [[[4.0, 5.0]], [[1.0, 3.0]], [[7.0, -8.5]]]

    with self.session(), self.test_scope():

      def fn():
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, tensor_array_name="foo", size=3)
        # Chain every write on the result of the previous one.
        for index, row in enumerate(rows):
          ta = ta.write(index, row)
        outputs = [ta.read(index) for index in range(len(rows))]
        outputs.append(ta.flow)
        return outputs

      d0, d1, d2, flow_val = self.evaluate(xla.compile(fn))

      for expected, actual in zip(rows, (d0, d1, d2)):
        self.assertAllEqual(expected, actual)
      # The fetched flow value is expected to be a scalar.
      self.assertAllEqual([], flow_val.shape)

  def _testTensorArrayWritePack(self, tf_dtype):
    """Checks that stack() returns the three written elements.

    Args:
      tf_dtype: Element dtype used for both the TensorArray and the data.
    """
    with self.session(), self.test_scope():
      to_np = _make_converter(tf_dtype)
      elements = [[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]

      def fn():
        ta = tensor_array_ops.TensorArray(
            dtype=tf_dtype, tensor_array_name="foo", size=3)
        for slot, element in enumerate(elements):
          ta = ta.write(slot, to_np(element))
        return ta.stack()

      # The stacked result has the written elements along a new leading axis.
      expected = to_np(elements)
      self.assertAllEqual(expected, self.evaluate(xla.compile(fn)[0]))

  def testTensorArrayWritePack(self):
    """Runs the write/stack check once per numeric type under test."""
    for tf_dtype in self.numeric_tf_types:
      self._testTensorArrayWritePack(tf_dtype)

  def testEmptyTensorArrayPack(self):
    """Stacking three zero-sized (0, 1) elements yields shape (3, 0, 1)."""
    with self.session(), self.test_scope():

      def fn():
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, tensor_array_name="foo", size=3)
        empty = np.zeros((0, 1), dtype=np.float32)
        for slot in range(3):
          ta = ta.write(slot, empty)
        return ta.stack()

      stacked = self.evaluate(xla.compile(fn)[0])
      self.assertAllEqual([3, 0, 1], stacked.shape)

  def _testTensorArrayWriteConcat(self, tf_dtype):
    """Checks concat() of three written 2x2 elements.

    Args:
      tf_dtype: Element dtype used for both the TensorArray and the data.
    """
    with self.session(), self.test_scope():
      to_np = _make_converter(tf_dtype)
      chunks = [
          [[4.0, 5.0], [104.0, 105.0]],
          [[6.0, 7.0], [106.0, 107.0]],
          [[8.0, 9.0], [124.0, 125.0]],
      ]

      def fn():
        ta = tensor_array_ops.TensorArray(
            dtype=tf_dtype, tensor_array_name="foo", size=3)
        for slot, chunk in enumerate(chunks):
          ta = ta.write(slot, to_np(chunk))
        return ta.concat()

      # The expected result is the rows of all chunks, in write order.
      expected = to_np([row for chunk in chunks for row in chunk])
      self.assertAllEqual(expected, self.evaluate(xla.compile(fn)[0]))

  @test_util.disable_control_flow_v2("b/122315751 (concat)")
  def testTensorArrayWriteConcat(self):
    """Runs the write/concat check once per numeric type under test."""
    for tf_dtype in self.numeric_tf_types:
      self._testTensorArrayWriteConcat(tf_dtype)

  def _testTensorArrayUnpackRead(self, tf_dtype):
    """Unstacks a vector, a matrix and an empty matrix; reads each row back.

    Args:
      tf_dtype: Element dtype used for both the TensorArray and the data.
    """
    with self.session(), self.test_scope():
      convert = _make_converter(tf_dtype)

      def read_rows(value):
        """Unstacks `value` into a new size-3 array and returns reads 0..2."""

        def fn():
          # A new TensorArray is built for every case because the element
          # shape changes between cases; reusing one would make shape
          # inference throw an error.
          ta = tensor_array_ops.TensorArray(
              dtype=tf_dtype, tensor_array_name="foo", size=3)
          unstacked = ta.unstack(convert(value))
          return [unstacked.read(index) for index in range(3)]

        return self.evaluate(xla.compile(fn))

      # Unpack a vector into scalars.
      vector = [1.0, 2.0, 3.0]
      d0, d1, d2 = read_rows(vector)
      for expected, actual in zip(vector, (d0, d1, d2)):
        self.assertAllEqual(convert(expected), actual)

      # Unpack a matrix into vectors.
      matrix = [[1.0, 1.03125], [2.0, 2.03125], [3.0, 3.03125]]
      d0, d1, d2 = read_rows(matrix)
      for expected, actual in zip(matrix, (d0, d1, d2)):
        self.assertAllEqual(convert(expected), actual)

      # Try unpacking an empty matrix, which should not cause an error.
      d0, d1, d2 = read_rows([[], [], []])
      for actual in (d0, d1, d2):
        self.assertAllEqual(convert([]), actual)

  def _testTensorArrayUnpackReadMaybeLegacy(self):
    """Runs the unstack/read check once per numeric type under test."""
    for tf_dtype in self.numeric_tf_types:
      self._testTensorArrayUnpackRead(tf_dtype)

  def testTensorArrayUnpackRead(self):
    """Test entry point; delegates to _testTensorArrayUnpackReadMaybeLegacy."""
    self._testTensorArrayUnpackReadMaybeLegacy()

  def _testTensorArraySplitRead(self, tf_dtype):
    """Splits an empty vector, a vector and a matrix; reads each piece back.

    Args:
      tf_dtype: Element dtype used for both the TensorArray and the data.
    """
    with self.session(), self.test_scope():
      convert = _make_converter(tf_dtype)

      def read_pieces(value, piece_lengths):
        """Splits `value` into a new size-3 array and returns reads 0..2."""

        def fn():
          ta = tensor_array_ops.TensorArray(
              dtype=tf_dtype, tensor_array_name="foo", size=3)
          lengths = constant_op.constant(piece_lengths)
          pieces = ta.split(convert(value), lengths=lengths)
          return [pieces.read(index) for index in range(3)]

        return self.evaluate(xla.compile(fn))

      # Split an empty vector.
      d0, d1, d2 = read_pieces([], [0, 0, 0])
      for actual in (d0, d1, d2):
        self.assertAllEqual(convert([]), actual)

      # Split a vector.
      d0, d1, d2 = read_pieces([1.0, 2.0, 3.0], [1, 1, 1])
      self.assertAllEqual(convert([1.0]), d0)
      self.assertAllEqual(convert([2.0]), d1)
      self.assertAllEqual(convert([3.0]), d2)

      # Split a matrix.
      matrix = [[1.0, 101.0], [2.0, 121.0], [3.0, 127.0]]
      d0, d1, d2 = read_pieces(matrix, [1, 1, 1])
      for row, actual in zip(matrix, (d0, d1, d2)):
        # Every piece keeps a leading dimension of length 1.
        self.assertAllEqual(convert([row]), actual)

  @test_util.disable_control_flow_v2("b/122315872 (split)")
  def testTensorArraySplitRead(self):
    """Runs the split/read check once per numeric type under test."""
    for tf_dtype in self.numeric_tf_types:
      self._testTensorArraySplitRead(tf_dtype)

  @test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2")
  def testTensorGradArrayWriteRead(self):
    """Writes to an array and to its "grad" array, then reads both back."""
    values = [[[4.0]], [[1.0]], [[-3.0]]]
    grad_values = [[[5.0]], [[2.0]], [[-2.0]]]

    with self.session(), self.test_scope():

      def fn():
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, tensor_array_name="foo", size=3)
        for index, value in enumerate(values):
          ta = ta.write(index, value)

        # The gradient array is requested after all forward writes.
        g_ta = ta.grad("grad")
        for index, value in enumerate(grad_values):
          g_ta = g_ta.write(index, value)

        forward_reads = [ta.read(index) for index in range(3)]
        grad_reads = [g_ta.read(index) for index in range(3)]
        return forward_reads + grad_reads

      d0, d1, d2, g_d0, g_d1, g_d2 = self.evaluate(xla.compile(fn))

      for expected, actual in zip(values, (d0, d1, d2)):
        self.assertAllEqual(expected, actual)
      for expected, actual in zip(grad_values, (g_d0, g_d1, g_d2)):
        self.assertAllEqual(expected, actual)

  @test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2")
  def testTensorGradArrayDynamicWriteRead(self):
    """Like the write/read grad test, but also checks both arrays' sizes."""
    values = [[[4.0]], [[1.0]], [[-3.0]]]
    grad_values = [[[5.0]], [[2.0]], [[-2.0]]]

    with self.session(), self.test_scope():

      def fn():
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, tensor_array_name="foo", size=3)
        for index, value in enumerate(values):
          ta = ta.write(index, value)

        # Get gradient array here so we know the shape.
        g_ta = ta.grad("grad")

        # Sizes are taken before anything is written to the gradient array.
        size = ta.size()
        grad_size = g_ta.size()

        for index, value in enumerate(grad_values):
          g_ta = g_ta.write(index, value)

        forward_reads = [ta.read(index) for index in range(3)]
        grad_reads = [g_ta.read(index) for index in range(3)]
        return forward_reads + grad_reads + [size, grad_size]

      d0, d1, d2, g_d0, g_d1, g_d2, vs, g_vs = self.evaluate(xla.compile(fn))

      for expected, actual in zip(values, (d0, d1, d2)):
        self.assertAllEqual(expected, actual)
      for expected, actual in zip(grad_values, (g_d0, g_d1, g_d2)):
        self.assertAllEqual(expected, actual)
      self.assertAllEqual(3, vs)
      self.assertAllEqual(3, g_vs)

  @test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2")
  def testTensorGradAccessTwiceReceiveSameObject(self):
    """Requests the "grad" array twice; writes via one, reads via the other.

    The value written through the first gradient TensorArray object must be
    visible when read through the second one.
    """
    with self.session() as session, self.test_scope():
      # Side channel for the two gradient handles. They are stored here but
      # not asserted on; see the note at the end of this test.
      ta_out = {}

      def fn():
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32,
            tensor_array_name="foo",
            size=3,
            element_shape=[1, 2])

        # Same source name on both calls.
        g_ta_0 = ta.grad("grad")
        g_ta_1 = ta.grad("grad")

        ta_out[0] = g_ta_0.handle
        ta_out[1] = g_ta_1.handle

        # The read is created under a control dependency on the write's flow.
        with ops.control_dependencies([g_ta_0.write(0, [[4.0, 5.0]]).flow]):
          # Write with one gradient handle, read with another copy of it
          r1_0 = g_ta_1.read(0)

        with ops.control_dependencies([g_ta_0.handle.op, g_ta_1.handle.op]):
          return [r1_0]

      [d_r1_0] = self.evaluate(xla.compile(fn))
      self.assertAllEqual([[4.0, 5.0]], d_r1_0)

      # Can't assert this because adding a side output like we have here fails
      # as follows:
      #
      # ValueError: Operation u'TensorArrayGrad/TensorArrayGradV3' has been
      # marked as not fetchable.
      #
      # On the other hand, legitimately returning the handle from the
      # xla.compile function fails because we don't support DT_RESOURCE outputs
      # from XLA clusters.
      #
      # self.assertAllEqual(ta_out[0], ta_out[1])

  @test_util.disable_control_flow_v2("b/124334470")
  def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
    """Writing an int32 constant into a float32 TensorArray must raise."""
    # TODO(b/129870929): Remove InvalidArgumentError/second regexp after all
    # callers provide proper init dtype.
    expected_errors = (ValueError, errors.InvalidArgumentError)
    expected_message = (
        r"("
        r"conversion requested dtype float32 for Tensor with dtype int32"
        r"|"
        r"TensorArray dtype is float but op has dtype int32"
        r")")

    with self.session(), self.test_scope():

      def fn():
        float_ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, tensor_array_name="foo", size=3)
        int_value = constant_op.constant(7)
        return float_ta.write(-1, int_value).flow

      # Test writing the wrong datatype.
      with self.assertRaisesRegex(expected_errors, expected_message):
        xla.compile(fn)[0].eval()

  @test_util.disable_control_flow_v2("b/124334096 verify dtype")
  def testTensorArrayReadWrongIndexOrDataTypeFails(self):
    """Reads with a mismatched dtype, then from an index never written."""
    # Find two different floating point types, create an array of
    # the first type, but try to read the other type. With fewer than two
    # float types available there is nothing to check.
    if len(self.float_types) <= 1:
      return
    dtype1, dtype2 = list(self.float_types)[:2]

    def written_array():
      """Returns a size-3 `dtype1` array with only index 0 written."""
      ta = tensor_array_ops.TensorArray(
          dtype=dtype1, tensor_array_name="foo", size=3)
      return ta.write(0, math_ops.cast([[4.0, 5.0]], dtype1))

    with self.session(), self.test_scope():

      def read_wrong_dtype():
        w0 = written_array()
        # Test reading wrong datatype.
        return gen_data_flow_ops.tensor_array_read_v3(
            handle=w0.handle, index=0, dtype=dtype2, flow_in=w0.flow)

      with self.assertRaisesOpError("TensorArray dtype is "):
        self.evaluate(xla.compile(read_wrong_dtype))

      def read_other_index():
        w0 = written_array()
        # Test reading from a different index than the one we wrote to
        with ops.control_dependencies([w0.read(1)]):
          return 1.0

      # No error is asserted for this case; it only has to run.
      xla.compile(read_other_index)[0].eval()

  @test_util.disable_control_flow_v2("b/122315872 (split)")
  def testTensorArraySplitIncompatibleShapesFails(self):
    """Checks the errors raised by split() for four invalid argument sets."""
    with self.session(), self.test_scope():

      def split_flow(value, lengths, size=3):
        """Builds a fn that splits `value` by `lengths`, returning the flow."""

        def fn():
          ta = tensor_array_ops.TensorArray(
              dtype=dtypes.float32,
              tensor_array_name="foo",
              size=size,
              infer_shape=False)
          return ta.split(value, lengths).flow

        return fn

      # Scalar `lengths`.
      with self.assertRaisesWithPredicateMatch(
          ValueError, r"Shape must be rank 1 but is rank 0"):
        xla.compile(split_flow([1.0, 2.0, 3.0], 1))[0].eval()

      # Entries of `lengths` differ from each other.
      with self.assertRaisesOpError(
          r"lengths must be equal: 1 vs. 2"):
        xla.compile(split_flow([1.0, 2.0, 3.0], [1, 2, 3]))[0].eval()

      # Scalar `value`.
      with self.assertRaisesOpError(
          r"value must have rank >= 1"):
        xla.compile(split_flow(1.0, [1]))[0].eval()

      # One length given for an array of size 2.
      with self.assertRaisesOpError(
          r"TensorArray's size is not equal to the size of lengths "
          r"\(1 vs. 2\)"):
        xla.compile(split_flow([1.0], [1], size=2))[0].eval()

  def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype):
    """Checks repeated writes to one index of a gradient TensorArray.

    Args:
      dtype: Element dtype; its `as_numpy_dtype` is used to build the data.
    """
    with self.session(), self.test_scope():

      def as_np(x):
        return np.asarray(x, dtype=dtype.as_numpy_dtype)

      def new_written_array():
        """Returns an array whose index 2 was written twice (3.0, then 4.0)."""
        ta = tensor_array_ops.TensorArray(
            dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False)
        return ta.write(2, as_np(3.0)).write(2, as_np(4.0))

      def sum_of_grad_writes():
        grad_ta = new_written_array().grad("grad")
        for addend in (3.0, 4.0, 5.0):
          grad_ta = grad_ta.write(2, as_np(addend))
        return grad_ta.read(2)

      # Assert that aggregation works correctly: 3 + 4 + 5.
      self.assertAllEqual(as_np(12.00), xla.compile(sum_of_grad_writes)[0])

      def mismatched_grad_writes():
        grad_ta = new_written_array().grad("grad")
        # Using differing shapes causes an exception
        scalar_write = grad_ta.write(1, as_np(1.0))
        vector_write = scalar_write.write(1, as_np([1.0]))
        return vector_write.flow

      with self.assertRaisesOpError(
          r"Mismatched TensorArray sizes"):
        xla.compile(mismatched_grad_writes)[0].eval()

  @test_util.disable_control_flow_v2("TensorArray.grad is not supported in v2")
  def testTensorArrayWriteGradientAddMultipleAdds(self):
    """Runs the gradient-write aggregation check once per numeric type."""
    for tf_dtype in self.numeric_tf_types:
      self._testTensorArrayWriteGradientAddMultipleAdds(tf_dtype)

  def testMultiTensorArray(self):
    """Uses two independent TensorArrays in one computation and sums reads."""
    with self.session(), self.test_scope():

      def fn():
        foo = tensor_array_ops.TensorArray(
            size=1, dtype=dtypes.float32, tensor_array_name="foo")
        foo_value = foo.write(0, 4.0).read(0)

        bar = tensor_array_ops.TensorArray(
            size=1, dtype=dtypes.float32, tensor_array_name="bar")
        bar_value = bar.write(0, 5.0).read(0)

        return foo_value + bar_value

      total = self.evaluate(xla.compile(fn)[0])
      self.assertAllClose(9.0, total)

  def _testTensorArrayGradientWriteReadType(self, dtype):
    """Checks gradients flowing from reads back to the written values.

    Args:
      dtype: Type passed to `np.array(..., dtype=dtype)` for all data and to
        `dtypes.as_dtype` for the TensorArray's element type.
    """
    with self.session(), self.test_scope():

      def as_np(x):
        return np.array(x, dtype=dtype)

      def fn():
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.as_dtype(dtype),
            tensor_array_name="foo",
            size=3,
            infer_shape=False)

        value_0 = constant_op.constant(as_np([[4.0, 5.0]]))
        value_1 = constant_op.constant(as_np([[3.0, 3.5]]))

        written = ta.write(0, value_0).write(1, value_1)
        read_0 = written.read(0)
        read_1 = written.read(1)
        read_0_again = written.read(0)

        # Test individual components' gradients
        only_read_0 = gradients_impl.gradients(
            ys=[read_0], xs=[value_0], grad_ys=[as_np([[2.0, 3.0]])])
        both_reads_of_0 = gradients_impl.gradients(
            ys=[read_0, read_0_again],
            xs=[value_0],
            grad_ys=[as_np([[2.0, 3.0]]), as_np([[1.0, -1.0]])])
        only_read_1 = gradients_impl.gradients(
            ys=[read_1], xs=[value_1], grad_ys=[as_np([[-2.0, -4.0]])])
        # Test combined gradients
        combined = gradients_impl.gradients(
            ys=[read_0, read_0_again, read_1],
            xs=[value_0, value_1],
            grad_ys=[as_np([[2.0, 3.0]]),
                     as_np([[1.0, -1.0]]),
                     as_np([[-2.0, -10.0]])])

        return [only_read_0, both_reads_of_0, only_read_1, combined]

      (only_read_0_vals, both_reads_of_0_vals, only_read_1_vals,
       combined_vals) = self.evaluate(xla.compile(fn))

      self.assertAllEqual(as_np([[2.0, 3.0]]), only_read_0_vals[0])

      # The two upstream gradients for index 0 add up: (2 + 1, 3 - 1).
      self.assertAllEqual(as_np([[3.0, 2.0]]), both_reads_of_0_vals[0])

      self.assertAllEqual(as_np([[-2.0, -4.0]]), only_read_1_vals[0])

      self.assertEqual(len(combined_vals), 2)
      self.assertAllEqual(as_np([[3.0, 2.0]]), combined_vals[0])
      self.assertAllEqual(as_np([[-2.0, -10.0]]), combined_vals[1])

  def testTensorArrayGradientWriteRead(self):
    """Runs the gradient write/read check for float, then complex types."""
    for type_group in (self.float_types, self.complex_types):
      for dtype in type_group:
        self._testTensorArrayGradientWriteReadType(dtype)

  def _testTensorArrayGradientWritePackConcatAndRead(self):
    """Checks gradient accumulation across stack(), read(0) and concat()."""
    with self.session(), self.test_scope():

      def fn():
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32,
            tensor_array_name="foo",
            size=2,
            clear_after_read=False)

        value_0 = constant_op.constant([-1.0, 1.0])
        value_1 = constant_op.constant([-10.0, 10.0])

        written = ta.write(0, value_0).write(1, value_1)
        stacked = written.stack()
        first = written.read(0)
        concatenated = written.concat()

        upstream = [
            [[2.0, 3.0], [4.0, 5.0]],  # stack gradient
            [-0.5, 1.5],  # read(0) gradient
            [20.0, 30.0, 40.0, 50.0],  # concat gradient
        ]

        # Test gradient accumulation between read(0), pack(), and concat().
        with ops.control_dependencies([stacked, first, concatenated]):
          return gradients_impl.gradients(
              ys=[stacked, first, concatenated],
              xs=[value_0, value_1],
              grad_ys=upstream)

      # One gradient per entry of `xs`.
      grad_vals = self.evaluate(xla.compile(fn))

      # value_0 receives a contribution from all three consumers.
      self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0])
      # value_1 is only reached through stack() and concat().
      self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1])

  @test_util.disable_control_flow_v2("b/122315751 (concat)")
  def testTensorArrayGradientWritePackConcatAndRead(self):
    """Test entry point; delegates to the underscore-prefixed helper."""
    self._testTensorArrayGradientWritePackConcatAndRead()

  def testTensorArrayReadTwice(self):
    """Reads index 0 twice and checks both reads return the same row.

    The array is created with `clear_after_read=False`, and the second read is
    sequenced after the first with a control dependency. Both results are
    verified, so a broken second read is actually detected.
    """
    with self.session(), self.test_scope():

      def fn():
        value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])

        ta_readtwice = tensor_array_ops.TensorArray(
            dtype=dtypes.float32,
            tensor_array_name="foo",
            size=2,
            clear_after_read=False)
        w_readtwice = ta_readtwice.unstack(value)
        r0_readtwice = w_readtwice.read(0)
        # Make sure the second read happens after the first one.
        with ops.control_dependencies([r0_readtwice]):
          r1_readtwice = w_readtwice.read(0)

        return [r0_readtwice, r1_readtwice]

      first_read, second_read = self.evaluate(xla.compile(fn))
      self.assertAllEqual([1.0, -1.0], first_read)
      # The second read of the same index must see the same data.
      self.assertAllEqual([1.0, -1.0], second_read)

  def _testTensorArrayGradientUnpackRead(self):
    """Checks the gradient of reads w.r.t. the unstacked input value."""
    with self.session(), self.test_scope():

      def fn():
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32,
            tensor_array_name="foo",
            size=2,
            clear_after_read=False)

        value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])

        rows = ta.unstack(value)
        first = rows.read(0)
        first_again = rows.read(0)
        second = rows.read(1)

        # Test combined gradients + aggregation of read(0).
        upstream = [[2.0, 3.0], [-1.5, 1.5], [4.0, 5.0]]
        return gradients_impl.gradients(
            ys=[first, first_again, second], xs=[value], grad_ys=upstream)

      grad_vals = self.evaluate(xla.compile(fn))

      self.assertEqual(len(grad_vals), 1)
      # Row 0 sums the two upstream gradients of read(0).
      expected = [[2.0 - 1.5, 3.0 + 1.5], [4.0, 5.0]]
      self.assertAllEqual(expected, grad_vals[0])

  def testTensorArrayGradientUnpackRead(self):
    """Test entry point; delegates to _testTensorArrayGradientUnpackRead."""
    self._testTensorArrayGradientUnpackRead()

  @test_util.disable_control_flow_v2("b/122315751(concat), b/122315872(split)")
  def testTensorArrayGradientSplitConcat(self):
    """The gradient through split() then concat() equals the upstream grad."""
    upstream = [[2.0, -2.0], [20.0, -20.0], [200.0, -200.0],
                [2000.0, -2000.0]]

    with self.session(), self.test_scope():

      def fn():
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, tensor_array_name="foo", size=2)

        value = constant_op.constant([[1.0, -1.0], [10.0, -10.0],
                                      [100.0, -100.0], [1000.0, -1000.0]])

        rejoined = ta.split(value, [2, 2]).concat()

        # Test combined gradients
        return gradients_impl.gradients(
            ys=[rejoined], xs=[value], grad_ys=[upstream])

      grad_vals = self.evaluate(xla.compile(fn))

      self.assertEqual(len(grad_vals), 1)
      self.assertAllEqual(upstream, grad_vals[0])

  def testCloseTensorArray(self):
    """Closing a newly created TensorArray compiles and runs."""
    with self.session(), self.test_scope():

      def fn():
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, tensor_array_name="foo", size=3)
        close_op = ta.close()
        # The returned constant carries a control dependency on the close op.
        with ops.control_dependencies([close_op]):
          return 1.0

      compiled = xla.compile(fn)
      self.evaluate(compiled[0])

  def testSizeTensorArray(self):
    """size() of an array created with size=3 evaluates to 3."""
    with self.session(), self.test_scope():

      def fn():
        return tensor_array_ops.TensorArray(
            dtype=dtypes.float32, tensor_array_name="foo", size=3).size()

      outputs = self.evaluate(xla.compile(fn))
      self.assertAllEqual(3, outputs[0])

  def testWriteCloseTensorArray(self):
    """Closing a TensorArray after two writes compiles and runs."""
    with self.session(), self.test_scope():

      def fn():
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32,
            tensor_array_name="foo",
            size=3,
            infer_shape=False)
        written = ta.write(0, [[4.0, 5.0]]).write(1, [[3.0, 1.0]])
        close_op = written.close()
        # The returned constant carries a control dependency on the close op.
        with ops.control_dependencies([close_op]):
          return 1.0

      compiled = xla.compile(fn)
      self.evaluate(compiled)

  # TODO(phawkins): implement while loops.
  # def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):
  #   np_dtype = dtype.as_numpy_dtype
  #   with self.session() as session, self.test_scope():
  #     v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5))
  #     var = variables.Variable(np.arange(100, 105, dtype=np_dtype))
  #     state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype))
  #     ta = tensor_array_ops.TensorArray(
  #         dtype=dtype,
  #         tensor_array_name="foo",
  #         size=0 if dynamic_size else 3,
  #         dynamic_size=dynamic_size)
  #     time_0 = array_ops.identity(0)

  #     def body(time, ta_t, state):
  #       sliced = array_ops.slice(
  #           v0, begin=array_ops_stack.stack([time, 0]), size=[1, -1])
  #       sliced = array_ops.squeeze(sliced)
  #       out = sliced + var + state
  #       state += sliced
  #       ta_t = ta_t.write(time, out)
  #       return (time + 1, ta_t, state)

  #     (unused_0, h_final, unused_2) = control_flow_ops.while_loop(
  #         cond=lambda time, unused_1, unused_2: time < 3,
  #         body=body,
  #         loop_vars=(time_0, ta, state0),
  #         shape_invariants=(time_0.get_shape(), tensor_shape.unknown_shape(),
  #                           tensor_shape.unknown_shape()),
  #         parallel_iterations=3)
  #     vout = h_final.stack()

  #     grad_val = -np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)
  #     v0_grad = gradients_impl.gradients([vout], [v0], [grad_val])[0]
  #     state0_grad = gradients_impl.gradients([vout], [state0], [grad_val])[0]
  #     var_grad = gradients_impl.gradients([vout], [var], [grad_val])[0]

  #     self.evaluate(variables.global_variables_initializer())
  #     state0_t, var_t, v0_t, vout_t, v0_grad_t, var_grad_t, state0_grad_t = (
  #         self.evaluate([state0, var, v0, vout, v0_grad, var_grad, state0_grad])
  #     )
  #     just_v0_grad_t, = self.evaluate([v0_grad])

  #     # state = [ state0 | state0 + v0[0] | state0 + v0[0] + v0[1] ]
  #     # vout = [ v0[0] + var + state[0] |
  #     #          v0[1] + var + state[1] |
  #     #          v0[2] + var + state[2] ]
  #     #      = [ v0[0] + var + state0 |
  #     #          v0[1] + var + state0 + v0[0] |
  #     #          v0[2] + var + state0 + v0[0] + v0[1] ]
  #     #
  #     # d(vout[0])/d(v0) = [1 | 0 | 0 ]
  #     # d(vout[1])/d(v0) = [1 | 1 | 0 ]
  #     # d(vout[2])/d(v0) = [1 | 1 | 1 ]
  #     # d(vout)/d(var) = [1 | 1 | 1]
  #     # d(vout)/d(state0) = [ 1 | 1 | 1 ]

  #     state_per_time = np.array(
  #         [state0_t, state0_t + v0_t[0, :],
  #         state0_t + v0_t[0, :] + v0_t[1, :]])

  #     # Compare forward prop
  #     self.assertAllClose(v0_t + var_t + state_per_time, vout_t)

  #     # Compare backward prop
  #     expected_v0_grad_t = np.array([
  #         grad_val[0, :] + grad_val[1, :] + grad_val[2, :],
  #         grad_val[1, :] + grad_val[2, :], grad_val[2, :]
  #     ])

  #     self.assertAllEqual(expected_v0_grad_t, v0_grad_t)
  #     self.assertAllEqual(expected_v0_grad_t, just_v0_grad_t)
  #     self.assertAllClose(grad_val.sum(axis=0), var_grad_t)
  #     self.assertAllClose(grad_val.sum(axis=0), state0_grad_t)

  # def testWhileLoopWritePackGradients(self):
  #   self._testWhileLoopWritePackGradients(
  #       dynamic_size=False, dtype=dtypes.float32)
  #   # TODO(ebrevdo): re-enable when While supports non-float32 gradients.
  #   # self._testWhileLoopWritePackGradients(
  #   #     dynamic_size=False, dtype=tf.int64)

  # def testWhileLoopDynamicWritePackGradients(self):
  #   self._testWhileLoopWritePackGradients(
  #       dynamic_size=True, dtype=dtypes.float32)

  # def testGradSerialTwoLoops(self):
  #   with self.session(), self.test_scope():
  #     num_steps = 100
  #     acc = tensor_array_ops.TensorArray(
  #         dtype=dtypes.float32,
  #         size=num_steps,
  #         clear_after_read=False,
  #         element_shape=tensor_shape.scalar())
  #     i = constant_op.constant(0, name="i")
  #     x = constant_op.constant(2.0, name="x")

  #     c = lambda i, acc: i < 5

  #     def b(i, acc):
  #       x1 = cond.cond(
  #           math_ops.equal(i, 0), lambda: x,
  #           lambda: math_ops.multiply(acc.read(i - 1), 2.0))
  #       return i + 1, acc.write(i, x1)

  #     i1, acc1 = control_flow_ops.while_loop(c, b, [i, acc])

  #     z = constant_op.constant(0.0)

  #     def fn(i, acc):
  #       return i + 1, acc.write(i, z)

  #     _, acc2 = control_flow_ops.while_loop(lambda i, acc: i < num_steps, fn,
  #                                           [i1, acc1])

  #     r = acc2.stack()
  #     grad = gradients_impl.gradients(r, [x])[0]
  #     self.assertAllClose(31.0, self.evaluate(grad))

  def testSumOfTwoReadVariablesWithoutRepeatGrad(self):
    """Gradients of read(0) + read(1) w.r.t. each written value equal g0."""
    with self.session(), self.test_scope():
      # 3x5 matrix holding 1..15; its negation is the upstream gradient.
      base = np.arange(3 * 5, dtype=np.float32).reshape(3, 5) + 1
      g0 = -base

      def fn():
        a = array_ops.identity(
            np.arange(3 * 5, dtype=np.float32).reshape(3, 5) + 1)
        b = array_ops.identity(
            np.arange(3 * 5, dtype=np.float32).reshape(3, 5) + 1 + 3 * 5)
        ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
        ta = ta.write(0, a, name="write_a")
        ta = ta.write(1, b, name="write_b")
        read_a = ta.read(0, name="read_a_0")
        read_b = ta.read(1, name="read_b_0")
        total = read_a + read_b  # a + b
        # d(a+b)/da = 1 and d(a+b)/db = 1, so both gradients are g0.
        grad_wrt_a = gradients_impl.gradients([total], [a], [g0])[0]
        grad_wrt_b = gradients_impl.gradients([total], [b], [g0])[0]
        return [grad_wrt_a, grad_wrt_b]

      grad_a, grad_b = xla.compile(fn)

      # Test gradients calculated individually
      grad_a_t, = self.evaluate([grad_a])
      self.assertAllEqual(grad_a_t, g0)

      grad_b_t, = self.evaluate([grad_b])
      self.assertAllEqual(grad_b_t, g0)

      # Test gradients calculated jointly.
      joint_grad_a_t, joint_grad_b_t = self.evaluate([grad_a, grad_b])
      self.assertAllEqual(joint_grad_a_t, g0)
      self.assertAllEqual(joint_grad_b_t, g0)

  def testWriteShape(self):
    """Read shapes match written shapes; a mismatched second write raises."""

    def new_array():
      return tensor_array_ops.TensorArray(
          dtype=dtypes.float32, tensor_array_name="foo", size=3)

    with self.session(), self.test_scope():

      def write_one():
        ta = new_array()
        c0 = constant_op.constant([4.0, 5.0])
        r0 = ta.write(0, c0).read(0)
        return [c0, r0]

      c0, r0 = xla.compile(write_one)

      self.assertAllEqual(c0.get_shape(), r0.get_shape())

      # From here on `c0` is the tensor returned by the compile call above;
      # the functions below capture it from this scope.
      def write_two():
        ta = new_array()
        c1 = constant_op.constant([6.0, 7.0])
        written = ta.write(0, c0).write(1, c1)
        r0 = written.read(0)
        r1 = written.read(1)
        return [r0, c1, r1]

      [r0, c1, r1] = xla.compile(write_two)

      self.assertAllEqual(c0.get_shape(), r0.get_shape())
      self.assertAllEqual(c1.get_shape(), r1.get_shape())

      def write_mismatched_shape():
        ta = new_array()
        w0 = ta.write(0, c0)
        # Three elements, whereas `c0` was built from two.
        c2 = constant_op.constant([4.0, 5.0, 6.0])
        return w0.write(0, c2).flow

      with self.assertRaises(ValueError):
        self.evaluate(xla.compile(write_mismatched_shape))

  def _testGradientWhenNotAllComponentsRead(self):
    """Checks the gradient when only one TensorArray element is read.

    A two-element vector is unstacked into a TensorArray and only element 0
    is read back, so the gradient with respect to the vector is expected to
    be 1 for the read element and 0 for the unread one.
    """
    # Results are fetched through self.evaluate, so the session is entered
    # without binding it to a name.
    with self.session(), self.test_scope():

      def fn():
        ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
        x = constant_op.constant([2.0, 3.0])
        w = ta.unstack(x)
        r0 = w.read(0)
        # Calculate (dr0/dx0, dr0/dx1).  since r0 = x0, gradients are (1, 0).
        return gradients_impl.gradients(ys=[r0], xs=[x], grad_ys=[1.0])

      grad_r0_vals = self.evaluate(xla.compile(fn))[0]
      self.assertAllEqual(grad_r0_vals, [1.0, 0.0])

  def testGradientWhenNotAllComponentsRead(self):
    """Runs the check in `_testGradientWhenNotAllComponentsRead`."""
    self._testGradientWhenNotAllComponentsRead()

  def _testTensorArrayEvalEmpty(self):
    """Expects stacking a size-0, shape-less TensorArray to fail on eval."""
    expected_error = ("Uninitialized TensorArray passed to "
                      "TensorArrayStack/TensorArrayGatherV3")
    with self.session(), self.test_scope():

      def stack_empty():
        empty = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=0, infer_shape=False)
        return empty.stack()

      # Both compilation and evaluation stay inside the assertion context, so
      # the expected error is accepted from either step.
      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                               expected_error):
        outputs = xla.compile(stack_empty)
        outputs[0].eval()

  @test_util.disable_control_flow_v2("b/124335246")
  def testTensorArrayEvalEmpty(self):
    """Runs the check in `_testTensorArrayEvalEmpty`.

    The decorator opts this test out of control flow v2 (see b/124335246).
    """
    self._testTensorArrayEvalEmpty()

  def _testTensorArrayEvalEmptyWithDefault(self):
    """Checks size, stack and concat results for an empty TensorArray.

    A size-0 TensorArray is filled by unstacking a [0, 3, 5] tensor of zeros.
    Its size taken before the unstack is expected to be 0 and its stacked
    shape [0, 3, 5]. The concat shape is only checked when control flow v2 is
    not enabled.
    """
    with self.session(), self.test_scope():

      def size_and_stack():
        array = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=0, infer_shape=True)
        initial_size = array.size()
        array = array.unstack(array_ops.zeros([0, 3, 5]))
        return [initial_size, array.stack()]

      size, stack = self.evaluate(xla.compile(size_and_stack))
      self.assertEqual(0, size)
      self.assertAllEqual([0, 3, 5], stack.shape)

      # The concat check below only runs when control flow v2 is off.
      # TODO(b/122315751): Enable this.
      if control_flow_util.ENABLE_CONTROL_FLOW_V2:
        return

      def concat_empty():
        array = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=0, infer_shape=True)
        array = array.unstack(array_ops.zeros([0, 3, 5]))
        return array.concat()

      # Concatenating zero tensors along their first dimension gives a
      # first dimension of zero
      concatenated = self.evaluate(xla.compile(concat_empty))[0]
      self.assertAllEqual([0, 5], concatenated.shape)

  def testTensorArrayEvalEmptyWithDefault(self):
    """Runs the check in `_testTensorArrayEvalEmptyWithDefault`."""
    self._testTensorArrayEvalEmptyWithDefault()

  def _testTensorArrayScatterRead(self, tf_dtype):
    """Scatters two rows into a TensorArray and reads them back by fed index.

    Args:
      tf_dtype: TensorFlow dtype used for the TensorArray and for the
        scattered values.
    """
    with self.session() as session, self.test_scope():
      convert = _make_converter(tf_dtype)
      first_index = array_ops.placeholder(dtypes.int32)
      second_index = array_ops.placeholder(dtypes.int32)

      def scatter_then_read():
        array = tensor_array_ops.TensorArray(
            dtype=tf_dtype, tensor_array_name="foo", size=10)

        indices = constant_op.constant([1, 8])
        value = constant_op.constant(convert([[1.0, 5.0], [10.0, 20.0]]))

        scattered = array.scatter(indices, value)
        first_read = scattered.read(first_index)
        second_read = scattered.read(second_index)
        return [first_read, second_read]

      # Feed the same indices that were scattered to, so each read is expected
      # to return the matching row of `value`.
      feeds = {first_index: 1, second_index: 8}
      first_row, second_row = session.run(
          xla.compile(scatter_then_read), feed_dict=feeds)
      self.assertAllEqual(convert([1.0, 5.0]), first_row)
      self.assertAllEqual(convert([10.0, 20.0]), second_row)

  @test_util.disable_control_flow_v2("b/122315734 (scatter)")
  def testTensorArrayScatterRead(self):
    """Runs the scatter/read check for each numeric dtype, then for bool."""
    for tf_dtype in list(self.numeric_tf_types) + [dtypes.bool]:
      self._testTensorArrayScatterRead(tf_dtype)

  @test_util.disable_control_flow_v2("b/122315734 (scatter)")
  def testTensorArrayScatterReadAndGradients(self):
    """Checks reads after a scatter and the gradient w.r.t. the scattered value.

    Two rows are scattered to indices 1 and 8 and read back through fed
    indices. The gradient of both reads with respect to the scattered value
    is expected to equal the supplied upstream gradients, row for row.
    """
    with self.session() as session, self.test_scope():
      first_index = array_ops.placeholder(dtypes.int32)
      second_index = array_ops.placeholder(dtypes.int32)

      def scatter_read_and_grad():
        array = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, tensor_array_name="foo", size=10)

        indices = constant_op.constant([1, 8])
        value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])

        scattered = array.scatter(indices, value)
        read_a = scattered.read(first_index)
        read_b = scattered.read(second_index)

        # Combined gradient of both reads with respect to the scattered value.
        upstream = [[2.0, 3.0], [4.0, 5.0]]
        grad = gradients_impl.gradients(
            ys=[read_a, read_b], xs=[value], grad_ys=upstream)
        return [[read_a, read_b], grad]

      feeds = {first_index: 1, second_index: 8}
      read_vals, grad_vals = session.run(
          xla.compile(scatter_read_and_grad), feed_dict=feeds)

      self.assertEqual(len(read_vals), 2)
      self.assertEqual(len(grad_vals), 1)
      self.assertAllEqual([1.0, -1.0], read_vals[0])
      self.assertAllEqual([10.0, -10.0], read_vals[1])
      self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])

  @test_util.disable_control_flow_v2("b/122315378 (gather)")
  def testTensorArrayWriteGatherAndGradients(self):
    """Checks `gather` values and its gradient with respect to the input.

    Ten rows are unstacked into a TensorArray and rows 1 and 8 are gathered.
    The gradient with respect to the full input is expected to carry the
    supplied upstream gradients in rows 1 and 8 and zeros everywhere else.
    """
    # Results are fetched through self.evaluate, so the session is entered
    # without binding it to a name.
    with self.session(), self.test_scope():

      def fn():
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, tensor_array_name="foo", size=10)

        values = constant_op.constant([[1.0 * x, -1.0 * x] for x in range(10)])
        indices = constant_op.constant([1, 8])

        w = ta.unstack(values)
        g = w.gather(indices)

        # Gradient of the gathered rows with respect to all ten input rows.
        grad = gradients_impl.gradients(
            ys=[g], xs=[values], grad_ys=[[[2.0, 3.0], [4.0, 5.0]]])
        return [[g], grad]

      g_vals, grad_vals = self.evaluate(xla.compile(fn))

      # Only rows 1 and 8 were gathered; the other 8 rows get zero gradient.
      expected_grad = np.zeros((10, 2))
      expected_grad[1] = [2.0, 3.0]
      expected_grad[8] = [4.0, 5.0]

      self.assertEqual(len(g_vals), 1)
      self.assertEqual(len(grad_vals), 1)
      self.assertAllEqual([[1.0, -1.0], [8.0, -8.0]], g_vals[0])
      self.assertAllEqual(expected_grad, grad_vals[0])

  def testTensorArrayIdentity(self):
    """Checks `TensorArray.identity()` under control dependencies.

    Two TensorArrays are written to and then passed through `identity()`
    inside a control-dependency scope that increments a resource variable.
    The test checks that the returned TensorArrays keep their dtypes, that
    reads and sizes are unchanged, and that both variables were incremented.
    """
    # Results are fetched through self.evaluate, so the session is entered
    # without binding it to a name.
    with self.session(), self.test_scope():
      # Filled in by fn so the post-identity TensorArray objects can be
      # inspected after compilation.
      tensor_arrays = {}

      v0 = resource_variable_ops.ResourceVariable(0.0)
      v1 = resource_variable_ops.ResourceVariable(0.0)

      def fn():
        ta0 = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=2, infer_shape=False)
        ta1 = tensor_array_ops.TensorArray(
            dtype=dtypes.int32, size=4, infer_shape=True)

        ta0 = ta0.write(0, 0.)
        ta1 = ta1.write(0, 1)

        with ops.control_dependencies([v0.assign_add(1.0)]):
          ta0 = ta0.identity()

        with ops.control_dependencies([v1.assign_add(1.0)]):
          ta1 = ta1.identity()

        read0 = ta0.read(0)
        read1 = ta1.read(0)

        size0 = ta0.size()
        size1 = ta1.size()

        tensor_arrays[0] = ta0
        tensor_arrays[1] = ta1

        return [read0, read1, size0, size1, v0, v1]

      self.evaluate(variables.global_variables_initializer())

      # The fetched variable values get their own names so they do not shadow
      # the v0/v1 variable objects that fn closes over.
      read0_v, read1_v, size0_v, size1_v, v0_val, v1_val = self.evaluate(
          xla.compile(fn))

      # Tests correct properties on new TensorArrays.
      self.assertEqual(dtypes.float32, tensor_arrays[0].dtype)
      self.assertEqual(dtypes.int32, tensor_arrays[1].dtype)

      # Tests that the control dependencies were added and executed.
      self.assertEqual(1.0, v0_val)
      self.assertEqual(1.0, v1_val)

      # Tests correct TensorArray.
      self.assertEqual(read0_v, 0)
      self.assertEqual(read1_v, 1)
      self.assertEqual(size0_v, 2)
      self.assertEqual(size1_v, 4)

# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
  test.main()
