diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index 0d256feef0..6c237b18cb 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -507,6 +507,9 @@ def __get_vlm_embeddings():
         cross_length = torch.tensor([msg.num_cross for msg in messages])
         history_cross_length = torch.tensor(
             [msg.num_history_cross for msg in messages])
+        if (cross_length + history_cross_length).max().item() == 0:
+            cross_length = None
+            history_cross_length = None

         return ModelInputs(
             input_ids=input_ids,
@@ -675,10 +678,10 @@ async def __long_context_single_forward(inputs):
             ret['logits'] = logits
             return ret

-    def _make_infer_outputs(self, next_token_ids: torch.LongTensor,
-                            logits: torch.Tensor, stopped: torch.Tensor,
-                            model_metas: List[Dict[str, Any]],
-                            event: torch.cuda.Event):
+    async def _make_infer_outputs(self, next_token_ids: torch.LongTensor,
+                                  logits: torch.Tensor, stopped: torch.Tensor,
+                                  model_metas: List[Dict[str, Any]],
+                                  event: torch.cuda.Event):
         """make infer output."""

         def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence,
@@ -699,8 +702,9 @@ def __get_q_start_loc():
             else:
                 return seq_length.cumsum(0) - seq_length

+        while not event.query():
+            await asyncio.sleep(0.001)
         with torch.cuda.stream(self._output_stream):
-            event.wait()
             next_token_ids = next_token_ids.cpu()
             stopped = stopped.cpu()

@@ -1004,7 +1008,7 @@ async def __step():
                 if isinstance(out, Exception):
                     raise out
                 (next_token_ids, logits, stopped, model_metas, event) = out
-                step_outputs = self._make_infer_outputs(
+                step_outputs = await self._make_infer_outputs(
                     next_token_ids, logits, stopped, model_metas, event)
                 __send_resps(step_outputs)
             except Exception as e:
diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index 421b171ee5..c1799f447b 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -276,8 +276,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
                                     swap_in_map=swap_in_map,
                                     swap_out_map=swap_out_map)
         await asyncio.sleep(0)
-        while not self.stream.query():
-            await asyncio.sleep(0)
         return output

     def get_logits(self, hidden_states: torch.Tensor):
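
The core change in both files is the same: the blocking `event.wait()` in `_make_infer_outputs` and the per-forward `self.stream.query()` poll in `async_forward` are replaced by a single asynchronous poll on the recorded CUDA event before the device-to-host copies, so the asyncio loop can keep serving other coroutines while the GPU finishes. Below is a minimal standalone sketch of that pattern, not part of the patch; the helper names `wait_cuda_event` and `gather_outputs` are illustrative assumptions.

# Sketch of the non-blocking CUDA event wait the diff introduces (assumed names).
import asyncio

import torch


async def wait_cuda_event(event: torch.cuda.Event, interval: float = 0.001):
    """Poll the CUDA event, yielding to the asyncio loop between checks."""
    while not event.query():
        await asyncio.sleep(interval)


async def gather_outputs(next_token_ids: torch.Tensor, stopped: torch.Tensor,
                         event: torch.cuda.Event,
                         output_stream: torch.cuda.Stream):
    # Wait for the forward pass to finish without blocking the host thread
    # the way event.wait() would.
    await wait_cuda_event(event)
    # Device-to-host copies are issued on a dedicated stream, as in the patch.
    with torch.cuda.stream(output_stream):
        next_token_ids = next_token_ids.cpu()
        stopped = stopped.cpu()
    return next_token_ids, stopped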