Skip to content

Commit

Permalink
[tuner] Add timeout for compilation in TuningClient
Browse files Browse the repository at this point in the history
Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 committed Jan 9, 2025
1 parent 7849f8e commit 78b43e1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 17 deletions.
3 changes: 3 additions & 0 deletions tuner/examples/simple/simple_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def __init__(self, tuner_context: libtuner.TunerContext):
def get_iree_compile_flags(self) -> list[str]:
return self.compile_flags

def get_iree_compile_timeout_s(self) -> int:
return 10

def get_iree_benchmark_module_flags(self) -> list[str]:
return self.benchmark_flags

Expand Down
46 changes: 29 additions & 17 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ def __init__(self, tuner_context: TunerContext):
def get_iree_compile_flags(self) -> list[str]:
pass

@abstractmethod
def get_iree_compile_timeout_s(self) -> int:
pass

@abstractmethod
def get_iree_benchmark_module_flags(self) -> list[str]:
pass
Expand All @@ -122,6 +126,7 @@ def get_benchmark_timeout_s(self) -> int:
@dataclass
class CompilePack:
iree_compile_flags: list[str]
iree_compile_timeout: int
candidate_tracker: CandidateTracker


Expand Down Expand Up @@ -440,30 +445,35 @@ def run_iree_compile_command(compile_pack: CompilePack) -> Optional[int]:
logging.debug(
f"Compiling candidate {candidate_tracker.candidate_id} with spec: {td_spec_path}"
)
extra_flags = [
f"--iree-codegen-tuning-spec-path={td_spec_path}",
]
extra_flags += compile_pack.iree_compile_flags
assert candidate_tracker.compiled_vmfb_path, "expected output vmfb path"
output_path = candidate_tracker.compiled_vmfb_path.as_posix()
crash_dump_path = f"{output_path}.crash_report.mlir"
assert candidate_tracker.mlir_path, "expected input mlir file path"
input_file = candidate_tracker.mlir_path.as_posix()
# TODO(Max191): Make the device in `traget_backends` a command line option
# instead of hardcoding in ireec.compile_str.
try:
ireec.compile_file(
input_file=input_file,
target_backends=["rocm"],
output_file=output_path,
extra_args=extra_flags,
crash_reproducer_path=crash_dump_path,
# TODO(Max191): Make the device in `traget_backend` a command line option
# instead of hardcoding rocm.
target_backend = "rocm"
iree_compile = ireec.binaries.find_tool("iree-compile")
compile_command = [
iree_compile,
input_file,
f"--iree-input-type={str(ireec.core.InputType.AUTO)}",
f"--iree-vm-bytecode-module-output-format={str(ireec.core.OutputFormat.FLATBUFFER_BINARY)}",
f"--iree-hal-target-backends={target_backend}",
f"-o={output_path}",
f"--mlir-pass-pipeline-crash-reproducer={crash_dump_path}",
f"--iree-codegen-tuning-spec-path={td_spec_path}",
]
compile_command += compile_pack.iree_compile_flags
result = candidate_gen.run_command(
candidate_gen.RunPack(
command=compile_command,
check=False,
timeout_seconds=compile_pack.iree_compile_timeout,
)
except ireec.CompilerToolError as e:
logging.info(f"Compilation returned non-zero exit status.")
logging.debug(e)
)
if result.process_res is None or result.is_timeout:
return None

return candidate_tracker.candidate_id


Expand Down Expand Up @@ -771,6 +781,7 @@ def compile(
task_list = [
CompilePack(
iree_compile_flags=tuning_client.get_iree_compile_flags(),
iree_compile_timeout=tuning_client.get_iree_compile_timeout_s(),
candidate_tracker=candidate_trackers[i],
)
for i in candidates
Expand All @@ -779,6 +790,7 @@ def compile(
task_list.append(
CompilePack(
iree_compile_flags=tuning_client.get_iree_compile_flags(),
iree_compile_timeout=tuning_client.get_iree_compile_timeout_s(),
candidate_tracker=candidate_trackers[0],
)
)
Expand Down

0 comments on commit 78b43e1

Please sign in to comment.