@@ -104,6 +104,7 @@ def add_llm_args(parser):
     parser.add_argument("--temperature", type=float, default=None)
     parser.add_argument("--top_k", type=int, default=None)
     parser.add_argument("--top_p", type=float, default=None)
+    parser.add_argument("--n", type=int, default=1)
     parser.add_argument('--load_format', type=str, default='auto')
     parser.add_argument('--max_beam_width', type=int, default=1)
 
@@ -187,6 +188,13 @@ def setup_llm(args):
     else:
         spec_config = None
 
+    # TorchSampler needs mixed_sampler set to True for non-greedy decoding.
+    greedy_decoding = ((args.temperature == 0.0)
+                       or (args.top_k == 1 and
+                           (args.top_p == 0.0 or args.top_p is None)))
+    mixed_sampler = not greedy_decoding and not args.enable_trtllm_sampler
+
+
     cuda_graph_config = CudaGraphConfig(
         batch_sizes=args.cuda_graph_batch_sizes,
         enable_padding=args.cuda_graph_padding_enabled,
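
As a sanity check on the greedy-detection logic added above, here is a standalone sketch (the `is_greedy` helper is illustrative, not part of the script):

```python
def is_greedy(temperature, top_k, top_p):
    # Mirrors the check above: decoding is greedy when temperature is pinned
    # to 0, or when top_k is 1 with top_p unset or zero.
    return (temperature == 0.0) or (top_k == 1
                                    and (top_p == 0.0 or top_p is None))

assert is_greedy(0.0, None, None)       # --temperature 0 -> greedy
assert is_greedy(None, 1, None)         # --top_k 1, no top_p -> greedy
assert not is_greedy(None, None, None)  # argparse defaults -> sampling path
assert not is_greedy(None, 1, 0.9)      # top_k 1 but top_p set -> sampling
```

With the argparse defaults (all three `None`), `greedy_decoding` is `False`, so `mixed_sampler` is enabled whenever `--enable_trtllm_sampler` is not passed.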
@@ -210,6 +218,7 @@ def setup_llm(args):
         if args.use_torch_compile else None,
         moe_backend=args.moe_backend,
         enable_trtllm_sampler=args.enable_trtllm_sampler,
+        mixed_sampler=mixed_sampler,
         max_seq_len=args.max_seq_len,
         max_batch_size=args.max_batch_size,
         max_num_tokens=args.max_num_tokens,
@@ -225,6 +234,10 @@ def setup_llm(args):
         gather_generation_logits=args.return_generation_logits,
         max_beam_width=args.max_beam_width)
 
+    if args.max_beam_width > 1:
+        # If beam search is used, set n to the beam width.
+        args.n = args.max_beam_width
+
     sampling_params = SamplingParams(
         max_tokens=args.max_tokens,
         temperature=args.temperature,
@@ -233,7 +246,7 @@ def setup_llm(args):
         return_context_logits=args.return_context_logits,
         return_generation_logits=args.return_generation_logits,
         logprobs=args.logprobs,
-        n=args.max_beam_width,
+        n=args.n,
         use_beam_search=args.max_beam_width > 1)
     return llm, sampling_params
 
@@ -247,23 +260,23 @@ def main():
 
     for i, output in enumerate(outputs):
         prompt = output.prompt
-        for beam_idx, beam in enumerate(output.outputs):
-            generated_text = beam.text
-            # Skip printing the beam_idx if no beam search was used
-            beam_id_text = f"[{beam_idx}]" if args.max_beam_width > 1 else ""
+        for seq_idx, seq_output in enumerate(output.outputs):
+            # Skip printing the sequence index if a single sequence is returned.
+            seq_id_text = f"[{seq_idx}]" if args.n > 1 else ""
+            generated_text = seq_output.text
             print(
-                f"[{i}]{beam_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
+                f"[{i}]{seq_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
             )
             if args.return_context_logits:
                 print(
-                    f"[{i}]{beam_id_text} Context logits: {output.context_logits}"
+                    f"[{i}]{seq_id_text} Context logits: {output.context_logits}"
                 )
             if args.return_generation_logits:
                 print(
-                    f"[{i}]{beam_id_text} Generation logits: {beam.generation_logits}"
+                    f"[{i}]{seq_id_text} Generation logits: {seq_output.generation_logits}"
                 )
             if args.logprobs:
-                print(f"[{i}]{beam_id_text} Logprobs: {beam.logprobs}")
+                print(f"[{i}]{seq_id_text} Logprobs: {seq_output.logprobs}")
 
 
 if __name__ == '__main__':
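
A minimal sketch of what the rewritten loop in `main` consumes; the `prompts` list and the `llm.generate` call are assumptions standing in for the parts of the script this hunk does not show:

```python
prompts = ["The capital of France is"]  # placeholder prompt
llm, sampling_params = setup_llm(args)
outputs = llm.generate(prompts, sampling_params)

for i, output in enumerate(outputs):
    # With --n 4 (no beam search) output.outputs holds 4 sampled sequences;
    # with --max_beam_width 4 it holds the 4 surviving beams. The loop body
    # is identical either way, hence the seq_* naming over beam_*.
    for seq_idx, seq_output in enumerate(output.outputs):
        print(f"[{i}][{seq_idx}] {seq_output.text!r}")
```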