# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for slicing."""

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest


class SliceTest(xla_test.XLATestCase):
  """Tests `array_ops.slice` lowered through the XLA test scope."""

  # Shared [3, 3, 10] input for the 3-D slice tests. It is only ever used as a
  # feed value (never mutated), so a single copy can be reused by every test.
  _INPUT_3X3X10 = [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                    [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
                    [5, 3, 1, 7, 9, 2, 4, 6, 8, 0]],
                   [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                    [8, 7, 6, 5, 4, 3, 2, 1, 8, 7]],
                   [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5],
                    [1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
                    [9, 8, 7, 9, 8, 7, 9, 8, 7, 9]]]

  def test1D(self):
    """Slices 4 elements starting at offset 2 from a length-10 vector."""
    for dtype in self.numeric_types:
      # subTest names the failing dtype and keeps checking the other dtypes
      # instead of stopping at the first failure.
      with self.subTest(dtype=dtype), self.session():
        i = array_ops.placeholder(dtype, shape=[10])
        with self.test_scope():
          o = array_ops.slice(i, [2], [4])
        params = {
            i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        }
        result = o.eval(feed_dict=params)

        self.assertAllEqual([2, 3, 4, 5], result)

  def testZeroSlice(self):
    """A slice of size 0 yields an empty result."""
    for dtype in self.numeric_types:
      with self.subTest(dtype=dtype), self.session():
        i = array_ops.placeholder(dtype, shape=[2])
        with self.test_scope():
          o = array_ops.slice(i, [0], [0])
        params = {
            i: [0, 1],
        }
        result = o.eval(feed_dict=params)

        self.assertAllEqual([], result)

  def test3D(self):
    """Slices a [1, 1, 4] window at static offset [1, 2, 2]."""
    for dtype in self.numeric_types:
      with self.subTest(dtype=dtype), self.session():
        i = array_ops.placeholder(dtype, shape=[3, 3, 10])
        with self.test_scope():
          o = array_ops.slice(i, [1, 2, 2], [1, 1, 4])
        params = {
            i: self._INPUT_3X3X10,
        }
        result = o.eval(feed_dict=params)

        self.assertAllEqual([[[6, 5, 4, 3]]], result)

  def test3DWithDynamicBegin(self):
    """Tests a slice where the start offset is not known at compile time."""
    for dtype in self.numeric_types:
      with self.subTest(dtype=dtype), self.session():
        i = array_ops.placeholder(dtype, shape=[3, 3, 10])
        begin = array_ops.placeholder(dtypes.int32, shape=[3])
        with self.test_scope():
          o = array_ops.slice(i, begin, [1, 1, 4])
        params = {
            i: self._INPUT_3X3X10,
            begin: [1, 2, 2],
        }
        result = o.eval(feed_dict=params)

        self.assertAllEqual([[[6, 5, 4, 3]]], result)

  def test3DWithDynamicBeginAndNegativeSize(self):
    """Tests a slice where `begin` is fed dynamically and `size` contains -1."""
    for dtype in self.numeric_types:
      with self.subTest(dtype=dtype), self.session():
        i = array_ops.placeholder(dtype, shape=[3, 3, 10])
        begin = array_ops.placeholder(dtypes.int32, shape=[3])
        with self.test_scope():
          # size -1 on axis 1 takes every remaining row from the begin offset
          # (3 - 1 = 2 rows), which is why two rows are expected below.
          o = array_ops.slice(i, begin, [1, -1, 4])
        params = {
            i: self._INPUT_3X3X10,
            begin: [1, 1, 2],
        }
        result = o.eval(feed_dict=params)

        self.assertAllEqual([[[1, 1, 1, 1], [6, 5, 4, 3]]], result)


class StridedSliceTest(xla_test.XLATestCase):
  """Tests `array_ops.strided_slice` lowered through the XLA test scope."""

  def test1D(self):
    """Equivalent to `i[2:6:2]` on a length-10 vector."""
    for dtype in self.numeric_types:
      # subTest names the failing dtype and keeps checking the other dtypes
      # instead of stopping at the first failure.
      with self.subTest(dtype=dtype), self.session():
        i = array_ops.placeholder(dtype, shape=[10])
        with self.test_scope():
          o = array_ops.strided_slice(i, [2], [6], [2])
        params = {
            i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        }
        result = o.eval(feed_dict=params)

        self.assertAllEqual([2, 4], result)

  def test1DDynamic(self):
    """Runtime-fed `begin` with `end = begin + 1` selects one element."""
    for dtype in self.numeric_types:
      with self.subTest(dtype=dtype), self.session():
        i = array_ops.placeholder(dtype, shape=[10])
        begin = array_ops.placeholder(dtypes.int32, shape=[1])
        with self.test_scope():
          end = math_ops.add(begin, [1])
          o = array_ops.strided_slice(i, begin, end, [1])
        params = {
            i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            begin: [0],
        }
        result = o.eval(feed_dict=params)

        self.assertAllEqual([0], result)

  def test1DNegativeStride(self):
    """Equivalent to `i[6:2:-2]`: walks backwards from index 6."""
    for dtype in self.numeric_types:
      with self.subTest(dtype=dtype), self.session():
        i = array_ops.placeholder(dtype, shape=[10])
        with self.test_scope():
          o = array_ops.strided_slice(i, [6], [2], [-2])
        params = {
            i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        }
        result = o.eval(feed_dict=params)

        self.assertAllEqual([6, 4], result)

  def test2DDegenerate(self):
    """`i[-1:0, 0:3]` is empty along axis 0, giving shape (0, 3)."""
    for dtype in self.numeric_types:
      with self.subTest(dtype=dtype), self.session():
        i = array_ops.placeholder(dtype, shape=[2, 3])
        with self.test_scope():
          o = array_ops.strided_slice(i, [-1, 0], [0, 3])
        params = {
            i: [[0, 1, 2],
                [3, 4, 5]],
        }
        result = o.eval(feed_dict=params)

        self.assertEqual(tensor_shape.TensorShape((0, 3)), result.shape)

  def test2DDegenerateNegativeStride(self):
    """`i[0:-1:-1, 0:3]` is empty along axis 0, giving shape (0, 3)."""
    for dtype in self.numeric_types:
      with self.subTest(dtype=dtype), self.session():
        i = array_ops.placeholder(dtype, shape=[2, 3])
        with self.test_scope():
          o = array_ops.strided_slice(i, [0, 0], [-1, 3], [-1, 1])
        params = {
            i: [[0, 1, 2],
                [3, 4, 5]],
        }
        result = o.eval(feed_dict=params)

        self.assertEqual(tensor_shape.TensorShape((0, 3)), result.shape)

  def test2DFullSlice(self):
    """Runtime-fed 2-D `begin` with `end = begin + 1` selects one element."""
    for dtype in self.numeric_types:
      with self.subTest(dtype=dtype), self.session():
        # NOTE(review): unlike the other tests, the placeholders here are
        # created inside test_scope(); preserved as in the original test.
        with self.test_scope():
          i = array_ops.placeholder(dtype, shape=[2, 4])
          begin = array_ops.placeholder(dtypes.int32, shape=[2])
          end = math_ops.add(begin, [1, 1])
          o = array_ops.strided_slice(i, begin, end, [1, 1])
        params = {
            i: [[0, 1, 2, 3], [4, 5, 6, 7]],
            begin: [1, 1],
        }
        result = o.eval(feed_dict=params)

        self.assertAllEqual([[5]], result)

  def test3D(self):
    """Equivalent to `i[0:2, 2:3, 2:6:2]` on a [3, 3, 10] input."""
    for dtype in self.numeric_types:
      with self.subTest(dtype=dtype), self.session():
        i = array_ops.placeholder(dtype, shape=[3, 3, 10])
        with self.test_scope():
          o = array_ops.strided_slice(i, [0, 2, 2], [2, 3, 6], [1, 1, 2])
        params = {
            i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                 [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
                 [5, 3, 1, 7, 9, 2, 4, 6, 8, 0]],
                [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
                 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                 [8, 7, 6, 5, 4, 3, 2, 1, 8, 7]],
                [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5],
                 [1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
                 [9, 8, 7, 9, 8, 7, 9, 8, 7, 9]]],
        }
        result = o.eval(feed_dict=params)

        self.assertAllEqual([[[1, 9]], [[6, 4]]], result)

  def test3DNegativeStride(self):
    """Equivalent to `i[2:0:-1, 2:0:-1, 6:2:-2]` on a [3, 4, 10] input."""
    for dtype in self.numeric_types:
      with self.subTest(dtype=dtype), self.session():
        i = array_ops.placeholder(dtype, shape=[3, 4, 10])
        with self.test_scope():
          o = array_ops.strided_slice(i, [2, 2, 6], [0, 0, 2], [-1, -1, -2])
        params = {
            i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                 [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
                 [5, 3, 1, 7, 9, 2, 4, 6, 8, 0],
                 [4, 5, 2, 4, 3, 7, 6, 8, 9, 4]],
                [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
                 [4, 3, 4, 5, 7, 6, 5, 3, 4, 5],
                 [8, 7, 6, 5, 4, 3, 2, 1, 8, 7],
                 [7, 1, 7, 1, 8, 1, 8, 1, 3, 1]],
                [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5],
                 [1, 2, 1, 2, 1, 2, 1, 2, 1, 2],
                 [9, 8, 7, 9, 8, 7, 9, 8, 7, 9],
                 [9, 9, 5, 5, 6, 6, 3, 3, 6, 6]]],
        }
        result = o.eval(feed_dict=params)

        self.assertAllEqual([[[9, 8],
                              [1, 1]],
                             [[2, 4],
                              [5, 7]]], result)

  def testShrinkAxisMask(self):
    """Tests shrink_axis_mask; this call is equivalent to `i[1,:]`.

    The out-of-range end (10) on the shrunk axis 0 does not affect the result.
    """
    for dtype in self.numeric_types:
      with self.subTest(dtype=dtype), self.session():
        i = array_ops.placeholder(dtype, shape=[2, 3])
        with self.test_scope():
          o = array_ops.strided_slice(i, [1, 0], [10, 3], shrink_axis_mask=1)
        params = {
            i: [[0, 1, 2], [3, 4, 5]],
        }
        result = o.eval(feed_dict=params)

        self.assertAllEqual([3, 4, 5], result)

  def testShrinkAxisMaskImplicitRange(self):
    """Tests shrink_axis_mask with the second dimension's range implicit.

    This call is equivalent to `i[1]`.
    """
    for dtype in self.numeric_types:
      with self.subTest(dtype=dtype), self.session():
        i = array_ops.placeholder(dtype, shape=[2, 3])
        with self.test_scope():
          o = array_ops.strided_slice(i, [1], [10], shrink_axis_mask=1)
        params = {
            i: [[0, 1, 2], [3, 4, 5]],
        }
        result = o.eval(feed_dict=params)

        self.assertAllEqual([3, 4, 5], result)

if __name__ == "__main__":
  # Run every test case in this module via the googletest runner imported
  # from tensorflow.python.platform.
  googletest.main()
