[TKW] Propagate dim index in thread shape analysis #288

Open · wants to merge 7 commits into main · Changes from 5 commits
3 changes: 3 additions & 0 deletions iree/turbine/kernel/_support/indexing.py
@@ -430,3 +430,6 @@ def __repr__(self):
         if isinstance(self.size, int) and self.size <= 1:
             return f"{self.start}"
         return f"{self.start} : {self.size} : {self.stride}"
+
+    def __hash__(self):
+        return hash((self.start, self.size, self.stride))
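
The hash is what lets IndexSequence live inside hashed containers: the thread-shape analysis in this PR wraps each sequence in a DimIndex whose own __hash__ delegates to it, and stores those in frozenset keys. A minimal sketch of the requirement, assuming an IndexSequence(start, size, stride) constructor and an index symbol M (both assumptions for illustration):

    # If a class defines __eq__ (as dataclass-style classes do), Python drops
    # the inherited __hash__, so an explicit one is needed before instances
    # can be frozenset members or dict keys.
    seq = IndexSequence(0, 4, 1)           # assumed constructor shape
    key = frozenset({DimIndex(M, seq)})    # DimIndex hashes (dim, seq)
    thread_size_to_ops = {key: set()}      # the analysis keys its buckets this way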
4 changes: 4 additions & 0 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -1349,6 +1349,10 @@ def num_reduction_dims(self) -> int:
     def reduction_dim(self) -> IndexSymbol:
         return self.dim
 
+    @property
+    def elements_per_thread(self) -> int:
+        return 1
+
 
 # TODO: Add support for more shuffle types.
 @define_op("shuffle")

Contributor: Can you add a comment explaining why this is necessary?
95 changes: 51 additions & 44 deletions iree/turbine/kernel/wave/thread_shape_analysis.py
@@ -19,26 +19,32 @@


 @dataclass(order=True)
-class DimSize:
+class DimIndex:
     dim: IndexSymbol
-    size: int
+    seq: IndexSequence
+
+    @property
+    def size(self) -> IndexExpr:
+        return self.seq.size

     def __hash__(self):
-        return hash((self.dim, self.size))
+        return hash((self.dim, self.seq))


-def get_dim_sizes(indices: list[IndexSequence]):
-    dims = frozenset(
-        [DimSize(dim, subs_idxc(seq.size)) for dim, seq in indices.items()]
-    )
+def process_seq(seq):
+    return subs_idxc(seq)
+
+
+def get_dim_indices(indices: list[IndexSequence]):
+    dims = frozenset([DimIndex(dim, process_seq(seq)) for dim, seq in indices.items()])
     return dims


-def get_custom_dim_sizes(custom: CustomOp):
-    return get_dim_sizes(custom.index)
+def get_custom_dim_indices(custom: CustomOp):
+    return get_dim_indices(custom.index)


-def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]):
+def set_custom_index(custom: CustomOp, target_dim_sizes: list[DimIndex]):
     for target in target_dim_sizes:
         if target.dim not in custom.index:
             raise NotImplementedError(
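
The net effect of the DimSize → DimIndex change: a bucket key now carries each dim's whole substituted IndexSequence (start, size, and stride) rather than only subs_idxc(seq.size). A toy sketch of get_dim_indices under assumed names (M, N, THREAD_0, and the IndexSequence constructor are illustrative):

    # Hypothetical op index; get_dim_indices substitutes each sequence via
    # process_seq/subs_idxc and keys it by its dim.
    index = {
        M: IndexSequence(THREAD_0 * 4, 4, 1),
        N: IndexSequence(0, 1, 1),
    }
    dims = get_dim_indices(index)
    # dims == frozenset({DimIndex(M, <substituted seq>),
    #                    DimIndex(N, <substituted seq>)})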
@@ -51,7 +57,7 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]):
 # Anchor Indices and Conflict resolution helpers
 #################################################################

-anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape)
+anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape, Permute)
 noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate)
 legalSubtypes = (IterArg,)
 nonPropagatableTypes = anchorOpTypes + noHandleTypes

Contributor: Can we add the forward propagation of permute, just as a safety measure, to ensure we won't be generating "valid" but incorrect IRs? Speaking from experience, it would be much better for the program to crash than to have to debug why the MLIR is wrong and where the wrong is coming from. 😄

Contributor: Also, a comment explaining why Permute is added as an anchor op?
@@ -145,7 +151,7 @@ def determine_thread_shapes(trace: CapturedTrace):
     thread_shape in its indexSequence.

     `thread_shapes` is used to store thread_size at every dimension that the op
-    cares about. We use a frozenset[DimSize] to represent it, where DimSize
+    cares about. We use a frozenset[DimIndex] to represent it, where DimIndex
     is essentially a pair<dimension: IndexSymbol, thread_size: int>. We are using
     a frozenset since we do not care about the order of dims for the shape/size
     propagation.
@@ -171,15 +177,17 @@

     """
     anchor_ops = trace.walk(is_anchor_op)
-    thread_size_to_ops: dict[frozenset[DimSize], set[CustomOp]] = {}
+    thread_size_to_ops: dict[frozenset[DimIndex], set[CustomOp]] = {}
+
+    def update_dims(index: frozenset[DimIndex], ops: set[CustomOp]):
+        thread_size_to_ops[index] = thread_size_to_ops.get(index, set([])).union(ops)
+
     for anchor_op in anchor_ops:
         custom = get_custom(anchor_op)
-        index_sizes = get_custom_dim_sizes(custom)
+        index_sizes = get_custom_dim_indices(custom)
         if isinstance(custom, Read):
             fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
-            thread_size_to_ops[index_sizes] = thread_size_to_ops.get(
-                index_sizes, set([])
-            ).union(fwd_slice)
+            update_dims(index_sizes, fwd_slice)
         elif isinstance(custom, ReduceOp):
             fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op)
             bwd_slice = set()
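
update_dims is a small union-accumulate helper replacing the repeated dict.get(...).union(...) pattern below; after the walk over anchor ops, the bucket map might look like this (symbols and node handles are invented for illustration):

    # Two buckets that disagree on M's thread index; values are the CustomOps
    # the index should be applied to. A collections.defaultdict(set) would be
    # an equivalent formulation of update_dims.
    thread_size_to_ops = {
        frozenset({DimIndex(M, IndexSequence(0, 4, 1))}): {read_node, add_node},
        frozenset({DimIndex(M, IndexSequence(0, 1, 1))}): {reduce_node},
    }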
@@ -188,18 +196,18 @@
             ):
                 bwd_slice = capture_backward_slice(custom.init, propagatable_op)
             reduce_dims = frozenset(
-                [DimSize(dim, 1) for dim in custom.index.keys() if dim != custom.dim]
-            )
-            thread_size_to_ops[reduce_dims] = (
-                thread_size_to_ops.get(reduce_dims, set([]))
-                .union(fwd_slice)
-                .union(bwd_slice)
+                [
+                    DimIndex(dim, process_seq(IndexSequence(seq.start, 1, 1)))
+                    for dim, seq in custom.index.items()
+                    if dim != custom.dim
+                ]
             )
+
+            update_dims(reduce_dims, fwd_slice)
+            update_dims(reduce_dims, bwd_slice)
         elif isinstance(custom, Write):
             bwd_slice = capture_backward_slice(custom.fx_node, propagatable_op)
-            thread_size_to_ops[index_sizes] = thread_size_to_ops.get(
-                index_sizes, set([])
-            ).union(bwd_slice)
+            update_dims(index_sizes, bwd_slice)
         elif isinstance(custom, MMA):
             lhs_bwd_slice = set([custom.lhs])
             if propagatable_op(custom.lhs):
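
Note the behavioral change in the ReduceOp branch: the old code keyed every non-reduction dim as DimSize(dim, 1), discarding the start; the new code keeps each dim's start and pins only size and stride to 1. A sketch with assumed symbols (M a parallel dim, K the reduction dim):

    # Given custom.dim == K and
    #   custom.index == {M: IndexSequence(THREAD_0 * 4, 4, 1),
    #                    K: IndexSequence(0, 16, 1)}
    # the key attached to the reduce's forward/backward slices becomes:
    reduce_dims = frozenset({
        DimIndex(M, process_seq(IndexSequence(THREAD_0 * 4, 1, 1))),
    })
    # K, the reduction dim itself, is excluded from the key.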
@@ -212,27 +220,19 @@
                 acc_slice = acc_slice.union(
                     capture_backward_slice(custom.acc, propagatable_op)
                 )
-            acc_index = get_dim_sizes(custom.acc_index)
-            lhs_index = get_dim_sizes(custom.lhs_index)
-            rhs_index = get_dim_sizes(custom.rhs_index)
-            thread_size_to_ops[acc_index] = thread_size_to_ops.get(
-                acc_index, set([])
-            ).union(acc_slice)
-            thread_size_to_ops[lhs_index] = thread_size_to_ops.get(
-                lhs_index, set([])
-            ).union(lhs_bwd_slice)
-            thread_size_to_ops[rhs_index] = thread_size_to_ops.get(
-                rhs_index, set([])
-            ).union(rhs_bwd_slice)
+            acc_index = get_dim_indices(custom.acc_index)
+            lhs_index = get_dim_indices(custom.lhs_index)
+            rhs_index = get_dim_indices(custom.rhs_index)
+            update_dims(acc_index, acc_slice)
+            update_dims(lhs_index, lhs_bwd_slice)
+            update_dims(rhs_index, rhs_bwd_slice)
         elif isinstance(custom, Reshape):
             # The reshape op acts like a barrier for the MMA, preventing
             # the MMA from propagating the thread shapes of its reshaped
             # operands backwards.
-            bwd_size = get_dim_sizes(custom.args.index)
+            bwd_size = get_dim_indices(custom.args.index)
             bwd_slice = capture_backward_slice(custom.args, propagatable_op)
-            thread_size_to_ops[bwd_size] = thread_size_to_ops.get(
-                bwd_size, set([])
-            ).union(bwd_slice)
+            update_dims(bwd_size, bwd_slice)

     # Go through each index-size bucket, and apply the index-size to the ops in the bucket.
     cummulative_set = set()
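
For intuition, a conflict in the resolution loop below means one op was claimed by two buckets whose dim indices disagree, e.g. reached forward from a Read with one thread shape and backward from a Write with another. A toy setup (all names invented):

    # op_x appears in both buckets, so the loop finds it in cummulative_set on
    # a later iteration and hands {op_x} to handle_conflicts; only if that
    # fails does the NotImplementedError below fire.
    bucket_a = frozenset({DimIndex(M, IndexSequence(0, 4, 1))})  # e.g. from a Read
    bucket_b = frozenset({DimIndex(M, IndexSequence(0, 1, 1))})  # e.g. from a ReduceOp
    thread_size_to_ops = {bucket_a: {op_x, op_y}, bucket_b: {op_x}}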
@@ -241,10 +241,17 @@
         if not cummulative_set.isdisjoint(target_ops):
             conflicted_ops = cummulative_set.intersection(target_ops)
             if handle_conflicts(conflicted_ops) == False:
-                raise NotImplementedError("Failed to handle conflicting thread shape.")
+                offenders = tuple(
+                    (ops, dim)
+                    for dim, ops in thread_size_to_ops.items()
+                    if not conflicted_ops.isdisjoint(ops)
+                )
+                raise NotImplementedError(
+                    f"Failed to handle conflicting thread shape: {conflicted_ops}, {offenders}"
+                )
             target_ops = target_ops.difference(conflicted_ops)
         cummulative_set = cummulative_set.union(target_ops)
         # Set each target op's indexSize to the one determined by the analysis.
         for user in target_ops:
             custom_user = get_custom(user)
-            set_index_size(custom_user, target_index_size)
+            set_custom_index(custom_user, target_index_size)
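
The reworked error path also improves diagnostics: instead of a bare message, it now reports the conflicted ops together with every bucket that claims them. A self-contained sketch of the offenders comprehension with toy values:

    # Stand-ins: ints for ops, strings for frozenset[DimIndex] keys.
    thread_size_to_ops = {"dims_a": {1, 2}, "dims_b": {2, 3}}
    conflicted_ops = {2}
    offenders = tuple(
        (ops, dim)
        for dim, ops in thread_size_to_ops.items()
        if not conflicted_ops.isdisjoint(ops)
    )
    assert offenders == (({1, 2}, "dims_a"), ({2, 3}, "dims_b"))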