Commit 5c0d5ba

add n to pytorch example
Signed-off-by: Jaedeok Kim <[email protected]>
1 parent e8ce36b commit 5c0d5ba

1 file changed

examples/pytorch/quickstart_advanced.py

Lines changed: 22 additions & 9 deletions
@@ -104,6 +104,7 @@ def add_llm_args(parser):
     parser.add_argument("--temperature", type=float, default=None)
     parser.add_argument("--top_k", type=int, default=None)
     parser.add_argument("--top_p", type=float, default=None)
+    parser.add_argument("--n", type=int, default=1)
     parser.add_argument('--load_format', type=str, default='auto')
     parser.add_argument('--max_beam_width', type=int, default=1)
 
@@ -187,6 +188,13 @@ def setup_llm(args):
     else:
         spec_config = None
 
+    # TorchSampler needs to set mixed_sampler to True for non-greedy decoding.
+    greedy_decoding = ((args.temperature == 0.0)
+                       or (args.top_k == 1 and
+                           (args.top_p == 0.0 or args.top_p is None)))
+    mixed_sampler = not greedy_decoding and not args.enable_trtllm_sampler
+
+
     cuda_graph_config = CudaGraphConfig(
         batch_sizes=args.cuda_graph_batch_sizes,
         enable_padding=args.cuda_graph_padding_enabled,
@@ -210,6 +218,7 @@ def setup_llm(args):
         if args.use_torch_compile else None,
         moe_backend=args.moe_backend,
         enable_trtllm_sampler=args.enable_trtllm_sampler,
+        mixed_sampler=mixed_sampler,
         max_seq_len=args.max_seq_len,
         max_batch_size=args.max_batch_size,
         max_num_tokens=args.max_num_tokens,
@@ -225,6 +234,10 @@ def setup_llm(args):
         gather_generation_logits=args.return_generation_logits,
         max_beam_width=args.max_beam_width)
 
+    if args.max_beam_width > 1:
+        # If beam search is used, set n to the beam width.
+        args.n = args.max_beam_width
+
     sampling_params = SamplingParams(
         max_tokens=args.max_tokens,
         temperature=args.temperature,
@@ -233,7 +246,7 @@ def setup_llm(args):
         return_context_logits=args.return_context_logits,
         return_generation_logits=args.return_generation_logits,
         logprobs=args.logprobs,
-        n=args.max_beam_width,
+        n=args.n,
         use_beam_search=args.max_beam_width > 1)
     return llm, sampling_params
 
@@ -247,23 +260,23 @@ def main():
 
     for i, output in enumerate(outputs):
         prompt = output.prompt
-        for beam_idx, beam in enumerate(output.outputs):
-            generated_text = beam.text
-            # Skip printing the beam_idx if no beam search was used
-            beam_id_text = f"[{beam_idx}]" if args.max_beam_width > 1 else ""
+        for seq_idx, seq_output in enumerate(output.outputs):
+            # Skip printing the sequence index if a single sequence is returned.
+            seq_id_text = f"[{seq_idx}]" if args.n > 1 else ""
+            generated_text = seq_output.text
             print(
-                f"[{i}]{beam_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
+                f"[{i}]{seq_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
             )
             if args.return_context_logits:
                 print(
-                    f"[{i}]{beam_id_text} Context logits: {output.context_logits}"
+                    f"[{i}]{seq_id_text} Context logits: {output.context_logits}"
                 )
             if args.return_generation_logits:
                 print(
-                    f"[{i}]{beam_id_text} Generation logits: {beam.generation_logits}"
+                    f"[{i}]{seq_id_text} Generation logits: {seq_output.generation_logits}"
                 )
             if args.logprobs:
-                print(f"[{i}]{beam_id_text} Logprobs: {beam.logprobs}")
+                print(f"[{i}]{seq_id_text} Logprobs: {seq_output.logprobs}")
 
 
 if __name__ == '__main__':

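For readers trying the change locally: the new --n flag feeds SamplingParams.n, which controls how many sequences are returned per prompt, and with --max_beam_width greater than 1 the script overrides n with the beam width so every beam is printed. A minimal sketch of the two resulting parameter sets, assuming the TensorRT-LLM LLM API names used in this example (SamplingParams with n and use_beam_search) and illustrative values:

# Sketch only: mirrors how setup_llm builds SamplingParams after this commit.
from tensorrt_llm import SamplingParams

# --n 4 without beam search: four independently sampled sequences per prompt.
sampling = SamplingParams(max_tokens=64,
                          temperature=0.8,
                          top_p=0.95,
                          n=4,
                          use_beam_search=False)

# --max_beam_width 4: the script sets n to the beam width, so all four beams
# are returned and printed with their sequence index.
beam_search = SamplingParams(max_tokens=64,
                             n=4,
                             use_beam_search=True)

With the sampling values above, greedy_decoding in setup_llm evaluates to False, so mixed_sampler is passed as True unless --enable_trtllm_sampler is set.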