Skip to content

Commit

Permalink
Add scalar shape class to Tripy (#202)
Browse files Browse the repository at this point in the history
This PR adds `ScalarShape` to Tripy which is encodes a value that is
sliced out of a `Shape` tensor.
  • Loading branch information
parthchadha authored Sep 16, 2024
1 parent 6da16dc commit 251102e
Show file tree
Hide file tree
Showing 12 changed files with 122 additions and 31 deletions.
24 changes: 24 additions & 0 deletions tripy/tests/frontend/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,30 @@ def other_values(request):
return request.param


class TestShapeScalar:
@pytest.mark.parametrize("value", [1, tp.Tensor(1), np.array(2)])
def test_scalar_shape(self, value):
s = tp.ShapeScalar(values)

assert isinstance(s, tp.ShapeScalar)
assert s.trace_tensor.producer.inputs == []

def test_scalar_slice(self):
a = tp.iota((3, 3))
assert isinstance(a.shape[0], tp.ShapeScalar)

s = a.shape[0] * a.shape[1]
b = tp.reshape(a, tp.reshape(s, (1,)))
assert tp.allclose(tp.flatten(a), b)

def test_scalar_scalar_op(self):
a = tp.iota((3, 4))
s1 = a.shape[0]
s2 = a.shape[1]
s = s1 + s2
assert isinstance(s, tp.ShapeScalar)


class TestShape:
def test_shape(self, values):
s = tp.Shape(values)
Expand Down
57 changes: 56 additions & 1 deletion tripy/tripy/frontend/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,62 @@
import tripy.frontend.utils as frontend_utils


@export.public_api()
class ShapeScalar(Tensor):
"""
Scalar shape is a tensor used to represent a scalar value extracted from a shape tensor.
ShapeScalars are scalars (rank 0) of non-negative integer (using int32 as the datatype).
"""

def __init__(
self,
data: Union[Sequence, Tensor, "np.ndarray", "cp.ndarray", "torch.Tensor", "jnp.ndarray"],
name: Optional[str] = None,
) -> None:
r"""
Args:
data: The value of the ShapeScalar, which should be a scalar integer.
name: An optional name
"""

from tripy.common.exception import raise_error

if isinstance(data, Tensor):
# these fields can be None in the case of an uninitialized tensor (like Tensor(None))
if data.trace_tensor.rank is not None and data.trace_tensor.rank != 0:
raise_error(
f"Scalar shape tensors must be of rank 0, but input tensor is rank {data.rank}", details=[data]
)
if data.dtype is not None and data.dtype != int32:
raise_error(
f"Scalar shape tensor must have int32 member, but input tensor has data type {data.dtype}",
details=[data],
)

# the shape of data should correspond to the given rank
super().__init__(data=None, dtype=int32, name=name, device=data.device)
# share the underlying data
self.trace_tensor = data.trace_tensor
self.stack_info = data.stack_info
else:
shape = data.shape if hasattr(data, "shape") else utils.get_shape(data)
device = data.device if hasattr(data, "device") else None
if len(shape) != 0:
raise_error(
f"Tensors used to represent scalar shapes must be of rank 0, but given shape {shape} has rank {len(shape)}."
)
super().__init__(data=data, dtype=int32, name=name, device=device)

def __repr__(self) -> str:
# denote the representation as a shape rather than a tensor
tensor_repr = super().__repr__()
assert tensor_repr[:6] == "tensor"
return "shape_scalar" + tensor_repr[6:]

def __str__(self) -> str:
return "shape_scalar" + "(" + ", ".join(map(str, self.tolist())) + ")"


@export.public_api()
class Shape(Tensor):
"""
Expand All @@ -47,7 +103,6 @@ def __init__(
r"""
Args:
data: The value of the shape, which should be a 1D array of integers (the dimensions).
num_dims: The number of dimensions in the shape (its rank), which should correspond to the number of elements in data
name: An optional name
"""

Expand Down
33 changes: 21 additions & 12 deletions tripy/tripy/frontend/trace/ops/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def build(cls, inputs: List["Tensor"], *args, num_outputs=1, **kwargs) -> Union[
"""

from tripy.common.exception import raise_error
from tripy.frontend.shape import Shape
from tripy.frontend.shape import Shape, ShapeScalar
from tripy.frontend.tensor import Tensor

# NOTE: If you change the stack depth where the tensors are constructed, update STACK_DEPTH_OF_BUILD in
Expand All @@ -87,22 +87,31 @@ def build(cls, inputs: List["Tensor"], *args, num_outputs=1, **kwargs) -> Union[
raise_error(
f"Error processing shape inputs in operator {cls.__name__}{custom_err}\n(Shape input indices: {shape_arg_msg}.)"
)
# for shape outputs, we infer the length
if len(res.value) != 0:
inferred_lengths = op.infer_len()
for idx in res.value:
outputs[idx] = Shape(outputs[idx])
if inferred_lengths[idx] is not None:
out_trace_tensors[idx].shape = [inferred_lengths[idx]]

shape = res.value.get("shape")
if shape is not None:
# for shape outputs, we infer the length
if len(shape) != 0:
inferred_lengths = op.infer_len()

for idx in shape:
outputs[idx] = Shape(outputs[idx])
if inferred_lengths[idx] is not None:
out_trace_tensors[idx].shape = [inferred_lengths[idx]]

scalar_shape = res.value.get("scalar")
if scalar_shape is not None:
for idx in scalar_shape:
outputs[idx] = ShapeScalar(outputs[idx])

if num_outputs == 1:
return outputs[0]
return outputs

def infer_shape_output_idxs(self, inputs: List["Tensor"]) -> Result:
"""
Given the operator's inputs, this method returns a `Result` containing a list of the operator's output indices
that should be wrapped in `tp.Shape`.
Given the operator's inputs, this method returns a `Result` containing a dict of the operator's output indices
that should be wrapped in `tp.Shape` or `tp.ShapeScalar`.
By default, this will wrap all the outputs in `tp.Shape` if all the inputs are `tp.Shape`s and not wrap any otherwise,
treating it as an error if the inputs are inconsistent.
Expand All @@ -126,9 +135,9 @@ def infer_shape_output_idxs(self, inputs: List["Tensor"]) -> Result:

if any(map(is_shape, inputs)):
if all(map(is_shape, inputs)):
return Result.ok(list(range(len(self.outputs))))
return Result.ok({"shape": list(range(len(self.outputs))), "scalar": []})
return Result.err(["Either all inputs must be tp.Shape or all must be tp.Tensor."])
return Result.ok([])
return Result.ok({})

def infer_len(self) -> List[Optional[int]]:
"""
Expand Down
9 changes: 6 additions & 3 deletions tripy/tripy/frontend/trace/ops/binary_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __str__(self):

def infer_shape_output_idxs(self, inputs):
# permit one input to be a shape but require the output to be a shape
from tripy.frontend.shape import Shape
from tripy.frontend.shape import Shape, ShapeScalar
from tripy.utils import Result

if any(map(lambda t: isinstance(t, Shape), inputs)):
Expand All @@ -66,9 +66,12 @@ def infer_shape_output_idxs(self, inputs):
f"The following inputs have invalid ranks: {invalid_indices_message}",
]
)
return Result.ok([0])
return Result.ok({"shape": [0]})
elif all(map(lambda t: isinstance(t, ShapeScalar), inputs)):
# Binary operation on ShapeScalar should yield another ShapeScalar.
return Result.ok({"scalar": [0]})
else:
return Result.ok([])
return Result.ok({})

def infer_len(self):
# For the shape case, the result will be broadcast to the max of the input shapes
Expand Down
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/trace/ops/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def infer_shape_output_idxs(self, inputs):
if isinstance(inputs[0], Shape):
# Only still a valid shape if it remains int32
if self.dtype == int32:
return Result.ok([0])
return Result.ok([])
return Result.ok({"shape": [0]})
return Result.ok({})

infer_len = InferLenPolicies.infer_same_as_first_input

Expand Down
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/trace/ops/expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def infer_shape_output_idxs(self, inputs) -> Result:

# wrap if the first input is a shape and the output is rank-1
if isinstance(inputs[0], Shape) and self.output_rank == 1:
return Result.ok([0])
return Result.ok([])
return Result.ok({"shape": [0]})
return Result.ok({})

def infer_len(self):
if self.output_len is not None:
Expand Down
2 changes: 1 addition & 1 deletion tripy/tripy/frontend/trace/ops/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def infer_shape_output_idxs(self, inputs):
if (isinstance(inputs[0], Shape) and isinstance(inputs[1], Shape)) or (
not isinstance(inputs[0], Shape) and not isinstance(inputs[1], Shape)
):
return Result.ok([])
return Result.ok({})
return Result.err(None)

def infer_rank(self):
Expand Down
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/trace/ops/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def infer_shape_output_idxs(self, inputs):

# Only wrap the reshaped output if the result is rank 1, otherwise don't wrap
if isinstance(inputs[0], Shape) and self.output_rank == 1:
return Result.ok([0])
return Result.ok([])
return Result.ok({"shape": [0]})
return Result.ok({})

def infer_rank(self):
if self.output_rank is None:
Expand Down
2 changes: 1 addition & 1 deletion tripy/tripy/frontend/trace/ops/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Shape(BaseTraceOp):

# always return a shape
def infer_shape_output_idxs(self, inputs) -> Result:
return Result.ok([0])
return Result.ok({"shape": [0]})

def infer_len(self):
return [self.inputs[0].rank]
Expand Down
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/trace/ops/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def __getitem__(self: "tripy.Tensor", index: Union[slice, int, Tuple[int], "trip
assert np.array_equal(cp.from_dlpack(output).get(), np.arange(10)[8:2:-1])
"""
from tripy.frontend.shape import Shape
from tripy.frontend.shape import ShapeScalar, Shape
from tripy.frontend.tensor import Tensor
from tripy.frontend.trace.ops.flip import flip
from tripy.frontend.trace.ops.reshape import reshape, squeeze
Expand Down Expand Up @@ -297,7 +297,7 @@ def clamp_bound(bound: Union[int, Tensor]) -> Union[int, Tensor]:
if squeeze_dims:
out = squeeze(out, make_tuple(squeeze_dims))

return out
return ShapeScalar(out) if isinstance(self, Shape) and out.rank == 0 else out


# Conveniently converts the inputs to tensors. The decorator also fills in column info for the converted tensors.
Expand Down
6 changes: 3 additions & 3 deletions tripy/tripy/frontend/trace/ops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ def infer_from_first_input_only(self, inputs):
from tripy.frontend.shape import Shape

if isinstance(inputs[0], Shape):
return Result.ok(list(range(len(self.outputs))))
return Result.ok([])
return Result.ok({"shape": list(range(len(self.outputs)))})
return Result.ok({})

def never_return_shape(self, inputs):
"""
Accepts shapes but the result is always no shape indices
"""
return Result.ok([])
return Result.ok({})


##
Expand Down
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/trace/ops/where.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def infer_shape_output_idxs(self, inputs):
f" the Boolean input must be rank 1, but given rank {inputs[0].rank}",
]
)
return Result.ok([0])
return Result.ok({"shape": [0]})
elif not isinstance(inputs[1], Shape) and not isinstance(inputs[2], Shape):
return Result.ok([])
return Result.ok({})
else:
return Result.err(
[
Expand Down

0 comments on commit 251102e

Please sign in to comment.