From 5100b5da69c8fc5400b13fe327355cfb0682574a Mon Sep 17 00:00:00 2001 From: Andy Wagner Date: Thu, 18 Apr 2024 16:08:15 -0700 Subject: [PATCH] Fix for issue with multiple -- in command line Differential Revision: D56317273 Pull Request resolved: https://github.com/pytorch/torchx/pull/888 --- torchx/cli/cmd_run.py | 6 +++++- torchx/cli/test/cmd_run_test.py | 26 ++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/torchx/cli/cmd_run.py b/torchx/cli/cmd_run.py index 91fd8c6e5..8e88c36d2 100644 --- a/torchx/cli/cmd_run.py +++ b/torchx/cli/cmd_run.py @@ -86,7 +86,11 @@ def _parse_component_name_and_args( component_args = args[1:] # Error if there are repeated command line arguments - all_options = [x for x in component_args if x.startswith("-")] + all_options = [ + x + for x in component_args + if x.startswith("-") and x.strip() != "-" and x.strip() != "--" + ] arg_count = Counter(all_options) duplicates = [arg for arg, count in arg_count.items() if count > 1] if len(duplicates) > 0: diff --git a/torchx/cli/test/cmd_run_test.py b/torchx/cli/test/cmd_run_test.py index 0fc355924..199087c7b 100644 --- a/torchx/cli/test/cmd_run_test.py +++ b/torchx/cli/test/cmd_run_test.py @@ -230,6 +230,27 @@ def test_parse_component_name_and_args_no_default(self) -> None: _parse_component_name_and_args(["utils.echo", "--msg", "hello"], sp), ) + self.assertEqual( + ("utils.echo", ["--msg", "hello", "--", "--"]), + _parse_component_name_and_args( + ["utils.echo", "--msg", "hello", "--", "--"], sp + ), + ) + + self.assertEqual( + ("utils.echo", ["--msg", "hello", "-", "-"]), + _parse_component_name_and_args( + ["utils.echo", "--msg", "hello", "-", "-"], sp + ), + ) + + self.assertEqual( + ("utils.echo", ["--msg", "hello", "- ", "- "]), + _parse_component_name_and_args( + ["utils.echo", "--msg", "hello", "- ", "- "], sp + ), + ) + with self.assertRaises(SystemExit): _parse_component_name_and_args(["--"], sp) @@ -245,6 +266,11 @@ def test_parse_component_name_and_args_no_default(self) -> None: with self.assertRaises(SystemExit): _parse_component_name_and_args(["--msg", "hello", "--msg", "repeate"], sp) + with self.assertRaises(SystemExit): + _parse_component_name_and_args( + ["--msg ", "hello", "--msg ", "repeate"], sp + ) + def test_parse_component_name_and_args_with_default(self) -> None: sp = argparse.ArgumentParser(prog="test") dirs = [str(self.tmpdir)]