# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for while loops in XLA."""

import os

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import while_loop
from tensorflow.python.platform import test


class WhileTest(xla_test.XLATestCase):
  """Tests for while loops built under the XLA test scope.

  The first group of tests drives `xla.while_loop` directly with `Defun`
  condition/body functions. The second group builds `while_loop.while_loop`
  and `map_fn.map_fn` inside an explicit `XLAControlFlowContext`; those tests
  always leave the context in a `finally` clause so that a failing assertion
  or run does not skip the matching `Exit()`.
  """

  def testSingletonLoopHandrolled(self):
    """Runs a loop with one int32 variable and expects it to stop at 10."""

    # Loop body: increment the step counter by one.
    @function.Defun(dtypes.int32)
    def loop_body(step):
      step_out = step + constant_op.constant(1, dtype=dtypes.int32)
      return step_out

    # Loop condition: keep going while the counter is below 10.
    @function.Defun(dtypes.int32)
    def loop_cond(step):
      return step < 10

    with self.session() as sess:
      init_index = array_ops.placeholder(dtypes.int32, [])
      with self.test_scope():
        loop_outputs = xla.while_loop([init_index], loop_cond, loop_body)

      result = sess.run(loop_outputs, {init_index: 0})
      self.assertAllClose(result, [10], rtol=1e-3)

  def testCountingLoopHandrolled(self):
    """Runs a counter plus float32 accumulator loop, with and without trips."""

    # Loop body: increment the counter and add 1.5 to the running sum.
    @function.Defun(dtypes.int32, dtypes.float32)
    def loop_body(step, rsum):
      step_out = step + constant_op.constant(1, dtype=dtypes.int32)
      sum_out = rsum + constant_op.constant(1.5, dtype=dtypes.float32)
      return step_out, sum_out

    # Loop condition: depends on the counter only.
    @function.Defun(dtypes.int32, dtypes.float32)
    def loop_cond(step, rsum):
      del rsum
      return step < 10

    with self.session() as sess:
      init_index = array_ops.placeholder(dtypes.int32, [])
      init_sum = array_ops.placeholder(dtypes.float32, [])
      with self.test_scope():
        loop_outputs = xla.while_loop([init_index, init_sum], loop_cond,
                                      loop_body)

      # Ten iterations of +1.5 starting from 0.0.
      result = sess.run(loop_outputs, {init_index: 0, init_sum: 0.0})
      self.assertAllClose(result, [10, 15.0], rtol=1e-3)
      # Starting at 10 makes the condition false immediately, so the inputs
      # must come back unchanged.
      no_iters_result = sess.run(loop_outputs, {init_index: 10, init_sum: 0.0})
      self.assertAllClose(no_iters_result, [10, 0.0], rtol=1e-3)

  def testCountingLoopHandrolledC64(self):
    """Same as the counting loop, but with a complex64 accumulator."""

    # Loop body: increment the counter and add (1.5 + 2j) to the running sum.
    @function.Defun(dtypes.int32, dtypes.complex64)
    def loop_body(step, rsum):
      step_out = step + constant_op.constant(1, dtype=dtypes.int32)
      sum_out = rsum + constant_op.constant(1.5 + 2j, dtype=dtypes.complex64)
      return step_out, sum_out

    # Loop condition: depends on the counter only.
    @function.Defun(dtypes.int32, dtypes.complex64)
    def loop_cond(step, rsum):
      del rsum
      return step < 10

    with self.session() as sess:
      init_index = array_ops.placeholder(dtypes.int32, [])
      init_sum = array_ops.placeholder(dtypes.complex64, [])
      with self.test_scope():
        loop_outputs = xla.while_loop([init_index, init_sum], loop_cond,
                                      loop_body)

      # Only the accumulator (index 1) is checked here.
      result = sess.run(loop_outputs, {init_index: 0, init_sum: 0.0})
      self.assertAllClose(result[1], np.complex64(15 + 20j), rtol=1e-3)
      no_iters_result = sess.run(loop_outputs, {init_index: 10, init_sum: 0.0})
      self.assertAllClose(no_iters_result[1], np.complex64(0), rtol=1e-3)

  def testLoopWithConstantOutput(self):
    """Runs a loop whose body returns the Python constant 7 for one output."""

    # Loop body: ignores `x` and always yields 7 in its place.
    @function.Defun(dtypes.int32, dtypes.int32)
    def loop_body(step, x):
      del x
      step_out = step + constant_op.constant(1, dtype=dtypes.int32)
      return (step_out, 7)

    # Loop condition: depends on the counter only.
    @function.Defun(dtypes.int32, dtypes.int32)
    def loop_cond(step, x):
      del x
      return step < 10

    with self.session() as sess:
      init_index = array_ops.placeholder(dtypes.int32, [])
      with self.test_scope():
        loop_outputs = xla.while_loop([init_index, 42], loop_cond, loop_body)

      # The initial 42 must be replaced by the body's constant 7.
      result = sess.run(loop_outputs, {init_index: 0})
      self.assertAllClose(result, [10, 7], rtol=1e-3)

  def _testMaxItersSimple(self):
    """Builds and runs the gradient of a loop bounded by a placeholder's size.

    The loop condition is always true, so termination relies entirely on
    `maximum_iterations`, which is `size(p)` for a fed placeholder `p`. The
    test only checks that the gradient can be built and run; the value is
    printed, not asserted.
    """
    if is_compile_on_demand():
      self.skipTest("list_ops are not supported in cpu_ondemand")
    with self.session() as sess, self.test_scope():
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()
      try:
        v = constant_op.constant(1.0)
        p = array_ops.placeholder(dtype=dtypes.int32)

        def create_while_loop():
          iterations = array_ops.size(p, name="iterations")
          r = while_loop.while_loop(
              lambda *_: True,
              lambda i, x: (i + 1, v * x), (0, 1.0),
              maximum_iterations=iterations,
              name="outer")
          return array_ops.identity(r[1])

        output = create_while_loop()
        output = gradients_impl.gradients(output, v)[0]

        result = sess.run(output, feed_dict={p: [0, 0, 0]})
        print(result)
      finally:
        # Leave the context even if graph construction or the run raised.
        xla_context.Exit()

  def testMaxItersSimple(self):
    """Skipped placeholder for the v1 control flow variant."""
    self.skipTest("Fails with v1 control flow")
    # `_testMaxItersSimple` fails with v1 (old) control flow, so it is not
    # invoked here:
    # self._testMaxItersSimple()

  @test_util.enable_control_flow_v2
  def testMaxItersSimpleV2(self):
    """Runs `_testMaxItersSimple` with control flow v2 enabled."""
    self._testMaxItersSimple()

  def _testNestedWhileLoopWithMaxItersFromOuterContext(self):
    """Runs three nested loops whose inner bounds come from a placeholder.

    The "mid" and "inner" loops both use `size(p)` as `maximum_iterations`,
    and the mid body takes a gradient through the inner loop. The test only
    checks that the graph builds and runs; no value is asserted.
    """
    if is_compile_on_demand():
      self.skipTest("list_ops are not supported in cpu_ondemand")
    with self.session() as sess, self.test_scope():
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()
      try:
        v = constant_op.constant(1.0)
        p = array_ops.placeholder(dtype=dtypes.int32)

        def mid_body_builder(iterations):
          # Closes over `iterations` so the inner loop reuses the bound that
          # was computed in the outer body.

          def mid_body(i, x):
            r = while_loop.while_loop(
                lambda *_: True,
                lambda i, x: (i + 1, v * x), (0, x),
                maximum_iterations=iterations,
                name="inner")
            return (i + 1, gradients_impl.gradients(x + r[1], v)[0])

          return mid_body

        def outer_body(i, x):
          iterations = array_ops.size(p, name="iterations")
          return (i + 1, x + while_loop.while_loop(
              lambda *_: True,
              mid_body_builder(iterations), (0, x),
              maximum_iterations=iterations,
              name="mid")[1])

        def create_while_loop():
          r = while_loop.while_loop(
              lambda *_: True,
              outer_body, (0, 1.0),
              maximum_iterations=5,
              name="outer")
          return array_ops.identity(r[1])

        # p:placeholder
        # j = 0
        # i, x = 0, 1.
        # while j++ < 5:
        #   i1, x1 = 0, x
        #   while i1++ < len(p):
        #     i2, x2 = 0, x1
        #     while i2++ < len(p):
        #       x2 = v * x2
        #     x1 = grad(x1 + x2, v)
        #   x = x + x1
        # output = x
        output = create_while_loop()
        sess.run(output, feed_dict={p: [0, 0, 0]})
      finally:
        # Leave the context even if graph construction or the run raised.
        xla_context.Exit()

  def testNestedWhileLoopWithMaxItersFromOuterContext(self):
    """Runs the nested-loop scenario with the default control flow."""
    self._testNestedWhileLoopWithMaxItersFromOuterContext()

  @test_util.enable_control_flow_v2
  def testNestedWhileLoopWithMaxItersFromOuterContextV2(self):
    """Runs the nested-loop scenario with control flow v2 enabled."""
    self._testNestedWhileLoopWithMaxItersFromOuterContext()

  @test_util.enable_control_flow_v2
  def testMap(self):
    """Checks `map_fn` computing (x + 3) * 2 inside an XLA context."""
    if is_compile_on_demand():
      self.skipTest("list_ops are not supported in cpu_ondemand")
    with self.session(), self.test_scope():
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()
      try:
        nums = [1, 2, 3, 4, 5, 6]
        elems = constant_op.constant(nums, name="data")
        r = map_fn.map_fn(lambda x: math_ops.multiply(math_ops.add(x, 3), 2),
                          elems)
        self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums]))
      finally:
        # Leave the context even if the assertion above failed.
        xla_context.Exit()

  @test_util.enable_control_flow_v2
  def testMapBackPropFalse(self):
    """Same as `testMap`, but passes `back_prop=False` to `map_fn`."""
    if is_compile_on_demand():
      self.skipTest("list_ops are not supported in cpu_ondemand")
    with self.session(), self.test_scope():
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()
      try:
        nums = [1, 2, 3, 4, 5, 6]
        elems = constant_op.constant(nums, name="data")
        r = map_fn.map_fn(
            lambda x: math_ops.multiply(math_ops.add(x, 3), 2),
            elems,
            back_prop=False)
        self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums]))
      finally:
        # Leave the context even if the assertion above failed.
        xla_context.Exit()


def is_compile_on_demand():
  """Returns whether TF_XLA_FLAGS mentions `tf_xla_compile_on_demand`.

  An unset TF_XLA_FLAGS environment variable is treated like an empty one, so
  the result is False in that case.
  """
  xla_flags = os.environ.get("TF_XLA_FLAGS", "")
  return "tf_xla_compile_on_demand" in xla_flags


if __name__ == "__main__":
  # Put the minimum cluster size flag in front of whatever TF_XLA_FLAGS the
  # caller already set (or an empty string when it is unset), then hand
  # control to the test runner.
  os.environ["TF_XLA_FLAGS"] = " ".join(
      ["--tf_xla_min_cluster_size=2", os.environ.get("TF_XLA_FLAGS", "")])
  test.main()
