Add paged flash decoding kernel
Signed-off-by: Harsh Menon <[email protected]>
harsh-nod committed Dec 11, 2024
1 parent 71eb1c8 commit 9eacda3
Showing 8 changed files with 1,073 additions and 525 deletions.
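
The changes below thread an optional stride through the Memory buffer type, let buffer subscripts carry that stride, accept general Sequences where lists were previously required, and plumb dynamic mapping values through index transforms and the load-minimization pass. For orientation only, here is a minimal NumPy sketch (invented names; an illustration of the idea, not the kernel added by this commit) of the paged-KV gather plus single-token attention that a paged flash-decoding step performs: a per-request block table maps logical KV positions to fixed-size physical pages.

    import numpy as np

    def paged_decode_attention(q, k_cache, v_cache, block_table, seq_len, page_size=16):
        """q: [heads, d]; k_cache/v_cache: [num_pages, page_size, heads, d]."""
        heads, d = q.shape
        # Gather this request's logical K/V sequence page by page via the block table.
        pages = block_table[: (seq_len + page_size - 1) // page_size]
        k = k_cache[pages].reshape(-1, heads, d)[:seq_len]   # [seq, heads, d]
        v = v_cache[pages].reshape(-1, heads, d)[:seq_len]
        # Standard single-token (decode) attention over the gathered sequence.
        scores = np.einsum("hd,shd->hs", q, k) / np.sqrt(d)
        weights = np.exp(scores - scores.max(axis=1, keepdims=True))
        weights /= weights.sum(axis=1, keepdims=True)
        return np.einsum("hs,shd->hd", weights, v)

    # Tiny smoke test: 8 pages of 16 tokens, 4 heads, head dim 8, 20 live tokens.
    rng = np.random.default_rng(0)
    k_cache = rng.standard_normal((8, 16, 4, 8))
    v_cache = rng.standard_normal((8, 16, 4, 8))
    q = rng.standard_normal((4, 8))
    out = paged_decode_attention(q, k_cache, v_cache, np.array([3, 5]), seq_len=20)
    print(out.shape)  # (4, 8)
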
3 changes: 3 additions & 0 deletions iree/turbine/kernel/lang/kernel_buffer.py
@@ -64,20 +64,23 @@ def new_subtype(
         address_space: AddressSpace | NotSetType = NotSet,
         symbolic_shape: tuple[IndexExpr, ...] | NotSetType = NotSet,
         dtype: DataType | NotSetType = NotSet,
+        stride: tuple[IndexExpr, ...] | NotSetType = NotSet,
         usage: KernelBufferUsage | NotSetType = NotSet,
     ) -> Type[SubtypeT]:
         init_address_space = (
             address_space if address_space else AddressSpace.GLOBAL_MEMORY
         )
         init_symbolic_shape = symbolic_shape if symbolic_shape is not NotSet else cls.symbolic_shape  # type: ignore
         init_dtype = dtype if dtype is not NotSet else cls.dtype  # type: ignore
+        init_stride = stride if stride else None
         init_usage = usage if usage is not NotSet else cls.usage  # type: ignore
 
         class SubType(cls):
             address_space = init_address_space
             symbolic_shape = init_symbolic_shape
             rank = len(init_symbolic_shape)  # type: ignore
             dtype = init_dtype
+            stride = init_stride
             usage = init_usage
 
         if name is not NotSet:
14 changes: 11 additions & 3 deletions iree/turbine/kernel/lang/wave_types.py
@@ -41,6 +41,7 @@ class Memory(metaclass=KernelBufferMeta):
     symbolic_shape: ClassVar[tuple[IndexExpr, ...]]
     rank: ClassVar[int]
     dtype: ClassVar[DataType]
+    stride: ClassVar[Optional[tuple[IndexExpr, ...]]]
     usage: ClassVar[Optional[KernelBufferUsage]]
 
     def __init__(self) -> None:
@@ -55,9 +56,15 @@ def __class_getitem__(
 
         shift = 0
         usage = KernelBufferUsage.NONE
-        if isinstance(shape_and_dtype[-1], KernelBufferUsage):
-            shift = 1
-            usage = shape_and_dtype[-1]
+        last_dim = -1
+        if isinstance(shape_and_dtype[last_dim], KernelBufferUsage):
+            shift += 1
+            usage = shape_and_dtype[last_dim]
+            last_dim -= 1
+        stride = None
+        if isinstance(shape_and_dtype[last_dim], Sequence):
+            shift += 1
+            stride = shape_and_dtype[last_dim]
         shape = shape_and_dtype[: -2 - shift]
         addressSpace = shape_and_dtype[-2 - shift]
         dtype = shape_and_dtype[-1 - shift]
@@ -85,6 +92,7 @@ def __class_getitem__(
             address_space=addressSpace,
             symbolic_shape=shape,
             dtype=dtype,
+            stride=stride,
             usage=usage,
         )
 
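A hedged usage sketch of the new subscript parsing (the import style and symbol names follow common turbine kernel examples and are assumptions, not part of this diff): the subscript is parsed as shape dims, address space, dtype, then an optional stride sequence, then an optional usage flag.

    import iree.turbine.kernel.lang as tkl

    M = tkl.sym.M
    K = tkl.sym.K
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

    # Contiguous layout, as before:
    A = tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16]
    # Explicit stride, passed as a sequence before the (optional) usage flag:
    B = tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16, (K, 1)]
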
14 changes: 11 additions & 3 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -461,7 +461,7 @@ def node_args(self) -> dict[int, Any]:
         for i, arg in enumerate(self.fx_node.args):
             if isinstance(arg, fx.Node):
                 custom_args[i] = get_custom(arg)
-            if isinstance(arg, list) and all(isinstance(x, fx.Node) for x in arg):
+            if isinstance(arg, Sequence) and all(isinstance(x, fx.Node) for x in arg):
                 custom_args[i] = [get_custom(x) for x in arg]
         return custom_args
 
@@ -1013,7 +1013,11 @@ def transform_index_backwards(
             iters = self.mapping.iters
             mapping = self.mapping.dynamic_val_mappings[i]
             subs = {v: k for k, v in zip(iters, mapping.keys())}
-            return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()}
+            return {
+                k: v.apply_expr(subs[k], mapping[k])
+                for k, v in index.items()
+                if k in mapping
+            }
 
         return index
 
@@ -1253,7 +1257,11 @@ def transform_index_backwards(
             iters = self.mapping.iters
             mapping = self.mapping.dynamic_val_mappings[i]
             subs = {v: k for k, v in zip(iters, mapping.keys())}
-            return {k: v.apply_expr(subs[k], mapping[k]) for k, v in index.items()}
+            return {
+                k: v.apply_expr(subs[k], mapping[k])
+                for k, v in index.items()
+                if k in mapping
+            }
 
         return index
 
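The two comprehension changes above add an "if k in mapping" guard. A small self-contained sketch (plain Python stand-ins for the real IndexSequence/apply_expr machinery, not this repo's code) of what the guard changes:

    iters = ["$i", "$j"]
    mapping = {"M": 0, "K": 1}                           # dims covered by this dynamic value
    index = {"M": "m_seq", "K": "k_seq", "N": "n_seq"}   # "N" is not covered
    subs = {v: k for k, v in zip(iters, mapping.keys())}  # {"M": "$i", "K": "$j"}
    transformed = {
        k: (subs[k], mapping[k], v)    # stand-in for v.apply_expr(subs[k], mapping[k])
        for k, v in index.items()
        if k in mapping                # the added guard
    }
    print(transformed)  # "N" is skipped instead of raising KeyError on subs["N"]
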
2 changes: 1 addition & 1 deletion iree/turbine/kernel/wave/expansion.py
@@ -283,7 +283,7 @@ def _expand_node(
     for i, arg in node.node_args.items():
         arg_list = arg
         unpack = lambda x: x
-        if isinstance(arg, list):
+        if isinstance(arg, Sequence):
             if not all(is_expandable(a) for a in arg):
                 continue
             else:
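
The list-to-Sequence change above (and the matching one in wave_ops.py) makes these checks accept tuples as well as lists. One caveat worth keeping in mind, shown in a standalone sketch that is not from this repo: str is also a Sequence, so code that must exclude strings has to do so explicitly.

    from collections.abc import Sequence

    def is_node_sequence(arg) -> bool:
        # Tuples and lists both pass; strings are excluded explicitly.
        return isinstance(arg, Sequence) and not isinstance(arg, str)

    print(is_node_sequence([1, 2]))   # True
    print(is_node_sequence((1, 2)))   # True
    print(is_node_sequence("ab"))     # False
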
8 changes: 5 additions & 3 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -217,9 +217,11 @@ def has_gpr_offsets(node: fx.Node) -> bool:
                        {GPR_NUM: cur_gpr_start_id}
                    ),
                    gpr_size,
-                    1
-                    if output_mapping[-1] == gpr_offset_dim
-                    else simplified_index[gpr_offset_dim].stride,
+                    (
+                        1
+                        if output_mapping[-1] == gpr_offset_dim
+                        else simplified_index[gpr_offset_dim].stride
+                    ),
                )
                updated_index_with_gpr_offset[
                    gpr_offset_dim
9 changes: 6 additions & 3 deletions iree/turbine/kernel/wave/minimize_global_loads.py
@@ -125,9 +125,12 @@ def add_optimized_nodes(
        access_pattern: dict[IndexSymbol, IndexSequence] = custom.index
        for i in range(expected_number_of_loads):
            with custom.graph.inserting_before(custom.fx_node):
-                read = Read(memory, load_elems_per_thread, custom.mapping).add_to_graph(
-                    custom.graph
-                )
+                read = Read(
+                    memory,
+                    load_elems_per_thread,
+                    custom.mapping,
+                    custom.mapping_dynamic_vals,
+                ).add_to_graph(custom.graph)
                global_offset = (
                    hardware_constraint.linearized_thread_id * load_elems_per_thread
                    + i * max_elements_per_load
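
Note: the rebuilt Read above now also forwards custom.mapping_dynamic_vals, presumably so that reads whose index mapping depends on dynamic values keep those values after the load-minimization rewrite.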