# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.MatrixTriangularSolve."""

import itertools

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.platform import test


def MakePlaceholder(x, dtype=None):
  """Creates a placeholder whose static shape matches `x`.

  Args:
    x: Object exposing `.dtype` and `.shape` (the tests in this file pass
      NumPy arrays).
    dtype: Optional dtype for the placeholder. When None, the dtype is derived
      from `x.dtype` via `dtypes.as_dtype`.

  Returns:
    The value returned by `array_ops.placeholder` for that dtype and shape.
  """
  if dtype is None:
    dtype = dtypes.as_dtype(x.dtype)
  return array_ops.placeholder(dtype, shape=x.shape)


class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
  """Tests for `linalg_ops.matrix_triangular_solve` on `xla_test.XLATestCase`.

  The numeric tests solve `op(a) @ x = b` with the op under test and then
  multiply the solution back by the triangular part of `a`, expecting to
  recover `b` (up to broadcasting and an absolute tolerance).
  """

  # MatrixTriangularSolve defined for float64, float32, complex64, complex128
  # (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve)
  @property
  def float_types(self):
    """Returns the base class float types restricted to the supported ones."""
    supported = (np.float64, np.float32, np.complex64, np.complex128)
    return set(super().float_types).intersection(supported)

  def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca,
                                 placeholder_b, a, clean_a, b, verification,
                                 atol):
    """Evaluates `verification` and asserts that it reproduces `b`.

    Args:
      sess: Session used to run `verification`.
      placeholder_a: Placeholder fed with `a`.
      placeholder_ca: Placeholder fed with `clean_a`.
      placeholder_b: Placeholder fed with `b`.
      a: Coefficient array handed to the solver.
      clean_a: `a` with the unused triangle zeroed out.
      b: Right-hand side array.
      verification: Tensor to evaluate; expected to equal `b` broadcast
        against the leading (batch) dimensions of `a`.
      atol: Absolute tolerance passed to `assertAllClose`.
    """
    feed_dict = {placeholder_a: a, placeholder_ca: clean_a, placeholder_b: b}
    verification_np = sess.run(verification, feed_dict)
    # Adding zeros shaped like a's batch dims plus b's matrix dims broadcasts
    # `b` to the shape of the evaluated product.
    broadcasted_shape = a.shape[:-2] + (b.shape[-2], b.shape[-1])
    broadcasted_b = b + np.zeros(shape=broadcasted_shape, dtype=b.dtype)
    self.assertAllClose(broadcasted_b, verification_np, atol=atol)

  def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol, dtype=None):
    """Solves one configuration and checks the result by multiplying back.

    The solver receives `a` unchanged, while the verification product uses only
    the requested triangle of `a`; the check therefore passes only if the
    entries in the opposite triangle do not influence the solution.

    Args:
      a: Coefficient array; the last two dimensions form the matrices.
      b: Right-hand side array.
      lower: Forwarded to the op; also selects `np.tril` (True) or `np.triu`
        (False) for the verification matrix.
      adjoint: Forwarded to the op and used as `adjoint_a` in the verification
        matmul.
      atol: Absolute tolerance for the comparison.
      dtype: Optional dtype for the placeholders; by default it is derived from
        each array (see `MakePlaceholder`).
    """
    clean_a = np.tril(a) if lower else np.triu(a)
    with self.session() as sess:
      placeholder_a = MakePlaceholder(a, dtype)
      placeholder_ca = MakePlaceholder(clean_a, dtype)
      placeholder_b = MakePlaceholder(b, dtype)
      # Only the op under test is built inside the test scope; the
      # verification matmul below is built outside of it.
      with self.test_scope():
        x = linalg_ops.matrix_triangular_solve(
            placeholder_a, placeholder_b, lower=lower, adjoint=adjoint)
      verification = test_util.matmul_without_tf32(
          placeholder_ca, x, adjoint_a=adjoint)
      self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca,
                                      placeholder_b, a, clean_a, b,
                                      verification, atol)

  def _VerifyTriangularSolveCombo(self, a, b, atol=1e-4, dtype=None):
    """Runs `_VerifyTriangularSolve` for every (lower, adjoint) combination.

    Args:
      a: Coefficient array used as-is for the `lower=True` cases; its last two
        axes are swapped for the `lower=False` cases.
      b: Right-hand side array.
      atol: Absolute tolerance for the comparison.
      dtype: Optional placeholder dtype, forwarded unchanged.
    """
    # Computed once up front: the swapped array does not depend on the loop
    # variables.
    a_transposed = np.swapaxes(a, -1, -2)
    for lower, adjoint in itertools.product([True, False], repeat=2):
      self._VerifyTriangularSolve(
          a if lower else a_transposed, b, lower, adjoint, atol, dtype=dtype)

  def testBasic(self):
    """Solves a single 5x5 lower-triangular system for each float type."""
    rng = np.random.RandomState(0)
    a = np.tril(rng.randn(5, 5))
    b = rng.randn(5, 7)
    for dtype in self.float_types:
      self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))

  def testBfloat16(self):
    """Same as testBasic with bfloat16 placeholders and a looser tolerance."""
    rng = np.random.RandomState(0)
    a = np.tril(rng.randn(5, 5))
    b = rng.randn(5, 7)
    self._VerifyTriangularSolveCombo(a, b, atol=5e-2, dtype=dtypes.bfloat16)

  def testBasicNotActuallyTriangular(self):
    """Feeds a dense matrix; only the selected triangle should be used."""
    rng = np.random.RandomState(0)
    a = rng.randn(5, 5)  # the `a` matrix is not lower-triangular
    b = rng.randn(5, 7)
    for dtype in self.float_types:
      self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))

  def testBasicComplexDtypes(self):
    """Solves a single 5x5 complex system for each complex type."""
    # Use the `test` module imported by this file rather than reaching through
    # `xla_test`'s own imports.
    if test.is_built_with_rocm():
      # The following subtest invokes the call to "BlasTrsm"
      # That operation is currently not supported on the ROCm platform
      self.skipTest("BlasTrsm op for complex types is not supported in ROCm")

    rng = np.random.RandomState(0)
    a = np.tril(rng.randn(5, 5) + rng.randn(5, 5) * 1j)
    b = rng.randn(5, 7) + rng.randn(5, 7) * 1j
    for dtype in self.complex_types:
      self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))

  def testBatch(self):
    """Solves batched systems whose batch dimensions match exactly."""
    rng = np.random.RandomState(0)
    shapes = [((4, 3, 3), (4, 3, 5)), ((1, 2, 2), (1, 2, 1)),
              ((1, 1, 1), (1, 1, 2)), ((2, 3, 4, 4), (2, 3, 4, 1))]
    tuples = itertools.product(self.float_types, shapes)
    for dtype, (a_shape, b_shape) in tuples:
      n = a_shape[-1]
      # Small off-diagonal entries plus the identity on the diagonal.
      a = np.tril(rng.rand(*a_shape) - 0.5) / (2.0 * n) + np.eye(n)
      b = rng.randn(*b_shape)
      self._VerifyTriangularSolveCombo(
          a.astype(dtype), b.astype(dtype), atol=1e-3)

  def testBatchBroadcast(self):
    """Solves batched systems whose batch dimensions require broadcasting."""
    rng = np.random.RandomState(0)
    shapes = [((3, 3), (4, 3, 5)), ((1, 2, 2), (3, 2, 1)), ((1, 1), (1, 1, 2)),
              ((1, 3, 4, 4), (2, 1, 4, 1))]
    tuples = itertools.product(self.float_types, shapes)
    for dtype, (a_shape, b_shape) in tuples:
      n = a_shape[-1]
      # Small off-diagonal entries plus the identity on the diagonal.
      a = np.tril(rng.rand(*a_shape) - 0.5) / (2.0 * n) + np.eye(n)
      b = rng.randn(*b_shape)
      self._VerifyTriangularSolveCombo(
          a.astype(dtype), b.astype(dtype), atol=1e-3)

  def testLarge(self):
    """Solves one 1024x1024 float32 lower-triangular, non-adjoint system."""
    n = 1024
    rng = np.random.RandomState(0)
    a = np.tril(rng.rand(n, n) - 0.5) / (2.0 * n) + np.eye(n)
    b = rng.randn(n, n)
    self._VerifyTriangularSolve(
        a.astype(np.float32),
        b.astype(np.float32),
        lower=True,
        adjoint=False,
        atol=1e-4)

  @test_util.disable_mlir_bridge("Error handling")
  def testNonSquareCoefficientMatrix(self):
    """A 3x4 coefficient matrix must be rejected."""
    rng = np.random.RandomState(0)
    for dtype in self.float_types:
      a = rng.randn(3, 4).astype(dtype)
      b = rng.randn(4, 4).astype(dtype)
      with self.test_scope():
        with self.assertRaises((ValueError, errors.InvalidArgumentError)):
          linalg_ops.matrix_triangular_solve(a, b)

  @test_util.run_v2_only  # Different error types
  @test_util.disable_mlir_bridge("Error handling")
  def testWrongDimensionsV2(self):
    """A 3x3 matrix with a 4-row right-hand side raises InvalidArgumentError."""
    # NOTE(review): unlike testNonSquareCoefficientMatrix, this call is not
    # wrapped in `self.test_scope()` -- confirm that this is intended.
    randn = np.random.RandomState(0).randn
    for dtype in self.float_types:
      lhs = constant_op.constant(randn(3, 3), dtype=dtype)
      rhs = constant_op.constant(randn(4, 3), dtype=dtype)
      # A single check per dtype: the previous second check repeated the
      # identical call and added no coverage.
      with self.assertRaises(errors.InvalidArgumentError):
        linalg_ops.matrix_triangular_solve(lhs, rhs)

  @test_util.run_v1_only("Different error types")
  @test_util.disable_mlir_bridge("Error handling")
  def testWrongDimensionsV1(self):
    """A 3x3 matrix with a 4-row right-hand side raises ValueError."""
    # NOTE(review): unlike testNonSquareCoefficientMatrix, this call is not
    # wrapped in `self.test_scope()` -- confirm that this is intended.
    randn = np.random.RandomState(0).randn
    for dtype in self.float_types:
      lhs = constant_op.constant(randn(3, 3), dtype=dtype)
      rhs = constant_op.constant(randn(4, 3), dtype=dtype)
      # A single check per dtype: the previous second check repeated the
      # identical call and added no coverage.
      with self.assertRaises(ValueError):
        linalg_ops.matrix_triangular_solve(lhs, rhs)


# Run the test suite via `test.main()` when this file is executed as a script.
if __name__ == "__main__":
  test.main()
