Skip to content

Commit

Permalink
better streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire committed Dec 5, 2024
1 parent b3a2887 commit 8f7a56f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
16 changes: 10 additions & 6 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,9 @@ def __get_vlm_embeddings():
cross_length = torch.tensor([msg.num_cross for msg in messages])
history_cross_length = torch.tensor(
[msg.num_history_cross for msg in messages])
+ if (cross_length + history_cross_length).max().item() == 0:
+     cross_length = None
+     history_cross_length = None

return ModelInputs(
input_ids=input_ids,
Expand Down Expand Up @@ -675,10 +678,10 @@ async def __long_context_single_forward(inputs):
ret['logits'] = logits
return ret

- def _make_infer_outputs(self, next_token_ids: torch.LongTensor,
-                         logits: torch.Tensor, stopped: torch.Tensor,
-                         model_metas: List[Dict[str, Any]],
-                         event: torch.cuda.Event):
+ async def _make_infer_outputs(self, next_token_ids: torch.LongTensor,
+                               logits: torch.Tensor, stopped: torch.Tensor,
+                               model_metas: List[Dict[str, Any]],
+                               event: torch.cuda.Event):
"""make infer output."""

def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence,
Expand All @@ -699,8 +702,9 @@ def __get_q_start_loc():
else:
return seq_length.cumsum(0) - seq_length

+ while not event.query():
+     await asyncio.sleep(0.001)
  with torch.cuda.stream(self._output_stream):
-     event.wait()
      next_token_ids = next_token_ids.cpu()
      stopped = stopped.cpu()

Expand Down Expand Up @@ -1004,7 +1008,7 @@ async def __step():
if isinstance(out, Exception):
raise out
(next_token_ids, logits, stopped, model_metas, event) = out
- step_outputs = self._make_infer_outputs(
+ step_outputs = await self._make_infer_outputs(
next_token_ids, logits, stopped, model_metas, event)
__send_resps(step_outputs)
except Exception as e:
Expand Down
2 changes: 0 additions & 2 deletions lmdeploy/pytorch/engine/model_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
await asyncio.sleep(0)
- while not self.stream.query():
-     await asyncio.sleep(0)
return output

def get_logits(self, hidden_states: torch.Tensor):
Expand Down

0 comments on commit 8f7a56f

Please sign in to comment.