Commit b31d1df: merge main

zhulin1 committed Mar 5, 2024
2 parents: 808fdca + 4bec832
Showing 5 changed files with 17 additions and 473 deletions.
16 changes: 16 additions & 0 deletions lmdeploy/pytorch/engine/engine.py
@@ -673,11 +673,13 @@ async def __long_context_forward(inputs):
             if token_count == 0 and slen > max_prefill_token_num:
                 tmp_out = await __long_context_single_forward(inputs, idx)
                 logits_gather.gather(tmp_out)
+                tmp_out.pop('logits', None)
                 idx += 1
             elif token_count + slen > max_prefill_token_num:
                 tmp_out = await __long_context_batched_forward(
                     inputs, indices[0], idx)
                 logits_gather.gather(tmp_out)
+                tmp_out.pop('logits', None)
                 indices = []
                 token_count = 0
             else:
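
The two inserted tmp_out.pop('logits', None) calls drop each chunk's reference to its logits tensor once logits_gather has copied it, so the memory can be reclaimed between iterations of the long-context loop instead of accumulating until the whole prefill finishes. A minimal sketch of the pattern; the names gather_then_release, chunks, and gather_buf are illustrative, not lmdeploy API:

import torch

def gather_then_release(chunks, gather_buf):
    """Copy each chunk's logits into a preallocated buffer, then drop the
    dict's reference so the tensor can be freed before the next chunk runs."""
    offset = 0
    for out in chunks:  # each 'out' is a dict holding a 'logits' tensor
        logits = out['logits']
        gather_buf[offset:offset + logits.size(0)].copy_(logits)
        offset += logits.size(0)
        out.pop('logits', None)  # release the only live reference
    return gather_buf[:offset]

# Example: two 2-token chunks gathered into one 4 x vocab buffer.
chunks = [{'logits': torch.randn(2, 32000)}, {'logits': torch.randn(2, 32000)}]
buf = torch.empty(4, 32000)
all_logits = gather_then_release(chunks, buf)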
@@ -834,6 +836,20 @@ async def _add_messages(session_ids, token_ids):
         output_token_len = [len(token_ids) for token_ids in output_token_ids]
         return (status, output_token_ids, output_token_len)
 
+    def batched_infer(self,
+                      session_ids: List[int],
+                      token_ids: List[List[int]] = None,
+                      gen_config: EngineGenerationConfig = None,
+                      adapter_names: List[str] = None,
+                      keep_cache: bool = False):
+        """batched infer."""
+        coro = self.async_batched_infer(session_ids,
+                                        token_ids,
+                                        gen_config=gen_config,
+                                        adapter_names=adapter_names,
+                                        keep_cache=keep_cache)
+        return self.req_sender.run_until_complete(coro)
+
     def decode(self,
                input_ids,
                steps: List[int] = None,
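
For context, a usage sketch of the new blocking wrapper. Everything except batched_infer itself is assumed here: 'engine' stands for an already-constructed pytorch Engine with sessions 0 and 1 added (setup not shown), the EngineGenerationConfig import path is an assumption, and the returned tuple mirrors the (status, output_token_ids, output_token_len) shape visible in the hunk above.

from lmdeploy.messages import EngineGenerationConfig  # assumed import path

gen_config = EngineGenerationConfig(max_new_tokens=128)

# Blocks the caller until every listed session finishes decoding;
# internally it just drives async_batched_infer to completion on the
# request sender's event loop (see the diff above).
status, output_token_ids, output_token_len = engine.batched_infer(
    session_ids=[0, 1],
    token_ids=[[1, 42, 7], [1, 99]],  # prompt token ids per session
    gen_config=gen_config)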
5 changes: 1 addition & 4 deletions lmdeploy/pytorch/kernels/__init__.py
@@ -1,18 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .alibi_pagedattention import alibi_paged_attention_fwd
 from .apply_rotary_pos_emb import apply_rotary_pos_emb
-from .biased_pagedattention import biased_paged_attention_fwd
 from .fill_kv_cache import fill_kv_cache
-from .flashattention_nopad import context_attention_fwd
 from .fused_rotary_emb import fused_rotary_emb
 from .multinomial_sampling import multinomial_sampling
 from .pagedattention import paged_attention_fwd
 from .rerope_attention import rerope_attention_fwd
 from .rms_norm import rms_norm
 
 __all__ = [
-    'apply_rotary_pos_emb', 'context_attention_fwd', 'fused_rotary_emb',
-    'paged_attention_fwd', 'biased_paged_attention_fwd',
+    'apply_rotary_pos_emb', 'fused_rotary_emb', 'paged_attention_fwd',
     'alibi_paged_attention_fwd', 'fill_kv_cache', 'multinomial_sampling',
     'rms_norm', 'rerope_attention_fwd'
 ]
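
After this change, biased_paged_attention_fwd and context_attention_fwd are no longer part of the package's public surface. A quick check of what still imports, grounded only in the new __all__ above:

# Names kept by this commit still import fine.
from lmdeploy.pytorch.kernels import (alibi_paged_attention_fwd, fill_kv_cache,
                                      paged_attention_fwd)

# The removed kernel is gone along with its module (deleted below):
try:
    from lmdeploy.pytorch.kernels import biased_paged_attention_fwd
except ImportError:
    print('biased_paged_attention_fwd was removed in commit b31d1df')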
240 changes: 0 additions & 240 deletions lmdeploy/pytorch/kernels/biased_pagedattention.py

This file was deleted.
