Skip to content

Commit

Permalink
Rework Reshape->IndexLambda lowering
Browse files Browse the repository at this point in the history
- Better handle 1-long axes in old and new shapes
- Avoid generating modulo expressions for direct pass-through
inducer committed Jan 30, 2025
1 parent 7b7c9ad commit b73a556
Showing 1 changed file with 56 additions and 23 deletions.
79 changes: 56 additions & 23 deletions pytato/transform/lower_to_index_lambda.py
Original file line number Diff line number Diff line change
@@ -68,9 +68,9 @@


@dataclass(frozen=True)
class _ReshapeIndexGroup:
old_ax_indices: tuple[ShapeComponent, ...]
new_ax_indices: tuple[ShapeComponent, ...]
class _ReshapeShapeGroup:
old_ax_shape_group: tuple[ShapeComponent, ...]
new_ax_shape_group: tuple[ShapeComponent, ...]


def _generate_index_expressions(
@@ -84,6 +84,15 @@ def _generate_index_expressions(
old_strides = old_strides[:len(old_shape)]
new_strides = new_strides[:len(new_shape)]

if not old_shape:
assert new_shape == (1,)
return (0,)

if old_shape == new_shape:
# Avoid generating modulo expressions for direct pass-through
assert len(old_shape) == 1
return (index_vars[0],)

old_size_tills = [old_shape[-1] if order == "C" else old_shape[0]]

old_stride_axs = (old_shape[::-1][:-1] if order == "C" else
@@ -148,7 +157,7 @@ def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]:

# {{{ generate subsets of old axes mapped to subsets of new axes

axis_mapping: list[_ReshapeIndexGroup] = []
axis_mapping: list[_ReshapeShapeGroup] = []

old_index = 0
new_index = 0
@@ -157,6 +166,21 @@ def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]:
old_ax_len_product = old_shape[old_index]
new_ax_len_product = new_shape[new_index]

# Specially handle (i.e. skip) axes of length 1 at the start of an index group
if old_ax_len_product != new_ax_len_product:
if old_ax_len_product == 1:
axis_mapping.append(_ReshapeShapeGroup(
old_ax_shape_group=(old_ax_len_product,),
new_ax_shape_group=()))
old_index += 1
continue
if new_ax_len_product == 1:
axis_mapping.append(_ReshapeShapeGroup(
old_ax_shape_group=(),
new_ax_shape_group=(new_ax_len_product,)))
new_index += 1
continue

old_product_end = old_index + 1
new_product_end = new_index + 1

@@ -173,48 +197,57 @@ def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]:
old_ax_len_product *= old_shape[old_product_end]
old_product_end += 1

old_ax_indices = old_shape[old_index:old_product_end]
new_ax_indices = new_shape[new_index:new_product_end]

axis_mapping.append(_ReshapeIndexGroup(
old_ax_indices=old_ax_indices,
new_ax_indices=new_ax_indices))
axis_mapping.append(_ReshapeShapeGroup(
old_ax_shape_group=old_shape[old_index:old_product_end],
new_ax_shape_group=new_shape[new_index:new_product_end]))

old_index = old_product_end
new_index = new_product_end

# handle trailing 1s
final_reshaped_indices = axis_mapping.pop(-1)
old_ax_indices = final_reshaped_indices.old_ax_indices
new_ax_indices = final_reshaped_indices.new_ax_indices

# At most one of the while loops below should execute.
assert not (
old_index < len(old_shape)
and
new_index < len(new_shape)
)

while old_index < len(old_shape):
old_ax_indices += tuple([old_shape[old_index]]) # noqa: C409
assert old_shape[old_index] == 1
axis_mapping.append(_ReshapeShapeGroup(
old_ax_shape_group=(old_shape[old_index],),
new_ax_shape_group=()))
old_index += 1

while new_index < len(new_shape):
new_ax_indices += tuple([new_shape[new_index]]) # noqa: C409
assert new_shape[new_index] == 1
axis_mapping.append(_ReshapeShapeGroup(
old_ax_shape_group=(),
new_ax_shape_group=(new_shape[new_index],),
))
new_index += 1

axis_mapping.append(_ReshapeIndexGroup(old_ax_indices=old_ax_indices,
new_ax_indices=new_ax_indices))

# }}}

# {{{ compute index expressions for sub shapes

index_vars_begin = 0
index_expressions = []
for reshaped_indices in axis_mapping:
sub_old_shape = reshaped_indices.old_ax_indices
sub_new_shape = reshaped_indices.new_ax_indices
for shape_group in axis_mapping:
sub_old_shape = shape_group.old_ax_shape_group
sub_new_shape = shape_group.new_ax_shape_group

index_vars_end = index_vars_begin + len(sub_new_shape)
sub_index_vars = index_vars[index_vars_begin:index_vars_end]
index_vars_begin = index_vars_end

index_expressions.append(_generate_index_expressions(
sub_old_shape, sub_new_shape, order, sub_index_vars))
if not sub_old_shape:
# No need to generate index into old array
assert sub_new_shape == (1,)
else:
index_expressions.append(_generate_index_expressions(
sub_old_shape, sub_new_shape, order, sub_index_vars))

# }}}

0 comments on commit b73a556

Please sign in to comment.