Skip to content

Commit

Permalink
[Bugfix] Fix load config when using bools (vllm-project#9533)
Browse files Browse the repository at this point in the history
Signed-off-by: Shanshan Wang <[email protected]>
  • Loading branch information
madt2709 authored and cooleel committed Oct 28, 2024
1 parent ee44915 commit 8b91237
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 22 deletions.
2 changes: 2 additions & 0 deletions tests/data/test_config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
port: 12312
served_model_name: mymodel
tensor_parallel_size: 2
trust_remote_code: true
multi_step_stream_outputs: false
6 changes: 5 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest

from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
get_open_port, merge_async_iterators, supports_kw)

from .utils import error_on_warning
Expand Down Expand Up @@ -141,6 +141,8 @@ def parser_with_config():
parser.add_argument('--config', type=str)
parser.add_argument('--port', type=int)
parser.add_argument('--tensor-parallel-size', type=int)
parser.add_argument('--trust-remote-code', action='store_true')
parser.add_argument('--multi-step-stream-outputs', action=StoreBoolean)
return parser


Expand Down Expand Up @@ -214,6 +216,8 @@ def test_config_args(parser_with_config):
args = parser_with_config.parse_args(
['serve', 'mymodel', '--config', './data/test_config.yaml'])
assert args.tensor_parallel_size == 2
assert args.trust_remote_code
assert not args.multi_step_stream_outputs


def test_config_file(parser_with_config):
Expand Down
14 changes: 1 addition & 13 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, StoreBoolean

if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
Expand Down Expand Up @@ -1144,18 +1144,6 @@ def add_cli_args(parser: FlexibleArgumentParser,
return parser


class StoreBoolean(argparse.Action):

def __call__(self, parser, namespace, values, option_string=None):
if values.lower() == "true":
setattr(namespace, self.dest, True)
elif values.lower() == "false":
setattr(namespace, self.dest, False)
else:
raise ValueError(f"Invalid boolean value: {values}. "
"Expected 'true' or 'false'.")


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
return EngineArgs.add_cli_args(FlexibleArgumentParser())
Expand Down
35 changes: 27 additions & 8 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,18 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
return wrapper


class StoreBoolean(argparse.Action):

def __call__(self, parser, namespace, values, option_string=None):
if values.lower() == "true":
setattr(namespace, self.dest, True)
elif values.lower() == "false":
setattr(namespace, self.dest, False)
else:
raise ValueError(f"Invalid boolean value: {values}. "
"Expected 'true' or 'false'.")


class FlexibleArgumentParser(argparse.ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names."""

Expand All @@ -1163,7 +1175,7 @@ def parse_args(self, args=None, namespace=None):
args = sys.argv[1:]

if '--config' in args:
args = FlexibleArgumentParser._pull_args_from_config(args)
args = self._pull_args_from_config(args)

# Convert underscores to dashes and vice versa in argument names
processed_args = []
Expand All @@ -1181,8 +1193,7 @@ def parse_args(self, args=None, namespace=None):

return super().parse_args(processed_args, namespace)

@staticmethod
def _pull_args_from_config(args: List[str]) -> List[str]:
def _pull_args_from_config(self, args: List[str]) -> List[str]:
"""Method to pull arguments specified in the config file
into the command-line args variable.
Expand Down Expand Up @@ -1226,7 +1237,7 @@ def _pull_args_from_config(args: List[str]) -> List[str]:

file_path = args[index + 1]

config_args = FlexibleArgumentParser._load_config_file(file_path)
config_args = self._load_config_file(file_path)

# 0th index is for {serve,chat,complete}
# followed by model_tag (only for serve)
Expand All @@ -1247,8 +1258,7 @@ def _pull_args_from_config(args: List[str]) -> List[str]:

return args

@staticmethod
def _load_config_file(file_path: str) -> List[str]:
def _load_config_file(self, file_path: str) -> List[str]:
"""Loads a yaml file and returns the key value pairs as a
flattened list with argparse like pattern
```yaml
Expand Down Expand Up @@ -1282,9 +1292,18 @@ def _load_config_file(file_path: str) -> List[str]:
Make sure path is correct", file_path)
raise ex

store_boolean_arguments = [
action.dest for action in self._actions
if isinstance(action, StoreBoolean)
]

for key, value in config.items():
processed_args.append('--' + key)
processed_args.append(str(value))
if isinstance(value, bool) and key not in store_boolean_arguments:
if value:
processed_args.append('--' + key)
else:
processed_args.append('--' + key)
processed_args.append(str(value))

return processed_args

Expand Down

0 comments on commit 8b91237

Please sign in to comment.