Skip to content

Commit

Permalink
[Tripy][Bugfix] Use correct types in __str__ method for `ShapeScala…
Browse files Browse the repository at this point in the history
…r` (#212)

Noticed a small error: The `__str__` method for `ShapeScalar` assumed
that evaluating the scalar would give a list output, resulting in a type
error, since the actual result is a scalar. This PR fixes that and adds
a check to the unit tests.
  • Loading branch information
slyubomirsky authored Sep 19, 2024
1 parent d3e9d2c commit f72a7af
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
17 changes: 15 additions & 2 deletions tripy/tests/frontend/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,26 @@ def other_values(request):


class TestShapeScalar:
@pytest.mark.parametrize("value", [1, tp.Tensor(1), np.array(2)])
@pytest.mark.parametrize(
"value",
[
1,
tp.Tensor(1),
# Note: if we don't specify the dtype, the tensor constructor will insert a cast
# and the assert below about the trace_tensor's producer will fail.
np.array(2, dtype=np.int32),
],
)
def test_scalar_shape(self, value):
s = tp.ShapeScalar(values)
s = tp.ShapeScalar(value)

assert isinstance(s, tp.ShapeScalar)
assert s.trace_tensor.producer.inputs == []

def test_scalar_shape_str_method(self):
s = tp.ShapeScalar(12)
assert s.__str__() == f"shape_scalar(12)"

def test_scalar_slice(self):
a = tp.iota((3, 3))
assert isinstance(a.shape[0], tp.ShapeScalar)
Expand Down
4 changes: 3 additions & 1 deletion tripy/tripy/frontend/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def __repr__(self) -> str:
return "shape_scalar" + tensor_repr[6:]

def __str__(self) -> str:
return "shape_scalar" + "(" + ", ".join(map(str, self.tolist())) + ")"
val = self.tolist()
assert isinstance(val, int)
return f"shape_scalar({val})"


@export.public_api()
Expand Down

0 comments on commit f72a7af

Please sign in to comment.