refactor lora tp1
grimoire committed Sep 13, 2024
1 parent 64fe4c5 commit d895e51
Showing 29 changed files with 605 additions and 2,010 deletions.
401 changes: 56 additions & 345 deletions lmdeploy/pytorch/adapter/adapter.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion lmdeploy/pytorch/backends/base.py
@@ -20,7 +20,7 @@ class OpType(Enum):
    GeluAndMul = auto()
    RMSNorm = auto()
    LayerNorm = auto()
    SLoRA = auto()
    LoRA = auto()
    LinearW8A8 = auto()
    RMSNormW8A8 = auto()
    MultinomialSampling = auto()
80 changes: 80 additions & 0 deletions lmdeploy/pytorch/backends/cuda/lora.py
@@ -0,0 +1,80 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass

import torch

from lmdeploy.pytorch.kernels.cuda.fused_lora import fused_lora
from lmdeploy.pytorch.model_inputs import StepContextManager

from ..lora import AdapterInfo, LoRABuilder, LoRAImpl


@dataclass
class PackedLoRAInput:
    """packed lora input."""
    x: torch.Tensor
    q_start_loc: torch.Tensor
    q_seqlens: torch.Tensor
    adapter_ids: torch.Tensor
    max_seq_len: int
    is_decoding: bool


class TritonLoRAImpl(LoRAImpl):
    """triton lora implementation."""

    @staticmethod
    def _make_packed_lora_input(x, ctx_mgr):
        """make PackedLoRAInput."""
        context = ctx_mgr.current_context()

        # adapter cache
        max_q_seq_length = x.numel() // x.size(-1)

        return PackedLoRAInput(x=x.flatten(0, -2).contiguous(),
                               q_start_loc=context.q_start_loc,
                               q_seqlens=context.q_seqlens,
                               adapter_ids=context.local_adapter_ids,
                               max_seq_len=max_q_seq_length,
                               is_decoding=context.is_decoding)

    def forward(self,
                x: torch.Tensor,
                lora_A: torch.Tensor,
                lora_B: torch.Tensor,
                base_output: torch.Tensor,
                adapter_info: AdapterInfo,
                ctx_mgr: StepContextManager,
                colwise: bool,
                is_tp: bool = True):
        """forward."""
        lora_input = self._make_packed_lora_input(x, ctx_mgr)

        lora_out = fused_lora(lora_input.x,
                              lora_A,
                              lora_B,
                              scaling=adapter_info.scalings,
                              rank_start=adapter_info.rank_offsets,
                              ranks=adapter_info.ranks,
                              seq_start=lora_input.q_start_loc,
                              seq_lens=lora_input.q_seqlens,
                              adapter_ids=lora_input.adapter_ids,
                              max_rank=adapter_info.max_rank,
                              max_seqlen=lora_input.max_seq_len,
                              )

        base_slice = adapter_info.base_slice
        sliced_base = base_output[..., base_slice]
        lora_out = lora_out.reshape(sliced_base.shape)
        sliced_base.add_(lora_out)
        output = base_output
        return output


class TritonLoRABuilder(LoRABuilder):
    """triton lora layer builder."""

    @staticmethod
    def build():
        """build."""
        return TritonLoRAImpl()
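
Note: the Triton kernel fused_lora itself is not part of this diff. As a rough mental model only, the call site above suggests a segmented low-rank update over the flattened token batch, with each adapter's A/B weights packed along the rank dimension. Below is a minimal unfused sketch of that computation; the packed [total_rank, in_features] / [total_rank, out_features] layout is an assumption for illustration, not the kernel's actual memory layout.

import torch


def fused_lora_reference(x, lora_A, lora_B, scaling, rank_start, ranks,
                         seq_start, seq_lens, adapter_ids):
    """Unfused sketch of the per-sequence LoRA delta computed by fused_lora.

    Assumes lora_A is packed as [total_rank, in_features] and lora_B as
    [total_rank, out_features], adapter a owning rows
    rank_start[a] : rank_start[a] + ranks[a].
    """
    out = x.new_zeros(x.size(0), lora_B.size(-1))
    for i in range(seq_start.numel()):
        adapter = adapter_ids[i].item()
        r0, r = rank_start[adapter].item(), ranks[adapter].item()
        if r == 0:
            continue  # this sequence runs without an adapter
        start, length = seq_start[i].item(), seq_lens[i].item()
        xs = x[start:start + length]   # tokens of sequence i
        a = lora_A[r0:r0 + r]          # [r, in_features]
        b = lora_B[r0:r0 + r]          # [r, out_features]
        out[start:start + length] = scaling[adapter] * (xs @ a.t()) @ b
    return out

As in TritonLoRAImpl.forward, the resulting delta would then be reshaped and added in place onto base_output[..., base_slice].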
6 changes: 3 additions & 3 deletions lmdeploy/pytorch/backends/cuda/op_backend.py
@@ -32,9 +32,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
        elif layer_type == OpType.RMSNorm:
            from .norm import TritonRMSNormBuilder
            return TritonRMSNormBuilder
        elif layer_type == OpType.SLoRA:
            from .slora import TritonSLoRABuilder
            return TritonSLoRABuilder
        elif layer_type == OpType.LoRA:
            from .lora import TritonLoRABuilder
            return TritonLoRABuilder
        elif layer_type == OpType.LinearW8A8:
            from .qmodules import TritonLinearW8A8Builder
            return TritonLinearW8A8Builder
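
The dispatch above is a plain elif chain keyed on OpType, with each branch returning a builder whose build() produces the layer implementation. A toy, self-contained mimic of that pattern (the real classes live in lmdeploy.pytorch.backends; everything here is illustrative only):

from enum import Enum, auto


class OpType(Enum):
    RMSNorm = auto()
    LoRA = auto()  # renamed from SLoRA in this commit


class TritonLoRABuilder:
    """Stand-in for the real builder; build() would return a TritonLoRAImpl."""

    @staticmethod
    def build():
        return 'TritonLoRAImpl'


def get_layer_impl_builder(layer_type: OpType):
    """Resolve a layer implementation builder from its op type."""
    if layer_type == OpType.LoRA:
        return TritonLoRABuilder
    raise NotImplementedError(f'unsupported op type: {layer_type}')


impl = get_layer_impl_builder(OpType.LoRA).build()

Keeping the imports inside each branch, as the real backend does, defers loading the Triton kernels until the corresponding op type is actually requested.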
224 changes: 0 additions & 224 deletions lmdeploy/pytorch/backends/cuda/slora.py

This file was deleted.

lmdeploy/pytorch/backends/lora.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field

import torch

@@ -14,20 +14,28 @@ class AdapterInfo:
    out_features: int
    ranks: torch.Tensor
    scalings: torch.Tensor
    rank_offsets: torch.Tensor
    a_cache: torch.Tensor
    b_cache: torch.Tensor
    base_slice: slice
    max_rank: int
    rank_offsets: torch.Tensor = field(init=False)
    max_rank: int = field(init=False)

    def __post_init__(self):
        """post init."""
        ranks = self.ranks
        rank_offsets = ranks.cumsum(0) - ranks
        max_rank = ranks.max().item()
        self.rank_offsets = rank_offsets
        self.max_rank = max_rank

class SLoRAImpl(ABC):
    """slora implementation api."""

class LoRAImpl(ABC):
    """lora implementation."""

    @abstractmethod
    def forward(self,
                x: torch.Tensor,
                base_output: torch.Tensor,
                lora_A: torch.Tensor,
                lora_B: torch.Tensor,
                adapter_info: AdapterInfo,
                ctx_mgr: StepContextManager,
                colwise: bool,
@@ -36,8 +36,8 @@ def forward(self,
        raise NotImplementedError


class SLoRABuilder(ABC):
    """slora implementation builder."""
class LoRABuilder(ABC):
    """lora implementation builder."""

    @staticmethod
    @abstractmethod
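
AdapterInfo now derives rank_offsets and max_rank inside __post_init__ instead of requiring callers to pass them. A small standalone illustration of that arithmetic (plain tensors only, not the real dataclass): with adapters packed back-to-back along the rank dimension, the exclusive cumulative sum of ranks is each adapter's starting row in the packed LoRA weights.

import torch

# Adapters of rank 8, 16 and 4 packed back-to-back along the rank dimension.
ranks = torch.tensor([8, 16, 4])
rank_offsets = ranks.cumsum(0) - ranks   # exclusive cumulative sum
max_rank = ranks.max().item()            # largest rank among loaded adapters
print(rank_offsets.tolist(), max_rank)   # [0, 8, 24] 16

These derived values are exactly what TritonLoRAImpl passes to fused_lora as rank_start and max_rank.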
