Commit

add baichuan lint (#445)
AllentDan authored Sep 20, 2023
1 parent 1231302 commit b478b31
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions lmdeploy/pytorch_poc/patch/baichuan.py
@@ -17,6 +17,7 @@

def _attention_partition_fn(mod_name: str, mod: nn.Module,
                            device_mesh: DeviceMesh):
    """A function for attention partition."""
    if mod_name in ['W_pack']:
        for name, param in mod.named_parameters():
            param = param.unflatten(0, (3, -1))
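
For readers new to Baichuan's weight layout, a short aside (not part of the commit): W_pack is the fused Q/K/V projection, so its weight has shape (3 * hidden_size, hidden_size). Unflattening the first dimension to (3, -1) exposes the three projections separately, which lets each one be column-sharded across tensor-parallel ranks. A minimal sketch of that shape manipulation, with illustrative sizes (the actual sharding in the patch is handled by the DeviceMesh machinery, not by hand as below):

import torch

hidden_size, world_size = 8, 2

# W_pack fuses the Q, K and V projections: weight shape (3 * hidden, hidden).
w_pack = torch.randn(3 * hidden_size, hidden_size)

# Viewing dim 0 as (3, hidden) keeps Q, K and V separable ...
w_qkv = w_pack.unflatten(0, (3, -1))        # (3, hidden, hidden)

# ... so each projection can be column-sharded across tensor-parallel ranks.
shards = w_qkv.chunk(world_size, dim=1)     # world_size tensors of (3, hidden/2, hidden)

for rank, shard in enumerate(shards):
    # Each rank flattens its slice back into a packed (3 * hidden/2, hidden) weight.
    print(rank, shard.flatten(0, 1).shape)  # torch.Size([12, 8]) on every rank
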
@@ -37,10 +38,12 @@ class Attention(nn.Module):
    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        return _attention_partition_fn(mod_name, mod, device_mesh)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs
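
The all_reduce hook above reflects a standard tensor-parallel pattern: when the attention output projection is split across ranks along its input features, each rank produces only a partial sum of the final hidden states, and summing across ranks recovers the full result. A single-process sketch of that arithmetic (illustrative only; the real hook uses torch.distributed.all_reduce across ranks):

import torch

torch.manual_seed(0)
hidden = 8
x = torch.randn(1, hidden)            # attention output before the o_proj matmul
w_o = torch.randn(hidden, hidden)     # output projection weight

# Tensor parallelism splits o_proj along its input features over two ranks.
x0, x1 = x.chunk(2, dim=-1)           # each rank sees half the features
w0, w1 = w_o.chunk(2, dim=0)          # and holds the matching rows of w_o

partial0 = x0 @ w0                    # partial sum on rank 0
partial1 = x1 @ w1                    # partial sum on rank 1

# Summing the partials (what all_reduce does across ranks) gives the full result.
assert torch.allclose(partial0 + partial1, x @ w_o, atol=1e-5)
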

@@ -54,6 +57,7 @@ def forward(
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite of Attention.forward."""
        if self.context.use_origin:
            return self.origin_mod(
                hidden_states,
@@ -85,6 +89,11 @@ def _contiguous_batching_forward(
        world_size: int = 1,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite implementation of Attention.forward.

        Add continuous batching support. Add paged attention support. TP
        support.
        """
        assert not output_attentions
        context = self.context.context
        history_lengths = context.history_lengths
@@ -125,10 +134,12 @@ class BaichuanAttention(nn.Module):
    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        return _attention_partition_fn(mod_name, mod, device_mesh)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs

@@ -141,6 +152,7 @@ def forward(
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite of BaichuanAttention.forward."""
        if self.context.use_origin:
            return self.origin_mod(
                hidden_states,
@@ -168,6 +180,11 @@ def _contiguous_batching_forward(
        world_size: int = 1,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite implementation of BaichuanAttention.forward.

        Add continuous batching support. Add paged attention support. TP
        support.
        """
        assert not output_attentions
        context = self.context.context
        position_ids = context.position_ids
@@ -210,6 +227,11 @@ def _continuous_batching_forward(
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Rewrite implementation of BaichuanModel.forward.

        Add continuous batching support. Add paged attention support. TP
        support.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # Attention mask is not necessary in continuous batching
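
The comment above captures the core idea of continuous batching: requests are concatenated into one flat token stream instead of being padded to a common length, so sequence boundaries are tracked with per-sequence lengths rather than an attention mask. A small illustrative sketch of that packing (the names seq_lengths and start_offsets and the data below are made up for illustration and are not the lmdeploy API):

import torch

# Three requests of different lengths, packed back to back with no padding.
requests = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6, 7, 8, 9])]

input_ids = torch.cat(requests)                             # flat token stream, shape (9,)
seq_lengths = torch.tensor([len(r) for r in requests])      # tensor([3, 2, 4])
start_offsets = torch.cumsum(seq_lengths, 0) - seq_lengths  # tensor([0, 3, 5])

# Position ids can be rebuilt from the lengths alone, so no padding mask is
# needed to keep the sequences apart inside the flattened batch.
position_ids = torch.cat([torch.arange(len(r)) for r in requests])
print(input_ids.tolist())       # [1, 2, 3, 4, 5, 6, 7, 8, 9]
print(position_ids.tolist())    # [0, 1, 2, 0, 1, 0, 1, 2, 3]
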
@@ -275,6 +297,7 @@ def forward(
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ):
        """Rewrite of BaichuanModel.forward."""
        use_origin = self.context.use_origin
        if use_origin:
            # use origin model
