diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 2a544ef55..ca5d035e9 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -38,7 +38,7 @@ tune_logger = logging.getLogger("tune") - +# TODO: remove the argument 'workgroup_sizes' and 'reduction_sizes'. def apply_configuration( template: list[str], configuration: Configuration, diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 915e84711..ea5c66f76 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -222,39 +222,32 @@ def generate_solutions( int_type = ir.IntegerType.get_signless(64) + tuner_ctx = TunerContext(ir.Context(), logger) + i = 0 while solver.check() == z3.sat: model = solver.model() lookup = lambda var: model[var].as_long() - lowering_config_dict = { - "mma_kind": getMMAAttr( - problem_size.res_type.element_type, - lookup(intrinsic_mn), - lookup(intrinsic_mn), - lookup(intrinsic_k), - problem_size.lhs_type.element_type, - problem_size.rhs_type.element_type, - ), - "workgroup": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(int_type, lookup(m)), - ir.IntegerAttr.get(int_type, lookup(n)), - ir.IntegerAttr.get(int_type, 0), - ] - ), - "reduction": ir.ArrayAttr.get( - [ - ir.IntegerAttr.get(int_type, 0), - ir.IntegerAttr.get(int_type, 0), - ir.IntegerAttr.get(int_type, lookup(k)), - ] - ), # Placeholder now to be consistent with iree. - "subgroup_m_count": ir.IntegerAttr.get(int_type, lookup(sg_m_cnt)), - "subgroup_n_count": ir.IntegerAttr.get(int_type, lookup(sg_n_cnt)), - } - - lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) - lowering_config = iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) + mma_attr = getMMAAttr( + problem_size.res_type.element_type, + lookup(intrinsic_mn), + lookup(intrinsic_mn), + lookup(intrinsic_k), + problem_size.lhs_type.element_type, + problem_size.rhs_type.element_type, + ) + lowering_config = get_lowering_config( + tuner_ctx=tuner_ctx, + mma_kind=mma_attr, + workgroup=[lookup(m), lookup(n), 0], + reduction=[ + 0, + 0, + lookup(k), + ], + subgroup_m_count=lookup(sg_m_cnt), + subgroup_n_count=lookup(sg_n_cnt), + ) config = Configuration( lookup(subgroup_size), [lookup(wg_x), lookup(wg_y), lookup(wg_z)],