diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index adece21511..edd62dbd41 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -399,7 +399,7 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList):
             max_rank=max_rank,
             meta=meta)
 
-    def _stoping_criteria(self, msg: SchedulerSequence, next_token_id: int):
+    def _stopping_criteria(self, msg: SchedulerSequence, next_token_id: int):
         """Check if the message should stop.
 
         Args:
@@ -489,9 +489,23 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor,
             msg.meta = meta
             msg.update_token_ids(token)
             msg.remain_output_len -= 1
-            if self._stoping_criteria(msg, token):
+            if msg.remain_output_len < 0:
+                msg.token_ids = torch.empty((0, ), dtype=torch.long)
+            if self._stopping_criteria(msg, token):
                 msg.status = MessageStatus.STOPPED
 
+    def _can_output_token(self, token: torch.Tensor, msg: SchedulerSequence):
+        """check if output is necessary."""
+        if isinstance(token, torch.Tensor):
+            token = token.item()
+        if token == self.model_config.eos_token_id:
+            return False
+
+        if token in msg.sampling_param.stop_words:
+            return False
+
+        return True
+
     def _model_forward(self, inputs: ModelInputs, swap_in_map: Dict,
                        swap_out_map: Dict):
         """model forward."""
@@ -638,12 +652,16 @@ def step(self, is_prefill: bool, return_logits: bool = False):
         outputs: Dict[int, InferOutput] = dict()
         for msg, next_id in zip(running, next_token_ids):
             session_id = msg.session_id
+            if self._can_output_token(next_id, msg):
+                out_token_ids = [next_id.item()]
+            else:
+                out_token_ids = []
             out = InferOutput(
                 session_id=session_id,
                 sender_id=msg.sender_id,
                 req_id=msg.req_id,
                 finish=(msg.status == MessageStatus.STOPPED),
-                token_ids=[next_id.item()],
+                token_ids=out_token_ids,
             )
             outputs[session_id] = out
 
diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py
index 2ac81c0b84..462dd91973 100644
--- a/lmdeploy/pytorch/paging/scheduler.py
+++ b/lmdeploy/pytorch/paging/scheduler.py
@@ -102,9 +102,6 @@ def add_sequence(self, seq: SchedulerSequence):
 
         # push message to waiting queue
         self._set_message_status(seq, MessageStatus.WAITING)
-        if seq.remain_output_len <= 0:
-            seq.remain_output_len = \
-                self.scheduler_config.max_request_output_len
         self.waiting.append(seq)
 
     def add_adapter(self, adapter_path: str, adapter_name: str):
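
For illustration, a minimal standalone sketch of the filtering rule the new `_can_output_token` helper encodes: an eos token or a user stop word still terminates the request (via `_stopping_criteria`), but is no longer surfaced in the response's `token_ids`. All names below are hypothetical stand-ins, not lmdeploy's actual classes.

# Standalone sketch (hypothetical names, not lmdeploy's real objects) of the
# _can_output_token rule: eos and stop-word tokens end generation but are
# suppressed from the streamed token_ids.
EOS_TOKEN_ID = 2  # assumed eos id; the engine reads this from model_config


def can_output_token(token_id: int, stop_words: set) -> bool:
    """Return False for tokens that must not be sent to the client."""
    if token_id == EOS_TOKEN_ID:
        return False
    if token_id in stop_words:
        return False
    return True


# Example stream: 42 is a user-supplied stop word, 2 is eos.
for tok in [10, 11, 42, 2]:
    out_token_ids = [tok] if can_output_token(tok, stop_words={42}) else []
    print(tok, '->', out_token_ids)
# 10 -> [10], 11 -> [11], 42 -> [], 2 -> []

A related design point visible in the patch: the scheduler-side fallback that reset a non-positive `remain_output_len` to `max_request_output_len` is removed, and the engine now handles exhausted budgets itself by clearing `token_ids` when `remain_output_len` drops below zero.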