Skip to content

Commit

Permalink
Print more info
Browse files Browse the repository at this point in the history
  • Loading branch information
RattataKing authored and nithinsubbiah committed Jun 21, 2024
1 parent fcceac8 commit ac6f800
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
6 changes: 3 additions & 3 deletions tuning/autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def compile_unet_candidates(
f"Hash value '{hash_val}' collided at candidate {indices}."
)
unique_unet_candidates.append(
candidate_trackers[indices[0]].unet_candidate_path
candidate_trackers[indices[0]].unet_candidate_path # If collision occurs, use the first candidate index in the list
)

return unique_unet_candidates if collision_detected else unet_candidates
Expand Down Expand Up @@ -626,15 +626,15 @@ def main():
best_log = benchmark_top_candidates(
args, base_dir, candidates_dir, compiled_files, candidate_trackers
)
print(f"Top candidates results are stored in {best_log}\n")
print(f"Top20 candidates results are stored in {best_log}\n")

print("Compiling unet candidates...")
unet_candidates = compile_unet_candidates(
args, base_dir, best_log, candidate_trackers
)
print(f"Unet candidates compiled in {base_dir}\n")

print("Bnechmarking unet candidates...")
print(f"Bnechmarking [{len(unet_candidates)}] unet candidates...")
unet_result_log = benchmark_unet(
args, base_dir, unet_candidates, candidate_trackers
)
Expand Down
4 changes: 3 additions & 1 deletion tuning/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,9 @@ def generate_constraints(
constraints += [k == intrinsic_mn * z3.FreshInt()]
constraints += [k * n % (wg_x * wg_y * wg_z) == 0]
constraints += [k * m % (wg_x * wg_y * wg_z) == 0]
constraints += [subgroup_m_count * subgroup_n_count == 4]
# constraints += [subgroup_m_count * subgroup_n_count == 4] # splat
# constraints += [subgroup_m_count * subgroup_n_count == 2] # real_weights
constraints += [subgroup_m_count * subgroup_n_count == 5] # conv

constraints += [z3.Or(waves_per_eu == 1, waves_per_eu == 2, waves_per_eu == 4)]

Expand Down

0 comments on commit ac6f800

Please sign in to comment.