
Commit 7071970

Merge pull request #1343 from zhouye/main: Mirror #1247 in server mode

2 parents cc367f5 + 00949d5


ktransformers/server/backend/interfaces/transformers.py

Lines changed: 62 additions & 8 deletions
@@ -11,6 +11,14 @@
     StaticCache,
     AutoModelForCausalLM,
     BitsAndBytesConfig,
+    LogitsProcessorList,
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
+    MinPLogitsWarper,
+    TypicalLogitsWarper,
+    EpsilonLogitsWarper,
+    EtaLogitsWarper,
 )

 from ktransformers.server.config.config import Config
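
A quick note on these newly imported names: each warper is a small callable that rewrites a row of logits, and LogitsProcessorList chains them so the whole pipeline can be invoked as warpers(input_ids, scores). A minimal sketch, not part of the commit, with made-up values:

# Minimal sketch (not part of the commit) of how the imported warpers compose.
import torch
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

warpers = LogitsProcessorList([
    TemperatureLogitsWarper(0.7),                       # sharpen the distribution
    TopKLogitsWarper(top_k=2, min_tokens_to_keep=1),    # keep the 2 best tokens
    TopPLogitsWarper(top_p=0.9, min_tokens_to_keep=1),  # then nucleus filtering
])

input_ids = torch.tensor([[0]])                # (batch, seq_len); unused by these warpers
scores = torch.tensor([[1.0, 2.0, 3.0, 4.0]])  # (batch, vocab_size)
warped = warpers(input_ids, scores)            # tokens outside the kept set become -inf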
@@ -206,6 +214,58 @@ def append_new_tokens(self, new_tokens: int) -> Optional[str]:
         self.seq_length += 1
         return self.streamer.put(new_tokens)

+    @staticmethod
+    def tf_logits_warper(generation_config):
+        """
+        This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
+        used for multinomial sampling.
+        """
+
+        # instantiate warpers list
+        warpers = LogitsProcessorList()
+
+        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+        # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
+        if generation_config.num_beams > 1:
+            if isinstance(generation_config._eos_token_tensor, list):
+                min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
+            elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
+                min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
+            else:
+                min_tokens_to_keep = 2
+        else:
+            min_tokens_to_keep = 1
+
+        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+        # all samplers can be found in `generation_utils_samplers.py`
+        if generation_config.temperature is not None and generation_config.temperature != 1.0:
+            warpers.append(TemperatureLogitsWarper(generation_config.temperature))
+        if generation_config.top_k is not None and generation_config.top_k != 0:
+            warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
+        if generation_config.top_p is not None and generation_config.top_p < 1.0:
+            warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
+        if generation_config.min_p is not None:
+            # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
+            warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
+        if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
+            warpers.append(
+                TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
+            )
+        if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
+            warpers.append(
+                EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
+            )
+        if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
+            warpers.append(
+                EtaLogitsWarper(
+                    epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
+                )
+            )
+        # `LogitNormalization` should always be the last logit processor, when present
+        if generation_config.renormalize_logits is True:
+            warpers.append(LogitNormalization())
+        return warpers
+
     def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):
         if temperature is None or temperature == 0:
             temperature = self.model.generation_config.temperature
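
The added tf_logits_warper mirrors the warper-selection logic of transformers' private _get_logits_warper (which the next hunk stops calling). A hedged usage sketch follows; the enclosing class name TransformersInterface is an assumption, and the sketch avoids the eta_cutoff and renormalize_logits branches because, as committed, those reference device and LogitNormalization, which the staticmethod neither receives nor imports.

# Hedged usage sketch; "TransformersInterface" is an assumed class name, only
# tf_logits_warper(generation_config) itself is taken from the diff above.
import torch
from transformers import GenerationConfig
from ktransformers.server.backend.interfaces.transformers import TransformersInterface  # assumed

generation_config = GenerationConfig(
    do_sample=True, temperature=0.7, top_k=50, top_p=0.9, num_beams=1
)
warpers = TransformersInterface.tf_logits_warper(generation_config)

input_ids = torch.tensor([[1, 2, 3]])   # (batch, seq_len)
logits = torch.randn(1, 32000)          # (batch, vocab_size)
filtered = warpers(input_ids, logits)   # temperature, then top-k, then top-p, in that order
probs = torch.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)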
@@ -222,14 +282,8 @@ def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] =
             repetition_penalty=self.args.repetition_penalty # change this to modify generate config
         )
         self.inputs = inputs
-        try: # transformers==4.43
-            self.logits_warper = (
-                self.model._get_logits_warper(generation_config, device=device)
-            )
-        except:
-            self.logits_warper = (
-                self.model._get_logits_warper(generation_config)
-            )
+
+        self.logits_warper = self.tf_logits_warper(generation_config)

     def logits_to_token(self, logits: torch.Tensor):
         logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
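
The deleted try/except leaned on transformers' private GenerationMixin._get_logits_warper, whose keyword signature differs across releases (the device= form is tied to transformers==4.43 per the original comment) and which more recent releases have dropped, which is presumably what #1247 and this server-mode mirror address by building the warper list locally. If one wanted to keep preferring the upstream helper whenever it is still present, a hedged compatibility sketch (not part of the commit; TransformersInterface is again an assumed class name) could look like this:

# Hedged compatibility sketch, not part of the commit: use the upstream private
# helper when it still exists, otherwise fall back to the local reimplementation.
from ktransformers.server.backend.interfaces.transformers import TransformersInterface  # assumed name

def build_logits_warper(model, generation_config, device):
    get_warper = getattr(model, "_get_logits_warper", None)
    if get_warper is None:
        # newer transformers: the private helper is gone, build the list ourselves
        return TransformersInterface.tf_logits_warper(generation_config)
    try:
        return get_warper(generation_config, device=device)  # transformers==4.43 signature
    except TypeError:
        return get_warper(generation_config)  # older signature without the device kwarg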
