diff --git a/iree/turbine/kernel/lang/kernel_buffer.py b/iree/turbine/kernel/lang/kernel_buffer.py index 4ab00ea7..980aac49 100644 --- a/iree/turbine/kernel/lang/kernel_buffer.py +++ b/iree/turbine/kernel/lang/kernel_buffer.py @@ -64,6 +64,7 @@ def new_subtype( address_space: AddressSpace | NotSetType = NotSet, symbolic_shape: tuple[IndexExpr, ...] | NotSetType = NotSet, dtype: DataType | NotSetType = NotSet, + stride: tuple[IndexExpr, ...] | NotSetType = NotSet, usage: KernelBufferUsage | NotSetType = NotSet, ) -> Type[SubtypeT]: init_address_space = ( @@ -71,6 +72,7 @@ def new_subtype( ) init_symbolic_shape = symbolic_shape if symbolic_shape is not NotSet else cls.symbolic_shape # type: ignore init_dtype = dtype if dtype is not NotSet else cls.dtype # type: ignore + init_stride = stride if stride else None init_usage = usage if usage is not NotSet else cls.usage # type: ignore class SubType(cls): @@ -78,6 +80,7 @@ class SubType(cls): symbolic_shape = init_symbolic_shape rank = len(init_symbolic_shape) # type: ignore dtype = init_dtype + stride = init_stride usage = init_usage if name is not NotSet: diff --git a/iree/turbine/kernel/lang/wave_types.py b/iree/turbine/kernel/lang/wave_types.py index f87a9570..e310c303 100644 --- a/iree/turbine/kernel/lang/wave_types.py +++ b/iree/turbine/kernel/lang/wave_types.py @@ -41,6 +41,7 @@ class Memory(metaclass=KernelBufferMeta): symbolic_shape: ClassVar[tuple[IndexExpr, ...]] rank: ClassVar[int] dtype: ClassVar[DataType] + stride: ClassVar[Optional[tuple[IndexExpr, ...]]] usage: ClassVar[Optional[KernelBufferUsage]] def __init__(self) -> None: @@ -55,9 +56,15 @@ def __class_getitem__( shift = 0 usage = KernelBufferUsage.NONE - if isinstance(shape_and_dtype[-1], KernelBufferUsage): - shift = 1 - usage = shape_and_dtype[-1] + last_dim = -1 + if isinstance(shape_and_dtype[last_dim], KernelBufferUsage): + shift += 1 + usage = shape_and_dtype[last_dim] + last_dim -= 1 + stride = None + if isinstance(shape_and_dtype[last_dim], Sequence): + shift += 1 + stride = shape_and_dtype[last_dim] shape = shape_and_dtype[: -2 - shift] addressSpace = shape_and_dtype[-2 - shift] dtype = shape_and_dtype[-1 - shift] @@ -85,6 +92,7 @@ def __class_getitem__( address_space=addressSpace, symbolic_shape=shape, dtype=dtype, + stride=stride, usage=usage, ) diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index ffa73618..b4b6826a 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -461,7 +461,7 @@ def node_args(self) -> dict[int, Any]: for i, arg in enumerate(self.fx_node.args): if isinstance(arg, fx.Node): custom_args[i] = get_custom(arg) - if isinstance(arg, list) and all(isinstance(x, fx.Node) for x in arg): + if isinstance(arg, Sequence) and all(isinstance(x, fx.Node) for x in arg): custom_args[i] = [get_custom(x) for x in arg] return custom_args @@ -1013,7 +1013,11 @@ def transform_index_backwards( iters = self.mapping.iters mapping = self.mapping.dynamic_val_mappings[i] subs = {v: k for k, v in zip(iters, mapping.keys())} - return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()} + return { + k: v.apply_expr(subs[k], mapping[k]) + for k, v in index.items() + if k in mapping + } return index @@ -1253,7 +1257,11 @@ def transform_index_backwards( iters = self.mapping.iters mapping = self.mapping.dynamic_val_mappings[i] subs = {v: k for k, v in zip(iters, mapping.keys())} - return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()} + return { + k: v.apply_expr(subs[k], 
mapping[k]) + for k, v in index.items() + if k in mapping + } return index diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py index 346d7529..11344e54 100644 --- a/iree/turbine/kernel/wave/expansion.py +++ b/iree/turbine/kernel/wave/expansion.py @@ -283,7 +283,7 @@ def _expand_node( for i, arg in node.node_args.items(): arg_list = arg unpack = lambda x: x - if isinstance(arg, list): + if isinstance(arg, Sequence): if not all(is_expandable(a) for a in arg): continue else: diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index 24c08fcf..ecbe67ff 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -217,9 +217,11 @@ def has_gpr_offsets(node: fx.Node) -> bool: {GPR_NUM: cur_gpr_start_id} ), gpr_size, - 1 - if output_mapping[-1] == gpr_offset_dim - else simplified_index[gpr_offset_dim].stride, + ( + 1 + if output_mapping[-1] == gpr_offset_dim + else simplified_index[gpr_offset_dim].stride + ), ) updated_index_with_gpr_offset[ gpr_offset_dim diff --git a/iree/turbine/kernel/wave/minimize_global_loads.py b/iree/turbine/kernel/wave/minimize_global_loads.py index 62b623a2..82793cee 100644 --- a/iree/turbine/kernel/wave/minimize_global_loads.py +++ b/iree/turbine/kernel/wave/minimize_global_loads.py @@ -125,9 +125,12 @@ def add_optimized_nodes( access_pattern: dict[IndexSymbol, IndexSequence] = custom.index for i in range(expected_number_of_loads): with custom.graph.inserting_before(custom.fx_node): - read = Read(memory, load_elems_per_thread, custom.mapping).add_to_graph( - custom.graph - ) + read = Read( + memory, + load_elems_per_thread, + custom.mapping, + custom.mapping_dynamic_vals, + ).add_to_graph(custom.graph) global_offset = ( hardware_constraint.linearized_thread_id * load_elems_per_thread + i * max_elements_per_load diff --git a/lit_tests/kernel/wave/attention.py b/lit_tests/kernel/wave/attention.py index 320d9014..bfd5b9d5 100644 --- a/lit_tests/kernel/wave/attention.py +++ b/lit_tests/kernel/wave/attention.py @@ -13,6 +13,7 @@ ) import torch from enum import Enum +import sympy # Input sizes B = tkl.sym.B @@ -36,442 +37,622 @@ STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD -@run_test -def test_evoformer(): - # B, BN, K2, H, K1, M, N - shape = (1, 256, 256, 4, 32, 256, 32) - # Expose user-constraints - constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] - constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] - constraints += [tkw.WorkgroupConstraint(BN, BLOCK_BN, 3)] - constraints += [tkw.WorkgroupConstraint(H, BLOCK_H, 4)] - constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] - constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] - constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] - - mfma_variant = tkw.MMAType.F32_16x16x16_F16 - constraints += [ - tkw.HardwareConstraint( - threads_per_wave=64, - waves_per_block=(2, 2, 1), - mma_type=mfma_variant, - vector_shapes={B: 0, BN: 0, H: 0, M: 16, N: 16}, - ) - ] - - i = tkw.IndexMapping.iterator(0) - j = tkw.IndexMapping.iterator(1) - k = tkw.IndexMapping.iterator(2) - l = tkw.IndexMapping.iterator(3) - m = tkw.IndexMapping.iterator(4) - # [B, BN, M, H, K1] -> [B, BN, H, M, K1] - q_mapping = tkw.IndexMapping( - num_iterators=5, - inputs={B: i, BN: j, H: k, M: l, K1: m}, - outputs={B: i, BN: j, H: k, M: l, K1: m}, - ) - # [B, BN, K2, H, K1] -> [B, BN, H, K2, K1] - 
k_mapping = tkw.IndexMapping( - num_iterators=5, - inputs={B: i, BN: j, H: k, K2: l, K1: m}, - outputs={B: i, BN: j, H: k, K2: l, K1: m}, - ) - # [B, BN, K2, H, N] -> [B, BN, H, N, K2] - v_mapping = tkw.IndexMapping( - num_iterators=5, - inputs={B: i, BN: j, H: k, N: l, K2: m}, - outputs={B: i, BN: j, H: k, N: l, K2: m}, - ) - # [B, BN, H, N, M] -> [B, BN, M, H, N] - o_mapping = tkw.IndexMapping( - num_iterators=5, - inputs={B: i, BN: j, H: k, N: l, M: m}, - outputs={B: i, BN: j, H: k, N: l, M: m}, - ) - - @tkw.wave(constraints) - def evoformer( - q: tkl.Memory[B, BN, M, H, K1, ADDRESS_SPACE, tkl.f16], - k: tkl.Memory[B, BN, K2, H, K1, ADDRESS_SPACE, tkl.f16], - v: tkl.Memory[B, BN, K2, H, N, ADDRESS_SPACE, tkl.f16], - mask: tkl.Memory[B, BN, K2, GLOBAL_ADDRESS_SPACE, tkl.f16], - bias: tkl.Memory[B, H, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[B, BN, M, H, N, GLOBAL_ADDRESS_SPACE, tkl.f16], - ): - c_reg = tkl.Register[B, BN, H, N, M, tkl.f32](0.0) - init_sum = tkl.Register[B, BN, H, M, tkl.f32](0.0) - init_max = tkl.Register[B, BN, H, M, tkl.f32](-1e6) - - @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) - def repeat( - partial_max: tkl.Register[B, BN, H, M, tkl.f32], - partial_sum: tkl.Register[B, BN, H, M, tkl.f32], - acc: tkl.Register[B, BN, H, N, M, tkl.f32], - ) -> ( - tkl.Register[B, BN, H, M, tkl.f32], - tkl.Register[B, BN, H, M, tkl.f32], - tkl.Register[B, BN, H, N, M, tkl.f32], - ): - imm_reg = tkl.Register[B, BN, H, K2, M, tkl.f32](0.0) - q_reg = tkw.read( - q, mapping=q_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD - ) - k_reg = tkw.read( - k, mapping=k_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD - ) - inner_acc = tkw.mma(k_reg, q_reg, imm_reg) - x_j = tkw.permute(inner_acc, target_shape=[B, BN, H, M, K2]) - mask_reg = tkw.read(mask, elements_per_thread=STORE_ELEMS_PER_THREAD) - casted_mask_reg = tkw.cast(mask_reg, tkl.f32) - y_j = x_j + casted_mask_reg - bias_reg = tkw.read(bias, elements_per_thread=STORE_ELEMS_PER_THREAD) - casted_bias_reg = tkw.cast(bias_reg, tkl.f32) - z_j = y_j + casted_bias_reg - m_j = tkw.max(z_j, partial_max, dim=K2) - e_delta_max = tkw.exp2(partial_max - m_j) - e_delta = tkw.exp2(z_j - m_j) - e_init = partial_sum * e_delta_max - d_j = tkw.sum(e_delta, e_init, dim=K2) - imm_f16 = tkw.cast(e_delta, tkl.f16) - v_reg = tkw.read( - v, mapping=v_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD - ) - new_acc = acc * e_delta_max - acc = tkw.mma(v_reg, imm_f16, new_acc) - return m_j, d_j, acc - - # repeat represents the results of the loop - res_max, res_sum, res_mm = repeat - res = res_mm / res_sum - casted = tkw.cast(res, tkl.f16) - tkw.write( - casted, c, mapping=o_mapping, elements_per_thread=STORE_ELEMS_PER_THREAD - ) - - hyperparams = { - ADDRESS_SPACE: SHARED_ADDRESS_SPACE, - LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), - STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), - B: shape[0], - BN: shape[1], - K2: shape[2], - H: shape[3], - K1: shape[4], - M: shape[5], - N: shape[6], - BLOCK_B: 1, - BLOCK_BN: 1, - BLOCK_H: 1, - BLOCK_M: 64, - BLOCK_N: 64, - BLOCK_K2: 32, - READ_SHARED_DELAY: 1, - WRITE_SHARED_DELAY: 1, - READ_GLOBAL_DELAY: 2, - WRITE_GLOBAL_DELAY: 2, - MMA_DELAY: 1, - VALU_DELAY: 1, - SHUFFLE_DELAY: 1, - SHARED_MEMORY_UNITS: 4, - GLOBAL_MEMORY_UNITS: 4, - MMA_UNITS: 4, - VALU_UNITS: 2, - SHUFFLE_UNITS: 2, - } - - with tk.gen.TestLaunchContext( - hyperparams, - canonicalize=True, - run=False, - run_bench=False, - schedule=False, - use_scheduling_barriers=False, - ): 
- torch.manual_seed(0) - q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) - k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) - v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) - output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) - print(evoformer(q, k, v, output).module_op) - - # CHECK: func.func @evoformer - # CHECK: {{.*}} = scf.for - # CHECK-COUNT-5: {{.*}} = vector.load - # CHECK-COUNT-1: vector.store {{.*}} - # CHECK-COUNT-4: {{.*}} = vector.load - # CHECK-COUNT-8: {{.*}} = amdgpu.mfma - # CHECK-COUNT-2: {{.*}} = vector.load - # CHECK-COUNT-2: {{.*}} = arith.extf - # CHECK-COUNT-4: {{.*}} = arith.addf - # CHECK-COUNT-4: {{.*}} = vector.load - # CHECK-COUNT-4: {{.*}} = arith.extf - # CHECK-COUNT-4: {{.*}} = arith.addf +# @run_test +# def test_evoformer(): +# # B, BN, K2, H, K1, M, N +# shape = (1, 256, 256, 4, 32, 256, 32) +# # Expose user-constraints +# constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] +# constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] +# constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] +# constraints += [tkw.WorkgroupConstraint(BN, BLOCK_BN, 3)] +# constraints += [tkw.WorkgroupConstraint(H, BLOCK_H, 4)] +# constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] +# constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] +# constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] +# +# mfma_variant = tkw.MMAType.F32_16x16x16_F16 +# constraints += [ +# tkw.HardwareConstraint( +# threads_per_wave=64, +# waves_per_block=(2, 2, 1), +# mma_type=mfma_variant, +# vector_shapes={B: 0, BN: 0, H: 0, M: 16, N: 16}, +# ) +# ] +# +# i = tkw.IndexMapping.iterator(0) +# j = tkw.IndexMapping.iterator(1) +# k = tkw.IndexMapping.iterator(2) +# l = tkw.IndexMapping.iterator(3) +# m = tkw.IndexMapping.iterator(4) +# # [B, BN, M, H, K1] -> [B, BN, H, M, K1] +# q_mapping = tkw.IndexMapping( +# num_iterators=5, +# inputs={B: i, BN: j, H: k, M: l, K1: m}, +# outputs={B: i, BN: j, H: k, M: l, K1: m}, +# ) +# # [B, BN, K2, H, K1] -> [B, BN, H, K2, K1] +# k_mapping = tkw.IndexMapping( +# num_iterators=5, +# inputs={B: i, BN: j, H: k, K2: l, K1: m}, +# outputs={B: i, BN: j, H: k, K2: l, K1: m}, +# ) +# # [B, BN, K2, H, N] -> [B, BN, H, N, K2] +# v_mapping = tkw.IndexMapping( +# num_iterators=5, +# inputs={B: i, BN: j, H: k, N: l, K2: m}, +# outputs={B: i, BN: j, H: k, N: l, K2: m}, +# ) +# # [B, BN, H, N, M] -> [B, BN, M, H, N] +# o_mapping = tkw.IndexMapping( +# num_iterators=5, +# inputs={B: i, BN: j, H: k, N: l, M: m}, +# outputs={B: i, BN: j, H: k, N: l, M: m}, +# ) +# +# @tkw.wave(constraints) +# def evoformer( +# q: tkl.Memory[B, BN, M, H, K1, ADDRESS_SPACE, tkl.f16], +# k: tkl.Memory[B, BN, K2, H, K1, ADDRESS_SPACE, tkl.f16], +# v: tkl.Memory[B, BN, K2, H, N, ADDRESS_SPACE, tkl.f16], +# mask: tkl.Memory[B, BN, K2, GLOBAL_ADDRESS_SPACE, tkl.f16], +# bias: tkl.Memory[B, H, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f16], +# c: tkl.Memory[B, BN, M, H, N, GLOBAL_ADDRESS_SPACE, tkl.f16], +# ): +# c_reg = tkl.Register[B, BN, H, N, M, tkl.f32](0.0) +# init_sum = tkl.Register[B, BN, H, M, tkl.f32](0.0) +# init_max = tkl.Register[B, BN, H, M, tkl.f32](-1e6) +# +# @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) +# def repeat( +# partial_max: tkl.Register[B, BN, H, M, tkl.f32], +# partial_sum: tkl.Register[B, BN, H, M, tkl.f32], +# acc: tkl.Register[B, BN, H, N, M, tkl.f32], +# ) -> ( +# tkl.Register[B, BN, H, M, tkl.f32], +# tkl.Register[B, BN, H, M, tkl.f32], +# tkl.Register[B, BN, H, N, M, 
tkl.f32], +# ): +# imm_reg = tkl.Register[B, BN, H, K2, M, tkl.f32](0.0) +# q_reg = tkw.read( +# q, mapping=q_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD +# ) +# k_reg = tkw.read( +# k, mapping=k_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD +# ) +# inner_acc = tkw.mma(k_reg, q_reg, imm_reg) +# x_j = tkw.permute(inner_acc, target_shape=[B, BN, H, M, K2]) +# mask_reg = tkw.read(mask, elements_per_thread=STORE_ELEMS_PER_THREAD) +# casted_mask_reg = tkw.cast(mask_reg, tkl.f32) +# y_j = x_j + casted_mask_reg +# bias_reg = tkw.read(bias, elements_per_thread=STORE_ELEMS_PER_THREAD) +# casted_bias_reg = tkw.cast(bias_reg, tkl.f32) +# z_j = y_j + casted_bias_reg +# m_j = tkw.max(z_j, partial_max, dim=K2) +# e_delta_max = tkw.exp2(partial_max - m_j) +# e_delta = tkw.exp2(z_j - m_j) +# e_init = partial_sum * e_delta_max +# d_j = tkw.sum(e_delta, e_init, dim=K2) +# imm_f16 = tkw.cast(e_delta, tkl.f16) +# v_reg = tkw.read( +# v, mapping=v_mapping, elements_per_thread=LOAD_ELEMS_PER_THREAD +# ) +# new_acc = acc * e_delta_max +# acc = tkw.mma(v_reg, imm_f16, new_acc) +# return m_j, d_j, acc +# +# # repeat represents the results of the loop +# res_max, res_sum, res_mm = repeat +# res = res_mm / res_sum +# casted = tkw.cast(res, tkl.f16) +# tkw.write( +# casted, c, mapping=o_mapping, elements_per_thread=STORE_ELEMS_PER_THREAD +# ) +# +# hyperparams = { +# ADDRESS_SPACE: SHARED_ADDRESS_SPACE, +# LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), +# STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), +# B: shape[0], +# BN: shape[1], +# K2: shape[2], +# H: shape[3], +# K1: shape[4], +# M: shape[5], +# N: shape[6], +# BLOCK_B: 1, +# BLOCK_BN: 1, +# BLOCK_H: 1, +# BLOCK_M: 64, +# BLOCK_N: 64, +# BLOCK_K2: 32, +# READ_SHARED_DELAY: 1, +# WRITE_SHARED_DELAY: 1, +# READ_GLOBAL_DELAY: 2, +# WRITE_GLOBAL_DELAY: 2, +# MMA_DELAY: 1, +# VALU_DELAY: 1, +# SHUFFLE_DELAY: 1, +# SHARED_MEMORY_UNITS: 4, +# GLOBAL_MEMORY_UNITS: 4, +# MMA_UNITS: 4, +# VALU_UNITS: 2, +# SHUFFLE_UNITS: 2, +# } +# +# with tk.gen.TestLaunchContext( +# hyperparams, +# canonicalize=True, +# run=False, +# run_bench=False, +# schedule=False, +# use_scheduling_barriers=False, +# ): +# torch.manual_seed(0) +# q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) +# k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) +# v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) +# output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) +# print(evoformer(q, k, v, output).module_op) +# +# # CHECK: func.func @evoformer +# # CHECK: {{.*}} = scf.for +# # CHECK-COUNT-5: {{.*}} = vector.load +# # CHECK-COUNT-1: vector.store {{.*}} +# # CHECK-COUNT-4: {{.*}} = vector.load +# # CHECK-COUNT-8: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-2: {{.*}} = vector.load +# # CHECK-COUNT-2: {{.*}} = arith.extf +# # CHECK-COUNT-4: {{.*}} = arith.addf +# # CHECK-COUNT-4: {{.*}} = vector.load +# # CHECK-COUNT-4: {{.*}} = arith.extf +# # CHECK-COUNT-4: {{.*}} = arith.addf +# +# +## This test sets all the dimensions except K1 to be dynamic. +## The reason why we can't set K1 to be dynamic is because K1 is the +## tile size we use for expanding the K1 MMA. We could set K1 to be +## dynamic if we tiled the K1 dimension with a tile size of BLOCK_K1. 
+# @run_test +# def test_dynamic_attention_pipelined(): +# shape = (8, 128, 128, 64, 256) +# # Expose user-constraints +# constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] +# constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] +# constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] +# constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] +# constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] +# constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] +# +# mfma_variant = tkw.MMAType.F32_16x16x16_F16 +# constraints += [ +# tkw.HardwareConstraint( +# threads_per_wave=64, +# waves_per_block=(2, 2, 1), +# mma_type=mfma_variant, +# vector_shapes={B: 0, M: 16, N: 16}, +# ) +# ] +# +# constraints += [tkw.Assumption(K2 > 4 * BLOCK_K2)] +# +# i = tkw.IndexMapping.iterator(0) +# j = tkw.IndexMapping.iterator(1) +# k = tkw.IndexMapping.iterator(2) +# mapping = tkw.IndexMapping( +# num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} +# ) +# +# @tkw.wave(constraints) +# def dynamic_attention_pipelined( +# q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], +# k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], +# v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], +# c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], +# ): +# c_reg = tkl.Register[B, N, M, tkl.f32](0.0) +# init_sum = tkl.Register[B, M, tkl.f32](0.0) +# init_max = tkl.Register[B, M, tkl.f32](-1e6) +# +# # This microkernel encodes the fact that if the reduction +# # dimension were tiled, then we would need to materialize a loop. +# @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) +# def repeat( +# partial_max: tkl.Register[B, M, tkl.f32], +# partial_sum: tkl.Register[B, M, tkl.f32], +# acc: tkl.Register[B, N, M, tkl.f32], +# ) -> ( +# tkl.Register[B, M, tkl.f32], +# tkl.Register[B, M, tkl.f32], +# tkl.Register[B, N, M, tkl.f32], +# ): +# imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) +# q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# inner_acc = tkw.mma(k_reg, q_reg, imm_reg) +# x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) +# m_j = tkw.max(x_j, partial_max, dim=K2) +# e_delta_max = tkw.exp2(partial_max - m_j) +# e_delta = tkw.exp2(x_j - m_j) +# e_init = partial_sum * e_delta_max +# d_j = tkw.sum(e_delta, e_init, dim=K2) +# imm_f16 = tkw.cast(e_delta, tkl.f16) +# v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# new_acc = acc * e_delta_max +# acc = tkw.mma(v_reg, imm_f16, new_acc) +# return m_j, d_j, acc +# +# # repeat represents the results of the loop +# res_max, res_sum, res_mm = repeat +# res = res_mm / res_sum +# tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) +# +# hyperparams = { +# ADDRESS_SPACE: SHARED_ADDRESS_SPACE, +# LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), +# STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), +# K1: shape[3], +# BLOCK_B: 1, +# BLOCK_M: 64, +# BLOCK_N: 64, +# BLOCK_K2: 32, +# READ_SHARED_DELAY: 1, +# WRITE_SHARED_DELAY: 1, +# READ_GLOBAL_DELAY: 2, +# WRITE_GLOBAL_DELAY: 2, +# MMA_DELAY: 1, +# VALU_DELAY: 1, +# SHUFFLE_DELAY: 1, +# SHARED_MEMORY_UNITS: 4, +# GLOBAL_MEMORY_UNITS: 4, +# MMA_UNITS: 4, +# VALU_UNITS: 2, +# SHUFFLE_UNITS: 2, +# } +# +# with tk.gen.TestLaunchContext( +# hyperparams, +# canonicalize=True, +# run=False, +# run_bench=False, +# schedule=True, +# use_scheduling_barriers=False, +# dynamic_symbols=(B, M, N, K2), +# dynamic_symbol_map={ +# B: shape[0], +# M: 
shape[1], +# N: shape[2], +# K2: shape[4], +# }, +# ): +# torch.manual_seed(0) +# q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) +# k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) +# v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) +# output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) +# print(dynamic_attention_pipelined(q, k, v, output).module_op) +# +# # CHECK-LABEL: func.func @dynamic_attention_pipelined +# # CHECK-COUNT-6: {{.*}} = vector.maskedload {{.*}} +# # CHECK: {{.*}} = scf.for +# # CHECK-COUNT-2: {{.*}} = vector.maskedload {{.*}} +# # CHECK-COUNT-14: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}} +# # CHECK-COUNT-7: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-5: {{.*}} = gpu.shuffle xor {{.*}} +# # CHECK-COUNT-2: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}} +# # CHECK-COUNT-16: vector.maskedstore {{.*}} +# +# +# @run_test +# def test_attention_pipelined(): +# shape = (8, 128, 128, 64, 256) +# # Expose user-constraints +# constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] +# constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] +# constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] +# constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] +# constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] +# constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] +# +# mfma_variant = tkw.MMAType.F32_16x16x16_F16 +# constraints += [ +# tkw.HardwareConstraint( +# threads_per_wave=64, +# waves_per_block=(2, 2, 1), +# mma_type=mfma_variant, +# vector_shapes={B: 0, M: 16, N: 16}, +# ) +# ] +# +# i = tkw.IndexMapping.iterator(0) +# j = tkw.IndexMapping.iterator(1) +# k = tkw.IndexMapping.iterator(2) +# mapping = tkw.IndexMapping( +# num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} +# ) +# +# @tkw.wave(constraints) +# def base_attention_pipelined( +# q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], +# k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], +# v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], +# c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], +# ): +# c_reg = tkl.Register[B, N, M, tkl.f32](0.0) +# init_sum = tkl.Register[B, M, tkl.f32](0.0) +# init_max = tkl.Register[B, M, tkl.f32](-1e6) +# +# # This microkernel encodes the fact that if the reduction +# # dimension were tiled, then we would need to materialize a loop. 
+# @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) +# def repeat( +# partial_max: tkl.Register[B, M, tkl.f32], +# partial_sum: tkl.Register[B, M, tkl.f32], +# acc: tkl.Register[B, N, M, tkl.f32], +# ) -> ( +# tkl.Register[B, M, tkl.f32], +# tkl.Register[B, M, tkl.f32], +# tkl.Register[B, N, M, tkl.f32], +# ): +# imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) +# q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# inner_acc = tkw.mma(k_reg, q_reg, imm_reg) +# x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) +# m_j = tkw.max(x_j, partial_max, dim=K2) +# e_delta_max = tkw.exp2(partial_max - m_j) +# e_delta = tkw.exp2(x_j - m_j) +# e_init = partial_sum * e_delta_max +# d_j = tkw.sum(e_delta, e_init, dim=K2) +# imm_f16 = tkw.cast(e_delta, tkl.f16) +# v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# new_acc = acc * e_delta_max +# acc = tkw.mma(v_reg, imm_f16, new_acc) +# return m_j, d_j, acc +# +# # repeat represents the results of the loop +# res_max, res_sum, res_mm = repeat +# res = res_mm / res_sum +# tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) +# +# hyperparams = { +# ADDRESS_SPACE: SHARED_ADDRESS_SPACE, +# LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), +# STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), +# BLOCK_B: 1, +# BLOCK_M: 64, +# BLOCK_N: 64, +# BLOCK_K2: 32, +# B: shape[0], +# M: shape[1], +# N: shape[2], +# K1: shape[3], +# K2: shape[4], +# READ_SHARED_DELAY: 1, +# WRITE_SHARED_DELAY: 1, +# READ_GLOBAL_DELAY: 2, +# WRITE_GLOBAL_DELAY: 2, +# MMA_DELAY: 1, +# VALU_DELAY: 1, +# SHUFFLE_DELAY: 1, +# SHARED_MEMORY_UNITS: 4, +# GLOBAL_MEMORY_UNITS: 4, +# MMA_UNITS: 4, +# VALU_UNITS: 2, +# SHUFFLE_UNITS: 2, +# } +# +# with tk.gen.TestLaunchContext( +# hyperparams, +# canonicalize=True, +# run=False, +# run_bench=False, +# schedule=True, +# use_scheduling_barriers=False, +# ): +# torch.manual_seed(0) +# q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) +# k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) +# v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) +# output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) +# print(base_attention_pipelined(q, k, v, output).module_op) +# +# # CHECK-LABEL: func.func @base_attention_pipelined +# # CHECK: {{.*}} = scf.for +# # CHECK-COUNT-14: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}} +# # CHECK-COUNT-7: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}} +# # CHECK-COUNT-2: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}} +# +# +# @run_test +# def test_flash_decoding(): +# shape = (8, 128, 128, 64, 256) +# mfma_variant = tkw.MMAType.F32_16x16x16_F16 +# +# class Phase(Enum): +# QK = (0,) +# SOFTMAX_V = (1,) +# +# def get_constraints(phase: Phase) -> list[tkw.Constraint]: +# constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] +# constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] +# constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] +# if phase == Phase.QK: +# constraints += [tkw.WorkgroupConstraint(K2, BLOCK_K2, 1)] +# constraints += [tkw.WaveConstraint(K2, BLOCK_K2 / 2)] +# vector_shapes = {B: 0, M: 16, K2: 16} +# else: +# constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] +# constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] +# constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] +# 
vector_shapes = {B: 0, M: 16, N: 16} +# constraints += [ +# tkw.HardwareConstraint( +# threads_per_wave=64, +# waves_per_block=(2, 2, 1), +# mma_type=mfma_variant, +# vector_shapes=vector_shapes, +# ) +# ] +# return constraints +# +# # The first kernel computes Q @ K.T. +# @tkw.wave(get_constraints(Phase.QK)) +# def qk_kernel( +# q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], +# k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], +# c: tkl.Memory[B, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f32], +# ): +# c_reg = tkl.Register[B, K2, M, tkl.f32](0.0) +# q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# acc = tkw.mma(k_reg, q_reg, c_reg) +# x_j = tkw.permute(acc, target_shape=[B, M, K2]) +# tkw.write(x_j, c, elements_per_thread=STORE_ELEMS_PER_THREAD) +# +# # The second kernel computes the softmax and V @ softmax(Q @ K.T). +# i = tkw.IndexMapping.iterator(0) +# j = tkw.IndexMapping.iterator(1) +# k = tkw.IndexMapping.iterator(2) +# mapping = tkw.IndexMapping( +# num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} +# ) +# +# @tkw.wave(get_constraints(Phase.SOFTMAX_V)) +# def softmax_v_kernel( +# qk: tkl.Memory[B, M, K2, ADDRESS_SPACE, tkl.f32], +# v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], +# c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], +# ): +# c_reg = tkl.Register[B, N, M, tkl.f32](0.0) +# init_sum = tkl.Register[B, M, tkl.f32](0.0) +# init_max = tkl.Register[B, M, tkl.f32](-1e6) +# +# # This microkernel encodes the fact that if the reduction +# # dimension were tiled, then we would need to materialize a loop. +# @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) +# def repeat( +# partial_max: tkl.Register[B, M, tkl.f32], +# partial_sum: tkl.Register[B, M, tkl.f32], +# acc: tkl.Register[B, N, M, tkl.f32], +# ) -> ( +# tkl.Register[B, M, tkl.f32], +# tkl.Register[B, M, tkl.f32], +# tkl.Register[B, N, M, tkl.f32], +# ): +# x_j = tkw.read(qk, elements_per_thread=STORE_ELEMS_PER_THREAD) +# m_j = tkw.max(x_j, partial_max, dim=K2) +# e_delta_max = tkw.exp2(partial_max - m_j) +# e_delta = tkw.exp2(x_j - m_j) +# e_init = partial_sum * e_delta_max +# d_j = tkw.sum(e_delta, e_init, dim=K2) +# imm_f16 = tkw.cast(e_delta, tkl.f16) +# v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) +# new_acc = acc * e_delta_max +# acc = tkw.mma(v_reg, imm_f16, new_acc) +# return m_j, d_j, acc +# +# # repeat represents the results of the loop +# res_max, res_sum, res_mm = repeat +# res = res_mm / res_sum +# tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) +# +# hyperparams = { +# ADDRESS_SPACE: SHARED_ADDRESS_SPACE, +# LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), +# STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), +# BLOCK_B: 1, +# BLOCK_M: 64, +# BLOCK_N: 64, +# BLOCK_K2: 32, +# B: shape[0], +# M: shape[1], +# N: shape[2], +# K1: shape[3], +# K2: shape[4], +# READ_SHARED_DELAY: 1, +# WRITE_SHARED_DELAY: 1, +# READ_GLOBAL_DELAY: 2, +# WRITE_GLOBAL_DELAY: 2, +# MMA_DELAY: 1, +# VALU_DELAY: 1, +# SHUFFLE_DELAY: 1, +# SHARED_MEMORY_UNITS: 4, +# GLOBAL_MEMORY_UNITS: 4, +# MMA_UNITS: 4, +# VALU_UNITS: 2, +# SHUFFLE_UNITS: 2, +# } +# +# torch.manual_seed(0) +# q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) +# k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) +# v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) +# qkt = torch.zeros(shape[0], shape[1], shape[4], 
dtype=torch.float32) +# output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) +# +# with tk.gen.TestLaunchContext( +# hyperparams, +# canonicalize=True, +# run=False, +# run_bench=False, +# schedule=False, +# use_scheduling_barriers=False, +# ): +# print(qk_kernel(q, k, qkt).module_op) +# +# # CHECK: func.func @qk_kernel +# # CHECK-NOT: {{.*}} = scf.for +# # CHECK-COUNT-1: {{.*}} = vector.load +# # CHECK-COUNT-1: vector.store +# # CHECK-COUNT-1: {{.*}} = vector.load +# # CHECK-COUNT-1: vector.store +# # CHECK-COUNT-8: {{.*}} = vector.load +# # CHECK-COUNT-1: {{.*}} = vector.load +# # CHECK-COUNT-1: vector.store +# # CHECK-COUNT-4: {{.*}} = vector.load +# # CHECK-COUNT-8: {{.*}} = amdgpu.mfma +# # CHECK-COUNT-2: vector.store +# +# with tk.gen.TestLaunchContext( +# hyperparams, +# canonicalize=True, +# run=False, +# run_bench=False, +# schedule=False, +# use_scheduling_barriers=False, +# ): +# print(softmax_v_kernel(qkt, v, output).module_op) +# +# # CHECK: func.func @softmax_v_kernel +# # CHECK: {{.*}} = scf.for +# # CHECK-COUNT-1: {{.*}} = vector.load +# # CHECK-COUNT-1: vector.store +# # CHECK-COUNT-1: {{.*}} = vector.load +# # CHECK-COUNT-1: vector.store +# # CHECK-COUNT-4: {{.*}} = vector.load +# # CHECK-COUNT-4: {{.*}} = gpu.shuffle +# # CHECK-COUNT-4: {{.*}} = arith.subf +# # CHECK-COUNT-4: {{.*}} = math.exp2 +# # CHECK-COUNT-4: {{.*}} = gpu.shuffle +# # CHECK-COUNT-8: {{.*}} = amdgpu.mfma +# -# This test sets all the dimensions except K1 to be dynamic. -# The reason why we can't set K1 to be dynamic is because K1 is the -# tile size we use for expanding the K1 MMA. We could set K1 to be -# dynamic if we tiled the K1 dimension with a tile size of BLOCK_K1. @run_test -def test_dynamic_attention_pipelined(): - shape = (8, 128, 128, 64, 256) - # Expose user-constraints - constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] - constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] - constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] - constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] - constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] +def test_flash_paged_decoding(): + shape = (32, 6, 128, 128, 256) - mfma_variant = tkw.MMAType.F32_16x16x16_F16 - constraints += [ - tkw.HardwareConstraint( - threads_per_wave=64, - waves_per_block=(2, 2, 1), - mma_type=mfma_variant, - vector_shapes={B: 0, M: 16, N: 16}, - ) - ] - - constraints += [tkw.Assumption(K2 > 4 * BLOCK_K2)] - - i = tkw.IndexMapping.iterator(0) - j = tkw.IndexMapping.iterator(1) - k = tkw.IndexMapping.iterator(2) - mapping = tkw.IndexMapping( - num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} - ) - - @tkw.wave(constraints) - def dynamic_attention_pipelined( - q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], - k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], - v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], - ): - c_reg = tkl.Register[B, N, M, tkl.f32](0.0) - init_sum = tkl.Register[B, M, tkl.f32](0.0) - init_max = tkl.Register[B, M, tkl.f32](-1e6) - - # This microkernel encodes the fact that if the reduction - # dimension were tiled, then we would need to materialize a loop. 
- @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) - def repeat( - partial_max: tkl.Register[B, M, tkl.f32], - partial_sum: tkl.Register[B, M, tkl.f32], - acc: tkl.Register[B, N, M, tkl.f32], - ) -> ( - tkl.Register[B, M, tkl.f32], - tkl.Register[B, M, tkl.f32], - tkl.Register[B, N, M, tkl.f32], - ): - imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) - q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) - k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) - inner_acc = tkw.mma(k_reg, q_reg, imm_reg) - x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) - m_j = tkw.max(x_j, partial_max, dim=K2) - e_delta_max = tkw.exp2(partial_max - m_j) - e_delta = tkw.exp2(x_j - m_j) - e_init = partial_sum * e_delta_max - d_j = tkw.sum(e_delta, e_init, dim=K2) - imm_f16 = tkw.cast(e_delta, tkl.f16) - v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) - new_acc = acc * e_delta_max - acc = tkw.mma(v_reg, imm_f16, new_acc) - return m_j, d_j, acc - - # repeat represents the results of the loop - res_max, res_sum, res_mm = repeat - res = res_mm / res_sum - tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) - - hyperparams = { - ADDRESS_SPACE: SHARED_ADDRESS_SPACE, - LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), - STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), - K1: shape[3], - BLOCK_B: 1, - BLOCK_M: 64, - BLOCK_N: 64, - BLOCK_K2: 32, - READ_SHARED_DELAY: 1, - WRITE_SHARED_DELAY: 1, - READ_GLOBAL_DELAY: 2, - WRITE_GLOBAL_DELAY: 2, - MMA_DELAY: 1, - VALU_DELAY: 1, - SHUFFLE_DELAY: 1, - SHARED_MEMORY_UNITS: 4, - GLOBAL_MEMORY_UNITS: 4, - MMA_UNITS: 4, - VALU_UNITS: 2, - SHUFFLE_UNITS: 2, - } - - with tk.gen.TestLaunchContext( - hyperparams, - canonicalize=True, - run=False, - run_bench=False, - schedule=True, - use_scheduling_barriers=False, - dynamic_symbols=(B, M, N, K2), - dynamic_symbol_map={ - B: shape[0], - M: shape[1], - N: shape[2], - K2: shape[4], - }, - ): - torch.manual_seed(0) - q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) - k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) - v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) - output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) - print(dynamic_attention_pipelined(q, k, v, output).module_op) - - # CHECK-LABEL: func.func @dynamic_attention_pipelined - # CHECK-COUNT-6: {{.*}} = vector.maskedload {{.*}} - # CHECK: {{.*}} = scf.for - # CHECK-COUNT-2: {{.*}} = vector.maskedload {{.*}} - # CHECK-COUNT-14: {{.*}} = amdgpu.mfma - # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-7: {{.*}} = amdgpu.mfma - # CHECK-COUNT-5: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-2: {{.*}} = amdgpu.mfma - # CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-16: vector.maskedstore {{.*}} - - -@run_test -def test_attention_pipelined(): - shape = (8, 128, 128, 64, 256) - # Expose user-constraints - constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] - constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] - constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] - constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] - constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] - - mfma_variant = tkw.MMAType.F32_16x16x16_F16 - constraints += [ - tkw.HardwareConstraint( - threads_per_wave=64, - waves_per_block=(2, 2, 1), - mma_type=mfma_variant, - vector_shapes={B: 0, M: 16, N: 16}, - ) - 
] - - i = tkw.IndexMapping.iterator(0) - j = tkw.IndexMapping.iterator(1) - k = tkw.IndexMapping.iterator(2) - mapping = tkw.IndexMapping( - num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} - ) - - @tkw.wave(constraints) - def base_attention_pipelined( - q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], - k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], - v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], - ): - c_reg = tkl.Register[B, N, M, tkl.f32](0.0) - init_sum = tkl.Register[B, M, tkl.f32](0.0) - init_max = tkl.Register[B, M, tkl.f32](-1e6) - - # This microkernel encodes the fact that if the reduction - # dimension were tiled, then we would need to materialize a loop. - @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) - def repeat( - partial_max: tkl.Register[B, M, tkl.f32], - partial_sum: tkl.Register[B, M, tkl.f32], - acc: tkl.Register[B, N, M, tkl.f32], - ) -> ( - tkl.Register[B, M, tkl.f32], - tkl.Register[B, M, tkl.f32], - tkl.Register[B, N, M, tkl.f32], - ): - imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) - q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) - k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) - inner_acc = tkw.mma(k_reg, q_reg, imm_reg) - x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) - m_j = tkw.max(x_j, partial_max, dim=K2) - e_delta_max = tkw.exp2(partial_max - m_j) - e_delta = tkw.exp2(x_j - m_j) - e_init = partial_sum * e_delta_max - d_j = tkw.sum(e_delta, e_init, dim=K2) - imm_f16 = tkw.cast(e_delta, tkl.f16) - v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) - new_acc = acc * e_delta_max - acc = tkw.mma(v_reg, imm_f16, new_acc) - return m_j, d_j, acc - - # repeat represents the results of the loop - res_max, res_sum, res_mm = repeat - res = res_mm / res_sum - tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) - - hyperparams = { - ADDRESS_SPACE: SHARED_ADDRESS_SPACE, - LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), - STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), - BLOCK_B: 1, - BLOCK_M: 64, - BLOCK_N: 64, - BLOCK_K2: 32, - B: shape[0], - M: shape[1], - N: shape[2], - K1: shape[3], - K2: shape[4], - READ_SHARED_DELAY: 1, - WRITE_SHARED_DELAY: 1, - READ_GLOBAL_DELAY: 2, - WRITE_GLOBAL_DELAY: 2, - MMA_DELAY: 1, - VALU_DELAY: 1, - SHUFFLE_DELAY: 1, - SHARED_MEMORY_UNITS: 4, - GLOBAL_MEMORY_UNITS: 4, - MMA_UNITS: 4, - VALU_UNITS: 2, - SHUFFLE_UNITS: 2, - } - - with tk.gen.TestLaunchContext( - hyperparams, - canonicalize=True, - run=False, - run_bench=False, - schedule=True, - use_scheduling_barriers=False, - ): - torch.manual_seed(0) - q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) - k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) - v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) - output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) - print(base_attention_pipelined(q, k, v, output).module_op) - - # CHECK-LABEL: func.func @base_attention_pipelined - # CHECK: {{.*}} = scf.for - # CHECK-COUNT-14: {{.*}} = amdgpu.mfma - # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-7: {{.*}} = amdgpu.mfma - # CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}} - # CHECK-COUNT-2: {{.*}} = amdgpu.mfma - # CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}} - - -@run_test -def test_flash_decoding(): - shape = (8, 128, 128, 64, 256) + R1 = 8196 + SEQ_LEN = 1024 + NUM_HEADS = 6 + HEAD_TILE_SIZE = 
16
     mfma_variant = tkw.MMAType.F32_16x16x16_F16
 
     class Phase(Enum):
@@ -481,95 +662,171 @@ class Phase(Enum):
     def get_constraints(phase: Phase) -> list[tkw.Constraint]:
         constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
         constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
-        constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
+        constraints += [tkw.WaveConstraint(M, BLOCK_M)]
+        m_ratio = 1
+        n_ratio = 2
         if phase == Phase.QK:
+            # Distribute the batch, head and sequence length dimensions across the workgroups.
             constraints += [tkw.WorkgroupConstraint(K2, BLOCK_K2, 1)]
-            constraints += [tkw.WaveConstraint(K2, BLOCK_K2 / 2)]
+            constraints += [tkw.WaveConstraint(K2, BLOCK_K2 / n_ratio)]
             vector_shapes = {B: 0, M: 16, K2: 16}
         else:
             constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
             constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
-            constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
+            constraints += [tkw.WaveConstraint(N, BLOCK_N / n_ratio)]
             vector_shapes = {B: 0, M: 16, N: 16}
         constraints += [
             tkw.HardwareConstraint(
                 threads_per_wave=64,
-                waves_per_block=(2, 2, 1),
+                waves_per_block=(m_ratio, n_ratio, 1),
                 mma_type=mfma_variant,
                 vector_shapes=vector_shapes,
             )
         ]
         return constraints
 
-    # The first kernel computes Q @ K.T.
+    i = tkw.IndexMapping.iterator(0)
+    j = tkw.IndexMapping.iterator(1)
+    k = tkw.IndexMapping.iterator(2)
+    l = tkw.IndexMapping.iterator(3)
+    d0 = tkw.IndexMapping.dynamic_val(0)
+
+    # Load a specific element from the request_to_tokens matrix.
+    # The request_to_tokens matrix has shape [R0, R1] and we are loading a single element.
+    request_to_tokens_mapping = tkw.IndexMapping(
+        num_iterators=2,
+        inputs={B: d0, K2: j},
+        outputs={B: i, K2: j},
+        dynamic_val_mappings={B: i},
+    )
+
+    # Broadcast the offset along the batch dimension.
+    k_mapping = tkw.IndexMapping(
+        num_iterators=4,
+        inputs={B: d0, M: j / sympy.ceiling(NUM_HEADS / HEAD_TILE_SIZE), K1: l},
+        outputs={B: i, M: j, K2: k, K1: l},
+        dynamic_val_mappings={B: i // LOAD_ELEMS_PER_THREAD, K2: k},
+    )
+
+    # Broadcast the offset along the batch dimension.
+    output_mapping = tkw.IndexMapping(
+        num_iterators=3,
+        inputs={B: i, M: j, K2: k},
+        outputs={B: d0, M: j, K2: k},
+        dynamic_val_mappings={B: i // STORE_ELEMS_PER_THREAD},
+    )
+
+    # The first kernel computes Q @ K.T, after loading K from the K-cache.
+    # Say the batch dimension B = 32, num heads M = 6 and head dimension K1 = 128.
+    # The shape of the Q matrix is [32, 6, 128] / [16, 128] (BLOCK_B: 1, BLOCK_M: 16).
+    # The K cache has a much larger first dimension that could be on the order of
+    # ~O(10^7). Since we will always be accessing this dimension using a dynamic variable,
+    # and will always be loading [K2, K1] from the K cache, we can set the first dimension
+    # to be B. After loading the K matrix, it would be of shape: [32, ?, 128] / [64, 128]
+    # (with a BLOCK_K2 = 64) where we are using the same batch dimension for all the K vectors.
+
+    # The request_to_tokens matrix maps from logical to physical indices in the K cache.
+    # In general it has a shape [R0, R1], but since we are only loading B offsets we represent it
+    # with shape [B] and use R0, R1 in the mapping to access the appropriate index. In general,
+    # we can always represent an n-D matrix as a 1-D matrix.
+    # The request indices and offsets have shape [B] and are used to index into the request_to_tokens matrix.
+
+    # Finally, the output matrix is of shape [B, M, K2] where K2 is the maximum sequence length.
+    # (So it would be of shape [32, 6, 128] / [16, 64]).
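+
+    # Illustrative sketch (comments only, not executed): assuming one request index per
+    # batch entry b and ignoring the head-tile division in k_mapping, the dynamic-value
+    # mappings above resolve roughly as:
+    #   k_offsets[b, s]   = request_to_tokens[request_indices[b], s]
+    #   k_reg[b, m, s, :] = the k_cache row selected by k_offsets[b, s]
+    #   output[request_offsets[b], m, s] = (Q @ K.T)[b, m, s]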
+ @tkw.wave(get_constraints(Phase.QK)) def qk_kernel( - q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], - k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[B, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f32], + q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16, (SEQ_LEN, K1, 1)], + k_cache: tkl.Memory[B, M, K2, K1, ADDRESS_SPACE, tkl.f16, (K1, K1, K1, 1)], + request_to_tokens: tkl.Memory[B, K2, GLOBAL_ADDRESS_SPACE, tkl.i32, (R1, 1)], + request_indices: tkl.Memory[B, GLOBAL_ADDRESS_SPACE, tkl.i32], + request_offsets: tkl.Memory[B, GLOBAL_ADDRESS_SPACE, tkl.i32], + output: tkl.Memory[B, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f32], ): c_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + req_idx_reg = tkw.read(request_indices, elements_per_thread=1) q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) - k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) + k_offsets = tkw.read( + request_to_tokens, + elements_per_thread=1, + mapping=request_to_tokens_mapping, + mapping_dynamic_vals=(req_idx_reg,), + ) + k_reg = tkw.read( + k_cache, + elements_per_thread=LOAD_ELEMS_PER_THREAD, + mapping=k_mapping, + mapping_dynamic_vals=(k_offsets,), + ) acc = tkw.mma(k_reg, q_reg, c_reg) x_j = tkw.permute(acc, target_shape=[B, M, K2]) - tkw.write(x_j, c, elements_per_thread=STORE_ELEMS_PER_THREAD) - - # The second kernel computes the softmax and V @ softmax(Q @ K.T). - i = tkw.IndexMapping.iterator(0) - j = tkw.IndexMapping.iterator(1) - k = tkw.IndexMapping.iterator(2) - mapping = tkw.IndexMapping( - num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} - ) - - @tkw.wave(get_constraints(Phase.SOFTMAX_V)) - def softmax_v_kernel( - qk: tkl.Memory[B, M, K2, ADDRESS_SPACE, tkl.f32], - v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], - ): - c_reg = tkl.Register[B, N, M, tkl.f32](0.0) - init_sum = tkl.Register[B, M, tkl.f32](0.0) - init_max = tkl.Register[B, M, tkl.f32](-1e6) - - # This microkernel encodes the fact that if the reduction - # dimension were tiled, then we would need to materialize a loop. - @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) - def repeat( - partial_max: tkl.Register[B, M, tkl.f32], - partial_sum: tkl.Register[B, M, tkl.f32], - acc: tkl.Register[B, N, M, tkl.f32], - ) -> ( - tkl.Register[B, M, tkl.f32], - tkl.Register[B, M, tkl.f32], - tkl.Register[B, N, M, tkl.f32], - ): - x_j = tkw.read(qk, elements_per_thread=STORE_ELEMS_PER_THREAD) - m_j = tkw.max(x_j, partial_max, dim=K2) - e_delta_max = tkw.exp2(partial_max - m_j) - e_delta = tkw.exp2(x_j - m_j) - e_init = partial_sum * e_delta_max - d_j = tkw.sum(e_delta, e_init, dim=K2) - imm_f16 = tkw.cast(e_delta, tkl.f16) - v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) - new_acc = acc * e_delta_max - acc = tkw.mma(v_reg, imm_f16, new_acc) - return m_j, d_j, acc + req_off_reg = tkw.read(request_offsets, elements_per_thread=1) + tkw.write( + x_j, + output, + elements_per_thread=STORE_ELEMS_PER_THREAD, + mapping=output_mapping, + mapping_dynamic_vals=(req_off_reg,), + ) - # repeat represents the results of the loop - res_max, res_sum, res_mm = repeat - res = res_mm / res_sum - tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) + ## The second kernel computes the softmax and V @ softmax(Q @ K.T). 
+ # i = tkw.IndexMapping.iterator(0) + # j = tkw.IndexMapping.iterator(1) + # k = tkw.IndexMapping.iterator(2) + # mapping = tkw.IndexMapping( + # num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} + # ) + + # @tkw.wave(get_constraints(Phase.SOFTMAX_V)) + # def softmax_v_kernel( + # qk: tkl.Memory[B, H, M, K2, ADDRESS_SPACE, tkl.f32], + # v: tkl.Memory[BV, K2, ADDRESS_SPACE, tkl.f16], + # request_to_tokens: tkl.Memory[R0, R1, ADDRESS_SPACE, tkl.f16], + # request_indices: tkl.Memory[B, ADDRESS_SPACE, tkl.i32], + # request_offsets: tkl.Memory[B, ADDRESS_SPACE, tkl.i32], + # output: tkl.Memory[B, H, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + # ): + # c_reg = tkl.Register[B, H, N, M, tkl.f32](0.0) + # init_sum = tkl.Register[B, H, M, tkl.f32](0.0) + # init_max = tkl.Register[B, H, M, tkl.f32](-1e6) + + # @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) + # def repeat( + # partial_max: tkl.Register[B, H, M, tkl.f32], + # partial_sum: tkl.Register[B, H, M, tkl.f32], + # acc: tkl.Register[B, H, N, M, tkl.f32], + # ) -> ( + # tkl.Register[B, H, M, tkl.f32], + # tkl.Register[B, H, M, tkl.f32], + # tkl.Register[B, H, N, M, tkl.f32], + # ): + # x_j = tkw.read(qk, elements_per_thread=STORE_ELEMS_PER_THREAD) + # m_j = tkw.max(x_j, partial_max, dim=K2) + # e_delta_max = tkw.exp2(partial_max - m_j) + # e_delta = tkw.exp2(x_j - m_j) + # e_init = partial_sum * e_delta_max + # d_j = tkw.sum(e_delta, e_init, dim=K2) + # imm_f16 = tkw.cast(e_delta, tkl.f16) + # v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # new_acc = acc * e_delta_max + # acc = tkw.mma(v_reg, imm_f16, new_acc) + # return m_j, d_j, acc + + # # repeat represents the results of the loop + # res_max, res_sum, res_mm = repeat + # res = res_mm / res_sum + # tkw.write( + # res, output, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD + # ) hyperparams = { ADDRESS_SPACE: SHARED_ADDRESS_SPACE, LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), BLOCK_B: 1, - BLOCK_M: 64, + BLOCK_M: 16, # Read heads in blocks of 16 BLOCK_N: 64, - BLOCK_K2: 32, + BLOCK_K2: 64, # Sequence length in blocks of 64 B: shape[0], M: shape[1], N: shape[2], @@ -619,28 +876,28 @@ def repeat( # CHECK-COUNT-8: {{.*}} = amdgpu.mfma # CHECK-COUNT-2: vector.store - with tk.gen.TestLaunchContext( - hyperparams, - canonicalize=True, - run=False, - run_bench=False, - schedule=False, - use_scheduling_barriers=False, - ): - print(softmax_v_kernel(qkt, v, output).module_op) - - # CHECK: func.func @softmax_v_kernel - # CHECK: {{.*}} = scf.for - # CHECK-COUNT-1: {{.*}} = vector.load - # CHECK-COUNT-1: vector.store - # CHECK-COUNT-1: {{.*}} = vector.load - # CHECK-COUNT-1: vector.store - # CHECK-COUNT-4: {{.*}} = vector.load - # CHECK-COUNT-4: {{.*}} = gpu.shuffle - # CHECK-COUNT-4: {{.*}} = arith.subf - # CHECK-COUNT-4: {{.*}} = math.exp2 - # CHECK-COUNT-4: {{.*}} = gpu.shuffle - # CHECK-COUNT-8: {{.*}} = amdgpu.mfma + # with tk.gen.TestLaunchContext( + # hyperparams, + # canonicalize=True, + # run=False, + # run_bench=False, + # schedule=False, + # use_scheduling_barriers=False, + # ): + # print(softmax_v_kernel(qkt, v, output).module_op) + + # CHECK: func.func @softmax_v_kernel + # CHECK: {{.*}} = scf.for + # CHECK-COUNT-1: {{.*}} = vector.load + # CHECK-COUNT-1: vector.store + # CHECK-COUNT-1: {{.*}} = vector.load + # CHECK-COUNT-1: vector.store + # CHECK-COUNT-4: {{.*}} = vector.load + # CHECK-COUNT-4: {{.*}} = gpu.shuffle + # CHECK-COUNT-4: 
{{.*}} = arith.subf + # CHECK-COUNT-4: {{.*}} = math.exp2 + # CHECK-COUNT-4: {{.*}} = gpu.shuffle + # CHECK-COUNT-8: {{.*}} = amdgpu.mfma @run_test diff --git a/tests/kernel/wave/wave_attention_test.py b/tests/kernel/wave/wave_attention_test.py index 3c8a57df..e3b74188 100644 --- a/tests/kernel/wave/wave_attention_test.py +++ b/tests/kernel/wave/wave_attention_test.py @@ -23,6 +23,7 @@ get_mfma_store_elems_per_thread, device_randn, device_zeros, + device_randint, ) from iree.turbine.kernel.wave.constraints import MMAType import os @@ -1136,3 +1137,269 @@ def repeat( f.write(mb_sv.module_op.get_asm()) assert_allclose(output, torch_ref) + + +@require_e2e +@pytest.mark.parametrize("shape", get_test_shapes("test_attention")) +@pytest.mark.parametrize("enable_scheduling", [False]) +@pytest.mark.parametrize("dynamic_dims", [False]) +@pytest.mark.parametrize( + "mfma_variant", + [ + MMAType.F32_16x16x16_F16, + ], +) +def testPagedFlashDecoding( + shape: tuple[int], + enable_scheduling: bool, + dynamic_dims: bool, + mfma_variant: MMAType, + request, +): + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") + # Input sizes + B = tkl.sym.B + M = tkl.sym.M + N = tkl.sym.N + K1 = tkl.sym.K1 + K2 = tkl.sym.K2 + # Workgroup tile sizes + BLOCK_B = tkl.sym.BLOCK_B + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K2 = tkl.sym.BLOCK_K2 + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + class Phase(Enum): + QK = (0,) + SOFTMAX_V = (1,) + + def get_constraints(phase: Phase) -> list[tkw.Constraint]: + if mfma_variant == MMAType.F32_16x16x16_F16: + Mvec = 16 + Nvec = 16 + if mfma_variant == MMAType.F32_32x32x8_F16: + Mvec = 32 + Nvec = 32 + ratio_m = 1 + ratio_n = 2 + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / ratio_m)] + if phase == Phase.QK: + constraints += [tkw.WorkgroupConstraint(K2, BLOCK_K2, 1)] + constraints += [tkw.WaveConstraint(K2, BLOCK_K2 / ratio_n)] + vector_shapes = {B: 0, M: Mvec, K2: Nvec} + else: + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / ratio_n)] + vector_shapes = {B: 0, M: Mvec, N: Nvec} + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(ratio_m, ratio_n, 1), + mma_type=mfma_variant, + vector_shapes=vector_shapes, + ) + ] + if dynamic_dims: + constraints += [tkw.Assumption(K2 > BLOCK_K2 * 4)] + return constraints + + # Shape of logical to physical mapping table. + R0 = 4097 + R1 = 8196 + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + d0 = tkw.IndexMapping.dynamic_val(0) + + # Load a specific element from the request_to_tokens matrix. + # The request_to_tokens matrix has shape [R0, R1] and we are loading a single element. + request_to_tokens_mapping = tkw.IndexMapping( + num_iterators=1, + inputs={B: d0 * R1}, + outputs={B: i}, + dynamic_val_mappings={B: i}, + ) + + # Broadcast the offset along the batch dimension. 
+ k_mapping = tkw.IndexMapping( + num_iterators=3, + inputs={B: d0, K2: j, K1: k}, + outputs={B: i, K2: j, K1: k}, + dynamic_val_mappings={B: i // LOAD_ELEMS_PER_THREAD}, + ) + + # Broadcast the offset along the batch dimension. + output_mapping = tkw.IndexMapping( + num_iterators=3, + inputs={B: i, M: j, K2: k}, + outputs={B: d0, M: j, K2: k}, + dynamic_val_mappings={B: i // STORE_ELEMS_PER_THREAD}, + ) + + # The first kernel computes K @ Q.T. + @tkw.wave(get_constraints(Phase.QK)) + def qk_kernel( + q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], + k_cache: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + request_to_tokens: tkl.Memory[B, GLOBAL_ADDRESS_SPACE, tkl.i32], + request_indices: tkl.Memory[B, GLOBAL_ADDRESS_SPACE, tkl.i32], + request_offsets: tkl.Memory[B, GLOBAL_ADDRESS_SPACE, tkl.i32], + output: tkl.Memory[B, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + req_idx_reg = tkw.read(request_indices, elements_per_thread=1) + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) + k_offsets = tkw.read( + request_to_tokens, + elements_per_thread=1, + mapping=request_to_tokens_mapping, + mapping_dynamic_vals=(req_idx_reg,), + ) + k_reg = tkw.read( + k_cache, + elements_per_thread=LOAD_ELEMS_PER_THREAD, + mapping=k_mapping, + mapping_dynamic_vals=(k_offsets,), + ) + acc = tkw.mma(k_reg, q_reg, c_reg) + x_j = tkw.permute(acc, target_shape=[B, M, K2]) + req_off_reg = tkw.read(request_offsets, elements_per_thread=1) + tkw.write( + x_j, + output, + elements_per_thread=STORE_ELEMS_PER_THREAD, + mapping=output_mapping, + mapping_dynamic_vals=(req_off_reg,), + ) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), + BLOCK_B: 1, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K2: 64, + B: shape[0], + M: shape[1], + N: shape[2], + K1: shape[3], + K2: shape[4], + } + hyperparams.update(get_default_scheduling_params()) + config = get_default_run_config() + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + + dynamic_symbols = [] + dynamic_symbols_map = {} + if dynamic_dims: + dynamic_symbols_map[M] = hyperparams[M] + dynamic_symbols_map[N] = hyperparams[N] + dynamic_symbols_map[B] = hyperparams[B] + dynamic_symbols_map[K2] = hyperparams[K2] + dynamic_symbols.append(M) + dynamic_symbols.append(N) + dynamic_symbols.append(B) + dynamic_symbols.append(K2) + del hyperparams[M] + del hyperparams[N] + del hyperparams[B] + del hyperparams[K2] + + torch.manual_seed(0) + + q = device_randn(shape[0], shape[1], shape[3], dtype=torch.float16) + + # Construct synthetic page tables for key and value matrices. 
+ max_entries = 32 + max_seq_len = 1025 + total_entries = max_entries * max_seq_len + request_to_tokens = device_randint(0, max_entries, (R0, R1), dtype=torch.int32) + request_indices = device_randint(0, R0, (shape[0],), dtype=torch.int32) + request_offsets = device_randint(0, shape[4], (shape[0],), dtype=torch.int32) + k_cache = device_randn(total_entries, shape[3], dtype=torch.float16) + + def extract_page_table_entries(request_to_tokens, request_indices, request_offsets): + entries = [] + for i in range(request_to_tokens.shape[0]): + idx = request_indices[i] + offset = request_offsets[i] + entries.append(request_to_tokens[i, idx] + offset) + return torch.tensor(entries, dtype=torch.int32) + + k = extract_page_table_entries(request_to_tokens, request_indices, request_offsets) + qk = device_zeros(shape[0], shape[1], shape[4], dtype=torch.float32) + + # v = device_randn(shape[0], shape[4], shape[2], dtype=torch.float16) + # output = device_zeros(shape[0], shape[1], shape[2], dtype=torch.float32) + log2e = 1.44269504089 + dk_sqrt = math.sqrt(1.0 / shape[3]) + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + schedule=enable_scheduling, + use_scheduling_barriers=enable_scheduling_barriers, + dynamic_symbols=dynamic_symbols, + dynamic_symbols_map=dynamic_symbols_map, + ): + # TODO: Add scaling of QK as part of kernel. + mb_qk = qk_kernel( + q * dk_sqrt * log2e, + k_cache, + request_to_tokens, + request_indices, + request_offsets, + qk, + ) + + torch_ref = torch.matmul(q, k.permute([0, 2, 1])) * dk_sqrt * log2e + assert_allclose(qk, torch_ref.permute([0, 2, 1])) + + # with tk.gen.TestLaunchContext( + # hyperparams, + # canonicalize=True, + # run=True, + # run_bench=run_bench, + # run_config=config, + # schedule=enable_scheduling, + # use_scheduling_barriers=enable_scheduling_barriers, + # dynamic_symbols=dynamic_symbols, + # dynamic_symbols_map=dynamic_symbols_map, + # ): + # # TODO: Add variant of non-transposed V attention kernel. + # mb_sv = softmax_v_kernel(qk, v.permute([0, 2, 1]), output) + + # torch_ref = torch.nn.functional.scaled_dot_product_attention( + # q, k, v, attn_mask=None + # ) + + if test_dump_generated_mlir: + filename = f"wave_qk_kernel_{'x'.join(map(str, shape))}.mlir" + with open(filename, "w") as f: + f.write(mb_qk.module_op.get_asm()) + # filename = f"wave_softmax_v_kernel_{'x'.join(map(str, shape))}.mlir" + # with open(filename, "w") as f: + # f.write(mb_sv.module_op.get_asm()) + + # assert_allclose(output, torch_ref)