diff --git a/python/scalehls/transforms.py b/python/scalehls/transforms.py index efd7c7cb..64afc020 100644 --- a/python/scalehls/transforms.py +++ b/python/scalehls/transforms.py @@ -303,8 +303,11 @@ def annotate(target: Value, annotation: str, param=None): return transform.AnnotateOp(target, annotation, param=param) -def tile(linalg_op_handle: Value, sizes: Sequence[int]): - return linalg_transform.TileUsingForOp(linalg_op_handle, sizes=sizes) +def tile(linalg_op_handle: Value, + sizes: Sequence[int], + interchange: Union[Sequence[int], None] = None): + return linalg_transform.TileUsingForOp( + linalg_op_handle, sizes=sizes, interchange=interchange) def tile_reduction(linalg_op_handle: Value, sizes: Sequence[int]): @@ -400,10 +403,11 @@ def foreach_merge_consecutive_extract_slice_and_convert_to_itensor_read( convert_extract_slice_to_itensor_read(merge_op.result) -def convert_full_tensor_linalg_generic_to_itensor( +def convert_full_tensor_linalg_op_to_itensor( linalg_op_handle: Value, parallel_tile_sizes: Sequence[int], reduction_tile_sizes: Sequence[int], + unroll_sizes: Sequence[int], permutation: Sequence[int], has_input: bool = True, split_combine_reduction=False): @@ -448,35 +452,24 @@ def convert_full_tensor_linalg_generic_to_itensor( foreach_merge_consecutive_extract_slice_and_convert_to_itensor_read( matched_input) - # Interchange the loops of the linalg op with the given permutation. - interchange_op = interchange(linalg_op_handle, permutation) - linalg_op_handle = interchange_op.transformed + # Interchange and "unroll" the linalg op with the given unroll sizes. + unroll_op = tile(linalg_op_handle, unroll_sizes, permutation) + linalg_op_handle = unroll_op.tiled_linalg_op return linalg_op_handle -# ===----------------------------------------------------------------------=== # -# Design Space Exploration Utils -# ===----------------------------------------------------------------------=== # - - # ===----------------------------------------------------------------------=== # # DesignSpaceGraph Class # ===----------------------------------------------------------------------=== # class DesignSpaceGraph(nx.Graph): - def __init__(self, - module: Module, - top_name: str = "forward", - default_tile_size: int = 16, - default_unroll_size: int = 4): + def __init__(self, module: Module, top_name: str = "forward"): super().__init__() self.module = module self.top = find_func(self.module, top_name) if self.top is None: raise ValueError("top function `" + top_name + "` not found") - self.default_tile_size = default_tile_size - self.default_unroll_size = default_unroll_size self.add_node(self.top, name=self.top.OPERATION_NAME, id=-1) for id, op in enumerate(self.top.entry_block): @@ -495,21 +488,8 @@ def is_nontrivial_node(node: Operation): node, (arith.ConstantOp, hls.TensorInitOp, tensor.EmptyOp)) @staticmethod - def get_generic_op_naive_permutation(node: linalg.GenericOp): - loop_properties = extract_loop_properties(node) - numReduction = 0 - interchange_permutation = [] - for index, (_, type) in enumerate(loop_properties): - if type == "parallel": - interchange_permutation.append(index) - elif type == "reduction": - interchange_permutation.insert(numReduction, index) - numReduction += 1 - return interchange_permutation - - @staticmethod - def get_generic_op_naive_tile_sizes(node: linalg.GenericOp, - default_tile_size: int = 16): + def get_linalg_op_naive_tile_sizes(node: linalg.GenericOp, + default_tile_size: int = 16): loop_properties = extract_loop_properties(node) parallel_tile_sizes = [] @@ -524,6 +504,34 @@ def get_generic_op_naive_tile_sizes(node: linalg.GenericOp, reduction_tile_sizes.append(tile_size) return parallel_tile_sizes, reduction_tile_sizes + @staticmethod + def get_linalg_op_naive_unroll_sizes(node: linalg.GenericOp, + default_unroll_size: int = 2): + loop_properties = extract_loop_properties(node) + + unroll_sizes = [] + for range, type in loop_properties: + unroll_size = default_unroll_size if range > default_unroll_size else 0 + if type == "parallel": + unroll_sizes.append(unroll_size) + elif type == "reduction": + unroll_sizes.append(1) + return unroll_sizes + + @staticmethod + def get_linalg_op_naive_permutation(node: linalg.GenericOp): + loop_properties = extract_loop_properties(node) + + num_reduction = 0 + interchange_permutation = [] + for index, (_, type) in enumerate(loop_properties): + if type == "parallel": + interchange_permutation.append(index) + elif type == "reduction": + interchange_permutation.insert(num_reduction, index) + num_reduction += 1 + return interchange_permutation + @staticmethod def get_reshape_op_naive_tile_sizes( node: Union[tensor.ExpandShapeOp, tensor.CollapseShapeOp], @@ -544,17 +552,19 @@ def get_reshape_op_naive_tile_sizes( "Source tile sizes do not match result tile sizes") return source_tile_sizes, result_tile_sizes - def naive_exploration(self): + def naive_exploration(self, default_tile_size: int = 16, default_unroll_size: int = 2): for node, data in self.nodes(data=True): if isinstance(node, linalg.GenericOp): - data["parallel_tile_sizes"], data["reduction_tile_sizes"] = self.get_generic_op_naive_tile_sizes( - node, default_tile_size=self.default_tile_size) - data["permutation"] = self.get_generic_op_naive_permutation( + data["parallel_tile_sizes"], data["reduction_tile_sizes"] = self.get_linalg_op_naive_tile_sizes( + node, default_tile_size=default_tile_size) + data["unroll_sizes"] = self.get_linalg_op_naive_unroll_sizes( + node, default_unroll_size=default_unroll_size) + data["permutation"] = self.get_linalg_op_naive_permutation( node) if isinstance(node, (tensor.ExpandShapeOp, tensor.CollapseShapeOp)): data["source_tile_sizes"], data["result_tile_sizes"] = self.get_reshape_op_naive_tile_sizes( - node, default_tile_size=self.default_tile_size) + node, default_tile_size=default_unroll_size) def print_dot(self, file_name: str): dot = Digraph() @@ -586,13 +596,16 @@ def construct_transform_sequence(target: BlockArgument, raise ValueError("parallel_tile_sizes not found") if "reduction_tile_sizes" not in data: raise ValueError("reduction_tile_sizes not found") + if "unroll_sizes" not in data: + raise ValueError("unroll_sizes not found") if "permutation" not in data: raise ValueError("permutation not found") - linalg_op_handle = convert_full_tensor_linalg_generic_to_itensor( + linalg_op_handle = convert_full_tensor_linalg_op_to_itensor( node_handle, data["parallel_tile_sizes"], data["reduction_tile_sizes"], + data["unroll_sizes"], data["permutation"], len(node.inputs) > 0) annotate(linalg_op_handle, "id", i64_param(data["id"]))