add eager-mode to cli (#2645)
RunningLeon authored Oct 24, 2024
1 parent cd3e791 commit 4958071
Showing 6 changed files with 28 additions and 1 deletion.
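For context, the new --eager-mode flag ends up as the eager_mode field of PytorchEngineConfig, which disables CUDA graph capture in the PyTorch engine. Below is a minimal sketch of the equivalent programmatic usage, assuming the lmdeploy pipeline API and that PytorchEngineConfig already exposes this field; the model path is a placeholder, not part of this commit.

from lmdeploy import pipeline, PytorchEngineConfig

# eager_mode=True is what passing --eager-mode now sets; it turns off CUDA graph capture.
backend_config = PytorchEngineConfig(eager_mode=True)
pipe = pipeline('path/to/model', backend_config=backend_config)  # placeholder model path
print(pipe(['Hello, world']))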
3 changes: 3 additions & 0 deletions benchmark/profile_generation.py
@@ -341,6 +341,8 @@ def parse_args():
ArgumentHelper.backend(parser)
# pytorch engine args
pt_group = parser.add_argument_group('PyTorch engine arguments')
ArgumentHelper.eager_mode(pt_group)

tp_act = ArgumentHelper.tp(pt_group)
cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
@@ -422,6 +424,7 @@ def main():
session_len=session_len,
tp=args.tp,
thread_safe=True,
eager_mode=args.eager_mode,
enable_prefix_caching=args.enable_prefix_caching,
)
gen_config = GenerationConfig(top_k=args.top_k,
3 changes: 3 additions & 0 deletions benchmark/profile_pipeline_api.py
@@ -194,6 +194,8 @@ def parse_args():

# pytorch engine args
pt_group = parser.add_argument_group('PyTorch engine arguments')
ArgumentHelper.eager_mode(pt_group)

tp_act = ArgumentHelper.tp(pt_group)
session_len_act = ArgumentHelper.session_len(pt_group, default=4096)
cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
@@ -241,6 +243,7 @@ def main():
max_batch_size=args.concurrency,
tp=args.tp,
thread_safe=False,
eager_mode=args.eager_mode,
enable_prefix_caching=args.enable_prefix_caching,
)

3 changes: 3 additions & 0 deletions benchmark/profile_throughput.py
@@ -281,6 +281,8 @@ def parse_args():

# pytorch engine args
pt_group = parser.add_argument_group('PyTorch engine arguments')
ArgumentHelper.eager_mode(pt_group)

tp_act = ArgumentHelper.tp(pt_group)
session_len_act = ArgumentHelper.session_len(pt_group, default=4096)
cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
@@ -328,6 +330,7 @@ def main():
max_batch_size=args.concurrency,
tp=args.tp,
thread_safe=True,
eager_mode=args.eager_mode,
enable_prefix_caching=args.enable_prefix_caching,
quant_policy=args.quant_policy,
)
3 changes: 3 additions & 0 deletions lmdeploy/cli/cli.py
@@ -122,6 +122,7 @@ def add_parser_chat():
pt_group = parser.add_argument_group('PyTorch engine arguments')
ArgumentHelper.adapters(pt_group)
ArgumentHelper.device(pt_group)
ArgumentHelper.eager_mode(pt_group)
# common engine args
dtype_act = ArgumentHelper.dtype(pt_group)
tp_act = ArgumentHelper.tp(pt_group)
@@ -265,6 +266,7 @@ def chat(args):
adapters=adapters,
enable_prefix_caching=args.enable_prefix_caching,
device_type=args.device,
eager_mode=args.eager_mode,
quant_policy=args.quant_policy)
run_chat(args.model_path,
engine_config,
@@ -275,6 +277,7 @@
kwargs.pop('chat_template')
kwargs.pop('backend')
kwargs.pop('device')
kwargs.pop('eager_mode')
kwargs['chat_template_config'] = chat_template_config
run_chat(**kwargs)
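The last hunk above handles the non-PyTorch branch of chat(): there the parsed argparse namespace is forwarded as keyword arguments, so options that only the PyTorch engine understands, now including eager_mode, have to be stripped first. A hedged sketch of that pattern, not the repository's exact code:

kwargs = vars(args).copy()
# Drop PyTorch-only or already-consumed CLI options before forwarding.
for pt_only in ('chat_template', 'backend', 'device', 'eager_mode'):
    kwargs.pop(pt_only, None)
run_chat(**kwargs)  # mirrors the call in the diff above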

7 changes: 6 additions & 1 deletion lmdeploy/cli/serve.py
@@ -59,11 +59,12 @@ def add_parser_gradio():

# pytorch engine args
pt_group = parser.add_argument_group('PyTorch engine arguments')
ArgumentHelper.device(pt_group)
ArgumentHelper.eager_mode(pt_group)

# common engine args
dtype_act = ArgumentHelper.dtype(pt_group)
tp_act = ArgumentHelper.tp(pt_group)
ArgumentHelper.device(pt_group)
session_len_act = ArgumentHelper.session_len(pt_group)
max_batch_size_act = ArgumentHelper.max_batch_size(pt_group)
cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group)
@@ -159,6 +160,8 @@ def add_parser_api_server():

ArgumentHelper.adapters(pt_group)
ArgumentHelper.device(pt_group)
ArgumentHelper.eager_mode(pt_group)

# common engine args
dtype_act = ArgumentHelper.dtype(pt_group)
tp_act = ArgumentHelper.tp(pt_group)
@@ -261,6 +264,7 @@ def gradio(args):
enable_prefix_caching=args.enable_prefix_caching,
device_type=args.device,
quant_policy=args.quant_policy,
eager_mode=args.eager_mode,
max_prefill_token_num=args.max_prefill_token_num)
else:
backend_config = TurbomindEngineConfig(
@@ -311,6 +315,7 @@ def api_server(args):
enable_prefix_caching=args.enable_prefix_caching,
device_type=args.device,
quant_policy=args.quant_policy,
eager_mode=args.eager_mode,
max_prefill_token_num=args.max_prefill_token_num)
else:
from lmdeploy.messages import TurbomindEngineConfig
10 changes: 10 additions & 0 deletions lmdeploy/cli/utils.py
@@ -495,3 +495,13 @@ def disable_fastapi_docs(parser):
default=False,
help="Disable FastAPI's OpenAPI schema,"
' Swagger UI, and ReDoc endpoint')

@staticmethod
def eager_mode(parser):
"""Add argument eager_mode to parser."""

return parser.add_argument('--eager-mode',
action='store_true',
default=False,
help='Whether to enable eager mode. '
'If True, cuda graph would be disabled')
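
A self-contained sketch of how this helper behaves once attached to a parser; everything except the helper's own add_argument call is illustrative:

import argparse

parser = argparse.ArgumentParser('demo')
pt_group = parser.add_argument_group('PyTorch engine arguments')
# Same definition as ArgumentHelper.eager_mode above: a store_true flag, off by default.
pt_group.add_argument('--eager-mode',
                      action='store_true',
                      default=False,
                      help='Whether to enable eager mode. '
                      'If True, cuda graph would be disabled')

print(parser.parse_args([]).eager_mode)                # False
print(parser.parse_args(['--eager-mode']).eager_mode)  # True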
