# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the LSTM cell and layer."""

import argparse
import os
import sys

import numpy as np

from tensorflow.compiler.tests import lstm
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


def _DumpGraph(graph, basename):
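  """Dumps `graph` as a pbtxt file under FLAGS.dump_graph_dir, if set."""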
  if FLAGS.dump_graph_dir:
    name = os.path.join(FLAGS.dump_graph_dir, basename + '.pbtxt')
    with open(name, 'w') as f:
      f.write(str(graph.as_graph_def()))


def _Sigmoid(x):
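  """Numpy reference implementation of the logistic sigmoid."""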
  return 1. / (1. + np.exp(-x))


def _Clip(x):
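  """Clips x elementwise to the range [-1., 1.]."""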
  return np.maximum(np.minimum(x, 1.), -1.)


class LSTMTest(test.TestCase):

  def setUp(self):
    # The tests for a single LSTM cell and LSTM layer use these values as
    # inputs.  We always set num_inputs=1, so batch_size effectively enumerates
    # the different input cases.
    self._inputs = np.array([[-1.], [-.5], [0.], [.5], [1.]], np.float32)
    self._batch_size = len(self._inputs)

  def _NextC(self, inputs, weight, m_prev, c_prev):
    """Returns the next c states of an LSTM cell."""
    x = (inputs + m_prev) * weight
    return _Clip(_Clip(_Sigmoid(x) * c_prev) + _Clip(_Sigmoid(x) * np.tanh(x)))

  def _NextM(self, inputs, weight, m_prev, c_prev):
    """Returns the next m states of an LSTM cell."""
    x = (inputs + m_prev) * weight
    return _Clip(_Sigmoid(x) * self._NextC(inputs, weight, m_prev, c_prev))

  def _RunLSTMCell(self, basename, init_weights, m_prev_scalar, c_prev_scalar,
                   pad_scalar):
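    """Builds a single LSTM cell and returns its evaluated (m, c) outputs."""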
    with self.session() as sess:
      num_inputs = 1
      num_nodes = 1

      weights = init_weights(lstm.LSTMCellWeightsShape(num_inputs, num_nodes))
      m_prev = constant_op.constant([[m_prev_scalar]] * self._batch_size)
      c_prev = constant_op.constant([[c_prev_scalar]] * self._batch_size)
      x = constant_op.constant(self._inputs)
      pad = constant_op.constant([[pad_scalar]] * self._batch_size)

      m, c = lstm.LSTMCell(weights, m_prev, c_prev, x, pad)
      _DumpGraph(sess.graph, 'lstm_cell_%s_%d_%d_%d' %
                 (basename, m_prev_scalar, c_prev_scalar, pad_scalar))

      # Initialize variables and run the unrolled LSTM step.
      self.evaluate(variables.global_variables_initializer())
      return self.evaluate([m, c])

  @test_util.run_without_tensor_float_32('TF32-capable devices fail the test'
                                         ' due to reduced matmul precision')
  def testLSTMCell(self):
    # Run with all-0 weights, no padding.
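    # With zero weights, x = 0, so sigmoid(x) = 0.5 and tanh(x) = 0; the
    # expected next state here is c = 0.5 * c_prev and m = 0.25 * c_prev,
    # regardless of the inputs or m_prev.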
    m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 0., 0., 0.)
    self.assertAllClose(m, [[0.]] * self._batch_size)
    self.assertAllClose(c, [[0.]] * self._batch_size)
    m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 0., 1., 0.)
    self.assertAllClose(m, [[.25]] * self._batch_size)
    self.assertAllClose(c, [[.5]] * self._batch_size)
    m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 1., 0., 0.)
    self.assertAllClose(m, [[.0]] * self._batch_size)
    self.assertAllClose(c, [[.0]] * self._batch_size)
    m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 1., 1., 0.)
    self.assertAllClose(m, [[.25]] * self._batch_size)
    self.assertAllClose(c, [[.5]] * self._batch_size)

    # Run with all-1 weights, no padding.
    for m_prev in [0., 1.]:
      for c_prev in [0., 1.]:
        m, c = self._RunLSTMCell('ones',
                                 init_ops.ones_initializer(), m_prev, c_prev,
                                 0.)
        self.assertAllClose(m, self._NextM(self._inputs, 1., m_prev, c_prev))
        self.assertAllClose(c, self._NextC(self._inputs, 1., m_prev, c_prev))

    # Run with random weights.
    for weight in np.random.rand(3):
      weight_tf = constant_op.constant(weight, dtypes.float32)
      random_weight = lambda shape, w=weight_tf: array_ops.fill(shape, w)

      # No padding.
      for m_prev in [0., 1.]:
        for c_prev in [0., 1.]:
          m, c = self._RunLSTMCell('random', random_weight, m_prev, c_prev, 0.)
          self.assertAllClose(m,
                              self._NextM(self._inputs, weight, m_prev, c_prev))
          self.assertAllClose(c,
                              self._NextC(self._inputs, weight, m_prev, c_prev))

      # Set padding.
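      # With pad = 1., the cell is expected to pass the previous m and c
      # through unchanged.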
      for m_prev in [0., 1.]:
        for c_prev in [0., 1.]:
          m, c = self._RunLSTMCell('random', random_weight, m_prev, c_prev, 1.)
          self.assertAllClose(m, [[m_prev]] * self._batch_size)
          self.assertAllClose(c, [[c_prev]] * self._batch_size)

  def testLSTMLayerErrors(self):
    num_inputs = 1
    num_nodes = 1
    seq_length = 3

    weights = array_ops.zeros(lstm.LSTMCellWeightsShape(num_inputs, num_nodes))
    m = constant_op.constant([[0.]] * self._batch_size)
    c = constant_op.constant([[0.]] * self._batch_size)
    x_seq = [constant_op.constant(self._inputs)] * seq_length
    pad = constant_op.constant([[0.]] * self._batch_size)
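    # pad_seq must have the same length as x_seq (seq_length == 3); each call
    # below passes a mismatched length and should raise a ValueError.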

    with self.assertRaisesWithPredicateMatch(ValueError, 'length of x_seq'):
      lstm.LSTMLayer('lstm', weights, m, c, x_seq, [pad])
    with self.assertRaisesWithPredicateMatch(ValueError, 'length of x_seq'):
      lstm.LSTMLayer('lstm', weights, m, c, x_seq, [pad] * 2)
    with self.assertRaisesWithPredicateMatch(ValueError, 'length of x_seq'):
      lstm.LSTMLayer('lstm', weights, m, c, x_seq, [pad] * 4)

  def _RunLSTMLayer(self, basename, init_weights, m_init_scalar, c_init_scalar,
                    pad_scalar):
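    """Builds an unrolled LSTM layer and returns the evaluated outputs."""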
    with self.session() as sess:
      num_inputs = 1
      num_nodes = 1
      seq_length = 3

      weights = init_weights(lstm.LSTMCellWeightsShape(num_inputs, num_nodes))
      m_init = constant_op.constant([[m_init_scalar]] * self._batch_size)
      c_init = constant_op.constant([[c_init_scalar]] * self._batch_size)
      x_seq = [constant_op.constant(self._inputs)] * seq_length
      pad_seq = [constant_op.constant([[pad_scalar]] * self._batch_size)
                ] * seq_length

      out_seq = lstm.LSTMLayer('lstm', weights, m_init, c_init, x_seq, pad_seq)
      _DumpGraph(sess.graph, 'lstm_layer_%s_%d_%d_%d' %
                 (basename, m_init_scalar, c_init_scalar, pad_scalar))

      # Initialize variables and run the unrolled LSTM layer.
      self.evaluate(variables.global_variables_initializer())
      return self.evaluate(out_seq)

  @test_util.run_without_tensor_float_32('TF32-capable devices fail the test'
                                         ' due to reduced matmul precision')
  def testLSTMLayer(self):
    # Run with all-0 weights, no padding.
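    # With zero weights the cell state halves at every step, so c_init = 1.
    # yields outputs of .25, .125, and .0625 across the three steps.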
    o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 0., 0., 0.)
    self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
    o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 0., 1., 0.)
    self.assertAllClose(o, [[[.25]] * self._batch_size,
                            [[.125]] * self._batch_size,
                            [[.0625]] * self._batch_size])
    o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 1., 0., 0.)
    self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
    o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 1., 1., 0.)
    self.assertAllClose(o, [[[.25]] * self._batch_size,
                            [[.125]] * self._batch_size,
                            [[.0625]] * self._batch_size])

    # Run with all-1 weights, no padding.
    weight1 = 1.
    for m_init in [0., 1.]:
      for c_init in [0., 1.]:
        o = self._RunLSTMLayer('ones',
                               init_ops.ones_initializer(), m_init, c_init, 0.)
        m0 = self._NextM(self._inputs, weight1, m_init, c_init)
        c0 = self._NextC(self._inputs, weight1, m_init, c_init)
        self.assertAllClose(o[0], m0)
        m1 = self._NextM(self._inputs, weight1, m0, c0)
        c1 = self._NextC(self._inputs, weight1, m0, c0)
        self.assertAllClose(o[1], m1)
        m2 = self._NextM(self._inputs, weight1, m1, c1)
        self.assertAllClose(o[2], m2)

    # Run with random weights.
    for weight in np.random.rand(3):
      weight_tf = constant_op.constant(weight, dtypes.float32)
      random_weight = lambda shape, w=weight_tf: array_ops.fill(shape, w)

      # No padding.
      for m_init in [0., 1.]:
        for c_init in [0., 1.]:
          o = self._RunLSTMLayer('random', random_weight, m_init, c_init, 0.)
          m0 = self._NextM(self._inputs, weight, m_init, c_init)
          c0 = self._NextC(self._inputs, weight, m_init, c_init)
          self.assertAllClose(o[0], m0)
          m1 = self._NextM(self._inputs, weight, m0, c0)
          c1 = self._NextC(self._inputs, weight, m0, c0)
          self.assertAllClose(o[1], m1)
          m2 = self._NextM(self._inputs, weight, m1, c1)
          self.assertAllClose(o[2], m2)

      # Set padding.
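      # With every step padded, the layer is expected to emit m_init unchanged
      # at each of the three steps.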
      o = self._RunLSTMLayer('random', random_weight, 0., 0., 1.)
      self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
      o = self._RunLSTMLayer('random', random_weight, 0., 1., 1.)
      self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
      o = self._RunLSTMLayer('random', random_weight, 1., 0., 1.)
      self.assertAllClose(o, [[[1.]] * self._batch_size] * 3)
      o = self._RunLSTMLayer('random', random_weight, 1., 1., 1.)
      self.assertAllClose(o, [[[1.]] * self._batch_size] * 3)


class LSTMBenchmark(test.Benchmark):
  """Mcro-benchmarks for a single layer of LSTM cells."""

  def _LayerBuilder(self, do_training):
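    """Builds an LSTM layer graph for the inference or training benchmark."""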
    out_seq, weights = lstm.BuildLSTMLayer(FLAGS.batch_size, FLAGS.seq_length,
                                           FLAGS.num_inputs, FLAGS.num_nodes)
    name, fetches = ('lstm_layer_inference', out_seq)
    if do_training:
      # Not a real loss function, but good enough for benchmarking backprop.
      loss = math_ops.reduce_sum(math_ops.add_n(out_seq))
      dw = gradients_impl.gradients(loss, weights)
      name, fetches = ('lstm_layer_training', dw)

    _DumpGraph(ops.get_default_graph(),
               '%s_%d_%d_%d_%d' % (name, FLAGS.batch_size, FLAGS.seq_length,
                                   FLAGS.num_inputs, FLAGS.num_nodes))
    return name, fetches

  def benchmarkLayerInference(self):
    xla_test.Benchmark(self, lambda: self._LayerBuilder(False), False,
                       FLAGS.device)

  def benchmarkLayerInferenceXLA(self):
    xla_test.Benchmark(self, lambda: self._LayerBuilder(False), True,
                       FLAGS.device)

  def benchmarkLayerTraining(self):
    xla_test.Benchmark(self, lambda: self._LayerBuilder(True), False,
                       FLAGS.device)

  def benchmarkLayerTrainingXLA(self):
    xla_test.Benchmark(self, lambda: self._LayerBuilder(True), True,
                       FLAGS.device)


if __name__ == '__main__':
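  # Example invocation (script name and flag values are illustrative):
  #   python lstm_test.py --batch_size=32 --seq_length=20 --device=cpu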
  parser = argparse.ArgumentParser()
  parser.register('type', 'bool', lambda v: v.lower() == 'true')
  parser.add_argument(
      '--batch_size',
      type=int,
      default=128,
      help="""\
      Inputs are fed in batches of this size, for both inference and training.
      Larger values cause the matmul in each LSTM cell to have higher
      dimensionality.\
      """
  )
  parser.add_argument(
      '--seq_length',
      type=int,
      default=60,
      help="""\
      Length of the unrolled sequence of LSTM cells in a layer. Larger values
      cause more LSTM matmuls to be run.\
      """
  )
  parser.add_argument(
      '--num_inputs',
      type=int,
      default=1024,
      help='Dimension of inputs that are fed into each LSTM cell.'
  )
  parser.add_argument(
      '--num_nodes',
      type=int,
      default=1024,
      help='Number of nodes in each LSTM cell.'
  )
  parser.add_argument(
      '--device',
      type=str,
      default='gpu',
      help="""\
      TensorFlow device to assign ops to, e.g. "gpu", "cpu". For details see
      documentation for tf.Graph.device.\
      """
  )
  parser.add_argument(
      '--dump_graph_dir',
      type=str,
      default='',
      help='If non-empty, dump graphs in *.pbtxt format to this directory.'
  )
  global FLAGS  # pylint:disable=global-at-module-level
  FLAGS, unparsed = parser.parse_known_args()
  # This test uses TensorFlow sessions, which are not compatible with eager
  # mode.
  ops.disable_eager_execution()
  test.main(argv=[sys.argv[0]] + unparsed)
