From ea11e4fe0504c9127f0926b665e7c3ddb47c9f05 Mon Sep 17 00:00:00 2001
From: Harsh Menon
Date: Sat, 14 Dec 2024 19:14:56 -0800
Subject: [PATCH] Remove thread shape analysis

Signed-off-by: Harsh Menon
---
 iree/turbine/kernel/wave/codegen.py            |   2 +-
 .../kernel/wave/index_sequence_analysis.py     | 100 +++++++++++++++++-
 iree/turbine/kernel/wave/utils.py              |  17 +++
 iree/turbine/kernel/wave/wave.py               |   4 -
 lit_tests/kernel/wave/expansion.py             |   8 +-
 5 files changed, 120 insertions(+), 11 deletions(-)

diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py
index d0236ebb..d6507a41 100644
--- a/iree/turbine/kernel/wave/codegen.py
+++ b/iree/turbine/kernel/wave/codegen.py
@@ -555,7 +555,7 @@ def _get_fastest_index(indices: dict[IndexExpr, IndexSequence]):
     there are multiple max_vals we pick the fastest/most minor one.
     """
 
-    index_sizes = [i.size for i in indices.values()]
+    index_sizes = [subs_idxc(i.size) for i in indices.values()]
     # Find the maximum value
     max_size = max(index_sizes)
     # Find the fastest/most minor index of the maximum value.
diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py
index 24c08fcf..4d4dd662 100644
--- a/iree/turbine/kernel/wave/index_sequence_analysis.py
+++ b/iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -6,11 +6,14 @@
 
 from ..ops.wave_ops import (
     Allocate,
+    BinaryPyOp,
+    Broadcast,
     Read,
     Write,
     ExtractSlice,
     get_custom,
     Reduction,
+    ReduceOp,
     MMA,
     Placeholder,
     IterArg,
@@ -30,6 +33,8 @@
     subs_idxc,
     get_inputs,
     get_users,
+    get_largest_index_and_size,
+    capture_backward_slice,
 )
 import torch.fx as fx
 import numpy as np
@@ -217,9 +222,11 @@ def has_gpr_offsets(node: fx.Node) -> bool:
                             {GPR_NUM: cur_gpr_start_id}
                         ),
                         gpr_size,
-                        1
-                        if output_mapping[-1] == gpr_offset_dim
-                        else simplified_index[gpr_offset_dim].stride,
+                        (
+                            1
+                            if output_mapping[-1] == gpr_offset_dim
+                            else simplified_index[gpr_offset_dim].stride
+                        ),
                     )
                     updated_index_with_gpr_offset[
                         gpr_offset_dim
@@ -307,6 +314,7 @@ def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]):
     trace.walk(partial(set_thread_independent_index, constraints))
     set_thread_dependent_index(constraints, mma_index, trace)
     set_derived_index(trace)
+    resolve_thread_shapes(trace, constraints)
 
 
 def compute_stride(
@@ -668,3 +676,89 @@ def apply_offset(node: fx.Node):
         return False
 
     trace.walk(apply_offset)
+
+
+def create_broadcast(
+    binary_op: BinaryPyOp,
+    to_broadcast: CustomOp,
+    broadcast_dim: IndexSymbol,
+    broadcast_size: int,
+    target_node: CustomOp,
+):
+    """
+    Create a broadcast node for one operand of the given binary operator.
+    """
+    with binary_op.graph.inserting_before(binary_op.fx_node):
+        broadcasted = Broadcast(to_broadcast.fx_node, target_node.type).add_to_graph(
+            binary_op.graph
+        )
+        custom = get_custom(broadcasted)
+        custom.vector_shapes = to_broadcast.vector_shapes
+        custom.index = deepcopy(target_node.index)
+        custom.index[broadcast_dim].size = broadcast_size
+        broadcast_idx = list(binary_op.node_args.values()).index(to_broadcast)
+        binary_op.update_arg(broadcast_idx, custom.fx_node)
+
+
+def resolve_thread_shapes(trace: CapturedTrace, constraints: list[Constraint]):
+    """
+    This function walks through all the binary operators in the graph and,
+    if there is a discrepancy between the thread shapes of their operands
+    along the same dimension, resolves the discrepancy.
+
+    Currently, the only mismatches that can be resolved are when one of
+    the shapes is 1 and the other is > 1.
+    """
+    binary_ops = trace.walk(lambda node: isinstance(get_custom(node), BinaryPyOp))
+    for binary_op in binary_ops:
+        custom = get_custom(binary_op)
+        # Get the largest dim and shape from the lhs and rhs.
+        lhs = get_custom(custom.lhs)
+        rhs = get_custom(custom.rhs)
+        lhs_dim, lhs_size = get_largest_index_and_size(lhs.index)
+        rhs_dim, rhs_size = get_largest_index_and_size(rhs.index)
+        if lhs_size > rhs_size:
+            to_broadcast = rhs
+            broadcast_dim = lhs_dim
+            broadcast_size = lhs_size
+            target = lhs
+        else:
+            to_broadcast = lhs
+            broadcast_dim = rhs_dim
+            broadcast_size = rhs_size
+            target = rhs
+        # If they are equal, we are done.
+        if lhs_dim == rhs_dim and lhs_size == rhs_size:
+            continue
+        # If both are unit dims, there is nothing to do.
+        if lhs_size == 1 and rhs_size == 1:
+            continue
+        if lhs_dim != rhs_dim:
+            # If the dimensions don't agree, we can still broadcast only if the node
+            # has a reduce op along the broadcasting dimension in its backward slice,
+            # or reads from a placeholder that doesn't have the broadcasting dimension.
+            bwd_slice = capture_backward_slice(to_broadcast.fx_node)
+            reduce_ops = [
+                get_custom(x) for x in bwd_slice if isinstance(get_custom(x), ReduceOp)
+            ]
+            reduce_criteria = reduce_ops and any(
+                reduce_op.dim == broadcast_dim for reduce_op in reduce_ops
+            )
+            read_ops = [
+                get_custom(x) for x in bwd_slice if isinstance(get_custom(x), Read)
+            ]
+            read_criteria = read_ops and all(
+                broadcast_dim not in read_op.indexing_dims for read_op in read_ops
+            )
+            if not reduce_criteria and not read_criteria:
+                raise NotImplementedError(
+                    "Currently only support resolving discrepancies along the same dimension."
+                )
+
+        # Cannot handle discrepancies when both shapes are > 1.
+        if lhs_size > 1 and rhs_size > 1:
+            raise NotImplementedError(
+                "Currently only support resolving discrepancies when one of the shapes is 1."
+            )
+        # Broadcast
+        create_broadcast(custom, to_broadcast, broadcast_dim, broadcast_size, target)
diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py
index c46b4490..77ffea45 100644
--- a/iree/turbine/kernel/wave/utils.py
+++ b/iree/turbine/kernel/wave/utils.py
@@ -1336,3 +1336,20 @@ def check_is_mapping_contiguous(
         expected_diff[-1] = 1
 
     return diff == expected_diff
+
+
+def get_largest_index_and_size(indices: dict[IndexExpr, IndexSequence]):
+    """
+    This function takes in the indices of a Node, extracts their sizes
+    into a list, and then returns the dimension with the largest size.
+    In case of ties, it picks the fastest changing dimension.
+    """
+
+    sorted_values = sorted(
+        [
+            (i, dim, subs_idxc(index.size))
+            for i, (dim, index) in enumerate(indices.items())
+        ],
+        key=lambda x: (-x[2], -x[0]),
+    )
+    return sorted_values[0][1:]
diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py
index 15ffe7f2..c0895976 100644
--- a/iree/turbine/kernel/wave/wave.py
+++ b/iree/turbine/kernel/wave/wave.py
@@ -51,7 +51,6 @@
     apply_shared_memory_indexing_corrections,
     align_index_sizes,
 )
-from .thread_shape_analysis import determine_thread_shapes
 from .scheduling.schedule import schedule_graph
 from .._support.indexing import IndexingContext, IndexExpr
 from .type_inference import infer_types
@@ -260,9 +259,6 @@ def _trace_and_get_kernel_signature(
         # Set indices.
         set_node_indices(graph, self.constraints)
 
-        # Analyze Thread Shapes per Op.
-        determine_thread_shapes(graph)
-
         # Expansion
         expand_graph(graph, self.constraints)
 
diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py
index 129085df..7d022eb1 100644
--- a/lit_tests/kernel/wave/expansion.py
+++ b/lit_tests/kernel/wave/expansion.py
@@ -561,6 +561,7 @@ def test_batched_gemm():
 def gemm_non_direct_acc(
     a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
     b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
+    bias: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
     c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
 ):
     c_reg = tkl.Register[M, N, tkl.f32](0.0)
@@ -569,7 +570,8 @@
     def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
         a_reg = tkw.read(a, elements_per_thread=4)
         b_reg = tkw.read(b, elements_per_thread=4)
-        new_acc = tkw.exp2(a_reg) + acc
+        bias_reg = tkw.read(bias, elements_per_thread=4)
+        new_acc = tkw.exp2(bias_reg) + acc
         acc = tkw.mma(a_reg, b_reg, new_acc)
         return acc
 
@@ -603,11 +605,11 @@ def test_gemm_non_direct_acc():
     # CHECK: %add_0_0_0
     # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_0_0, %acc_0_0_0), kwargs = {})
     # CHECK: %add_1_1_0
-    # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_0_0, %acc_1_1_0), kwargs = {})
+    # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_1_0, %acc_1_1_0), kwargs = {})
     # CHECK: %add_1_0_0
     # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_0_0, %acc_1_0_0), kwargs = {})
     # CHECK: %add_0_1_0
-    # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_0_0, %acc_0_1_0), kwargs = {})
+    # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_1_0, %acc_0_1_0), kwargs = {})
     # CHECK: %mma_0_0_0
     # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_0, %read_0_0_0, %add_0_0_0, None), kwargs = {})
     # CHECK: %mma_0_0_1
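
A note on the tie-breaking in get_largest_index_and_size above: the sort key
(-size, -position) orders entries by decreasing size and, among entries of
equal size, by decreasing position in the dict, so ties go to the dimension
listed last, i.e. the fastest changing/most minor one. The sketch below is a
standalone illustration of that behavior, not part of the patch; plain strings
and ints stand in for the actual IndexSymbol and IndexSequence types, and the
function name is hypothetical:

# Standalone sketch of the tie-breaking in get_largest_index_and_size.
# Strings stand in for dims; ints stand in for the substituted index sizes.
def largest_index_and_size_sketch(indices: dict[str, int]) -> tuple[str, int]:
    sorted_values = sorted(
        [(i, dim, size) for i, (dim, size) in enumerate(indices.items())],
        # Negate both keys: the largest size sorts first; on ties, the highest
        # position (fastest changing, most minor dimension) sorts first.
        key=lambda x: (-x[2], -x[0]),
    )
    # Drop the position, keeping (dim, size), as the patch does with [1:].
    return sorted_values[0][1:]

# M and N tie at size 4; N is listed last (most minor), so it wins the tie.
assert largest_index_and_size_sketch({"M": 4, "N": 4, "K": 1}) == ("N", 4)
# K has the strictly largest size, so no tie-breaking is needed.
assert largest_index_and_size_sketch({"M": 1, "N": 2, "K": 8}) == ("K", 8)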