From e7fd21a29835d6885293f807d08df5d313f86f13 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Wed, 11 Dec 2024 10:44:48 -0600 Subject: [PATCH] [tuner]: add code comments Signed-off-by: Bangtian Liu --- tuner/tuner/common.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 12b40805e..f82c3f5cc 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -163,15 +163,34 @@ def get_lowering_config( return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) +# Generate a config dictionary in translation info +""" +Example IR +translation_info = #iree_codegen.translation_info< + pipeline = LLVMGPUVectorDistribute + workgroup_size = [512, 1, 1] + subgroup_size = 64, + {gpu_pipeline_options = #iree_gpu.pipeline_options<> + llvm_func_attrs = {"amdgpu-waves-per-eu" = "3"} + } + > +Example Usage: + pipeline_options = iree_gpu.PipelineOptionsAttr.get(...) + waves_per_eu = 3 + + config_dict = get_translation_info_config( + pipeline_options=pipeline_options, + waves_per_eu=waves_per_eu + ) + +this 'config_dict' is subsequently used afterward to generate the 'translation_info' in the above example IR. +""" + + def get_translation_info_config( - pipeline_options: iree_gpu.PipelineOptionsAttr, waves_per_eu: int | str + pipeline_options: iree_gpu.PipelineOptionsAttr, waves_per_eu: int ) -> ir.DictAttr: - if isinstance(waves_per_eu, int): - waves_per_eu = str(waves_per_eu) - elif not isinstance(waves_per_eu, str): - assert ( - False - ), f"waves_per_eu must be an int or str, but got {type(waves_per_eu).__name__}" + waves_per_eu = str(waves_per_eu) # Create the waves_per_eu dictionary attribute. waves_per_eu_dict = ir.DictAttr.get(