Fix reduce dims to work with multiple negative dims (#136)
- Adds test cases for multiple negative reduce dims, e.g.
```python
a = tp.ones((5,5,5))
tp.sum(a, dim=[-2,-1])
```
- Fixes the `_reduce_impl` to ensure negatives are sorted in decreasing
order when performing `unsqueeze`
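
As a side note, a minimal NumPy sketch (illustration only, not Tripy code) of why the unsqueeze order matters when re-inserting the reduced axes for `keepdim`:

```python
import numpy as np

x = np.ones((2, 3, 4, 5))
reduced = x.sum(axis=(-2, -1))  # shape (2, 3); both reduced axes must be re-inserted for keepdim

# Unsqueezing negatives in increasing order (-2 first, then -1) places the new axes incorrectly:
wrong = np.expand_dims(np.expand_dims(reduced, -2), -1)
assert wrong.shape == (2, 1, 3, 1)

# Unsqueezing negatives in decreasing order (-1 first, then -2) restores the expected shape:
right = np.expand_dims(np.expand_dims(reduced, -1), -2)
assert right.shape == (2, 3, 1, 1)
```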
markkraay authored Aug 21, 2024
1 parent bc19b9b commit 977f653
Showing 2 changed files with 16 additions and 4 deletions.
tripy/tests/integration/test_reduce.py: 13 additions & 3 deletions

```diff
@@ -33,14 +33,17 @@ class TestReduceOp:
             ((2, 3, 4), (1, 2), False),
             ((2, 3, 4), None, False),
             ((2, 3, 4), None, True),
+            ((2, 3, 4, 5), (-2, -1), True),
         ],
     )
     def test_all(self, x_shape, axis, keepdim):
         x = np.array([i % 2 == 0 for i in np.arange(np.prod(x_shape))]).reshape(x_shape)
         a = tp.Tensor(x)
         out = tp.all(a, dim=axis, keepdim=keepdim)
+        expected = tp.Tensor(np.array(x.all(axis=axis, keepdims=keepdim)))
         #np.array is necessary to deal with case where x.all returns a numpy scalar (5th case)
-        assert tp.allclose(out, tp.Tensor(np.array(x.all(axis=axis, keepdims=keepdim))))
+        assert out.shape == expected.shape
+        assert tp.allclose(out, expected)
 
     @pytest.mark.parametrize(
         "x_shape, axis, keepdim",
@@ -51,6 +54,7 @@ def test_all(self, x_shape, axis, keepdim):
             ((2, 3, 4), (1, 2), False),
             ((2, 3, 4), None, False),
             ((2, 3, 4), None, True),
+            ((2, 3, 4, 5), (-2, -1), True),
         ],
     )
 
@@ -69,6 +73,7 @@ def test_any(self, x_shape, axis, keepdim):
             ((2, 3, 4), (1, 2), False),
             ((2, 3, 4), None, False),
             ((2, 3, 4), None, True),
+            ((2, 3, 4, 5), (-2, -1), True),
         ],
     )
     @pytest.mark.parametrize("dtype", [tp.float32, tp.float16])
@@ -77,7 +82,9 @@ def test_mean(self, x_shape, axis, keepdim: bool, dtype):
         x = np.arange(np.prod(x_shape)).reshape(x_shape).astype(np_dtype)
         a = tp.Tensor(x, dtype=dtype)
         out = tp.mean(a, dim=axis, keepdim=keepdim)
-        assert tp.allclose(out, tp.Tensor(cp.array(x.mean(axis=axis, keepdims=keepdim))))
+        expected = tp.Tensor(cp.array(x.mean(axis=axis, keepdims=keepdim)))
+        assert out.shape == expected.shape
+        assert tp.allclose(out, expected, rtol=1e-3, atol=1e-3)
 
     @pytest.mark.parametrize(
         "x_shape, axis, keepdim",
@@ -88,14 +95,17 @@ def test_mean(self, x_shape, axis, keepdim: bool, dtype):
             ((2, 3, 4), None, True),
             ((2, 3), 1, False),
             ((2, 3, 4), (1, 2), False),
+            ((2, 3, 4, 5), (-2, -1), True),
         ],
     )
     def test_var(self, x_shape, axis, keepdim: bool):
         x = np.arange(np.prod(x_shape)).reshape(x_shape).astype(np.float32)
         a = tp.Tensor(x)
         out = tp.var(a, dim=axis, keepdim=keepdim)
         torch_tensor = torch.Tensor(x)
-        assert tp.allclose(out, tp.Tensor(torch_tensor.var(dim=axis, keepdim=keepdim)))
+        expected = tp.Tensor(torch_tensor.var(dim=axis, keepdim=keepdim))
+        assert out.shape == expected.shape
+        assert tp.allclose(out, expected)
 
     @pytest.mark.parametrize(
         "x_shape, axis, keepdim",
```
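
As a rough reference for what the new `((2, 3, 4, 5), (-2, -1), True)` cases exercise (a sketch, not part of the commit): the NumPy baselines the tests compare against accept a tuple of negative axes and keep the rank when `keepdims=True`.

```python
import numpy as np

x = np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)

# Reducing over the last two axes with keepdims preserves the rank.
ref = x.mean(axis=(-2, -1), keepdims=True)
assert ref.shape == (2, 3, 1, 1)
```

The added `assert out.shape == expected.shape` checks mean a misplaced size-1 axis is caught directly instead of relying on `tp.allclose` alone.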
tripy/tripy/frontend/trace/ops/reduce.py: 3 additions & 1 deletion

```diff
@@ -134,7 +134,9 @@ def _reduce_impl(input: "tripy.Tensor", kind: Reduce.Kind, dim: Union[int, Seque
     if dim is None:
         out = reshape(out, (1,) * input.rank)
     else:
-        for d in sorted(make_list(dim)):
+        # Custom comparison function ensures negatives are sorted in decreasing order, otherwise increasing.
+        # e.g, [-2, 0, -1, 2] is sorted as [-1, -2, 0, 2].
+        for d in sorted(make_list(dim), key=lambda x: (0, -x) if x < 0 else (1, x)):
             out = unsqueeze(out, d)
 
     return out
```
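
To see what the new sort key does, a quick snippet reproducing the example from the committed comment:

```python
# Negatives come first, in decreasing order; non-negatives follow in increasing order.
key = lambda x: (0, -x) if x < 0 else (1, x)

print(sorted([-2, 0, -1, 2], key=key))  # [-1, -2, 0, 2]
print(sorted([-2, -1], key=key))        # [-1, -2]
print(sorted([1, 2], key=key))          # [1, 2]
```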
