From 8b91237942957d2d2b4b69224d65bbf339ebb86b Mon Sep 17 00:00:00 2001 From: madt2709 <55849102+madt2709@users.noreply.github.com> Date: Sun, 27 Oct 2024 10:46:41 -0700 Subject: [PATCH] [Bugfix] Fix load config when using bools (#9533) Signed-off-by: Shanshan Wang --- tests/data/test_config.yaml | 2 ++ tests/test_utils.py | 6 +++++- vllm/engine/arg_utils.py | 14 +------------- vllm/utils.py | 35 +++++++++++++++++++++++++++-------- 4 files changed, 35 insertions(+), 22 deletions(-) diff --git a/tests/data/test_config.yaml b/tests/data/test_config.yaml index 42f4f6f7bb992..5090e8f357bb8 100644 --- a/tests/data/test_config.yaml +++ b/tests/data/test_config.yaml @@ -1,3 +1,5 @@ port: 12312 served_model_name: mymodel tensor_parallel_size: 2 +trust_remote_code: true +multi_step_stream_outputs: false diff --git a/tests/test_utils.py b/tests/test_utils.py index 0fed8e678fc76..a731b11eae81c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,7 +6,7 @@ import pytest -from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs, +from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs, get_open_port, merge_async_iterators, supports_kw) from .utils import error_on_warning @@ -141,6 +141,8 @@ def parser_with_config(): parser.add_argument('--config', type=str) parser.add_argument('--port', type=int) parser.add_argument('--tensor-parallel-size', type=int) + parser.add_argument('--trust-remote-code', action='store_true') + parser.add_argument('--multi-step-stream-outputs', action=StoreBoolean) return parser @@ -214,6 +216,8 @@ def test_config_args(parser_with_config): args = parser_with_config.parse_args( ['serve', 'mymodel', '--config', './data/test_config.yaml']) assert args.tensor_parallel_size == 2 + assert args.trust_remote_code + assert not args.multi_step_stream_outputs def test_config_file(parser_with_config): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c49f475b9ee61..38687809a31f6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -19,7 +19,7 @@ from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) from vllm.transformers_utils.utils import check_gguf_file -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, StoreBoolean if TYPE_CHECKING: from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup @@ -1144,18 +1144,6 @@ def add_cli_args(parser: FlexibleArgumentParser, return parser -class StoreBoolean(argparse.Action): - - def __call__(self, parser, namespace, values, option_string=None): - if values.lower() == "true": - setattr(namespace, self.dest, True) - elif values.lower() == "false": - setattr(namespace, self.dest, False) - else: - raise ValueError(f"Invalid boolean value: {values}. " - "Expected 'true' or 'false'.") - - # These functions are used by sphinx to build the documentation def _engine_args_parser(): return EngineArgs.add_cli_args(FlexibleArgumentParser()) diff --git a/vllm/utils.py b/vllm/utils.py index 1f75de89d0cc2..d4f2c936ca9cc 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1155,6 +1155,18 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: return wrapper +class StoreBoolean(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + if values.lower() == "true": + setattr(namespace, self.dest, True) + elif values.lower() == "false": + setattr(namespace, self.dest, False) + else: + raise ValueError(f"Invalid boolean value: {values}. " + "Expected 'true' or 'false'.") + + class FlexibleArgumentParser(argparse.ArgumentParser): """ArgumentParser that allows both underscore and dash in names.""" @@ -1163,7 +1175,7 @@ def parse_args(self, args=None, namespace=None): args = sys.argv[1:] if '--config' in args: - args = FlexibleArgumentParser._pull_args_from_config(args) + args = self._pull_args_from_config(args) # Convert underscores to dashes and vice versa in argument names processed_args = [] @@ -1181,8 +1193,7 @@ def parse_args(self, args=None, namespace=None): return super().parse_args(processed_args, namespace) - @staticmethod - def _pull_args_from_config(args: List[str]) -> List[str]: + def _pull_args_from_config(self, args: List[str]) -> List[str]: """Method to pull arguments specified in the config file into the command-line args variable. @@ -1226,7 +1237,7 @@ def _pull_args_from_config(args: List[str]) -> List[str]: file_path = args[index + 1] - config_args = FlexibleArgumentParser._load_config_file(file_path) + config_args = self._load_config_file(file_path) # 0th index is for {serve,chat,complete} # followed by model_tag (only for serve) @@ -1247,8 +1258,7 @@ def _pull_args_from_config(args: List[str]) -> List[str]: return args - @staticmethod - def _load_config_file(file_path: str) -> List[str]: + def _load_config_file(self, file_path: str) -> List[str]: """Loads a yaml file and returns the key value pairs as a flattened list with argparse like pattern ```yaml @@ -1282,9 +1292,18 @@ def _load_config_file(file_path: str) -> List[str]: Make sure path is correct", file_path) raise ex + store_boolean_arguments = [ + action.dest for action in self._actions + if isinstance(action, StoreBoolean) + ] + for key, value in config.items(): - processed_args.append('--' + key) - processed_args.append(str(value)) + if isinstance(value, bool) and key not in store_boolean_arguments: + if value: + processed_args.append('--' + key) + else: + processed_args.append('--' + key) + processed_args.append(str(value)) return processed_args