Skip to content

Commit

Permalink
fix cache position for pytorch engine (#2388)
Browse files Browse the repository at this point in the history
  • Loading branch information
RunningLeon authored Aug 27, 2024
1 parent 97b880b commit c57b635
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 19 deletions.
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/models/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite of Attention.forward."""
Expand Down Expand Up @@ -186,6 +187,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite of BaichuanAttention.forward."""
Expand Down
17 changes: 8 additions & 9 deletions lmdeploy/pytorch/models/chatglm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,14 @@ def _contiguous_batching_forward(

return output, kv_cache

def forward(
self,
hidden_states,
attention_mask,
rotary_pos_emb,
kv_cache=None,
use_cache=True,
output_attentions=False,
):
def forward(self,
hidden_states,
attention_mask,
rotary_pos_emb,
kv_cache=None,
use_cache=True,
output_attentions=False,
**kwargs):
return self._contiguous_batching_forward(
hidden_states,
rotary_pos_emb,
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/pytorch/models/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""forward."""
Expand Down
19 changes: 9 additions & 10 deletions lmdeploy/pytorch/models/falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,16 +215,15 @@ def __rotary_emb_fn(query_states, key_states, value_states):
else:
return output_tensor, layer_past

def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
def forward(self,
            hidden_states: torch.Tensor,
            alibi: Optional[torch.Tensor],
            attention_mask: torch.Tensor,
            layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            head_mask: Optional[torch.Tensor] = None,
            use_cache: bool = False,
            output_attentions: bool = False,
            **kwargs):
    """Rewrite of the attention ``forward`` for contiguous batching.

    Delegates directly to ``self._contiguous_batching_forward``; only
    ``hidden_states``, ``alibi`` and ``layer_past`` are forwarded.

    Args:
        hidden_states: input hidden states for this layer.
        alibi: optional ALiBi bias tensor (may be ``None``).
        attention_mask: accepted for signature compatibility; not
            forwarded here.
        layer_past: cached key/value pair from previous steps.
        head_mask: accepted for signature compatibility; not forwarded.
        use_cache: accepted for signature compatibility; not forwarded.
        output_attentions: accepted for signature compatibility; not
            forwarded.
        **kwargs: absorbs extra keyword arguments passed by newer
            callers (per this commit's title, presumably
            ``cache_position`` — confirm against the upstream
            transformers call site).

    Returns:
        Whatever ``_contiguous_batching_forward`` returns for these
        inputs (its definition is outside this view).
    """
    # NOTE(review): the unforwarded arguments are presumably recovered
    # from engine context inside _contiguous_batching_forward — verify.
    return self._contiguous_batching_forward(hidden_states, alibi,
                                             layer_past)

Expand Down
1 change: 1 addition & 0 deletions lmdeploy/pytorch/models/internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""forward."""
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/pytorch/models/phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""rewrite of forward."""
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/pytorch/models/starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""forward."""
Expand Down

0 comments on commit c57b635

Please sign in to comment.