From 77ee930f86c06424798acd3b2c781e746db8a0f3 Mon Sep 17 00:00:00 2001 From: Yuanhao Ji Date: Mon, 2 Dec 2024 11:30:07 +0800 Subject: [PATCH] Add --skip for test_bench --- .../util/experiment/instantiator.py | 1 + userbenchmark/test_bench/run.py | 23 ++++++++++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/torchbenchmark/util/experiment/instantiator.py b/torchbenchmark/util/experiment/instantiator.py index f98e1380a..71614d40a 100644 --- a/torchbenchmark/util/experiment/instantiator.py +++ b/torchbenchmark/util/experiment/instantiator.py @@ -31,6 +31,7 @@ class TorchBenchModelConfig: extra_args: List[str] extra_env: Optional[Dict[str, str]] = None output_dir: Optional[pathlib.Path] = None + skip: bool = False def _set_extra_env(extra_env): diff --git a/userbenchmark/test_bench/run.py b/userbenchmark/test_bench/run.py index 32b6bb358..c2f115d88 100644 --- a/userbenchmark/test_bench/run.py +++ b/userbenchmark/test_bench/run.py @@ -82,6 +82,7 @@ def generate_model_configs( tests: List[str], batch_sizes: List[str], model_names: List[str], + skip_models: List[str], extra_args: List[str], ) -> List[TorchBenchModelConfig]: """Use the default batch size and default mode.""" @@ -94,6 +95,7 @@ def generate_model_configs( batch_size=None if not batch_size else int(batch_size), extra_args=extra_args, extra_env=None, + skip=model_name in skip_models ) for device, test, batch_size, model_name in cfgs ] @@ -165,6 +167,9 @@ def run_config( if dryrun: print(" [skip_by_dryrun]", flush=True) return dict.fromkeys(metrics, "skip_by_dryrun") + if config.skip: + print(" [skip]", flush=True) + return dict.fromkeys(metrics, "skip") # We do not allow RuntimeError in this test try: # load the model instance in subprocess @@ -214,6 +219,9 @@ def assertEqual(x, y): if dryrun: print(" [skip_by_dryrun] ", flush=True) return {"memleak": "skip_by_dryrun"} + if config.skip: + print(" [skip]", flush=True) + return {"memleak": "skip"} try: with task.watch_cuda_memory( skip=False, @@ -250,6 +258,9 @@ def run_config_accuracy( if dryrun: print(" [skip_by_dryrun]", flush=True) return {"accuracy": "skip_by_dryrun"} + if config.skip: + print(" [skip]", flush=True) + return {"accuracy": "skip"} try: accuracy = get_model_accuracy(config) print(" [done]", flush=True) @@ -277,6 +288,11 @@ def parse_known_args(args): default=None, help="Name of models to run, split by comma.", ) + parser.add_argument( + "--skip", + default=None, + help="Name of models to skip running, split by comma." + ) parser.add_argument( "--device", "-d", @@ -333,7 +349,10 @@ def run(args: List[str]): modelset = modelset.union(timm_set).union(huggingface_set) if not args.models: args.models = [] + if not args.skip: + args.skip = [] args.models = parse_str_to_list(args.models) + args.skip = parse_str_to_list(args.skip) if args.timm: args.models.extend(timm_set) if args.huggingface: @@ -348,8 +367,10 @@ def run(args: List[str]): tests = validate(parse_str_to_list(args.test), list_tests()) batch_sizes = parse_str_to_list(args.bs) models = validate(args.models, modelset) + skips = validate(args.skip, modelset) configs = generate_model_configs( - devices, tests, batch_sizes, model_names=models, extra_args=extra_args + devices, tests, batch_sizes, + model_names=models, skip_models=skips, extra_args=extra_args ) debug_output_dir = get_default_debug_output_dir(args.output) if args.debug else None configs = (