Skip to content

Commit

Permalink
[tuner]: use tuner_context
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Dec 3, 2024
1 parent 98b4add commit 07919e4
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 9 deletions.
2 changes: 1 addition & 1 deletion tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def tune(
tune_logger.debug(str(problem_size))
configs = []
for i, config in enumerate(
generate_solutions(tune_logger, problem_size, num_subgroups, mma_list)
generate_solutions(tuner_context, problem_size, num_subgroups, mma_list)
):
if i >= limit:
break
Expand Down
10 changes: 3 additions & 7 deletions tuner/tuner/dispatch_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ def getMMAAttr(


def generate_solutions(
logger: logging.Logger,
tuner_ctx: TunerContext,
problem_size: ProblemSize,
num_subgrups: int,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> Iterator[Configuration]:
M, N, K = problem_size.MNK
logger.info(f"{M},{N},{K}")
tuner_ctx.logger.info(f"{M},{N},{K}")
m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k")
subgroup_size = z3.Int("subgroup_size")
intrinsic_mn = z3.Int("intrinsic_mn")
Expand Down Expand Up @@ -218,11 +218,7 @@ def generate_solutions(
mma_intrinsics,
)
solver.add(z3.simplify(z3.And(constraints)))
logger.debug(f"Initial constraints: {solver}")

int_type = ir.IntegerType.get_signless(64)

tuner_ctx = TunerContext(ir.Context(), logger)
tuner_ctx.logger.debug(f"Initial constraints: {solver}")

i = 0
while solver.check() == z3.sat:
Expand Down
2 changes: 1 addition & 1 deletion tuner/tuner/dispatch_constraints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_generate_solutions(tuner_ctx: common.TunerContext) -> None:
matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
)
configs = dispatch_constraints.generate_solutions(
tuner_ctx.logger,
tuner_ctx,
problem_size,
4,
[
Expand Down

0 comments on commit 07919e4

Please sign in to comment.