# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tridiagonal solve ops."""

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients as gradient_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl
from tensorflow.python.platform import test

_sample_diags = np.array([[2, 1, 4, 0], [1, 3, 2, 2], [0, 1, -1, 1]],
                         dtype=np.float32)
_sample_rhs = np.array([1, 2, 3, 4], dtype=np.float32)
_sample_result = np.array([-9, 5, -4, 4], dtype=np.float32)
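# The compact format above stores the superdiagonal in row 0 (last entry
# ignored), the main diagonal in row 1, and the subdiagonal in row 2 (first
# entry ignored). _sample_diags thus encodes the matrix
#     [[1, 2, 0, 0],
#      [1, 3, 1, 0],
#      [0, -1, 2, 4],
#      [0, 0, 1, 2]]
# (cf. testMatrixFormat below), and _sample_result solves it for _sample_rhs.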


def _tfconst(array):
  return constant_op.constant(array, dtype=dtypes.float32)


def _tf_ones(shape):
  return array_ops.ones(shape, dtype=dtypes.float64)


class TridiagonalSolveOpsTest(xla_test.XLATestCase):
  """Test for tri-diagonal matrix related ops."""

  def testTridiagonalSolverSolves1Rhs(self):
    np.random.seed(19)

    batch_size = 8
    num_dims = 11

    diagonals_np = np.random.normal(size=(batch_size, 3,
                                          num_dims)).astype(np.float32)
    rhs_np = np.random.normal(size=(batch_size, num_dims, 1)).astype(np.float32)

    with self.session() as sess, self.test_scope():
      diags = array_ops.placeholder(
          shape=(batch_size, 3, num_dims), dtype=dtypes.float32)
      rhs = array_ops.placeholder(
          shape=(batch_size, num_dims, 1), dtype=dtypes.float32)
      x_np = sess.run(
          linalg_impl.tridiagonal_solve(diags, rhs, partial_pivoting=False),
          feed_dict={
              diags: diagonals_np,
              rhs: rhs_np
          })[:, :, 0]

    superdiag_np = diagonals_np[:, 0]
    diag_np = diagonals_np[:, 1]
    subdiag_np = diagonals_np[:, 2]

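    # Verify the solution: applying the tridiagonal matrix to x must
    # reconstruct the original right-hand side.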
    y = np.zeros((batch_size, num_dims), dtype=np.float32)

    for i in range(num_dims):
      if i == 0:
        y[:, i] = (
            diag_np[:, i] * x_np[:, i] + superdiag_np[:, i] * x_np[:, i + 1])
      elif i == num_dims - 1:
        y[:, i] = (
            subdiag_np[:, i] * x_np[:, i - 1] + diag_np[:, i] * x_np[:, i])
      else:
        y[:, i] = (
            subdiag_np[:, i] * x_np[:, i - 1] + diag_np[:, i] * x_np[:, i] +
            superdiag_np[:, i] * x_np[:, i + 1])

    self.assertAllClose(y, rhs_np[:, :, 0], rtol=1e-4, atol=1e-4)

  def testTridiagonalSolverSolvesKRhs(self):
    np.random.seed(19)

    batch_size = 8
    num_dims = 11
    num_rhs = 5

    diagonals_np = np.random.normal(size=(batch_size, 3,
                                          num_dims)).astype(np.float32)
    rhs_np = np.random.normal(size=(batch_size, num_dims,
                                    num_rhs)).astype(np.float32)

    with self.session() as sess, self.test_scope():
      diags = array_ops.placeholder(
          shape=(batch_size, 3, num_dims), dtype=dtypes.float32)
      rhs = array_ops.placeholder(
          shape=(batch_size, num_dims, num_rhs), dtype=dtypes.float32)
      x_np = sess.run(
          linalg_impl.tridiagonal_solve(diags, rhs, partial_pivoting=False),
          feed_dict={
              diags: diagonals_np,
              rhs: rhs_np
          })

    superdiag_np = diagonals_np[:, 0]
    diag_np = diagonals_np[:, 1]
    subdiag_np = diagonals_np[:, 2]

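    # Verify each right-hand side independently: applying the tridiagonal
    # matrix to the corresponding solution column must reproduce rhs_np.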
    for eq in range(num_rhs):
      y = np.zeros((batch_size, num_dims), dtype=np.float32)
      for i in range(num_dims):
        if i == 0:
          y[:, i] = (
              diag_np[:, i] * x_np[:, i, eq] +
              superdiag_np[:, i] * x_np[:, i + 1, eq])
        elif i == num_dims - 1:
          y[:, i] = (
              subdiag_np[:, i] * x_np[:, i - 1, eq] +
              diag_np[:, i] * x_np[:, i, eq])
        else:
          y[:, i] = (
              subdiag_np[:, i] * x_np[:, i - 1, eq] +
              diag_np[:, i] * x_np[:, i, eq] +
              superdiag_np[:, i] * x_np[:, i + 1, eq])

      self.assertAllClose(y, rhs_np[:, :, eq], rtol=1e-4, atol=1e-4)

  # The tests below are adapted from tridiagonal_solve_op_test.py.
  def _test(self,
            diags,
            rhs,
            expected,
            diags_format="compact",
            transpose_rhs=False):
    with self.session() as sess, self.test_scope():
      self.assertAllClose(
          sess.run(
              linalg_impl.tridiagonal_solve(
                  _tfconst(diags),
                  _tfconst(rhs),
                  diags_format,
                  transpose_rhs,
                  conjugate_rhs=False,
                  partial_pivoting=False)),
          np.asarray(expected, dtype=np.float32))

  def _testWithDiagonalLists(self,
                             diags,
                             rhs,
                             expected,
                             diags_format="compact",
                             transpose_rhs=False):
    with self.session() as sess, self.test_scope():
      self.assertAllClose(
          sess.run(
              linalg_impl.tridiagonal_solve([_tfconst(x) for x in diags],
                                            _tfconst(rhs),
                                            diags_format,
                                            transpose_rhs,
                                            conjugate_rhs=False,
                                            partial_pivoting=False)),
          sess.run(_tfconst(expected)))

  def testReal(self):
    self._test(diags=_sample_diags, rhs=_sample_rhs, expected=_sample_result)

  # testComplex is skipped as complex type is not yet supported.

  def test3x3(self):
    self._test(
        diags=[[2.0, -1.0, 0.0], [1.0, 3.0, 1.0], [0.0, -1.0, -2.0]],
        rhs=[1.0, 2.0, 3.0],
        expected=[-3.0, 2.0, 7.0])

  def test2x2(self):
    self._test(
        diags=[[2.0, 0.0], [1.0, 3.0], [0.0, 1.0]],
        rhs=[1.0, 4.0],
        expected=[-5.0, 3.0])

  def test1x1(self):
    self._test(diags=[[0], [3], [0]], rhs=[6], expected=[2])

  def test0x0(self):
    self._test(
        diags=np.zeros(shape=(3, 0), dtype=np.float32),
        rhs=np.zeros(shape=(0, 1), dtype=np.float32),
        expected=np.zeros(shape=(0, 1), dtype=np.float32))

  def test2x2WithMultipleRhs(self):
    self._test(
        diags=[[2, 0], [1, 3], [0, 1]],
        rhs=[[1, 2, 3], [4, 8, 12]],
        expected=[[-5, -10, -15], [3, 6, 9]])

  def test1x1WithMultipleRhs(self):
    self._test(diags=[[0], [3], [0]], rhs=[[6, 9, 12]], expected=[[2, 3, 4]])

  # test1x1NotInvertible is skipped as the runtime error is not raised for now.

  # test2x2NotInvertible is skipped as the runtime error is not raised for now.

  @test_util.disable_mlir_bridge("Error messages differ")
  def testPartialPivotingRaises(self):
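    # The XLA implementation does not support partial pivoting, so requesting
    # it should fail with an UnimplementedError at run time.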
    np.random.seed(0)
    batch_size = 8
    num_dims = 11
    num_rhs = 5

    diagonals_np = np.random.normal(size=(batch_size, 3,
                                          num_dims)).astype(np.float32)
    rhs_np = np.random.normal(size=(batch_size, num_dims,
                                    num_rhs)).astype(np.float32)

    with self.session() as sess, self.test_scope():
      with self.assertRaisesRegex(
          errors_impl.UnimplementedError,
          "Current implementation does not yet support pivoting."):
        diags = array_ops.placeholder(
            shape=(batch_size, 3, num_dims), dtype=dtypes.float32)
        rhs = array_ops.placeholder(
            shape=(batch_size, num_dims, num_rhs), dtype=dtypes.float32)
        sess.run(
            linalg_impl.tridiagonal_solve(diags, rhs, partial_pivoting=True),
            feed_dict={
                diags: diagonals_np,
                rhs: rhs_np
            })

  # testCaseRequiringPivotingLastRows is skipped as pivoting is not supported
  # for now.

  # testNotInvertible is skipped as the runtime error is not raised for now.

  def testDiagonal(self):
    self._test(
        diags=[[0, 0, 0, 0], [1, 2, -1, -2], [0, 0, 0, 0]],
        rhs=[1, 2, 3, 4],
        expected=[1, 1, -3, -2])

  def testUpperTriangular(self):
    self._test(
        diags=[[2, 4, -1, 0], [1, 3, 1, 2], [0, 0, 0, 0]],
        rhs=[1, 6, 4, 4],
        expected=[13, -6, 6, 2])

  def testLowerTriangular(self):
    self._test(
        diags=[[0, 0, 0, 0], [2, -1, 3, 1], [0, 1, 4, 2]],
        rhs=[4, 5, 6, 1],
        expected=[2, -3, 6, -11])

  def testWithTwoRightHandSides(self):
    self._test(
        diags=_sample_diags,
        rhs=np.transpose([_sample_rhs, 2 * _sample_rhs]),
        expected=np.transpose([_sample_result, 2 * _sample_result]))

  def testBatching(self):
    self._test(
        diags=np.array([_sample_diags, -_sample_diags]),
        rhs=np.array([_sample_rhs, 2 * _sample_rhs]),
        expected=np.array([_sample_result, -2 * _sample_result]))

  def testWithTwoBatchingDimensions(self):
    self._test(
        diags=np.array([[_sample_diags, -_sample_diags, _sample_diags],
                        [-_sample_diags, _sample_diags, -_sample_diags]]),
        rhs=np.array([[_sample_rhs, 2 * _sample_rhs, 3 * _sample_rhs],
                      [4 * _sample_rhs, 5 * _sample_rhs, 6 * _sample_rhs]]),
        expected=np.array(
            [[_sample_result, -2 * _sample_result, 3 * _sample_result],
             [-4 * _sample_result, 5 * _sample_result, -6 * _sample_result]]))

  def testBatchingAndTwoRightHandSides(self):
    rhs = np.transpose([_sample_rhs, 2 * _sample_rhs])
    expected_result = np.transpose([_sample_result, 2 * _sample_result])
    self._test(
        diags=np.array([_sample_diags, -_sample_diags]),
        rhs=np.array([rhs, 2 * rhs]),
        expected=np.array([expected_result, -2 * expected_result]))

  def testSequenceFormat(self):
    self._testWithDiagonalLists(
        diags=[[2, 1, 4], [1, 3, 2, 2], [1, -1, 1]],
        rhs=[1, 2, 3, 4],
        expected=[-9, 5, -4, 4],
        diags_format="sequence")

  def testSequenceFormatWithDummyElements(self):
    dummy = 20  # Should be ignored by the solver.
    self._testWithDiagonalLists(
        diags=[
            [2, 1, 4, dummy],
            [1, 3, 2, 2],
            [dummy, 1, -1, 1],
        ],
        rhs=[1, 2, 3, 4],
        expected=[-9, 5, -4, 4],
        diags_format="sequence")

  def testSequenceFormatWithBatching(self):
    self._testWithDiagonalLists(
        diags=[[[2, 1, 4], [-2, -1, -4]], [[1, 3, 2, 2], [-1, -3, -2, -2]],
               [[1, -1, 1], [-1, 1, -1]]],
        rhs=[[1, 2, 3, 4], [1, 2, 3, 4]],
        expected=[[-9, 5, -4, 4], [9, -5, 4, -4]],
        diags_format="sequence")

  def testMatrixFormat(self):
    self._test(
        diags=[[1, 2, 0, 0], [1, 3, 1, 0], [0, -1, 2, 4], [0, 0, 1, 2]],
        rhs=[1, 2, 3, 4],
        expected=[-9, 5, -4, 4],
        diags_format="matrix")

  def testMatrixFormatWithMultipleRightHandSides(self):
    self._test(
        diags=[[1, 2, 0, 0], [1, 3, 1, 0], [0, -1, 2, 4], [0, 0, 1, 2]],
        rhs=[[1, -1], [2, -2], [3, -3], [4, -4]],
        expected=[[-9, 9], [5, -5], [-4, 4], [4, -4]],
        diags_format="matrix")

  def testMatrixFormatWithBatching(self):
    self._test(
        diags=[[[1, 2, 0, 0], [1, 3, 1, 0], [0, -1, 2, 4], [0, 0, 1, 2]],
               [[-1, -2, 0, 0], [-1, -3, -1, 0], [0, 1, -2, -4], [0, 0, -1,
                                                                  -2]]],
        rhs=[[1, 2, 3, 4], [1, 2, 3, 4]],
        expected=[[-9, 5, -4, 4], [9, -5, 4, -4]],
        diags_format="matrix")

  def testRightHandSideAsColumn(self):
    self._test(
        diags=_sample_diags,
        rhs=np.transpose([_sample_rhs]),
        expected=np.transpose([_sample_result]),
        diags_format="compact")

  def testTransposeRhs(self):
    self._test(
        diags=_sample_diags,
        rhs=np.array([_sample_rhs, 2 * _sample_rhs]),
        expected=np.array([_sample_result, 2 * _sample_result]).T,
        transpose_rhs=True)

  # testConjugateRhs is skipped as complex type is not yet supported.

  # testAdjointRhs is skipped as complex type is not yet supported.

  def testTransposeRhsWithRhsAsVector(self):
    self._test(
        diags=_sample_diags,
        rhs=_sample_rhs,
        expected=_sample_result,
        transpose_rhs=True)

  # testConjugateRhsWithRhsAsVector is skipped as complex type is not yet
  # supported.

  def testTransposeRhsWithRhsAsVectorAndBatching(self):
    self._test(
        diags=np.array([_sample_diags, -_sample_diags]),
        rhs=np.array([_sample_rhs, 2 * _sample_rhs]),
        expected=np.array([_sample_result, -2 * _sample_result]),
        transpose_rhs=True)

  # Gradient tests

  def _gradientTest(
      self,
      diags,
      rhs,
      y,  # output = reduce_sum(y * tridiag_solve(diags, rhs))
      expected_grad_diags,  # expected gradient of output w.r.t. diags
      expected_grad_rhs,  # expected gradient of output w.r.t. rhs
      diags_format="compact",
      transpose_rhs=False,
      feed_dict=None):
    expected_grad_diags = np.array(expected_grad_diags).astype(np.float32)
    expected_grad_rhs = np.array(expected_grad_rhs).astype(np.float32)
    with self.session() as sess, self.test_scope():
      diags = _tfconst(diags)
      rhs = _tfconst(rhs)
      y = _tfconst(y)

      x = linalg_impl.tridiagonal_solve(
          diags,
          rhs,
          diagonals_format=diags_format,
          transpose_rhs=transpose_rhs,
          conjugate_rhs=False,
          partial_pivoting=False)

      res = math_ops.reduce_sum(x * y)
      actual_grad_diags = sess.run(
          gradient_ops.gradients(res, diags)[0], feed_dict=feed_dict)
      actual_grad_rhs = sess.run(
          gradient_ops.gradients(res, rhs)[0], feed_dict=feed_dict)
    self.assertAllClose(expected_grad_diags, actual_grad_diags)
    self.assertAllClose(expected_grad_rhs, actual_grad_rhs)

  def testGradientSimple(self):
    self._gradientTest(
        diags=_sample_diags,
        rhs=_sample_rhs,
        y=[1, 3, 2, 4],
        expected_grad_diags=[[-5, 0, 4, 0], [9, 0, -4, -16], [0, 0, 5, 16]],
        expected_grad_rhs=[1, 0, -1, 4])

  def testGradientWithMultipleRhs(self):
    self._gradientTest(
        diags=_sample_diags,
        rhs=[[1, 2], [2, 4], [3, 6], [4, 8]],
        y=[[1, 5], [2, 6], [3, 7], [4, 8]],
        expected_grad_diags=[[-20, 28, -60, 0], [36, -35, 60, 80],
                             [0, 63, -75, -80]],
        expected_grad_rhs=[[0, 2], [1, 3], [1, 7], [0, -10]])

  def _makeDataForGradientWithBatching(self):
    y = np.array([1, 3, 2, 4]).astype(np.float32)
    grad_diags = np.array([[-5, 0, 4, 0], [9, 0, -4, -16],
                           [0, 0, 5, 16]]).astype(np.float32)
    grad_rhs = np.array([1, 0, -1, 4]).astype(np.float32)

    diags_batched = np.array(
        [[_sample_diags, 2 * _sample_diags, 3 * _sample_diags],
         [4 * _sample_diags, 5 * _sample_diags,
          6 * _sample_diags]]).astype(np.float32)
    rhs_batched = np.array([[_sample_rhs, -_sample_rhs, _sample_rhs],
                            [-_sample_rhs, _sample_rhs,
                             -_sample_rhs]]).astype(np.float32)
    y_batched = np.array([[y, y, y], [y, y, y]]).astype(np.float32)
    expected_grad_diags_batched = np.array(
        [[grad_diags, -grad_diags / 4, grad_diags / 9],
         [-grad_diags / 16, grad_diags / 25,
          -grad_diags / 36]]).astype(np.float32)
    expected_grad_rhs_batched = np.array(
        [[grad_rhs, grad_rhs / 2, grad_rhs / 3],
         [grad_rhs / 4, grad_rhs / 5, grad_rhs / 6]]).astype(np.float32)

    return (y_batched, diags_batched, rhs_batched, expected_grad_diags_batched,
            expected_grad_rhs_batched)

  def testGradientWithBatchDims(self):
    y, diags, rhs, expected_grad_diags, expected_grad_rhs = (
        self._makeDataForGradientWithBatching())

    self._gradientTest(
        diags=diags,
        rhs=rhs,
        y=y,
        expected_grad_diags=expected_grad_diags,
        expected_grad_rhs=expected_grad_rhs)

  # testGradientWithUnknownShapes is skipped as shapes should be fully known.

  def _assertRaises(self, diags, rhs, diags_format="compact"):
    with self.assertRaises(ValueError):
      linalg_impl.tridiagonal_solve(diags, rhs, diags_format)

  # Invalid input shapes
  def testInvalidShapesCompactFormat(self):

    def test_raises(diags_shape, rhs_shape):
      self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "compact")

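    # In compact format, diagonals must have shape [..., 3, M] and rhs must
    # have shape [..., M] or [..., M, K], with matching batch dimensions.
    # Each case below violates one of these constraints.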
    test_raises((5, 4, 4), (5, 4))
    test_raises((5, 3, 4), (4, 5))
    test_raises((5, 3, 4), (5,))
    test_raises((5,), (5, 4))

  def testInvalidShapesSequenceFormat(self):

    def test_raises(diags_tuple_shapes, rhs_shape):
      diagonals = tuple(_tf_ones(shape) for shape in diags_tuple_shapes)
      self._assertRaises(diagonals, _tf_ones(rhs_shape), "sequence")

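    # Sequence format expects exactly three diagonals (superdiag, maindiag,
    # subdiag); superdiag and subdiag must have length M or M - 1, and the
    # batch dimensions of the diagonals and rhs must agree.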
    test_raises(((5, 4), (5, 4)), (5, 4))
    test_raises(((5, 4), (5, 4), (5, 6)), (5, 4))
    test_raises(((5, 3), (5, 4), (5, 6)), (5, 4))
    test_raises(((5, 6), (5, 4), (5, 3)), (5, 4))
    test_raises(((5, 4), (7, 4), (5, 4)), (5, 4))
    test_raises(((5, 4), (7, 4), (5, 4)), (3, 4))

  def testInvalidShapesMatrixFormat(self):

    def test_raises(diags_shape, rhs_shape):
      self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "matrix")

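    # In matrix format, diagonals is a full [..., M, M] matrix whose M and
    # batch dimensions must match those of rhs.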
    test_raises((5, 4, 7), (5, 4))
    test_raises((5, 4, 4), (3, 4))
    test_raises((5, 4, 4), (5, 3))

  # Tests involving placeholders with an unknown dimension are all skipped, as
  # all dimensions have to be statically known.


if __name__ == "__main__":
  test.main()
