Skip to content

Commit

Permalink
Fix split in torch and tensorflow (#18988)
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 authored Dec 24, 2023
1 parent b1a1107 commit 8d02be3
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 9 deletions.
16 changes: 15 additions & 1 deletion keras/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,7 +1331,21 @@ def sort(x, axis=-1):


def split(x, indices_or_sections, axis=0):
return tfnp.split(x, indices_or_sections, axis=axis)
if not isinstance(indices_or_sections, int):
# `tf.split` requires `num_or_size_splits`, so we need to convert
# `indices_or_sections` to the appropriate format.
# The following implementation offers better compatibility for the
# tensor argument `indices_or_sections` than original `tfnp.split`.
total_size = x.shape[axis]
indices_or_sections = convert_to_tensor(indices_or_sections)
start_size = indices_or_sections[0:1]
end_size = total_size - indices_or_sections[-1:]
num_or_size_splits = tf.concat(
[start_size, tfnp.diff(indices_or_sections), end_size], axis=0
)
else:
num_or_size_splits = indices_or_sections
return tf.split(x, num_or_size_splits, axis=axis)


def stack(x, axis=0):
Expand Down
16 changes: 8 additions & 8 deletions keras/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,15 +1168,15 @@ def sort(x, axis=-1):
def split(x, indices_or_sections, axis=0):
x = convert_to_tensor(x)
dim = x.shape[axis]
if isinstance(indices_or_sections, (list, tuple)):
idxs = convert_to_tensor(indices_or_sections)
start_size = indices_or_sections[0]
end_size = dim - indices_or_sections[-1]
chunk_sizes = (
[start_size]
+ torch.diff(idxs).type(torch.int).tolist()
+ [end_size]
if not isinstance(indices_or_sections, int):
indices_or_sections = convert_to_tensor(indices_or_sections)
start_size = indices_or_sections[0:1]
end_size = dim - indices_or_sections[-1:]
chunk_sizes = torch.concat(
[start_size, torch.diff(indices_or_sections), end_size], dim=0
)
# torch.split doesn't support tensor input for `split_size_or_sections`
chunk_sizes = chunk_sizes.tolist()
else:
if dim % indices_or_sections != 0:
raise ValueError(
Expand Down
31 changes: 31 additions & 0 deletions keras/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3740,6 +3740,37 @@ def test_split(self):
self.assertEqual(len(knp.split(x, 2)), 2)
self.assertEqual(len(knp.Split(2)(x)), 2)

# test indices_or_sections as tensor
x = knp.array([[1, 2, 3], [3, 2, 1]])
indices_or_sections = knp.array([1, 2])
x_np = np.array([[1, 2, 3], [3, 2, 1]])
indices_or_sections_np = np.array([1, 2])
self.assertAllClose(
knp.split(x, indices_or_sections, axis=1),
np.split(x_np, indices_or_sections_np, axis=1),
)

@pytest.mark.skipif(
backend.backend() != "tensorflow",
reason="Only test tensorflow backend",
)
def test_split_with_jit_in_tf(self):
import tensorflow as tf

x = knp.array([[1, 2, 3], [3, 2, 1]])
indices = knp.array([1, 2])
x_np = np.array([[1, 2, 3], [3, 2, 1]])
indices_np = np.array([1, 2])

@tf.function(jit_compile=True)
def fn(x, indices, axis):
return knp.split(x, indices, axis=axis)

self.assertAllClose(
fn(x, indices, axis=1),
np.split(x_np, indices_np, axis=1),
)

def test_sqrt(self):
x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float32")
ref_y = np.sqrt(x)
Expand Down

0 comments on commit 8d02be3

Please sign in to comment.