# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple LSTM layer with benchmarks.

This sets up a simple LSTM (Long Short Term Memory) layer, unrolled to a fixed
length sequence.  The only deviation from standard LSTM cells is that
activations are clipped, inspired by the GNMT machine translation model.
The GNMT paper has more details: https://arxiv.org/abs/1609.08144
"""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_v1


def Clip(x):
  """Clips x to the range [-1., 1.]."""
  return math_ops.maximum(math_ops.minimum(x, 1.), -1.)


def LSTMCellWeightsShape(num_inputs, num_nodes):
  """Returns the shape of the weights for a single LSTM cell."""
  # Dimension 0 accounts for combining x with the previous m state.
  # Dimension 1 accounts for the in value and the (in, forget, out) gates.
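  # For example, with num_inputs=1024 and num_nodes=1024 the shape is
  # [2048, 4096]: x (1024) and m (1024) stack along dimension 0, and four
  # 1024-wide blocks (the in value plus the three gates) pack along
  # dimension 1.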
  return [num_inputs + num_nodes, 4 * num_nodes]


def LSTMCell(weights, m_prev, c_prev, x, pad):
  """Unrolls a single LSTM cell with clipped activations forward by one step.

  Args:
    weights: Weight matrix with the shape given by LSTMCellWeightsShape().
    m_prev: Previous m states with shape [batch_size, num_nodes].
    c_prev: Previous c states with shape [batch_size, num_nodes].
    x: Input with shape [batch_size, num_inputs].
    pad: Padding with shape [batch_size, 1].  Each padding value is either
        0 or 1, where 1 indicates padding; i.e. the input is shorter than the
        sequence length, and the (m, c) states should simply be passed through
        from the previous states.
  Returns:
    The next (m, c) states, each with shape [batch_size, num_nodes].
  """
  # Apply weights to the input and previous hidden state.
  # The matmul here is the "big" operation.
  xm = array_ops.concat([x, m_prev], 1)
  xmw = math_ops.matmul(xm, weights)
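  # Shapes: xm is [batch_size, num_inputs + num_nodes]; xmw is
  # [batch_size, 4 * num_nodes].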

  # Element-wise ops for the standard LSTM cell, with clipped activations.
  # XLA can fuse these operations into a single loop.
  in_value, in_gate, forget_gate, out_gate = array_ops.split(
      value=xmw, num_or_size_splits=4, axis=1)
  in_value = math_ops.tanh(in_value)
  in_gate = math_ops.sigmoid(in_gate)
  forget_gate = math_ops.sigmoid(forget_gate)
  out_gate = math_ops.sigmoid(out_gate)
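  # Writing i, f, o for the gates and g for the tanh'd in value, the clipped
  # update below computes
  #   c_next = Clip(Clip(f * c_prev) + Clip(i * g))
  #   m_next = Clip(o * c_next)
  # where a standard LSTM would use c_next = f * c_prev + i * g and
  # m_next = o * tanh(c_next).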
  c_next = Clip(Clip(forget_gate * c_prev) + Clip(in_gate * in_value))
  m_next = Clip(out_gate * c_next)

  # Account for padding.
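  # pad == 1 passes the previous (m, c) states straight through unchanged,
  # while pad == 0 keeps the freshly computed states.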
  c_next = c_prev * pad + c_next * (1.0 - pad)
  m_next = m_prev * pad + m_next * (1.0 - pad)

  return m_next, c_next


def LSTMLayer(cell_name, weights, m, c, x_seq, pad_seq):
  """Unrolls a layer of LSTM cells forward by the sequence length.

  The sequence length is determined by the lengths of x_seq and pad_seq, which
  must match.

  Args:
    cell_name: Base name of each cell.
    weights: Weight matrix with the shape given by LSTMCellWeightsShape().
    m: Initial m states with shape [batch_size, num_nodes].
    c: Initial c states with shape [batch_size, num_nodes].
    x_seq: List of inputs, each with shape [batch_size, num_inputs].
        The length of the list is the sequence length.
    pad_seq: List of paddings, each with shape [batch_size, 1].
        The length of the list is the sequence length.
        Each padding value is either 0 or 1, where 1 indicates padding;
        i.e. the input is shorter than the sequence length.
  Returns:
    List of per-sequence-step outputs, each with shape [batch_size, num_nodes].
  Raises:
    ValueError: If len(x_seq) != len(pad_seq).
  """
  if len(x_seq) != len(pad_seq):
    raise ValueError('length of x_seq (%d) != length of pad_seq (%d)' %
                     (len(x_seq), len(pad_seq)))
  out_seq = []
  for seq in range(len(x_seq)):
    with ops.name_scope('%s_%d' % (cell_name, seq)):
      m, c = LSTMCell(weights, m, c, x_seq[seq], pad_seq[seq])
      out_seq.append(array_ops.identity(m, name='out'))
  return out_seq


def RandomVar(shape, name=None):
  """Returns a variable of the given shape initialized to random values."""
  return variable_v1.VariableV1(
      random_ops.random_uniform(shape), dtype=dtypes.float32, name=name)


def RandomInputs(batch_size, seq_length, num_inputs):
  """Returns randomly initialized (x_seq, pad_seq) sequences."""
  x_seq = []
  pad_seq = []
  with ops.name_scope('inputs'):
    for seq in range(seq_length):
      x_seq.append(RandomVar([batch_size, num_inputs], name='x_seq_%d' % seq))
      # Real padding values are always a run of 0s followed by a run of 1s,
      # but random values are fine for benchmarking.
      pad_seq.append(RandomVar([batch_size, 1], name='pad_seq_%d' % seq))
  return x_seq, pad_seq


def BuildLSTMLayer(batch_size, seq_length, num_inputs, num_nodes):
  """Builds a single LSTM layer with random weights and inputs.

  Args:
    batch_size: Inputs are fed in batches of this size.
    seq_length: The sequence length to which the LSTM layer is unrolled.
    num_inputs: Dimension of inputs that are fed into each LSTM cell.
    num_nodes: The number of nodes in each LSTM cell.

  Returns:
    (out_seq, weights) pair.  The out_seq is a list of per-sequence-step
    outputs, each with shape [batch_size, num_nodes].  The weights are a list of
    weight variables that may be trained.
  """
  weights = RandomVar(
      LSTMCellWeightsShape(num_inputs, num_nodes), name='weights')
  m = array_ops.zeros([batch_size, num_nodes], name='init_m')
  c = array_ops.zeros([batch_size, num_nodes], name='init_c')
  x_seq, pad_seq = RandomInputs(batch_size, seq_length, num_inputs)

  out_seq = LSTMLayer('lstm', weights, m, c, x_seq, pad_seq)
  return out_seq, [weights]
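

# A minimal, illustrative driver, assuming TF1-style graph execution (the
# benchmark harness that normally exercises this file drives the graph the
# same way).  The sizes below are arbitrary; this is a sketch, not part of
# the benchmark itself.
def _ExampleRun():
  """Builds a small LSTM layer and runs it once; illustrative only."""
  # Imported locally so this illustrative helper stays self-contained.
  from tensorflow.python.client import session
  from tensorflow.python.ops import variables

  with ops.Graph().as_default():
    out_seq, _ = BuildLSTMLayer(
        batch_size=4, seq_length=3, num_inputs=8, num_nodes=16)
    with session.Session() as sess:
      sess.run(variables.global_variables_initializer())
      # Evaluates to a list of 3 arrays, each of shape [4, 16].
      return sess.run(out_seq)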
