
Commit 6692e1d

add n to pytorch example
Signed-off-by: Jaedeok Kim <[email protected]>
1 parent 4323474 commit 6692e1d

1 file changed

examples/llm-api/quickstart_advanced.py

Lines changed: 22 additions & 9 deletions
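This change lets the example request several sampled completions per prompt via a new --n flag, instead of tying the number of returned sequences to --max_beam_width. A hedged invocation sketch follows; only the --temperature, --top_k, --top_p, --n, and --max_beam_width flags are visible in this diff, so the --model_dir flag and the placeholder model path are assumptions about the rest of the script:

python examples/llm-api/quickstart_advanced.py \
    --model_dir <path-to-model> \
    --temperature 0.8 --top_p 0.95 \
    --n 4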
@@ -103,6 +103,7 @@ def add_llm_args(parser):
     parser.add_argument("--temperature", type=float, default=None)
     parser.add_argument("--top_k", type=int, default=None)
     parser.add_argument("--top_p", type=float, default=None)
+    parser.add_argument("--n", type=int, default=1)
     parser.add_argument('--load_format', type=str, default='auto')
     parser.add_argument('--max_beam_width', type=int, default=1)

@@ -186,6 +187,13 @@ def setup_llm(args):
     else:
         spec_config = None

+    # TorchSampler needs to set mixed_sampler to True for non-greedy decoding.
+    greedy_decoding = ((args.temperature == 0.0)
+                       or (args.top_k == 1 and
+                           (args.top_p == 0.0 or args.top_p is None)))
+    mixed_sampler = not greedy_decoding and not args.enable_trtllm_sampler
+
+
     cuda_graph_config = CudaGraphConfig(
         batch_sizes=args.cuda_graph_batch_sizes,
         padding_enabled=args.cuda_graph_padding_enabled,
@@ -209,6 +217,7 @@ def setup_llm(args):
         if args.use_torch_compile else None,
         moe_backend=args.moe_backend,
         enable_trtllm_sampler=args.enable_trtllm_sampler,
+        mixed_sampler=mixed_sampler,
         max_seq_len=args.max_seq_len,
         max_batch_size=args.max_batch_size,
         max_num_tokens=args.max_num_tokens,
@@ -224,6 +233,10 @@ def setup_llm(args):
         gather_generation_logits=args.return_generation_logits,
         max_beam_width=args.max_beam_width)

+    if args.max_beam_width > 1:
+        # If beam search is used, set n to the beam width.
+        args.n = args.max_beam_width
+
     sampling_params = SamplingParams(
         max_tokens=args.max_tokens,
         temperature=args.temperature,
@@ -232,7 +245,7 @@
         return_context_logits=args.return_context_logits,
         return_generation_logits=args.return_generation_logits,
         logprobs=args.logprobs,
-        n=args.max_beam_width,
+        n=args.n,
         use_beam_search=args.max_beam_width > 1)
     return llm, sampling_params

@@ -246,23 +259,23 @@ def main():

     for i, output in enumerate(outputs):
         prompt = output.prompt
-        for beam_idx, beam in enumerate(output.outputs):
-            generated_text = beam.text
-            # Skip printing the beam_idx if no beam search was used
-            beam_id_text = f"[{beam_idx}]" if args.max_beam_width > 1 else ""
+        for seq_idx, seq_output in enumerate(output.outputs):
+            # Skip printing the sequence index if a single sequence is returned.
+            seq_id_text = f"[{seq_idx}]" if args.n > 1 else ""
+            generated_text = seq_output.text
             print(
-                f"[{i}]{beam_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
+                f"[{i}]{seq_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
             )
             if args.return_context_logits:
                 print(
-                    f"[{i}]{beam_id_text} Context logits: {output.context_logits}"
+                    f"[{i}]{seq_id_text} Context logits: {output.context_logits}"
                 )
             if args.return_generation_logits:
                 print(
-                    f"[{i}]{beam_id_text} Generation logits: {beam.generation_logits}"
+                    f"[{i}]{seq_id_text} Generation logits: {seq_output.generation_logits}"
                 )
             if args.logprobs:
-                print(f"[{i}]{beam_id_text} Logprobs: {beam.logprobs}")
+                print(f"[{i}]{seq_id_text} Logprobs: {seq_output.logprobs}")


 if __name__ == '__main__':
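With n decoupled from the beam width, output.outputs in main() holds one entry per returned sequence, so the loop above prints several completions for the same prompt whenever --n is greater than 1 (or when beam search sets n to the beam width). A minimal sketch of the same pattern outside the example, assuming the tensorrt_llm LLM API and a placeholder model path:

from tensorrt_llm import LLM, SamplingParams

# Placeholder checkpoint path; replace with a real model directory or HF model name.
llm = LLM(model="/path/to/model")

# Request four sampled sequences per prompt, mirroring --n 4 in the example.
sampling_params = SamplingParams(max_tokens=32,
                                 temperature=0.8,
                                 top_p=0.95,
                                 n=4)

outputs = llm.generate(["The capital of France is"], sampling_params)
for i, output in enumerate(outputs):
    for seq_idx, seq_output in enumerate(output.outputs):
        # One generated sequence per requested sample (n entries per prompt).
        print(f"[{i}][{seq_idx}] {seq_output.text!r}")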
