Skip to content

Commit

Permalink
Add tensor.tolist() method (#125)
Browse files Browse the repository at this point in the history
  • Loading branch information
parthchadha authored Aug 19, 2024
1 parent c0b6b49 commit 7c5bcbf
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tripy/tests/frontend/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,16 @@ def test_multiple_elements_boolean_fails(self):

with pytest.raises(tp.TripyException):
bool(tensor)

@pytest.mark.parametrize(
"tensor, expected",
[
(tp.Tensor([0]), [0]),
(tp.zeros((1, 1, 1)), [[[0]]]),
(tp.Tensor([[[0.1]]]), [[[0.1]]]),
(tp.Tensor([True]), [True]),
(tp.ones((1, 2), dtype=tp.float16), [[1.0, 1.0]]),
],
)
def test_tolist(self, tensor, expected):
assert np.allclose(tensor.tolist(), expected)
3 changes: 3 additions & 0 deletions tripy/tripy/frontend/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ def data(self) -> Array:
arr = cast(Tensor(arr), tripy.common.datatype.float32).eval()
return arr

def tolist(self) -> List:
return self.data().data()

def __iter__(self):
raise TypeError("Iterating over tensors is not supported")

Expand Down

0 comments on commit 7c5bcbf

Please sign in to comment.