Skip to content

Commit

Permalink
Remove no_workgroup_reorder in tune.py and test_tune.py (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
RattataKing authored Jun 18, 2024
1 parent 2d9ec68 commit 57aaf1d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 38 deletions.
37 changes: 12 additions & 25 deletions tuning/test_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ def test_get_mmt_tile_sizes():
tile_sizes=[128, 320, 32],
subgroup_m_count=0,
subgroup_n_count=0,
waves_per_eu=0,
no_workgroup_reorder=0,
waves_per_eu=0
)
assert tune.get_mmt_tile_sizes(config) == [128, 320, 32]

Expand All @@ -28,8 +27,7 @@ def test_get_conv_tile_sizes():
tile_sizes=[464, 320, 16],
subgroup_m_count=1,
subgroup_n_count=4,
waves_per_eu=1,
no_workgroup_reorder=0,
waves_per_eu=1
)
assert tune.get_conv_tile_sizes(config) == (1, 1, 464, 320, 1, 1, 16)

Expand All @@ -42,8 +40,7 @@ def test_get_contract_tile_sizes():
tile_sizes=[4, 8, 16],
subgroup_m_count=1,
subgroup_n_count=1,
waves_per_eu=2,
no_workgroup_reorder=0,
waves_per_eu=2
)
assert tune.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16]
assert tune.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16]
Expand All @@ -59,8 +56,7 @@ def test_get_pipeline_config():
tile_sizes=[4, 8, 16],
subgroup_m_count=1,
subgroup_n_count=1,
waves_per_eu=2,
no_workgroup_reorder=0,
waves_per_eu=2
)
config2 = tune.Configuration(
subgroup_size=32,
Expand All @@ -69,13 +65,12 @@ def test_get_pipeline_config():
tile_sizes=[4, 8, 16],
subgroup_m_count=1,
subgroup_n_count=1,
waves_per_eu=4,
no_workgroup_reorder=1,
waves_per_eu=4
)
assert tune.get_pipeline_config(config1) == ""
assert (
tune.get_pipeline_config(config2)
== ', no_reorder_workgroups, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}'
== ', llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}'
)


Expand Down Expand Up @@ -148,7 +143,6 @@ def test_generate_constraints_valid_input():
sg_m_cnt = tune.z3.Int("sg_m_cnt")
sg_n_cnt = tune.z3.Int("sg_n_cnt")
waves_per_eu = tune.z3.Int("waves_per_eu")
no_workgroup_reorder = tune.z3.Int("no_workgroup_reorder")

constraints = tune.generate_constraints(
[M, N, K],
Expand All @@ -158,8 +152,7 @@ def test_generate_constraints_valid_input():
[wg_x, wg_y, wg_z],
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
no_workgroup_reorder,
waves_per_eu
)

solver = tune.z3.Solver()
Expand All @@ -180,7 +173,6 @@ def test_generate_constraints_invalid_input():
sg_m_cnt = tune.z3.Int("sg_m_cnt")
sg_n_cnt = tune.z3.Int("sg_n_cnt")
waves_per_eu = tune.z3.Int("waves_per_eu")
no_workgroup_reorder = tune.z3.Int("no_workgroup_reorder")

constraints = tune.generate_constraints(
[M, N, K],
Expand All @@ -190,8 +182,7 @@ def test_generate_constraints_invalid_input():
[wg_x, wg_y, wg_z],
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
no_workgroup_reorder,
waves_per_eu
)
constraints.append(m > 1000) # Adding an additional unsatisfiable constraint

Expand Down Expand Up @@ -219,8 +210,7 @@ def test_apply_params_mmt():
tile_sizes=[8, 8, 8],
subgroup_m_count=16,
subgroup_n_count=16,
waves_per_eu=8,
no_workgroup_reorder=0,
waves_per_eu=8
)

modified, embeddable = tune.apply_params_mmt(M, N, K, mlir_template, config)
Expand Down Expand Up @@ -256,8 +246,7 @@ def test_apply_params_conv():
tile_sizes=[464, 320, 16],
subgroup_m_count=1,
subgroup_n_count=4,
waves_per_eu=1,
no_workgroup_reorder=0,
waves_per_eu=1
)

modified, embeddable = tune.apply_params_conv(
Expand Down Expand Up @@ -297,8 +286,7 @@ def test_apply_params_contract():
tile_sizes=[480, 384, 32],
subgroup_m_count=1,
subgroup_n_count=4,
waves_per_eu=2,
no_workgroup_reorder=1,
waves_per_eu=2
)

new_mlir = tune.apply_params_contract(
Expand Down Expand Up @@ -337,8 +325,7 @@ def test_apply_params_batch_matmul():
tile_sizes=[416, 320, 128],
subgroup_m_count=2,
subgroup_n_count=2,
waves_per_eu=1,
no_workgroup_reorder=1,
waves_per_eu=1
)

modified, embeddable = tune.apply_params_batch_matmul(
Expand Down
17 changes: 4 additions & 13 deletions tuning/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class Configuration:
subgroup_m_count: int
subgroup_n_count: int
waves_per_eu: int
no_workgroup_reorder: int


def read_input_mlir(filename):
Expand Down Expand Up @@ -86,8 +85,6 @@ def get_contract_tile_sizes(configuration: Configuration, tile_dims):

def get_pipeline_config(configuration: Configuration) -> str:
extra_config = ""
if configuration.no_workgroup_reorder == 1:
extra_config += ", no_reorder_workgroups"
if configuration.waves_per_eu != 2:
extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}'
return extra_config
Expand Down Expand Up @@ -520,8 +517,7 @@ def generate_constraints(
workgroup_size,
subgroup_m_count,
subgroup_n_count,
waves_per_eu,
no_workgroup_reorder,
waves_per_eu
):
M, N, K = problem_size
m, n, k = tile_sizes
Expand Down Expand Up @@ -564,7 +560,6 @@ def generate_constraints(
constraints += [subgroup_m_count * subgroup_n_count == 4]

constraints += [z3.Or(waves_per_eu == 1, waves_per_eu == 2, waves_per_eu == 4)]
constraints += [no_workgroup_reorder >= 0, no_workgroup_reorder <= 1]

return constraints

Expand All @@ -579,7 +574,6 @@ def generate_solutions(M, N, K):
sg_m_cnt = z3.Int("sg_m_cnt")
sg_n_cnt = z3.Int("sg_n_cnt")
waves_per_eu = z3.Int("waves_per_eu")
no_workgroup_reorder = z3.Int("no_workgroup_reorder")
all_vars = [
m,
n,
Expand All @@ -592,8 +586,7 @@ def generate_solutions(M, N, K):
wg_z,
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
no_workgroup_reorder,
waves_per_eu
]

solver = z3.Solver()
Expand All @@ -605,8 +598,7 @@ def generate_solutions(M, N, K):
[wg_x, wg_y, wg_z],
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
no_workgroup_reorder,
waves_per_eu
)
solver.add(z3.simplify(z3.And(constraints)))
tune_logger.debug(f"Initial constraints: {solver}")
Expand All @@ -622,8 +614,7 @@ def generate_solutions(M, N, K):
[lookup(m), lookup(n), lookup(k)],
lookup(sg_m_cnt),
lookup(sg_n_cnt),
lookup(waves_per_eu),
lookup(no_workgroup_reorder),
lookup(waves_per_eu)
)
solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars)))))
i += 1
Expand Down

0 comments on commit 57aaf1d

Please sign in to comment.