@@ -103,6 +103,7 @@ def add_llm_args(parser):
     parser.add_argument("--temperature", type=float, default=None)
     parser.add_argument("--top_k", type=int, default=None)
     parser.add_argument("--top_p", type=float, default=None)
+    parser.add_argument("--n", type=int, default=1)
     parser.add_argument('--load_format', type=str, default='auto')
     parser.add_argument('--max_beam_width', type=int, default=1)
 
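A brief note on the new flag: `--n` feeds SamplingParams.n, the number of sequences returned per prompt. A minimal sketch (the values are illustrative, not defaults from this change):

    # `n=2` asks for two generated sequences per prompt; with the default
    # n=1 the output loop below prints a single sequence without an index.
    from tensorrt_llm import SamplingParams

    params = SamplingParams(max_tokens=32, temperature=0.8, n=2)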
@@ -186,6 +187,13 @@ def setup_llm(args):
     else:
         spec_config = None
 
+    # TorchSampler needs mixed_sampler set to True for non-greedy decoding.
+    greedy_decoding = ((args.temperature == 0.0)
+                       or (args.top_k == 1 and
+                           (args.top_p == 0.0 or args.top_p is None)))
+    mixed_sampler = not greedy_decoding and not args.enable_trtllm_sampler
+
+
     cuda_graph_config = CudaGraphConfig(
         batch_sizes=args.cuda_graph_batch_sizes,
         padding_enabled=args.cuda_graph_padding_enabled,
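The check above treats temperature == 0.0, or top_k == 1 with top_p unset or 0.0, as greedy decoding; anything else enables the mixed sampler, unless the TRTLLM sampler is active and handles sampling itself. A standalone sketch of the same decision; the helper name `needs_mixed_sampler` is hypothetical, not part of the script:

    # Standalone sketch of the greedy-vs-mixed decision made above.
    def needs_mixed_sampler(temperature, top_k, top_p, enable_trtllm_sampler):
        greedy = ((temperature == 0.0)
                  or (top_k == 1 and (top_p == 0.0 or top_p is None)))
        return not greedy and not enable_trtllm_sampler

    assert needs_mixed_sampler(0.8, None, 0.95, False)      # sampling -> mixed sampler
    assert not needs_mixed_sampler(0.0, None, None, False)  # greedy decoding
    assert not needs_mixed_sampler(0.8, None, 0.95, True)   # TRTLLM sampler path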
@@ -209,6 +217,7 @@ def setup_llm(args):
         if args.use_torch_compile else None,
         moe_backend=args.moe_backend,
         enable_trtllm_sampler=args.enable_trtllm_sampler,
+        mixed_sampler=mixed_sampler,
         max_seq_len=args.max_seq_len,
         max_batch_size=args.max_batch_size,
         max_num_tokens=args.max_num_tokens,
@@ -224,6 +233,10 @@ def setup_llm(args):
         gather_generation_logits=args.return_generation_logits,
         max_beam_width=args.max_beam_width)
 
+    if args.max_beam_width > 1:
+        # If beam search is used, set n to the beam width.
+        args.n = args.max_beam_width
+
     sampling_params = SamplingParams(
         max_tokens=args.max_tokens,
         temperature=args.temperature,
@@ -232,7 +245,7 @@ def setup_llm(args):
         return_context_logits=args.return_context_logits,
         return_generation_logits=args.return_generation_logits,
         logprobs=args.logprobs,
-        n=args.max_beam_width,
+        n=args.n,
         use_beam_search=args.max_beam_width > 1)
     return llm, sampling_params
 
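Since SamplingParams now always takes n from args.n, beam search works by pinning args.n to the beam width beforehand. A sketch of the resulting beam-search configuration (the width is illustrative):

    # Under beam search, n must match the beam width, which is why the
    # script overwrites args.n with args.max_beam_width above.
    from tensorrt_llm import SamplingParams

    max_beam_width = 4  # illustrative value
    params = SamplingParams(max_tokens=64,
                            n=max_beam_width,
                            use_beam_search=True)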
@@ -246,23 +259,23 @@ def main():
 
     for i, output in enumerate(outputs):
         prompt = output.prompt
-        for beam_idx, beam in enumerate(output.outputs):
-            generated_text = beam.text
-            # Skip printing the beam_idx if no beam search was used
-            beam_id_text = f"[{beam_idx}]" if args.max_beam_width > 1 else ""
+        for seq_idx, seq_output in enumerate(output.outputs):
+            # Skip printing the sequence index if a single sequence is returned.
+            seq_id_text = f"[{seq_idx}]" if args.n > 1 else ""
+            generated_text = seq_output.text
             print(
-                f"[{i}]{beam_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
+                f"[{i}]{seq_id_text} Prompt: {prompt!r}, Generated text: {generated_text!r}"
             )
             if args.return_context_logits:
                 print(
-                    f"[{i}]{beam_id_text} Context logits: {output.context_logits}"
+                    f"[{i}]{seq_id_text} Context logits: {output.context_logits}"
                 )
             if args.return_generation_logits:
                 print(
-                    f"[{i}]{beam_id_text} Generation logits: {beam.generation_logits}"
+                    f"[{i}]{seq_id_text} Generation logits: {seq_output.generation_logits}"
                 )
             if args.logprobs:
-                print(f"[{i}]{beam_id_text} Logprobs: {beam.logprobs}")
+                print(f"[{i}]{seq_id_text} Logprobs: {seq_output.logprobs}")
 
 
 if __name__ == '__main__':
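End to end, the renamed loop behaves the same for n > 1 sampling and for beam search. A self-contained sketch, assuming a placeholder model name and prompt:

    # Each request's `.outputs` carries n generated sequences; the model
    # name below is a placeholder, not part of this change.
    from tensorrt_llm import LLM, SamplingParams

    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    params = SamplingParams(max_tokens=32, temperature=0.8, n=2)
    for i, output in enumerate(llm.generate(["The capital of France is"], params)):
        for seq_idx, seq_output in enumerate(output.outputs):
            print(f"[{i}][{seq_idx}] {seq_output.text!r}")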