PytorchEngine refactor multimodal (#2742)
* WIP
* support mrope
* support long context
* support causal=false
* fix mask
* flash attn bound
* optimize
* Moskau, Moskau, throw the glasses against the wall
* YMCA
* optimize mllama
* update processor
* support cogvlm
* all work and no play makes Jack a dull boy
* upgrade triton
* support qwen2vl
* support internvl
* phi3-v WIP
* glm4v WIP
* support chatglm and cogvlm
* use image tokens
* support llava
* support internvl-mono
* phi3v, mllama
* add llavanext
* use img token ids
* support multi-image chatglm cogvlm
* fix unit tests
Showing 38 changed files with 5,148 additions and 815 deletions.
@@ -0,0 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor

from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl


class TritonFlashAttentionImpl(FlashAttentionImpl):
    """triton flash attention implementation."""

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_dim: int = None,
        causal: bool = True,
        sliding_window: int = None,
        logical_softcapping: float = None,
    ):
        # default softmax scale is 1/sqrt(head_dim)
        if scale is None:
            scale = 1.0 / (head_dim**0.5)

        # fall back to plain MHA when no kv-head count is given
        if num_kv_heads is None:
            num_kv_heads = num_heads

        if v_head_dim is None:
            v_head_dim = head_dim

        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.v_head_dim = v_head_dim
        self.causal = causal
        self.sliding_window = sliding_window
        self.logical_softcapping = logical_softcapping

        # lazy import so the CUDA kernel is only pulled in when this backend is used
        from lmdeploy.pytorch.kernels.cuda import flash_attention_fwd
        self.flash_attention_fwd = flash_attention_fwd

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                q_start_loc: Tensor,
                q_seqlens: Tensor,
                kv_start_loc: Tensor,
                kv_seqlens: Tensor,
                max_q_seqlen: int = None):
        """forward."""
        # output keeps the query layout but uses the value head dim
        q_shape = query.shape
        o_shape = q_shape[:-1] + (self.v_head_dim, )
        out = query.new_empty(o_shape)
        self.flash_attention_fwd(
            query,
            key,
            value,
            out,
            q_start_loc=q_start_loc,
            q_seqlens=q_seqlens,
            kv_start_loc=kv_start_loc,
            kv_seqlens=kv_seqlens,
            max_seqlen=max_q_seqlen,
            window_size=self.sliding_window,
            sm_scale=self.scale,
            logit_softcapping=self.logical_softcapping,
            causal=self.causal,
            kv_layout='shd',
        )

        return out


class TritonFlashAttentionBuilder(FlashAttentionBuilder):
    """triton attention builder."""

    @staticmethod
    def build(
        num_heads: int,
        head_dim: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_dim: int = None,
        causal: bool = True,
        sliding_window: int = None,
        logical_softcapping: float = None,
        **kwargs,
    ) -> FlashAttentionImpl:
        """build."""
        return TritonFlashAttentionImpl(
            num_heads=num_heads,
            head_dim=head_dim,
            scale=scale,
            num_kv_heads=num_kv_heads,
            v_head_dim=v_head_dim,
            causal=causal,
            sliding_window=sliding_window,
            logical_softcapping=logical_softcapping,
        )
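For orientation, here is a minimal usage sketch of the implementation above. It is not part of this commit: the packed "varlen" layout of [total_tokens, num_heads, head_dim] with per-sequence start offsets and lengths, and the concrete sizes, are assumptions inferred from the argument names.

# Hypothetical usage sketch, not part of this diff. Shapes assume a packed
# varlen layout: [total_tokens, num_heads, head_dim] plus per-sequence
# start offsets and lengths.
import torch

attn = TritonFlashAttentionImpl(num_heads=32, head_dim=128, num_kv_heads=8)

# two sequences of 16 and 32 tokens packed back to back
q_seqlens = torch.tensor([16, 32], device='cuda')
q_start_loc = torch.tensor([0, 16], device='cuda')
total_tokens = int(q_seqlens.sum())

query = torch.randn(total_tokens, 32, 128, dtype=torch.float16, device='cuda')
key = torch.randn(total_tokens, 8, 128, dtype=torch.float16, device='cuda')
value = torch.randn(total_tokens, 8, 128, dtype=torch.float16, device='cuda')

out = attn.forward(query, key, value,
                   q_start_loc=q_start_loc, q_seqlens=q_seqlens,
                   kv_start_loc=q_start_loc, kv_seqlens=q_seqlens,
                   max_q_seqlen=32)
# out: [total_tokens, num_heads, v_head_dim]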
@@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod

from torch import Tensor


class FlashAttentionImpl(ABC):
    """FlashAttention implementation."""

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                q_start_loc: Tensor,
                q_seqlens: Tensor,
                kv_start_loc: Tensor,
                kv_seqlens: Tensor,
                max_q_seqlen: int = None):
        """forward."""
        raise NotImplementedError


class FlashAttentionBuilder(ABC):
    """FlashAttention implementation builder."""

    @staticmethod
    @abstractmethod
    def build(
        num_heads: int,
        head_dim: int,
        scale: float = None,
        num_kv_heads: int = None,
        v_head_dim: int = None,
        causal: bool = True,
        sliding_window: int = None,
        logical_softcapping: float = None,
        **kwargs,
    ) -> FlashAttentionImpl:
        """build."""
        raise NotImplementedError
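The abstract pair above is the contract a backend has to satisfy: an impl that runs varlen attention over packed sequences and a builder that constructs it from model parameters. As an illustration only (the class names, the plain-PyTorch fallback, and the causal-mask alignment below are invented for this sketch, not part of this commit), a new backend would subclass both:

# Illustrative only: a hypothetical plain-PyTorch backend showing how the
# interface is meant to be implemented. Class names and the causal alignment
# (last query token aligned with last key token) are assumptions.
import math
import torch
from torch import Tensor


class TorchRefAttentionImpl(FlashAttentionImpl):

    def __init__(self, num_heads, head_dim, scale=None, num_kv_heads=None,
                 v_head_dim=None, causal=True, **kwargs):
        self.scale = scale if scale is not None else 1.0 / math.sqrt(head_dim)
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        self.v_head_dim = v_head_dim or head_dim
        self.causal = causal

    def forward(self, query: Tensor, key: Tensor, value: Tensor,
                q_start_loc: Tensor, q_seqlens: Tensor,
                kv_start_loc: Tensor, kv_seqlens: Tensor,
                max_q_seqlen: int = None):
        out = query.new_empty(query.shape[:-1] + (self.v_head_dim, ))
        group = self.num_heads // self.num_kv_heads
        seqs = zip(q_start_loc.tolist(), q_seqlens.tolist(),
                   kv_start_loc.tolist(), kv_seqlens.tolist())
        for qs, ql, ks, kl in seqs:
            q = query[qs:qs + ql].transpose(0, 1)    # (H, Lq, D)
            k = key[ks:ks + kl].transpose(0, 1)      # (Hkv, Lk, D)
            v = value[ks:ks + kl].transpose(0, 1)    # (Hkv, Lk, Dv)
            # expand kv heads for grouped-query attention
            k = k.repeat_interleave(group, dim=0)
            v = v.repeat_interleave(group, dim=0)
            scores = torch.bmm(q, k.transpose(-1, -2)) * self.scale
            if self.causal:
                pos_q = torch.arange(ql, device=q.device)[:, None] + (kl - ql)
                pos_k = torch.arange(kl, device=q.device)[None, :]
                scores = scores.masked_fill(pos_k > pos_q, float('-inf'))
            attn = scores.softmax(dim=-1)
            out[qs:qs + ql] = torch.bmm(attn, v).transpose(0, 1)
        return out


class TorchRefAttentionBuilder(FlashAttentionBuilder):

    @staticmethod
    def build(num_heads, head_dim, scale=None, num_kv_heads=None,
              v_head_dim=None, causal=True, sliding_window=None,
              logical_softcapping=None, **kwargs) -> FlashAttentionImpl:
        return TorchRefAttentionImpl(num_heads, head_dim, scale=scale,
                                     num_kv_heads=num_kv_heads,
                                     v_head_dim=v_head_dim, causal=causal)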