Skip to content

Commit

Permalink
address comments, fix bugs with benchmarking bindings, remove root_op…
Browse files Browse the repository at this point in the history
… attr from spec generation

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 committed Dec 17, 2024
1 parent 3ae2d3e commit ad77baf
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 39 deletions.
29 changes: 20 additions & 9 deletions tuner/examples/test/tuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@

class TestTuner(libtuner.TuningClient):
def __init__(self):
self.compile_flags = [
"--iree-hip-target=gfx942",
"--compile-from=executable-sources",
]
self.compile_flags = ["--compile-from=executable-sources"]
self.benchmark_flags = ["--benchmark_repetitions=3", "--input=1"]

def get_iree_compile_flags(self) -> list[str]:
Expand Down Expand Up @@ -96,10 +93,10 @@ def main():
# TODO(Max191): Some bug seems to be causing OOM errors in benchmarking
# when device validation happens, so this is commented for now. Uncomment
# when the bug is fixed.
# if not args.dry_run:
# print("Validating devices")
# libtuner.validate_devices(args.devices)
# print("Validation successful!\n")
if not args.dry_run:
print("Validating devices")
libtuner.validate_devices(args.devices)
print("Validation successful!\n")

print("Generating candidates...")
candidates = libtuner.generate_candidate_specs(
Expand All @@ -126,11 +123,25 @@ def main():
)

print("Compiling models with top candidates...")
test_tuner.compile_flags = [
"--iree-hal-target-backends=rocm",
"--iree-hip-target=gfx942",
]
compiled_model_candidates = libtuner.compile(
args, path_config, top_candidates, candidate_trackers, test_tuner
args,
path_config,
top_candidates,
candidate_trackers,
test_tuner,
args.test_model_file,
)

print("Benchmarking compiled model candidates...")
test_tuner.benchmark_flags = [
"--benchmark_repetitions=3",
"--input=2048x2048xf16",
"--input=2048x2048xf16",
]
top_model_candidates = libtuner.benchmark(
args,
path_config,
Expand Down
89 changes: 61 additions & 28 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from iree.compiler import ir # type: ignore
from . import candidate_gen
from . import dispatch_parser
from .op_matchers import *
from .common import *


Expand Down Expand Up @@ -641,6 +642,22 @@ def run_command(run_pack: RunPack) -> RunResult:
return RunResult(result, is_timeout)


# The `strip_root_op_attr` and `strip_compilation_info` functions are used for
# getting consistent inputs to the compilation step in tuning. Inputs may come
# in with lowering configs, translation info, and root_op attrs when the input
# is a benchmark, but not when the input is a source MLIR file. Stripping the
# info makes the inputs to compilation consistent, and allows for overwriting
# the compilation info with generated TD specs during codegen.
def strip_root_op_attr(module: ir.Module):
    """Delete the root-op marker attribute from every root op in `module`.

    Root ops are collected with `get_ops_from_module(module, is_root_op)`.
    Each collected op is expected to carry the `ROOT_OP_ATTR_NAME` attribute;
    the attribute is removed in place, so `module` itself is mutated and
    nothing is returned.

    Raises:
        AssertionError: if a collected root op does not have the
            `ROOT_OP_ATTR_NAME` attribute.
    """
    for tagged_op in get_ops_from_module(module, is_root_op):
        # Sanity check before deleting: the matcher should only hand back ops
        # that actually carry the marker attribute.
        has_marker = ROOT_OP_ATTR_NAME in tagged_op.opview.attributes
        assert has_marker, f"expected root op to have '{ROOT_OP_ATTR_NAME}' attr"
        del tagged_op.opview.attributes[ROOT_OP_ATTR_NAME]


# See the above comment for `strip_root_op_attr`.
def strip_compilation_info(input_path: Path) -> str:
# Strip compilation info from the source and save the stripped IR
strip_command = [
Expand All @@ -661,39 +678,33 @@ def strip_compilation_info(input_path: Path) -> str:
logging.error(f"Command '{cmd}' returned non-zero exit status {e.returncode}.")
logging.error(f"Command '{cmd}' failed with error: {e.stderr}")
raise
except KeyboardInterrupt:
print("Ctrl+C detected, terminating child processes...")
return ""
return stripped_mlir


def run_iree_compile_command(compile_pack: CompilePack) -> Optional[int]:
candidate_tracker = compile_pack.candidate_tracker
tuning_client = compile_pack.tuning_client

# Strip compilation info from the source and save the stripped IR
input_path = candidate_tracker.mlir_path
assert input_path is not None, "expected input mlir_path"
stripped_mlir = strip_compilation_info(input_path)

# Compile to vmfb.
assert candidate_tracker.spec_path is not None, "expected candidate spec path"
assert candidate_tracker.spec_path, "expected candidate spec path"
td_spec_path = candidate_tracker.spec_path.as_posix()
logging.debug(
f"Compiling candidate {candidate_tracker.candidate_id} with spec: td_spec_path"
f"Compiling candidate {candidate_tracker.candidate_id} with spec: {td_spec_path}"
)
extra_flags = [
"--iree-codegen-tuning-spec-path=" + td_spec_path,
f"--iree-codegen-tuning-spec-path={td_spec_path}",
]
extra_flags.extend(tuning_client.get_iree_compile_flags())
assert candidate_tracker.compiled_vmfb_path is not None, "expected output vmfb path"
extra_flags += tuning_client.get_iree_compile_flags()
assert candidate_tracker.compiled_vmfb_path, "expected output vmfb path"
output_path = candidate_tracker.compiled_vmfb_path.as_posix()
crash_dump_path = output_path + ".crash_report.mlir"
crash_dump_path = f"{output_path}.crash_report.mlir"
assert candidate_tracker.mlir_path, "expected input mlir file path"
input_file = candidate_tracker.mlir_path.as_posix()
# TODO(Max191): Make the device in `target_backends` a command line option
# instead of hardcoding in ireec.compile_str.
try:
ireec.compile_str(
input_str=stripped_mlir,
ireec.compile_file(
input_file=input_file,
target_backends=["rocm"],
output_file=output_path,
extra_args=extra_flags,
Expand Down Expand Up @@ -721,15 +732,14 @@ def run_iree_benchmark_module_command(benchmark_pack: BenchmarkPack):
with open(vmfb_path, "rb") as f:
vmfb_buffer = f.read()

rt_config = ireert.Config(device_id)
device = rt_config.device
vm_instance = rt_config.vm_instance
vm_instance = ireert.VmInstance()
vm_module = ireert.VmModule.copy_buffer(vm_instance, vmfb_buffer)

# Parse the flags passed from the tuning client and create a kwargs dict
# for the benchmark_module function.
extra_flags = {}
func_name = None
inputs = []
for flag in tuning_client.get_iree_benchmark_module_flags():
assert flag[:2] == "--", "iree_benchmark_module_flags should begin with '--'"
split_key_value = flag[2:].split("=")
Expand All @@ -742,6 +752,10 @@ def run_iree_benchmark_module_command(benchmark_pack: BenchmarkPack):
if key == "function":
func_name = value
continue
# Special handling for `--input`, since it can be passed many times.
if key == "input":
inputs.append(value)
continue
# Other flags become normal kwargs.
extra_flags[key] = value

Expand All @@ -751,7 +765,8 @@ def run_iree_benchmark_module_command(benchmark_pack: BenchmarkPack):
benchmark_results = ireert.benchmark.benchmark_module(
vm_module,
entry_function=func_name,
device=device,
inputs=inputs,
device=device_id,
timeout=timeout,
**extra_flags,
)
Expand Down Expand Up @@ -1034,14 +1049,15 @@ def generate_candidate_specs(
logging.debug("generate_candidate_specs()")

path_config.specs_dir.mkdir(parents=True, exist_ok=True)
shutil.copy(args.input_file, path_config.template_mlir)
tune_logger = logging.getLogger("tune")

# Generate transform dialect specs.
try:
# Strip compilation info before generating td_specs, since the generated
# td_specs can end up matching against the compilation info from the
# source mlir.
mlir_text = strip_compilation_info(args.input_file)
mlir_text = strip_compilation_info(path_config.template_mlir)
with ir.Context() as ctx:
tuner_context = TunerContext(ctx, tune_logger)
mlir_module = dispatch_parser.parse_mlir(mlir_text, tuner_context)
Expand All @@ -1068,7 +1084,7 @@ def generate_candidate_specs(
with open(spec_path, "w") as f:
f.write(str(spec))
new_candidate = CandidateTracker(
mlir_path=args.input_file,
mlir_path=path_config.template_mlir,
candidate_id=candidate_num,
spec_path=spec_path,
)
Expand Down Expand Up @@ -1119,20 +1135,31 @@ def compile(
logging.warning("No model candidates to compile.")
return []

# Set the source and output file paths for compilation of each candidate.
# If `input_file` is not None, then replace the currently tracked mlir_path
# If `input_file` is not None, then replace the currently tracked template
# with the passed input mlir file.
if input_file is not None:
shutil.copy(input_file, path_config.template_mlir)

# Strip compilation info and root_op attribute from the source and save
# the stripped IR, since the TD specs do not expect these attributes.
stripped_mlir = strip_compilation_info(path_config.template_mlir)
with ir.Context():
stripped_module = ir.Module.parse(stripped_mlir)
strip_root_op_attr(stripped_module)
stripped_mlir = str(stripped_module)
with open(path_config.template_mlir, "w") as f:
f.write(stripped_mlir)

# Set the source and output file paths for compilation of each candidate.
path_config.compiled_dir.mkdir(parents=True, exist_ok=True)
for i in candidates:
vmfb_file_name = path_config.get_candidate_vmfb_filename(
candidate_trackers[i].candidate_id
)
vmfb_path = path_config.compiled_dir / vmfb_file_name
candidate_trackers[i].compiled_vmfb_path = vmfb_path
if input_file is not None:
candidate_trackers[i].mlir_path = input_file
if input_file is not None:
candidate_trackers[0].mlir_path = input_file
candidate_trackers[i].mlir_path = path_config.template_mlir
candidate_trackers[0].mlir_path = path_config.template_mlir

# Run compilation for all candidates.
task_list = [
Expand All @@ -1141,6 +1168,12 @@ def compile(
)
for i in candidates
]
if 0 not in candidates:
task_list.append(
CompilePack(
tuning_client=tuning_client, candidate_tracker=candidate_trackers[0]
)
)
num_worker = min(args.max_cpu_workers, len(task_list))
compiled_candidates = multiprocess_progress_wrapper(
num_worker=num_worker, task_list=task_list, function=run_iree_compile_command
Expand Down
5 changes: 4 additions & 1 deletion tuner/tuner/op_matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,12 @@ def get_ops_from_module(module: ir.Module, fn):
return ops


ROOT_OP_ATTR_NAME = "root_op"


def is_root_op(op: ir.Operation) -> bool:
for attr in op.opview.attributes:
if attr.name == "root_op":
if attr.name == ROOT_OP_ATTR_NAME:
return True
return False

Expand Down
34 changes: 33 additions & 1 deletion tuner/tuner/spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .common import *
from .dispatch_constraints import *
from .dispatch_parser import *
from .op_matchers import ROOT_OP_ATTR_NAME


def get_placeholder_spec(context: ir.Context) -> ir.Module:
Expand All @@ -37,12 +38,43 @@ def build_td_spec(
func_name: str,
) -> ir.Module:
bbargs = []
# The `root_op` attribute will prevent matching of ops without the attr in
# the resulting TD spec matcher if it is not removed, so we remove it here.
# After removing, we must add it back, since the op is connected to the
# input module, which gets used for all candidates.
# TODO(Max191): Find a cleaner way to do this without removing and adding
# back the attribute.
has_root_attr = ROOT_OP_ATTR_NAME in op.opview.attributes
if has_root_attr:
assert isinstance(
op.opview.attributes[ROOT_OP_ATTR_NAME], ir.UnitAttr
), f"expected '{ROOT_OP_ATTR_NAME}' attr to be a unit attr"
if has_root_attr:
del op.opview.attributes[ROOT_OP_ATTR_NAME]
# Get the root op string for formatting the final spec.
root_operation = str(op)
if has_root_attr:
op.opview.attributes[ROOT_OP_ATTR_NAME] = ir.UnitAttr.get(op.context)

# Get the SSA names of the operands to make sure they match in the
# template after string formatting.
captured_values: set[ir.Value] = set()
for operand in op.operands:
if operand in captured_values:
# TODO(Max191): Remove this warning when the transform for the
# `cast_compatible_dag_from_root` op fixes a bug in the matching
# logic that causes failure to match when the same operand is
# repeated. For now, still avoid adding duplicate SSA values to
# prevent parsing failure.
logging.warning(
f"Root op has repeated operand. This can cause failure to match in the resulting TD spec at compile time."
)
continue
ssa_name = operand.get_name()
operand_type = operand.type
bbargs.append(f"{ssa_name}: {operand_type}")
captured_values.add(operand)
bbargs_str = ", ".join(bbargs)
root_operation = str(op)
spec_text = f"""
module attributes {{ transform.with_named_sequence }} {{
// Annotation Transform
Expand Down

0 comments on commit ad77baf

Please sign in to comment.