Remove thread shape analysis
Signed-off-by: Harsh Menon <[email protected]>
harsh-nod committed Dec 15, 2024
1 parent 3972153 commit ea11e4f
Showing 5 changed files with 120 additions and 11 deletions.
2 changes: 1 addition & 1 deletion iree/turbine/kernel/wave/codegen.py
@@ -555,7 +555,7 @@ def _get_fastest_index(indices: dict[IndexExpr, IndexSequence]):
there are multiple max_vals we pick the fastest/most minor one.
"""

index_sizes = [i.size for i in indices.values()]
index_sizes = [subs_idxc(i.size) for i in indices.values()]
# Find the maximum value
max_size = max(index_sizes)
# Find the fastest/most minor index of the maximum value.
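Editor's note: the one-line change above substitutes indexing-context values into each size before taking the maximum. A rough, hypothetical illustration of why that matters (sympy used directly; subs_idxc_sketch, IDXC, and the sample sizes are stand-ins, not the real iree.turbine helpers): symbolic sizes cannot be compared by max() until they resolve to concrete integers.

import sympy

# Assumed stand-in for the kernel's indexing context; real values come from
# IndexingContext inside iree.turbine, not from this dict.
BLOCK_K = sympy.Symbol("BLOCK_K")
IDXC = {BLOCK_K: 32}

def subs_idxc_sketch(size):
    """Resolve a possibly symbolic size to a concrete sympy integer."""
    return sympy.sympify(size).subs(IDXC)

sizes = [1, BLOCK_K / 8, 4]
# max(sizes) would raise here: BLOCK_K / 8 has no truth value when compared.
index_sizes = [subs_idxc_sketch(s) for s in sizes]
print(max(index_sizes))  # prints 4 once BLOCK_K / 8 resolves to 32 / 8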
100 changes: 97 additions & 3 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -6,11 +6,14 @@

from ..ops.wave_ops import (
Allocate,
BinaryPyOp,
Broadcast,
Read,
Write,
ExtractSlice,
get_custom,
Reduction,
ReduceOp,
MMA,
Placeholder,
IterArg,
@@ -30,6 +33,8 @@
subs_idxc,
get_inputs,
get_users,
get_largest_index_and_size,
capture_backward_slice,
)
import torch.fx as fx
import numpy as np
@@ -217,9 +222,11 @@ def has_gpr_offsets(node: fx.Node) -> bool:
{GPR_NUM: cur_gpr_start_id}
),
gpr_size,
1
if output_mapping[-1] == gpr_offset_dim
else simplified_index[gpr_offset_dim].stride,
(
1
if output_mapping[-1] == gpr_offset_dim
else simplified_index[gpr_offset_dim].stride
),
)
updated_index_with_gpr_offset[
gpr_offset_dim
@@ -307,6 +314,7 @@ def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]):
trace.walk(partial(set_thread_independent_index, constraints))
set_thread_dependent_index(constraints, mma_index, trace)
set_derived_index(trace)
resolve_thread_shapes(trace, constraints)


def compute_stride(
@@ -668,3 +676,89 @@ def apply_offset(node: fx.Node):
return False

trace.walk(apply_offset)


def create_broadcast(
binary_op: BinaryPyOp,
to_broadcast: CustomOp,
broadcast_dim: IndexSymbol,
broadcast_size: int,
target_node: CustomOp,
):
"""
Create a broadcast node for the given binary operator.
"""
with binary_op.graph.inserting_before(binary_op.fx_node):
broadcasted = Broadcast(to_broadcast.fx_node, target_node.type).add_to_graph(
binary_op.graph
)
custom = get_custom(broadcasted)
custom.vector_shapes = to_broadcast.vector_shapes
custom.index = deepcopy(target_node.index)
custom.index[broadcast_dim].size = broadcast_size
broadcast_idx = list(binary_op.node_args.values()).index(to_broadcast)
binary_op.update_arg(broadcast_idx, custom.fx_node)


def resolve_thread_shapes(trace: CapturedTrace, constraints: list[Constraint]):
"""
This function walks through all the binary operators in the graph and,
if there is a discrepancy between the thread shapes of their operands
along the same dimension, resolves the discrepancy.
Currently, the only mismatches that can be resolved are when one of
the shapes is 1 and the other is > 1.
"""
binary_ops = trace.walk(lambda node: isinstance(get_custom(node), BinaryPyOp))
for binary_op in binary_ops:
custom = get_custom(binary_op)
# Get the largest dim and shape from the lhs and rhs.
lhs = get_custom(custom.lhs)
rhs = get_custom(custom.rhs)
lhs_dim, lhs_size = get_largest_index_and_size(lhs.index)
rhs_dim, rhs_size = get_largest_index_and_size(rhs.index)
if lhs_size > rhs_size:
to_broadcast = rhs
broadcast_dim = lhs_dim
broadcast_size = lhs_size
target = lhs
else:
to_broadcast = lhs
broadcast_dim = rhs_dim
broadcast_size = rhs_size
target = rhs
# If they are equal we are done.
if lhs_dim == rhs_dim and lhs_size == rhs_size:
continue
# If all are unit dims, there is nothing to do.
if lhs_size == 1 and rhs_size == 1:
continue
if lhs_dim != rhs_dim:
# If the dimensions don't agree, we can still do this broadcast only if
# the node has a reduce op along the broadcasting dimension in its backward slice,
# or is read from a placeholder that doesn't have the broadcasting dimension.
bwd_slice = capture_backward_slice(to_broadcast.fx_node)
reduce_ops = [
get_custom(x) for x in bwd_slice if isinstance(get_custom(x), ReduceOp)
]
reduce_criteria = reduce_ops and any(
reduce_op.dim == broadcast_dim for reduce_op in reduce_ops
)
read_ops = [
get_custom(x) for x in bwd_slice if isinstance(get_custom(x), Read)
]
read_criteria = read_ops and all(
broadcast_dim not in read_op.indexing_dims for read_op in read_ops
)
if not reduce_criteria and not read_criteria:
raise NotImplementedError(
"Currently only support resolving discrepancies along the same dimension."
)

# Cannot handle discrepancies when both shapes are > 1.
if lhs_size > 1 and rhs_size > 1:
raise NotImplementedError(
"Currently only support resolving discrepancies when one of the shapes is 1."
)
# Broadcast
create_broadcast(custom, to_broadcast, broadcast_dim, broadcast_size, target)
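
Editor's note: a minimal sketch of the resolution rule the new pass applies (toy dicts of per-dimension thread sizes stand in for node indices; the backward-slice criteria and the actual Broadcast insertion are elided, and resolve_pair is a hypothetical name, not part of the real API).

def resolve_pair(lhs_sizes: dict[str, int], rhs_sizes: dict[str, int]):
    """Decide which operand of a binary op needs broadcasting, if any."""
    lhs_dim, lhs_size = max(lhs_sizes.items(), key=lambda kv: kv[1])
    rhs_dim, rhs_size = max(rhs_sizes.items(), key=lambda kv: kv[1])
    if (lhs_dim, lhs_size) == (rhs_dim, rhs_size):
        return None  # thread shapes already agree
    if lhs_size == 1 and rhs_size == 1:
        return None  # all unit dims, nothing to do
    if lhs_size > 1 and rhs_size > 1:
        raise NotImplementedError("one of the shapes must be 1")
    # Broadcast the unit-sized operand up to the larger operand's dim/size.
    if lhs_size > rhs_size:
        return ("rhs", lhs_dim, lhs_size)
    return ("lhs", rhs_dim, rhs_size)

# e.g. a per-row reduction result (size 1 along N) added to a size-4 tile:
print(resolve_pair({"N": 1}, {"N": 4}))  # ('lhs', 'N', 4)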
17 changes: 17 additions & 0 deletions iree/turbine/kernel/wave/utils.py
@@ -1336,3 +1336,20 @@ def check_is_mapping_contiguous(
expected_diff[-1] = 1

return diff == expected_diff


def get_largest_index_and_size(indices: dict[IndexExpr, IndexSequence]):
"""
This function takes in the indices of a node, extracts their sizes
into a list, and then returns the dimension with the largest size.
In case of ties, it picks the fastest-changing dimension.
"""

sorted_values = sorted(
[
(i, dim, subs_idxc(index.size))
for i, (dim, index) in enumerate(indices.items())
],
key=lambda x: (-x[2], -x[0]),
)
return sorted_values[0][1:]
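
Editor's note: a small standalone sketch of the tie-breaking behaviour above (plain integers in place of IndexSequence objects and subs_idxc; largest_index_and_size is a hypothetical stand-in): ties on size go to the later, fastest-varying dimension.

def largest_index_and_size(sizes: dict[str, int]):
    """Return (dim, size) for the largest size, preferring the most minor dim."""
    ranked = sorted(
        [(i, dim, size) for i, (dim, size) in enumerate(sizes.items())],
        key=lambda x: (-x[2], -x[0]),
    )
    return ranked[0][1:]

# M and N tie at size 4; N is the faster-varying (later) dimension, so it wins.
print(largest_index_and_size({"M": 4, "N": 4, "K": 1}))  # ('N', 4)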
4 changes: 0 additions & 4 deletions iree/turbine/kernel/wave/wave.py
@@ -51,7 +51,6 @@
apply_shared_memory_indexing_corrections,
align_index_sizes,
)
from .thread_shape_analysis import determine_thread_shapes
from .scheduling.schedule import schedule_graph
from .._support.indexing import IndexingContext, IndexExpr
from .type_inference import infer_types
@@ -260,9 +259,6 @@ def _trace_and_get_kernel_signature(
# Set indices.
set_node_indices(graph, self.constraints)

# Analyze Thread Shapes per Op.
determine_thread_shapes(graph)

# Expansion
expand_graph(graph, self.constraints)

8 changes: 5 additions & 3 deletions lit_tests/kernel/wave/expansion.py
@@ -561,6 +561,7 @@ def test_batched_gemm():
def gemm_non_direct_acc(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
bias: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)
@@ -569,7 +570,8 @@ def gemm_non_direct_acc(
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
a_reg = tkw.read(a, elements_per_thread=4)
b_reg = tkw.read(b, elements_per_thread=4)
new_acc = tkw.exp2(a_reg) + acc
bias_reg = tkw.read(bias, elements_per_thread=4)
new_acc = tkw.exp2(bias_reg) + acc
acc = tkw.mma(a_reg, b_reg, new_acc)
return acc

@@ -603,11 +605,11 @@ def test_gemm_non_direct_acc():
# CHECK: %add_0_0_0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_0_0, %acc_0_0_0), kwargs = {})
# CHECK: %add_1_1_0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_0_0, %acc_1_1_0), kwargs = {})
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_1_0, %acc_1_1_0), kwargs = {})
# CHECK: %add_1_0_0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_0_0, %acc_1_0_0), kwargs = {})
# CHECK: %add_0_1_0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_0_0, %acc_0_1_0), kwargs = {})
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_1_0, %acc_0_1_0), kwargs = {})
# CHECK: %mma_0_0_0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_0, %read_0_0_0, %add_0_0_0, None), kwargs = {})
# CHECK: %mma_0_0_1