Add --skip for test_bench
shink committed Dec 2, 2024
1 parent c57be06 commit 77ee930
Showing 2 changed files with 23 additions and 1 deletion.
1 change: 1 addition & 0 deletions torchbenchmark/util/experiment/instantiator.py
@@ -31,6 +31,7 @@ class TorchBenchModelConfig:
     extra_args: List[str]
     extra_env: Optional[Dict[str, str]] = None
     output_dir: Optional[pathlib.Path] = None
+    skip: bool = False
 
 
 def _set_extra_env(extra_env):
23 changes: 22 additions & 1 deletion userbenchmark/test_bench/run.py
@@ -82,6 +82,7 @@ def generate_model_configs(
     tests: List[str],
     batch_sizes: List[str],
     model_names: List[str],
+    skip_models: List[str],
     extra_args: List[str],
 ) -> List[TorchBenchModelConfig]:
     """Use the default batch size and default mode."""
@@ -94,6 +95,7 @@
             batch_size=None if not batch_size else int(batch_size),
             extra_args=extra_args,
             extra_env=None,
+            skip=model_name in skip_models
         )
         for device, test, batch_size, model_name in cfgs
     ]
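
Taken together, the two hunks above give every generated config its own per-model skip flag. A minimal sketch of that membership test (illustrative model names, not part of the commit):

    # Each model name found in the skip list yields a config whose skip
    # flag is True; every other model runs normally.
    skip_models = ["resnet50"]
    model_names = ["BERT_pytorch", "resnet50"]
    flags = {name: name in skip_models for name in model_names}
    assert flags == {"BERT_pytorch": False, "resnet50": True}
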
@@ -165,6 +167,9 @@ def run_config(
     if dryrun:
         print(" [skip_by_dryrun]", flush=True)
         return dict.fromkeys(metrics, "skip_by_dryrun")
+    if config.skip:
+        print(" [skip]", flush=True)
+        return dict.fromkeys(metrics, "skip")
     # We do not allow RuntimeError in this test
     try:
         # load the model instance in subprocess
@@ -214,6 +219,9 @@ def assertEqual(x, y):
     if dryrun:
         print(" [skip_by_dryrun] ", flush=True)
         return {"memleak": "skip_by_dryrun"}
+    if config.skip:
+        print(" [skip]", flush=True)
+        return {"memleak": "skip"}
     try:
         with task.watch_cuda_memory(
             skip=False,
@@ -250,6 +258,9 @@ def run_config_accuracy(
     if dryrun:
         print(" [skip_by_dryrun]", flush=True)
         return {"accuracy": "skip_by_dryrun"}
+    if config.skip:
+        print(" [skip]", flush=True)
+        return {"accuracy": "skip"}
     try:
         accuracy = get_model_accuracy(config)
         print(" [done]", flush=True)
@@ -277,6 +288,11 @@ def parse_known_args(args):
         default=None,
         help="Name of models to run, split by comma.",
     )
+    parser.add_argument(
+        "--skip",
+        default=None,
+        help="Name of models to skip running, split by comma."
+    )
     parser.add_argument(
         "--device",
         "-d",
@@ -333,7 +349,10 @@ def run(args: List[str]):
     modelset = modelset.union(timm_set).union(huggingface_set)
     if not args.models:
         args.models = []
+    if not args.skip:
+        args.skip = []
     args.models = parse_str_to_list(args.models)
+    args.skip = parse_str_to_list(args.skip)
     if args.timm:
         args.models.extend(timm_set)
     if args.huggingface:
@@ -348,8 +367,10 @@ def run(args: List[str]):
     tests = validate(parse_str_to_list(args.test), list_tests())
     batch_sizes = parse_str_to_list(args.bs)
     models = validate(args.models, modelset)
+    skips = validate(args.skip, modelset)
     configs = generate_model_configs(
-        devices, tests, batch_sizes, model_names=models, extra_args=extra_args
+        devices, tests, batch_sizes,
+        model_names=models, skip_models=skips, extra_args=extra_args
     )
     debug_output_dir = get_default_debug_output_dir(args.output) if args.debug else None
     configs = (
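
With the flag plumbed through, a skipped model still appears in the results, but every requested metric is reported as "skip" rather than being measured. A hedged usage sketch (assuming test_bench is launched through TorchBench's run_benchmark.py driver, as userbenchmarks conventionally are; the model names are illustrative):

    # Benchmark BERT_pytorch and resnet50, but short-circuit resnet50;
    # it prints " [skip]" and reports "skip" for every metric.
    python run_benchmark.py test_bench --models BERT_pytorch,resnet50 --skip resnet50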
