Skip to content

Commit

Permalink
Fix iree-tf/benchmark-model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Che-Yu Wu committed May 26, 2023
1 parent 064f2aa commit 10ff9ed
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions iree-tf/benchmark/benchmark_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Optional

# Add library dir to the search path.
sys.path.insert(0, str(pathlib.Path(__file__).parent.parent / "library"))
sys.path.insert(0, str(pathlib.Path(__file__).parents[1] / "library"))
from models import resnet50, bert_large, t5_large

# Add benchmark definitions to the search path.
Expand Down Expand Up @@ -46,6 +46,15 @@ def benchmark_lookup(unique_id: str):
raise ValueError(f"Model definition not supported")


def benchmark_lookup(benchmark_id: str):
benchmark = tf_inference_benchmarks.ID_TO_BENCHMARK_MAP.get(benchmark_id)
if benchmark is None:
raise ValueError(f"Id {benchmark_id} does not exist in benchmark suite.")

model_name, model_class = model_lookup(benchmark.model.id)
return (model_name, model_class, benchmark)


def dump_result(file_path: str, result: dict) -> None:
with open(file_path, "r") as f:
dictObj = json.load(f)
Expand All @@ -66,7 +75,8 @@ def bytes_to_mb(bytes: Optional[int]) -> Optional[float]:
def run_framework_benchmark(model_name: str, model_class: type[tf.Module],
batch_size: int, warmup_iterations: int,
benchmark_iterations: int, tf_device: str,
hlo_dump_dir: str, dump_hlo: bool, shared_dict) -> None:
hlo_dump_dir: str, dump_hlo: bool,
shared_dict) -> None:
try:
with tf.device(tf_device):
if dump_hlo:
Expand Down Expand Up @@ -216,17 +226,16 @@ def run_compiler_benchmark(hlo_benchmark_tool_path: str, hlo_dir: str,

args = argParser.parse_args()

model_name, model_class, model_definition = benchmark_lookup(
args.benchmark_id)
model_name, model_class, benchmark = benchmark_lookup(args.benchmark_id)
print(
f"\n\n--- {model_name} {args.benchmark_id} -------------------------------------"
f"\n\n--- {benchmark.name} {args.benchmark_id} -------------------------------------"
)

if os.path.exists(_HLO_DUMP_DIR):
shutil.rmtree(_HLO_DUMP_DIR)
os.mkdir(_HLO_DUMP_DIR)

batch_size = model_definition.input_batch_size
batch_size = benchmark.input_batch_size
benchmark_definition = {
"benchmark_id": args.benchmark_id,
"benchmark_name": model_definition.name,
Expand All @@ -248,9 +257,9 @@ def run_compiler_benchmark(hlo_benchmark_tool_path: str, hlo_dir: str,
shared_dict = manager.dict()

if args.run_in_process:
run_framework_benchmark(model_name, model_class, batch_size, args.warmup_iterations,
args.iterations, tf_device, _HLO_DUMP_DIR, dump_hlo,
shared_dict)
run_framework_benchmark(model_name, model_class, batch_size,
args.warmup_iterations, args.iterations,
tf_device, _HLO_DUMP_DIR, dump_hlo, shared_dict)
else:
p = multiprocessing.Process(target=run_framework_benchmark,
args=(model_name, model_class, batch_size,
Expand All @@ -269,8 +278,10 @@ def run_compiler_benchmark(hlo_benchmark_tool_path: str, hlo_dir: str,
shared_dict = manager.dict()

if args.run_in_process:
run_compiler_benchmark(args.hlo_benchmark_path, _HLO_DUMP_DIR, args.hlo_iterations,
"cuda" if args.device == "gpu" else "cpu", shared_dict)
run_compiler_benchmark(args.hlo_benchmark_path, _HLO_DUMP_DIR,
args.hlo_iterations,
"cuda" if args.device == "gpu" else "cpu",
shared_dict)
else:
p = multiprocessing.Process(
target=run_compiler_benchmark,
Expand Down

0 comments on commit 10ff9ed

Please sign in to comment.