     StaticCache,
     AutoModelForCausalLM,
     BitsAndBytesConfig,
+    LogitsProcessorList,
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
+    MinPLogitsWarper,
+    TypicalLogitsWarper,
+    EpsilonLogitsWarper,
+    EtaLogitsWarper,
+    LogitNormalization,
 )
 
 from ktransformers.server.config.config import Config
@@ -206,6 +214,58 @@ def append_new_tokens(self, new_tokens: int) -> Optional[str]:
         self.seq_length += 1
         return self.streamer.put(new_tokens)
 
+    @staticmethod
+    def tf_logits_warper(generation_config, device="cpu"):
+        """
+        Returns a [`LogitsProcessorList`] containing all the [`LogitsWarper`] instances relevant for multinomial
+        sampling, built from `generation_config`.
+        """
+
+        # instantiate warpers list
+        warpers = LogitsProcessorList()
+
+        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+        # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
+        if generation_config.num_beams > 1:
+            if isinstance(generation_config._eos_token_tensor, list):
+                min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
+            elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
+                min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
+            else:
+                min_tokens_to_keep = 2
+        else:
+            min_tokens_to_keep = 1
+
+        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+        # all samplers can be found in `generation_utils_samplers.py`
+        if generation_config.temperature is not None and generation_config.temperature != 1.0:
+            warpers.append(TemperatureLogitsWarper(generation_config.temperature))
+        if generation_config.top_k is not None and generation_config.top_k != 0:
+            warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
+        if generation_config.top_p is not None and generation_config.top_p < 1.0:
+            warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
+        if generation_config.min_p is not None:
+            # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
+            warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
+        if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
+            warpers.append(
+                TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
+            )
+        if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
+            warpers.append(
+                EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
+            )
+        if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
+            warpers.append(
+                EtaLogitsWarper(
+                    epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
+                )
+            )
+        # `LogitNormalization` should always be the last logit processor, when present
+        if generation_config.renormalize_logits is True:
+            warpers.append(LogitNormalization())
+        return warpers
+
     def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
         if temperature is None or temperature == 0:
             temperature = self.model.generation_config.temperature
@@ -222,14 +282,8 @@ def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] =
             repetition_penalty=self.args.repetition_penalty  # change this to modify generate config
         )
         self.inputs = inputs
-        try:  # transformers==4.43
-            self.logits_warper = (
-                self.model._get_logits_warper(generation_config, device=device)
-            )
-        except:
-            self.logits_warper = (
-                self.model._get_logits_warper(generation_config)
-            )
+
+        self.logits_warper = self.tf_logits_warper(generation_config, device=device)
 
     def logits_to_token(self, logits: torch.Tensor):
         logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
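
For context, here is a minimal standalone sketch of how a warper chain like the one returned by `tf_logits_warper` is applied when sampling a token, mirroring the `logits_to_token` pattern above. The `GenerationConfig` values, `vocab_size`, `input_ids`, and the random logits are illustrative assumptions, not part of this commit.

# Minimal sketch: build a warper chain from a GenerationConfig and sample one token.
# The vocab_size, input_ids, and random logits below are illustrative assumptions only.
import torch
from transformers import (
    GenerationConfig,
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

generation_config = GenerationConfig(do_sample=True, temperature=0.7, top_k=50, top_p=0.9)

# Hand-built chain; tf_logits_warper assembles the same kind of list from generation_config.
warpers = LogitsProcessorList(
    [
        TemperatureLogitsWarper(generation_config.temperature),
        TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1),
        TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1),
    ]
)

vocab_size = 32000                     # illustrative
input_ids = torch.tensor([[1, 2, 3]])  # illustrative prompt token ids
logits = torch.randn(1, vocab_size)    # last-step logits from the model

# Same pattern as logits_to_token: warp the logits, then sample with multinomial.
warped = warpers(input_ids, logits)
probs = torch.softmax(warped, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)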