From a74ba32a7d8c0d881bc7197cc8d7b9fc5eeb3942 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 21 Nov 2024 17:32:02 +0100 Subject: [PATCH 1/7] update determine_thread_shapes Signed-off-by: Ivan Butygin --- iree/turbine/kernel/_support/indexing.py | 3 + .../kernel/wave/thread_shape_analysis.py | 60 ++++++++++++------- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/iree/turbine/kernel/_support/indexing.py b/iree/turbine/kernel/_support/indexing.py index b99d7b5b..326927ed 100644 --- a/iree/turbine/kernel/_support/indexing.py +++ b/iree/turbine/kernel/_support/indexing.py @@ -430,3 +430,6 @@ def __repr__(self): if isinstance(self.size, int) and self.size <= 1: return f"{self.start}" return f"{self.start} : {self.size} : {self.stride}" + + def __hash__(self): + return hash((self.start, self.size, self.stride)) diff --git a/iree/turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py index 129f7551..7051d58a 100644 --- a/iree/turbine/kernel/wave/thread_shape_analysis.py +++ b/iree/turbine/kernel/wave/thread_shape_analysis.py @@ -19,26 +19,35 @@ @dataclass(order=True) -class DimSize: +class DimIndex: dim: IndexSymbol - size: int + seq: IndexSequence + + @property + def size(self) -> IndexExpr: + if isinstance(self.seq, int): + return self.seq + return subs_idxc(self.seq).size def __hash__(self): - return hash((self.dim, self.size)) + return hash((self.dim, self.seq)) -def get_dim_sizes(indices: list[IndexSequence]): - dims = frozenset( - [DimSize(dim, subs_idxc(seq.size)) for dim, seq in indices.items()] - ) +def process_seq(seq): + return seq + return IndexSequence(seq.start, seq.size, 1) + + +def get_dim_indices(indices: list[IndexSequence]): + dims = frozenset([DimIndex(dim, process_seq(seq)) for dim, seq in indices.items()]) return dims -def get_custom_dim_sizes(custom: CustomOp): - return get_dim_sizes(custom.index) +def get_custom_dim_indices(custom: CustomOp): + return get_dim_indices(custom.index) 
-def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]): +def set_custom_index(custom: CustomOp, target_dim_sizes: list[DimIndex]): for target in target_dim_sizes: if target.dim not in custom.index: raise NotImplementedError( @@ -145,7 +154,7 @@ def determine_thread_shapes(trace: CapturedTrace): thread_shape in it's indexSequence. `thread_shapes` is used to store thread_size at every dimension that the op - cares about. We use a frozenset[DimSize] to represent it, where DimSize + cares about. We use a frozenset[DimIndex] to represent it, where DimIndex is essentially a pair. we are using frozen_set since we do not care about the order of dims for the shape/size propagation. @@ -171,10 +180,10 @@ def determine_thread_shapes(trace: CapturedTrace): """ anchor_ops = trace.walk(is_anchor_op) - thread_size_to_ops: dict[frozenset[DimSize], set[CustomOp]] = {} + thread_size_to_ops: dict[frozenset[DimIndex], set[CustomOp]] = {} for anchor_op in anchor_ops: custom = get_custom(anchor_op) - index_sizes = get_custom_dim_sizes(custom) + index_sizes = get_custom_dim_indices(custom) if isinstance(custom, Read): fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op) thread_size_to_ops[index_sizes] = thread_size_to_ops.get( @@ -188,7 +197,11 @@ def determine_thread_shapes(trace: CapturedTrace): ): bwd_slice = capture_backward_slice(custom.init, propagatable_op) reduce_dims = frozenset( - [DimSize(dim, 1) for dim in custom.index.keys() if dim != custom.dim] + [ + DimIndex(dim, seq) + for dim, seq in custom.index.items() + if dim != custom.dim + ] ) thread_size_to_ops[reduce_dims] = ( thread_size_to_ops.get(reduce_dims, set([])) @@ -212,9 +225,9 @@ def determine_thread_shapes(trace: CapturedTrace): acc_slice = acc_slice.union( capture_backward_slice(custom.acc, propagatable_op) ) - acc_index = get_dim_sizes(custom.acc_index) - lhs_index = get_dim_sizes(custom.lhs_index) - rhs_index = get_dim_sizes(custom.rhs_index) + acc_index = 
get_dim_indices(custom.acc_index) + lhs_index = get_dim_indices(custom.lhs_index) + rhs_index = get_dim_indices(custom.rhs_index) thread_size_to_ops[acc_index] = thread_size_to_ops.get( acc_index, set([]) ).union(acc_slice) @@ -228,7 +241,7 @@ def determine_thread_shapes(trace: CapturedTrace): # The reshape op acts like a barrier for the MMA preventing # the mma from propagating the thread shapes of its reshaped # operands backwards. - bwd_size = get_dim_sizes(custom.args.index) + bwd_size = get_dim_indices(custom.args.index) bwd_slice = capture_backward_slice(custom.args, propagatable_op) thread_size_to_ops[bwd_size] = thread_size_to_ops.get( bwd_size, set([]) @@ -241,10 +254,17 @@ def determine_thread_shapes(trace: CapturedTrace): if not cummulative_set.isdisjoint(target_ops): conflicted_ops = cummulative_set.intersection(target_ops) if handle_conflicts(conflicted_ops) == False: - raise NotImplementedError("Failed to handle conflicting thread shape.") + offenders = tuple( + (ops, dim) + for dim, ops in thread_size_to_ops.items() + if not conflicted_ops.isdisjoint(ops) + ) + raise NotImplementedError( + f"Failed to handle conflicting thread shape: {conflicted_ops}, {offenders}" + ) target_ops = target_ops.difference(conflicted_ops) cummulative_set = cummulative_set.union(target_ops) # Set target ops's indexSize to be the determined from analysis. 
for user in target_ops: custom_user = get_custom(user) - set_index_size(custom_user, target_index_size) + set_custom_index(custom_user, target_index_size) From a19acf0becd10ced1b286d04eff47596cbca61cb Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 21 Nov 2024 23:06:32 +0100 Subject: [PATCH 2/7] refac Signed-off-by: Ivan Butygin --- .../kernel/wave/thread_shape_analysis.py | 40 +++++++------------ 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/iree/turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py index 7051d58a..62633c5c 100644 --- a/iree/turbine/kernel/wave/thread_shape_analysis.py +++ b/iree/turbine/kernel/wave/thread_shape_analysis.py @@ -60,7 +60,7 @@ def set_custom_index(custom: CustomOp, target_dim_sizes: list[DimIndex]): # Anchor Indicies and Conflict resolution helpers ################################################################# -anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape) +anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape, Permute) noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate) legalSubtypes = (IterArg,) nonPropagatableTypes = anchorOpTypes + noHandleTypes @@ -181,14 +181,16 @@ def determine_thread_shapes(trace: CapturedTrace): """ anchor_ops = trace.walk(is_anchor_op) thread_size_to_ops: dict[frozenset[DimIndex], set[CustomOp]] = {} + + def update_dims(index: frozenset[DimIndex], ops: set[CustomOp]): + thread_size_to_ops[index] = thread_size_to_ops.get(index, set([])).union(ops) + for anchor_op in anchor_ops: custom = get_custom(anchor_op) index_sizes = get_custom_dim_indices(custom) if isinstance(custom, Read): fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op) - thread_size_to_ops[index_sizes] = thread_size_to_ops.get( - index_sizes, set([]) - ).union(fwd_slice) + update_dims(index_sizes, fwd_slice) elif isinstance(custom, ReduceOp): fwd_slice = capture_forward_slice(custom.fx_node, propagatable_op) bwd_slice = set() @@ -198,21 +200,17 
@@ def determine_thread_shapes(trace: CapturedTrace): bwd_slice = capture_backward_slice(custom.init, propagatable_op) reduce_dims = frozenset( [ - DimIndex(dim, seq) + DimIndex(dim, IndexSequence(seq.start, 1, 1)) for dim, seq in custom.index.items() if dim != custom.dim ] ) - thread_size_to_ops[reduce_dims] = ( - thread_size_to_ops.get(reduce_dims, set([])) - .union(fwd_slice) - .union(bwd_slice) - ) + + update_dims(reduce_dims, fwd_slice) + update_dims(reduce_dims, bwd_slice) elif isinstance(custom, Write): bwd_slice = capture_backward_slice(custom.fx_node, propagatable_op) - thread_size_to_ops[index_sizes] = thread_size_to_ops.get( - index_sizes, set([]) - ).union(bwd_slice) + update_dims(index_sizes, bwd_slice) elif isinstance(custom, MMA): lhs_bwd_slice = set([custom.lhs]) if propagatable_op(custom.lhs): @@ -228,24 +226,16 @@ def determine_thread_shapes(trace: CapturedTrace): acc_index = get_dim_indices(custom.acc_index) lhs_index = get_dim_indices(custom.lhs_index) rhs_index = get_dim_indices(custom.rhs_index) - thread_size_to_ops[acc_index] = thread_size_to_ops.get( - acc_index, set([]) - ).union(acc_slice) - thread_size_to_ops[lhs_index] = thread_size_to_ops.get( - lhs_index, set([]) - ).union(lhs_bwd_slice) - thread_size_to_ops[rhs_index] = thread_size_to_ops.get( - rhs_index, set([]) - ).union(rhs_bwd_slice) + update_dims(acc_index, acc_slice) + update_dims(lhs_index, lhs_bwd_slice) + update_dims(rhs_index, rhs_bwd_slice) elif isinstance(custom, Reshape): # The reshape op acts like a barrier for the MMA preventing # the mma from propagating the thread shapes of its reshaped # operands backwards. bwd_size = get_dim_indices(custom.args.index) bwd_slice = capture_backward_slice(custom.args, propagatable_op) - thread_size_to_ops[bwd_size] = thread_size_to_ops.get( - bwd_size, set([]) - ).union(bwd_slice) + update_dims(bwd_size, bwd_slice) # Go through each index-size buckets, and apply the index-size to ops in the bucket. 
cummulative_set = set() From 69b9d6a1162e237bda2625b6ea094baf6a1f10d7 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 21 Nov 2024 23:44:38 +0100 Subject: [PATCH 3/7] elem per thread Signed-off-by: Ivan Butygin --- iree/turbine/kernel/ops/wave_ops.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index b1c25440..c13b2892 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -1349,6 +1349,10 @@ def num_reduction_dims(self) -> int: def reduction_dim(self) -> IndexSymbol: return self.dim + @property + def elements_per_thread(self) -> int: + return 1 + # TODO: Add support for more shuffle types. @define_op("shuffle") From eabe087af58aaf98ada5fbcac111d60dbe6dc7f4 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 21 Nov 2024 23:51:30 +0100 Subject: [PATCH 4/7] subs Signed-off-by: Ivan Butygin --- iree/turbine/kernel/wave/thread_shape_analysis.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/iree/turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py index 62633c5c..d1d74d6d 100644 --- a/iree/turbine/kernel/wave/thread_shape_analysis.py +++ b/iree/turbine/kernel/wave/thread_shape_analysis.py @@ -25,17 +25,14 @@ class DimIndex: @property def size(self) -> IndexExpr: - if isinstance(self.seq, int): - return self.seq - return subs_idxc(self.seq).size + return self.seq.size def __hash__(self): return hash((self.dim, self.seq)) def process_seq(seq): - return seq - return IndexSequence(seq.start, seq.size, 1) + return subs_idxc(seq) def get_dim_indices(indices: list[IndexSequence]): From 4e1dbcdf4a065525f25ab1535fe1c2d43ef74a0f Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 21 Nov 2024 23:56:02 +0100 Subject: [PATCH 5/7] fix Signed-off-by: Ivan Butygin --- iree/turbine/kernel/wave/thread_shape_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/iree/turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py index d1d74d6d..ed063e00 100644 --- a/iree/turbine/kernel/wave/thread_shape_analysis.py +++ b/iree/turbine/kernel/wave/thread_shape_analysis.py @@ -197,7 +197,7 @@ def update_dims(index: frozenset[DimIndex], ops: set[CustomOp]): bwd_slice = capture_backward_slice(custom.init, propagatable_op) reduce_dims = frozenset( [ - DimIndex(dim, IndexSequence(seq.start, 1, 1)) + DimIndex(dim, process_seq(IndexSequence(seq.start, 1, 1))) for dim, seq in custom.index.items() if dim != custom.dim ] From 2579cec8618f13159bdf0ef6b5a1db3593a45d79 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 22 Nov 2024 16:54:22 +0100 Subject: [PATCH 6/7] set index Signed-off-by: Ivan Butygin --- iree/turbine/kernel/wave/thread_shape_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iree/turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py index ed063e00..f1c40aef 100644 --- a/iree/turbine/kernel/wave/thread_shape_analysis.py +++ b/iree/turbine/kernel/wave/thread_shape_analysis.py @@ -50,7 +50,7 @@ def set_custom_index(custom: CustomOp, target_dim_sizes: list[DimIndex]): raise NotImplementedError( "NYI: Handle when source target index size is not found in target/user index." 
) - custom.index[target.dim].size = target.size + custom.index[target.dim] = target.seq ################################################################# From 2ca85f19a9de1e97b8e2efa96bd1eed3fcf1e6fc Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 22 Nov 2024 17:10:20 +0100 Subject: [PATCH 7/7] comments Signed-off-by: Ivan Butygin --- iree/turbine/kernel/ops/wave_ops.py | 4 ++++ iree/turbine/kernel/wave/thread_shape_analysis.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index c13b2892..8848b691 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -1349,6 +1349,10 @@ def num_reduction_dims(self) -> int: def reduction_dim(self) -> IndexSymbol: return self.dim + # In `set_node_indices` there is logic that propagates `elements_per_thread` + # from previous ops if it wasn't set for the current op, which causes ReduceOp to + # get wrong indices. This property prevents that propagation. + # TODO: remove after index handling is fully switched to thread_shape_analysis. @property def elements_per_thread(self) -> int: return 1 diff --git a/iree/turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py index f1c40aef..49caa640 100644 --- a/iree/turbine/kernel/wave/thread_shape_analysis.py +++ b/iree/turbine/kernel/wave/thread_shape_analysis.py @@ -57,6 +57,9 @@ def set_custom_index(custom: CustomOp, target_dim_sizes: list[DimIndex]): # Anchor Indicies and Conflict resolution helpers ################################################################# +# TODO: Permute ops can have different indices on input and output. +# Add it to the anchorOpTypes to stop index propagation during forward/backward +# lookups. anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape, Permute) noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate) legalSubtypes = (IterArg,)