PytorchEngine refactor multimodal (#2742)
* WIP

* support mrope

* support long context

* support causal=false

* fix mask

* flash attn bound

* optimize

* Moskau, Moskau, throw the glasses against the wall

* YMCA

* optimize mllama

* update processor

* support cogvlm

* all work and no play makes Jack a dull boy

* upgrade triton

* support qwen2vl

* support internvl

* phi3-v WIP

* glm4v WIP

* support chatglm and cogvlm

* use image tokens

* support llava

* support internvl-mono

* phi3v, mllama

* add llavanext

* use img token ids

* support multiimage chatglm cogvlm

* fix ut
grimoire authored Nov 25, 2024
1 parent 787f765 commit 099721a
Showing 38 changed files with 5,148 additions and 815 deletions.
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/backends/attention.py
@@ -34,6 +34,7 @@ def __init__(
alibi: bool = None,
sliding_window: int = None,
logit_softcapping: float = None,
causal: bool = True,
**kwargs,
) -> None:
if scale is None:
@@ -53,6 +54,7 @@ def __init__(
self.alibi = alibi
self.sliding_window = sliding_window
self.logit_softcapping = logit_softcapping
self.causal = causal

@abstractmethod
def forward(
@@ -82,6 +84,7 @@ def build(
alibi: bool = False,
sliding_window: int = None,
logical_softcapping: float = None,
causal: bool = True,
**kwargs,
) -> AttentionImpl[T]:
"""build."""
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/backends/base.py
@@ -12,7 +12,8 @@

class OpType(Enum):
"""Layer type enumerate."""
Attention = auto()
PagedAttention = auto()
FlashAttention = auto()
Linear = auto()
RotaryEmbedding = auto()
ApplyRotaryEmb = auto()
6 changes: 6 additions & 0 deletions lmdeploy/pytorch/backends/cuda/attention.py
@@ -41,6 +41,7 @@ def __init__(
alibi: bool = False,
sliding_window: int = None,
logit_softcapping: float = None,
causal: bool = True,
**kwargs,
):
super().__init__(
@@ -52,8 +53,10 @@
alibi=alibi,
sliding_window=sliding_window,
logit_softcapping=logit_softcapping,
causal=causal,
**kwargs,
)
assert not (alibi and not causal)

from lmdeploy.pytorch.kernels.cuda import (alibi_paged_attention_fwd,
fill_kv_cache,
@@ -169,6 +172,7 @@ def forward(
window_size=self.sliding_window,
sm_scale=self.scale,
logit_softcapping=self.logit_softcapping,
causal=self.causal,
)
else:
self.alibi_paged_attention_fwd(
@@ -204,6 +208,7 @@ def build(
alibi: bool = False,
sliding_window: int = None,
logical_softcapping: float = None,
causal: bool = True,
**kwargs,
) -> TritonAttentionImpl:
"""build."""
@@ -215,4 +220,5 @@
alibi=alibi,
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
causal=causal,
**kwargs)
101 changes: 101 additions & 0 deletions lmdeploy/pytorch/backends/cuda/flash_attention.py
@@ -0,0 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor

from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl


class TritonFlashAttentionImpl(FlashAttentionImpl):
"""triton flash attention implementation."""

def __init__(
self,
num_heads: int,
head_dim: int,
scale: float = None,
num_kv_heads: int = None,
v_head_dim: int = None,
causal: bool = True,
sliding_window: int = None,
logical_softcapping: float = None,
):
if scale is None:
scale = 1.0 / (head_dim**0.5)

if num_kv_heads is None:
num_kv_heads = num_heads

if v_head_dim is None:
v_head_dim = head_dim

self.num_heads = num_heads
self.head_dim = head_dim
self.scale = scale
self.num_kv_heads = num_kv_heads
self.v_head_dim = v_head_dim
self.causal = causal
self.sliding_window = sliding_window
self.logical_softcapping = logical_softcapping

from lmdeploy.pytorch.kernels.cuda import flash_attention_fwd
self.flash_attention_fwd = flash_attention_fwd

def forward(self,
query: Tensor,
key: Tensor,
value: Tensor,
q_start_loc: Tensor,
q_seqlens: Tensor,
kv_start_loc: Tensor,
kv_seqlens: Tensor,
max_q_seqlen: int = None):
"""forward."""

q_shape = query.shape
o_shape = q_shape[:-1] + (self.v_head_dim, )
out = query.new_empty(o_shape)
self.flash_attention_fwd(
query,
key,
value,
out,
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_start_loc=kv_start_loc,
kv_seqlens=kv_seqlens,
max_seqlen=max_q_seqlen,
window_size=self.sliding_window,
sm_scale=self.scale,
logit_softcapping=self.logical_softcapping,
causal=self.causal,
kv_layout='shd',
)

return out


class TritonFlashAttentionBuilder(FlashAttentionBuilder):
"""triton attention builder."""

@staticmethod
def build(
num_heads: int,
head_dim: int,
scale: float = None,
num_kv_heads: int = None,
v_head_dim: int = None,
causal: bool = True,
sliding_window: int = None,
logical_softcapping: float = None,
**kwargs,
) -> FlashAttentionImpl:
"""build."""
return TritonFlashAttentionImpl(
num_heads=num_heads,
head_dim=head_dim,
scale=scale,
num_kv_heads=num_kv_heads,
v_head_dim=v_head_dim,
causal=causal,
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
)
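
For illustration only (not part of the commit), a minimal usage sketch of the new builder. It assumes lmdeploy installed on a machine with a CUDA device and a supported Triton, half-precision inputs, and the packed (total_tokens, num_heads, head_dim) layout implied by kv_layout='shd' above; the concrete sizes below are made up.

import torch
from lmdeploy.pytorch.backends.cuda.flash_attention import TritonFlashAttentionBuilder

num_heads, num_kv_heads, head_dim = 8, 2, 64
attn = TritonFlashAttentionBuilder.build(
    num_heads=num_heads,
    head_dim=head_dim,
    num_kv_heads=num_kv_heads,
    causal=True,
)

# Two packed sequences: 5 and 3 query tokens attending to 7 and 4 kv tokens.
q_seqlens = torch.tensor([5, 3], device='cuda')
kv_seqlens = torch.tensor([7, 4], device='cuda')
q_start_loc = q_seqlens.cumsum(0) - q_seqlens
kv_start_loc = kv_seqlens.cumsum(0) - kv_seqlens

# Assumed flattened layout: (sum(seqlens), heads, head_dim).
query = torch.randn(8, num_heads, head_dim, dtype=torch.float16, device='cuda')
key = torch.randn(11, num_kv_heads, head_dim, dtype=torch.float16, device='cuda')
value = torch.randn(11, num_kv_heads, head_dim, dtype=torch.float16, device='cuda')

out = attn.forward(query, key, value,
                   q_start_loc=q_start_loc, q_seqlens=q_seqlens,
                   kv_start_loc=kv_start_loc, kv_seqlens=kv_seqlens,
                   max_q_seqlen=5)
print(out.shape)  # (8, num_heads, head_dim) since v_head_dim defaults to head_dim
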
51 changes: 27 additions & 24 deletions lmdeploy/pytorch/backends/cuda/op_backend.py
@@ -23,9 +23,12 @@ def get_name() -> str:
@classmethod
def get_layer_impl_builder(cls, layer_type: OpType):
"""get cuda layer builder."""
if layer_type == OpType.Attention:
if layer_type == OpType.PagedAttention:
from .attention import TritonAttentionBuilder
return TritonAttentionBuilder
elif layer_type == OpType.FlashAttention:
from .flash_attention import TritonFlashAttentionBuilder
return TritonFlashAttentionBuilder
elif layer_type == OpType.ApplyRotaryEmb:
from .apply_rotary_emb import TritonApplyRotaryEmbBuilder
return TritonApplyRotaryEmbBuilder
@@ -121,30 +124,30 @@ def update_step_context(cls, step_context):
quant_policy=step_context.kv_quant_policy,
)

cross_attn_metadata = None
fill_seqlens = None
if step_context.cross_attention_states is not None:
fill_seqlens = torch.zeros_like(q_seqlens)
for idx, state in enumerate(step_context.cross_attention_states):
if state is not None:
fill_seqlens[idx] = state.shape[-2]
cross_seqlens = step_context.cross_seqlens
cross_kv_seqlens = step_context.cross_kv_seqlens
cross_kv_start_loc = None
cross_kv_flatten_size = None
if not step_context.is_decoding and cross_kv_seqlens is not None:
cross_kv_start_loc = cross_kv_seqlens.cumsum(0) - cross_kv_seqlens
cross_kv_flatten_size = cross_kv_seqlens.sum().item()
cross_attn_metadata = attn_meta_cls(
step_context.is_decoding,
step_context.block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_start_loc=cross_kv_start_loc,
kv_seqlens=cross_kv_seqlens,
kv_flatten_size=cross_kv_flatten_size,
fill_seqlens=fill_seqlens,
quant_policy=step_context.kv_quant_policy,
)
cross_attn_metadata = None
if cross_seqlens is not None:
fill_seqlens = cross_seqlens
if fill_seqlens.sum().item() == 0:
fill_seqlens = None
cross_kv_start_loc = None
cross_kv_flatten_size = None
if not step_context.is_decoding and cross_kv_seqlens is not None:
cross_kv_start_loc = cross_kv_seqlens.cumsum(
0) - cross_kv_seqlens
cross_kv_flatten_size = cross_kv_seqlens.sum().item()
cross_attn_metadata = attn_meta_cls(
step_context.is_decoding,
step_context.block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_start_loc=cross_kv_start_loc,
kv_seqlens=cross_kv_seqlens,
kv_flatten_size=cross_kv_flatten_size,
fill_seqlens=fill_seqlens,
quant_policy=step_context.kv_quant_policy,
)

step_context.attn_metadata = attn_metadata
step_context.cross_attn_metadata = cross_attn_metadata
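
The start-location and flatten-size computation above is an exclusive prefix sum over the per-request cross-KV lengths; a standalone illustration with made-up values:

import torch

cross_kv_seqlens = torch.tensor([4, 0, 7])  # per-request cross-KV token counts
cross_kv_start_loc = cross_kv_seqlens.cumsum(0) - cross_kv_seqlens
cross_kv_flatten_size = cross_kv_seqlens.sum().item()
print(cross_kv_start_loc.tolist(), cross_kv_flatten_size)  # [0, 4, 4] 11
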
5 changes: 5 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -30,8 +30,10 @@ def __init__(
alibi: bool = None,
sliding_window: int = None,
logit_softcapping: float = None,
causal: bool = True,
**kwargs,
):
assert causal
super().__init__(
num_heads,
head_size,
@@ -41,6 +43,7 @@
alibi,
sliding_window,
logit_softcapping,
causal=causal,
**kwargs,
)

@@ -121,6 +124,7 @@ def build(
alibi_scale: float = None,
sliding_window: int = None,
logical_softcapping: float = None,
causal: bool = True,
**kwargs,
) -> DlinferAttentionImpl:
"""build."""
@@ -132,4 +136,5 @@
alibi_scale=alibi_scale,
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
causal=causal,
**kwargs)
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/backends/dlinfer/op_backend.py
@@ -22,7 +22,7 @@ def get_name() -> str:
@classmethod
def get_layer_impl_builder(cls, layer_type: OpType):
"""get dlinfer layer builder."""
if layer_type == OpType.Attention:
if layer_type == OpType.PagedAttention:
from .attention import DlinferAttentionBuilder
return DlinferAttentionBuilder
elif layer_type == OpType.ApplyRotaryEmb:
40 changes: 40 additions & 0 deletions lmdeploy/pytorch/backends/flash_attention.py
@@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod

from torch import Tensor


class FlashAttentionImpl(ABC):
"""FlashAttention implementation."""

def forward(self,
query: Tensor,
key: Tensor,
value: Tensor,
q_start_loc: Tensor,
q_seqlens: Tensor,
kv_start_loc: Tensor,
kv_seqlens: Tensor,
max_q_seqlen: int = None):
"""forward."""
raise NotImplementedError


class FlashAttentionBuilder(ABC):
"""FlashAttention implementation builder."""

@staticmethod
@abstractmethod
def build(
num_heads: int,
head_dim: int,
scale: float = None,
num_kv_heads: int = None,
v_head_dim: int = None,
causal: bool = True,
sliding_window: int = None,
logical_softcapping: float = None,
**kwargs,
) -> FlashAttentionImpl:
"""build."""
raise NotImplementedError
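
The abstract interface above defines the contract a backend must implement. For illustration only (not part of the commit), a naive, loop-based PyTorch sketch of a conforming backend; it assumes the packed (total_tokens, heads, head_dim) layout used by the CUDA implementation above, assumes kv lengths are at least as long as query lengths, and ignores sliding windows and logit softcapping.

import math

import torch
from torch import Tensor

from lmdeploy.pytorch.backends.flash_attention import (FlashAttentionBuilder,
                                                        FlashAttentionImpl)


class NaiveFlashAttentionImpl(FlashAttentionImpl):
    """loop-based reference of the varlen contract (illustration only)."""

    def __init__(self, num_heads: int, head_dim: int, scale: float = None,
                 num_kv_heads: int = None, v_head_dim: int = None,
                 causal: bool = True):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale or 1.0 / math.sqrt(head_dim)
        self.num_kv_heads = num_kv_heads or num_heads
        self.v_head_dim = v_head_dim or head_dim
        self.causal = causal

    def forward(self, query: Tensor, key: Tensor, value: Tensor,
                q_start_loc: Tensor, q_seqlens: Tensor,
                kv_start_loc: Tensor, kv_seqlens: Tensor,
                max_q_seqlen: int = None):
        """attend each packed sequence independently."""
        out = query.new_empty(query.shape[:-1] + (self.v_head_dim, ))
        group = self.num_heads // self.num_kv_heads
        for qs, ql, ks, kl in zip(q_start_loc.tolist(), q_seqlens.tolist(),
                                  kv_start_loc.tolist(), kv_seqlens.tolist()):
            q = query[qs:qs + ql].transpose(0, 1).float()    # (Hq, Lq, D)
            k = key[ks:ks + kl].transpose(0, 1).float()      # (Hkv, Lk, D)
            v = value[ks:ks + kl].transpose(0, 1).float()    # (Hkv, Lk, Dv)
            k = k.repeat_interleave(group, dim=0)            # GQA -> (Hq, Lk, D)
            v = v.repeat_interleave(group, dim=0)
            scores = torch.bmm(q, k.transpose(-1, -2)) * self.scale
            if self.causal:
                # right-aligned causal mask: query i may see kv j <= i + (Lk - Lq)
                q_idx = torch.arange(ql, device=scores.device).unsqueeze(-1)
                kv_idx = torch.arange(kl, device=scores.device).unsqueeze(0)
                scores.masked_fill_(kv_idx > q_idx + (kl - ql), float('-inf'))
            attn = scores.softmax(dim=-1)
            o = torch.bmm(attn, v).transpose(0, 1)           # (Lq, Hq, Dv)
            out[qs:qs + ql] = o.to(out.dtype)
        return out


class NaiveFlashAttentionBuilder(FlashAttentionBuilder):
    """builder for the reference implementation above."""

    @staticmethod
    def build(num_heads: int, head_dim: int, scale: float = None,
              num_kv_heads: int = None, v_head_dim: int = None,
              causal: bool = True, sliding_window: int = None,
              logical_softcapping: float = None, **kwargs) -> FlashAttentionImpl:
        assert sliding_window is None and logical_softcapping is None, (
            'not modelled in this sketch')
        return NaiveFlashAttentionImpl(num_heads, head_dim, scale=scale,
                                       num_kv_heads=num_kv_heads,
                                       v_head_dim=v_head_dim, causal=causal)
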
23 changes: 23 additions & 0 deletions lmdeploy/pytorch/backends/graph_runner.py
@@ -46,3 +46,26 @@ def prepare_inputs_for_generation(
inputs_embeds,
context,
)

def update_model_metas(
self,
past_key_values: List[List[torch.Tensor]],
inputs_embeds: torch.Tensor = None,
context: StepContext = None,
):
"""prepare inputs."""
if hasattr(self.model, 'update_model_metas'):
return self.model.update_model_metas(
past_key_values,
inputs_embeds,
context,
)

return None

def get_input_processor(self):
"""get input processor."""
if hasattr(self.model, 'get_input_processor'):
return self.model.get_input_processor()
else:
return None
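
Both hooks above are forwarded only when the wrapped model defines them, so text-only models need no change. A minimal model-side sketch (class name, attribute, and return values are hypothetical, not part of the commit):

import torch


class MyVLModel(torch.nn.Module):
    """hypothetical multimodal model wrapped by the graph runner."""

    def __init__(self):
        super().__init__()
        # would hold the model-specific image/input preprocessor
        self._input_processor = None

    def update_model_metas(self, past_key_values, inputs_embeds=None,
                           context=None):
        """refresh per-step metadata (e.g. mrope position deltas); optional."""
        return None

    def get_input_processor(self):
        """expose the processor that maps images to input token spans."""
        return self._input_processor
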
21 changes: 10 additions & 11 deletions lmdeploy/pytorch/check_env/__init__.py
@@ -58,6 +58,7 @@ def check_env_torch():
_handle_exception(e, 'PyTorch', logger)


MIN_TRITON_VERSION = '3.0.0'
MAX_TRITON_VERSION = '3.0.0'


@@ -74,8 +75,10 @@ def check_env_triton(device: str):
logger.debug('Checking <Triton> environment.')
import torch
import triton
max_version = version.parse(MAX_TRITON_VERSION)
triton_version = version.parse(triton.__version__)
if triton_version > version.parse(MAX_TRITON_VERSION):

if triton_version > max_version:
logger.warning(
f'Engine has not been tested on triton>{MAX_TRITON_VERSION}.')

@@ -96,16 +99,12 @@
_handle_exception(e, 'Triton', logger, msg)

if device == 'cuda':
device_cap = torch.cuda.get_device_capability()
TRITON_VER_231 = version.parse('2.3.1')

if device_cap[0] <= 7:
if triton_version <= TRITON_VER_231:
err = RuntimeError(
'Attention triton kernel does not fully support '
'triton<3.0.0 on device with capability<8. '
'Please upgrade your triton version.')
_handle_exception(err, 'Triton', logger)
min_version = version.parse(MIN_TRITON_VERSION)
if triton_version < min_version:
msg = (f'triton>={MIN_TRITON_VERSION} is required. '
f'Found triton=={triton_version}')
e = RuntimeError(msg)
_handle_exception(e, 'Triton', logger, msg)


def check_env(device_type: str):