78 changes: 78 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/lora.py
@@ -0,0 +1,78 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass

import torch

from lmdeploy.pytorch.kernels.dlinfer.fused_lora import fused_lora
from lmdeploy.pytorch.model_inputs import StepContextManager

from ..lora import AdapterInfo, LoRABuilder, LoRAImpl


@dataclass
class PackedLoRAInput:
"""Packed lora input."""
x: torch.Tensor
q_start_loc: torch.Tensor
q_seqlens: torch.Tensor
adapter_ids: torch.Tensor
max_seq_len: int
is_decoding: bool


class DlinferLoRAImpl(LoRAImpl):
"""Triton lora implementation."""

@staticmethod
def _make_packed_lora_input(x, ctx_mgr):
"""Make PackedLoRAInput."""
context = ctx_mgr.current_context()

        # total number of query tokens in the packed input
        max_q_seq_length = x.numel() // x.size(-1)

return PackedLoRAInput(x=x.flatten(0, -2).contiguous(),
q_start_loc=context.q_start_loc,
q_seqlens=context.q_seqlens,
adapter_ids=context.local_adapter_ids,
max_seq_len=max_q_seq_length,
is_decoding=context.is_decoding)

def forward(self,
x: torch.Tensor,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
base_output: torch.Tensor,
adapter_info: AdapterInfo,
ctx_mgr: StepContextManager,
colwise: bool,
is_tp: bool = True):
"""forward."""
lora_input = self._make_packed_lora_input(x, ctx_mgr)

return fused_lora(
lora_input.x,
lora_A,
lora_B,
scaling=adapter_info.scalings,
rank_start=adapter_info.rank_offsets,
ranks=adapter_info.ranks,
seq_start=lora_input.q_start_loc,
seq_lens=lora_input.q_seqlens,
adapter_ids=lora_input.adapter_ids,
max_rank=adapter_info.max_rank,
max_seqlen=lora_input.max_seq_len,
slice_start=adapter_info.base_slice.start,
slice_stop=adapter_info.base_slice.stop,
slice_step=adapter_info.base_slice.step,
output=base_output,
)


class DlinferLoRABuilder(LoRABuilder):
"""Dlinfer lora layer builder."""

@staticmethod
def build():
"""build."""
return DlinferLoRAImpl()
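
For reviewers unfamiliar with the packing step in `_make_packed_lora_input`, the standalone sketch below reproduces the two tensor manipulations it performs, using an assumed (batch, seq, hidden) input; the shapes are illustrative only and not taken from the runtime step context.

import torch

# Assumed hidden states: batch=2, seq=8, hidden=16 (illustrative shapes only).
x = torch.randn(2, 8, 16)

# Same operations as DlinferLoRAImpl._make_packed_lora_input:
# collapse all leading dims so the kernel sees a (tokens, hidden) matrix.
packed_x = x.flatten(0, -2).contiguous()    # shape (2 * 8, 16) == (16, 16)
# total token count, passed to the fused kernel as max_seq_len
max_q_seq_length = x.numel() // x.size(-1)  # 16

print(packed_x.shape, max_q_seq_length)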
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/op_backend.py
@@ -58,6 +58,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
elif layer_type == OpType.RotaryEmbedding:
from .rotary_embedding import DlinferRotaryEmbeddingBuilder
return DlinferRotaryEmbeddingBuilder
elif layer_type == OpType.LoRA:
from .lora import DlinferLoRABuilder
return DlinferLoRABuilder
else:
logger.debug(f'Op {layer_type} fallback to default implementation.')
return super().get_layer_impl_builder(layer_type)
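
As a rough illustration of what this dispatch change enables (the backend class name `DlinferOpsBackend` and the import paths below are assumptions for the sketch, not part of this diff):

from lmdeploy.pytorch.backends.base import OpType
from lmdeploy.pytorch.backends.dlinfer.op_backend import DlinferOpsBackend

# Requesting the LoRA layer builder now resolves to the dlinfer implementation
# instead of falling back to the default one.
builder = DlinferOpsBackend.get_layer_impl_builder(OpType.LoRA)
lora_impl = builder.build()  # DlinferLoRAImpl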
13 changes: 13 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/fused_lora.py
@@ -0,0 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import dlinfer.ops as ext_ops
from torch import Tensor


def fused_lora(input: Tensor, lora_a: Tensor, lora_b: Tensor, scaling: Tensor, rank_start: Tensor, ranks: Tensor,
seq_start: Tensor, seq_lens: Tensor, adapter_ids: Tensor, max_rank: int, max_seqlen: int,
slice_start: int, slice_stop: int, slice_step: Optional[int], output: Optional[Tensor]):
"""Fused lora."""
return ext_ops.fused_lora(input, lora_a, lora_b, scaling, rank_start, ranks, seq_start, seq_lens, adapter_ids,
max_rank, max_seqlen, slice_start, slice_stop, slice_step, output)
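
The wrapper simply forwards to `dlinfer.ops.fused_lora`. As a reference for what the arguments mean, the unfused per-adapter computation it replaces looks roughly like the sketch below; this is an assumption based on the argument names (per-adapter scaling, a rank slice into packed A/B weights, an output column slice), not the kernel's actual memory layout.

import torch

def lora_delta_reference(x: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor,
                         scaling: float, base_output: torch.Tensor, out_slice: slice):
    """Unfused reference: base_output[:, out_slice] += scaling * (x @ A^T) @ B^T.

    Assumed shapes: x (tokens, hidden), lora_a (rank, hidden), lora_b (out_dim, rank).
    """
    delta = (x @ lora_a.t()) @ lora_b.t() * scaling
    base_output[:, out_slice] += delta
    return base_output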
6 changes: 4 additions & 2 deletions requirements/runtime_ascend.txt
@@ -17,8 +17,10 @@ safetensors
sentencepiece
shortuuid
tiktoken
torch<=2.4.0,>=2.3.1
torch-npu==2.3.1
# Supported torch versions: 2.3.1, 2.5.1, 2.6.0, 2.7.1
# Please install one of the supported versions manually
torch>=2.3.1,<2.8.0
torch-npu>=2.3.1,<2.8.0
torchvision<=0.19.0,>=0.18.1
transformers
uvicorn