@@ -91,17 +91,14 @@ def _make_tensor(data: list, dtype: torch.dtype, device: torch.device) -> torch.
     def _prepare_logits_with_temperature(
         logits: torch.Tensor,
         group_logit_indices: Optional[torch.Tensor],
-        temperature: Optional[torch.Tensor],
+        temperature: torch.Tensor,
     ) -> torch.Tensor:
-        if temperature is not None:
-            temperature = temperature.unsqueeze(-1)
-            if group_logit_indices is not None:
-                logits = torch.index_select(logits, 0, group_logit_indices)  # ensures copy
-                logits /= temperature
-            else:
-                logits = logits / temperature  # not inplace
-        elif group_logit_indices is not None:
-            logits = logits[group_logit_indices]
+        temperature = temperature.unsqueeze(-1)
+        if group_logit_indices is not None:
+            logits = torch.index_select(logits, 0, group_logit_indices)  # ensures copy
+            logits /= temperature
+        else:
+            logits = logits / temperature  # not inplace
         return logits
 
     @staticmethod
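
The `None` branch for `temperature` is gone, so the helper always scales. A minimal standalone sketch of the new behavior (the function name and shapes below are illustrative, not part of the PR):

import torch
from typing import Optional

def prepare_logits_with_temperature(
    logits: torch.Tensor,                         # [num_rows, vocab_size]
    group_logit_indices: Optional[torch.Tensor],  # row indices, or None
    temperature: torch.Tensor,                    # one temperature per selected row
) -> torch.Tensor:
    temperature = temperature.unsqueeze(-1)       # [num_rows, 1] broadcasts over vocab
    if group_logit_indices is not None:
        # index_select returns a copy, so the in-place divide below
        # cannot modify the caller's original logits tensor.
        logits = torch.index_select(logits, 0, group_logit_indices)
        logits /= temperature
    else:
        logits = logits / temperature             # out-of-place, input left intact
    return logits

# Usage sketch: row 0 is unchanged (temperature 1.0), row 1 is sharpened.
out = prepare_logits_with_temperature(torch.randn(2, 8), None, torch.tensor([1.0, 0.5]))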
@@ -112,12 +109,12 @@ def _prepare_probs_with_temperature(
     ) -> torch.Tensor:
         if group_logit_indices is not None:
             logits = logits[group_logit_indices]
-        logits = flashinfer.sampling.softmax(
+        probs = flashinfer.sampling.softmax(
             logits,
             temperature,
             enable_pdl=ENABLE_PDL,
         )
-        return logits
+        return probs
 
     @classmethod
     def _sample_from_probs(
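
Renaming the result of `flashinfer.sampling.softmax` from `logits` to `probs` only clarifies what is returned. Assuming the fused kernel matches a temperature-scaled softmax (an assumption, not verified against flashinfer here), a plain-PyTorch reference would be:

import torch

def probs_reference(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    # Assumed equivalent of flashinfer.sampling.softmax(logits, temperature):
    # divide each row by its temperature, then normalize over the vocab dim.
    return torch.softmax(logits / temperature.unsqueeze(-1), dim=-1)

probs = probs_reference(torch.randn(2, 8), torch.tensor([1.0, 0.7]))
assert torch.allclose(probs.sum(dim=-1), torch.ones(2))  # each row sums to 1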
@@ -151,7 +148,7 @@ def _sample_with_probs(
         group_logit_indices: Optional[torch.Tensor],
         top_k: Optional[torch.Tensor],
         top_p: Optional[torch.Tensor],
-        temperature: Optional[torch.Tensor],
+        temperature: torch.Tensor,
         generator: Optional[torch.Generator],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         if top_k is not None:
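
With `temperature` now a required `torch.Tensor` in `_sample_with_probs` as well, callers that previously passed `None` must pass an explicit tensor. A hypothetical caller-side sketch (dividing by 1.0 leaves logits unchanged, so an all-ones tensor is the drop-in replacement for the old `None`):

import torch

num_reqs, vocab_size = 4, 32
logits = torch.randn(num_reqs, vocab_size)
temperature = torch.ones(num_reqs)  # was: temperature = None (no scaling)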