Skip to content

Commit

Permalink
Makes ops.split in torch consistent with other backends (#914)
Browse files Browse the repository at this point in the history
* Makes split in torch consistent with other backends

* Update error msg
  • Loading branch information
james77777778 authored Sep 19, 2023
1 parent 956e89a commit ffd736e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 37 deletions.
18 changes: 15 additions & 3 deletions keras_core/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,22 +800,34 @@ def sort(x, axis=-1):

def split(x, indices_or_sections, axis=0):
x = convert_to_tensor(x)
dim = x.shape[axis]
if isinstance(indices_or_sections, (list, tuple)):
idxs = convert_to_tensor(indices_or_sections)
start_size = indices_or_sections[0]
end_size = x.shape[axis] - indices_or_sections[-1]
end_size = dim - indices_or_sections[-1]
chunk_sizes = (
[start_size]
+ torch.diff(idxs).type(torch.int).tolist()
+ [end_size]
)
else:
chunk_sizes = x.shape[axis] // indices_or_sections
return torch.split(
if dim % indices_or_sections != 0:
raise ValueError(
f"Received indices_or_sections={indices_or_sections} "
f"(interpreted as a number of sections) and axis={axis}, "
f"but input dimension x.shape[{axis}]={x.shape[axis]} "
f"is not divisible by {indices_or_sections}. "
f"Full input shape: x.shape={x.shape}"
)
chunk_sizes = dim // indices_or_sections
out = torch.split(
tensor=x,
split_size_or_sections=chunk_sizes,
dim=axis,
)
if dim == 0 and isinstance(indices_or_sections, int):
out = tuple(out[0].clone() for _ in range(indices_or_sections))
return out


def stack(x, axis=0):
Expand Down
53 changes: 19 additions & 34 deletions keras_core/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3283,40 +3283,25 @@ def test_sort(self):

def test_split(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
if backend.backend() == "torch":
self.assertAllClose(
[backend.convert_to_numpy(t) for t in knp.split(x, 2)],
np.split(x, 2),
)
self.assertAllClose(
[backend.convert_to_numpy(t) for t in knp.Split(2)(x)],
np.split(x, 2),
)
self.assertAllClose(
[
backend.convert_to_numpy(t)
for t in knp.split(x, [1, 2], axis=1)
],
np.split(x, [1, 2], axis=1),
)
self.assertAllClose(
[
backend.convert_to_numpy(t)
for t in knp.Split([1, 2], axis=1)(x)
],
np.split(x, [1, 2], axis=1),
)
else:
self.assertAllClose(knp.split(x, 2), np.split(x, 2))
self.assertAllClose(knp.Split(2)(x), np.split(x, 2))
self.assertAllClose(
knp.split(x, [1, 2], axis=1),
np.split(x, [1, 2], axis=1),
)
self.assertAllClose(
knp.Split([1, 2], axis=1)(x),
np.split(x, [1, 2], axis=1),
)
self.assertAllClose(knp.split(x, 2), np.split(x, 2))
self.assertAllClose(knp.Split(2)(x), np.split(x, 2))
self.assertAllClose(
knp.split(x, [1, 2], axis=1),
np.split(x, [1, 2], axis=1),
)
self.assertAllClose(
knp.Split([1, 2], axis=1)(x),
np.split(x, [1, 2], axis=1),
)

# test invalid indices_or_sections
with self.assertRaises(Exception):
knp.split(x, 3)

# test zero dimension
x = np.ones(shape=(0,))
self.assertEqual(len(knp.split(x, 2)), 2)
self.assertEqual(len(knp.Split(2)(x)), 2)

def test_sqrt(self):
x = np.array([[1, 4, 9], [16, 25, 36]])
Expand Down

0 comments on commit ffd736e

Please sign in to comment.