Skip to content

Commit

Permalink
Ensure duplicate arguments are only checked within their respective a…
Browse files Browse the repository at this point in the history
…rgument groups

Differential Revision: D57459718

Pull Request resolved: #911
  • Loading branch information
hstonec authored May 16, 2024
1 parent ac3cc78 commit 17a76a6
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 10 deletions.
24 changes: 14 additions & 10 deletions torchx/cli/cmd_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import threading
from collections import Counter
from dataclasses import asdict
from itertools import groupby
from pathlib import Path
from pprint import pformat
from typing import Dict, List, Optional, Tuple
Expand Down Expand Up @@ -85,16 +86,19 @@ def _parse_component_name_and_args(
component = args[0]
component_args = args[1:]

# Error if there are repeated command line arguments
all_options = [
x
for x in component_args
if x.startswith("-") and x.strip() != "-" and x.strip() != "--"
]
arg_count = Counter(all_options)
duplicates = [arg for arg, count in arg_count.items() if count > 1]
if len(duplicates) > 0:
subparser.error(f"Repeated Command Line Arguments: {duplicates}")
# Error if there are repeated command line arguments each group of arguments,
# where the groups are separated by "--"
arg_groups = [list(g) for _, g in groupby(component_args, key=lambda x: x == "--")]
for arg_group in arg_groups:
all_options = [
x
for x in arg_group
if x.startswith("-") and x.strip() != "-" and x.strip() != "--"
]
arg_count = Counter(all_options)
duplicates = [arg for arg, count in arg_count.items() if count > 1]
if len(duplicates) > 0:
subparser.error(f"Repeated Command Line Arguments: {duplicates}")

if not component:
subparser.error(MISSING_COMPONENT_ERROR_MSG)
Expand Down
31 changes: 31 additions & 0 deletions torchx/cli/test/cmd_run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,32 @@ def test_parse_component_name_and_args_no_default(self) -> None:
),
)

self.assertEqual(
(
"fb.python.binary",
[
"--img",
"lex_ig_o3_package",
"-m",
"dper_lib.instagram.pyper_v2.teams.stories.train",
"--",
"-m",
],
),
_parse_component_name_and_args(
[
"fb.python.binary",
"--img",
"lex_ig_o3_package",
"-m",
"dper_lib.instagram.pyper_v2.teams.stories.train",
"--",
"-m",
],
sp,
),
)

with self.assertRaises(SystemExit):
_parse_component_name_and_args(["--"], sp)

Expand All @@ -271,6 +297,11 @@ def test_parse_component_name_and_args_no_default(self) -> None:
["--msg ", "hello", "--msg ", "repeate"], sp
)

with self.assertRaises(SystemExit):
_parse_component_name_and_args(
["--m", "hello", "--", "--msg", "msg", "--msg", "repeate"], sp
)

def test_parse_component_name_and_args_with_default(self) -> None:
sp = argparse.ArgumentParser(prog="test")
dirs = [str(self.tmpdir)]
Expand Down

0 comments on commit 17a76a6

Please sign in to comment.