Skip to content

Commit

Permalink
Support custom conversation template in multi_model_worker (#2434)
Browse files: browse the repository at this point in the history
  • Loading branch information
hi-jin authored Sep 18, 2023
1 parent 2e0e60b commit c7e3e67
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion fastchat/serve/multi_model_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ def create_multi_model_worker():
action="append",
help="One or more model names. Values must be aligned with `--model-path` values.",
)
parser.add_argument(
"--conv-template",
type=str,
default=None,
action="append",
help="Conversation prompt template. Values must be aligned with `--model-path` values. If only one value is provided, it will be repeated for all models.",
)
parser.add_argument("--limit-worker-concurrency", type=int, default=5)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")
Expand All @@ -201,9 +208,16 @@ def create_multi_model_worker():
if args.model_names is None:
args.model_names = [[x.split("/")[-1]] for x in args.model_path]

if args.conv_template is None:
args.conv_template = [None] * len(args.model_path)
elif len(args.conv_template) == 1: # Repeat the same template
args.conv_template = args.conv_template * len(args.model_path)

# Launch all workers
workers = []
for model_path, model_names in zip(args.model_path, args.model_names):
for conv_template, model_path, model_names in zip(
args.conv_template, args.model_path, args.model_names
):
w = ModelWorker(
args.controller_address,
args.worker_address,
Expand All @@ -219,6 +233,7 @@ def create_multi_model_worker():
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
stream_interval=args.stream_interval,
conv_template=conv_template,
)
workers.append(w)
for model_name in model_names:
Expand Down

0 comments on commit c7e3e67

Please sign in to comment.