Skip to content

Commit

Permalink
[tuner]: add code comments
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Dec 11, 2024
1 parent bc88d23 commit e7fd21a
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,34 @@ def get_lowering_config(
return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs)


# Generate a config dictionary in translation info
"""
Example IR
translation_info = #iree_codegen.translation_info<
pipeline = LLVMGPUVectorDistribute
workgroup_size = [512, 1, 1]
subgroup_size = 64,
{gpu_pipeline_options = #iree_gpu.pipeline_options<>
llvm_func_attrs = {"amdgpu-waves-per-eu" = "3"}
}
>
Example Usage:
pipeline_options = iree_gpu.PipelineOptionsAttr.get(...)
waves_per_eu = 3
config_dict = get_translation_info_config(
pipeline_options=pipeline_options,
waves_per_eu=waves_per_eu
)
this 'config_dict' is subsequently used afterward to generate the 'translation_info' in the above example IR.
"""


def get_translation_info_config(
pipeline_options: iree_gpu.PipelineOptionsAttr, waves_per_eu: int | str
pipeline_options: iree_gpu.PipelineOptionsAttr, waves_per_eu: int
) -> ir.DictAttr:
if isinstance(waves_per_eu, int):
waves_per_eu = str(waves_per_eu)
elif not isinstance(waves_per_eu, str):
assert (
False
), f"waves_per_eu must be an int or str, but got {type(waves_per_eu).__name__}"
waves_per_eu = str(waves_per_eu)

# Create the waves_per_eu dictionary attribute.
waves_per_eu_dict = ir.DictAttr.get(
Expand Down

0 comments on commit e7fd21a

Please sign in to comment.