Skip to content

Commit

Permalink
[tuner]: use helper function
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
bangtianliu committed Dec 3, 2024
1 parent 06c1b51 commit 98b4add
Showing 2 changed files with 23 additions and 30 deletions.
2 changes: 1 addition & 1 deletion tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
@@ -38,7 +38,7 @@

tune_logger = logging.getLogger("tune")


# TODO: remove the argument 'workgroup_sizes' and 'reduction_sizes'.
def apply_configuration(
template: list[str],
configuration: Configuration,
51 changes: 22 additions & 29 deletions tuner/tuner/dispatch_constraints.py
Original file line number Diff line number Diff line change
@@ -222,39 +222,32 @@ def generate_solutions(

int_type = ir.IntegerType.get_signless(64)

tuner_ctx = TunerContext(ir.Context(), logger)

i = 0
while solver.check() == z3.sat:
model = solver.model()
lookup = lambda var: model[var].as_long()
lowering_config_dict = {
"mma_kind": getMMAAttr(
problem_size.res_type.element_type,
lookup(intrinsic_mn),
lookup(intrinsic_mn),
lookup(intrinsic_k),
problem_size.lhs_type.element_type,
problem_size.rhs_type.element_type,
),
"workgroup": ir.ArrayAttr.get(
[
ir.IntegerAttr.get(int_type, lookup(m)),
ir.IntegerAttr.get(int_type, lookup(n)),
ir.IntegerAttr.get(int_type, 0),
]
),
"reduction": ir.ArrayAttr.get(
[
ir.IntegerAttr.get(int_type, 0),
ir.IntegerAttr.get(int_type, 0),
ir.IntegerAttr.get(int_type, lookup(k)),
]
), # Placeholder now to be consistent with iree.
"subgroup_m_count": ir.IntegerAttr.get(int_type, lookup(sg_m_cnt)),
"subgroup_n_count": ir.IntegerAttr.get(int_type, lookup(sg_n_cnt)),
}

lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
lowering_config = iree_gpu.LoweringConfigAttr.get(lowering_config_attrs)
mma_attr = getMMAAttr(
problem_size.res_type.element_type,
lookup(intrinsic_mn),
lookup(intrinsic_mn),
lookup(intrinsic_k),
problem_size.lhs_type.element_type,
problem_size.rhs_type.element_type,
)
lowering_config = get_lowering_config(
tuner_ctx=tuner_ctx,
mma_kind=mma_attr,
workgroup=[lookup(m), lookup(n), 0],
reduction=[
0,
0,
lookup(k),
],
subgroup_m_count=lookup(sg_m_cnt),
subgroup_n_count=lookup(sg_n_cnt),
)
config = Configuration(
lookup(subgroup_size),
[lookup(wg_x), lookup(wg_y), lookup(wg_z)],

0 comments on commit 98b4add

Please sign in to comment.