diff --git a/lmdeploy/pytorch/backends/attention.py b/lmdeploy/pytorch/backends/attention.py index 92a0befbf..f0e60d86a 100644 --- a/lmdeploy/pytorch/backends/attention.py +++ b/lmdeploy/pytorch/backends/attention.py @@ -34,6 +34,7 @@ def __init__( alibi: bool = None, sliding_window: int = None, logit_softcapping: float = None, + causal: bool = True, **kwargs, ) -> None: if scale is None: @@ -53,6 +54,7 @@ def __init__( self.alibi = alibi self.sliding_window = sliding_window self.logit_softcapping = logit_softcapping + self.causal = causal @abstractmethod def forward( @@ -82,6 +84,7 @@ def build( alibi: bool = False, sliding_window: int = None, logical_softcapping: float = None, + causal: bool = True, **kwargs, ) -> AttentionImpl[T]: """build.""" diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py index ef538f7a3..c8623666d 100644 --- a/lmdeploy/pytorch/backends/base.py +++ b/lmdeploy/pytorch/backends/base.py @@ -12,7 +12,8 @@ class OpType(Enum): """Layer type enumerate.""" - Attention = auto() + PagedAttention = auto() + FlashAttention = auto() Linear = auto() RotaryEmbedding = auto() ApplyRotaryEmb = auto() diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index d01d6fe9b..ff1b86d3a 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -41,6 +41,7 @@ def __init__( alibi: bool = False, sliding_window: int = None, logit_softcapping: float = None, + causal: bool = True, **kwargs, ): super().__init__( @@ -52,8 +53,10 @@ def __init__( alibi=alibi, sliding_window=sliding_window, logit_softcapping=logit_softcapping, + causal=causal, **kwargs, ) + assert not (alibi and not causal) from lmdeploy.pytorch.kernels.cuda import (alibi_paged_attention_fwd, fill_kv_cache, @@ -169,6 +172,7 @@ def forward( window_size=self.sliding_window, sm_scale=self.scale, logit_softcapping=self.logit_softcapping, + causal=self.causal, ) else: self.alibi_paged_attention_fwd( @@ -204,6 +208,7 @@ def build( alibi: bool = False, sliding_window: int = None, logical_softcapping: float = None, + causal: bool = True, **kwargs, ) -> TritonAttentionImpl: """build.""" @@ -215,4 +220,5 @@ def build( alibi=alibi, sliding_window=sliding_window, logical_softcapping=logical_softcapping, + causal=causal, **kwargs) diff --git a/lmdeploy/pytorch/backends/cuda/flash_attention.py b/lmdeploy/pytorch/backends/cuda/flash_attention.py new file mode 100644 index 000000000..5d3925b74 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/flash_attention.py @@ -0,0 +1,101 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
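[Illustrative aside, not part of the patch] The hunks above thread a new `causal` flag through the attention interfaces and split `OpType.Attention` into `PagedAttention` and `FlashAttention`. As a plain-PyTorch sketch of what the flag controls (the helper name `naive_attention` is made up; the real Triton kernels additionally offset the causal mask by the cached history length):

import torch

def naive_attention(q, k, v, scale, causal=True):
    """Reference attention; q, k, v are [seq, heads, dim]."""
    q, k, v = (x.transpose(0, 1) for x in (q, k, v))  # -> [heads, seq, dim]
    scores = q @ k.transpose(-1, -2) * scale
    if causal:
        # causal=True: token i may only attend to keys j <= i
        seq_q, seq_k = scores.shape[-2:]
        mask = torch.ones(seq_q, seq_k, dtype=torch.bool,
                          device=scores.device).tril()
        scores = scores.masked_fill(~mask, float('-inf'))
    # causal=False keeps every key visible (encoder/bidirectional attention)
    probs = torch.softmax(scores, dim=-1)
    return (probs @ v).transpose(0, 1)  # back to [seq, heads, dim]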
+from torch import Tensor + +from ..flash_attention import FlashAttentionBuilder, FlashAttentionImpl + + +class TritonFlashAttentionImpl(FlashAttentionImpl): + """triton flash attention implementation.""" + + def __init__( + self, + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logical_softcapping: float = None, + ): + if scale is None: + scale = 1.0 / (head_dim**0.5) + + if num_kv_heads is None: + num_kv_heads = num_heads + + if v_head_dim is None: + v_head_dim = head_dim + + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = scale + self.num_kv_heads = num_kv_heads + self.v_head_dim = v_head_dim + self.causal = causal + self.sliding_window = sliding_window + self.logical_softcapping = logical_softcapping + + from lmdeploy.pytorch.kernels.cuda import flash_attention_fwd + self.flash_attention_fwd = flash_attention_fwd + + def forward(self, + query: Tensor, + key: Tensor, + value: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_start_loc: Tensor, + kv_seqlens: Tensor, + max_q_seqlen: int = None): + """forward.""" + + q_shape = query.shape + o_shape = q_shape[:-1] + (self.v_head_dim, ) + out = query.new_empty(o_shape) + self.flash_attention_fwd( + query, + key, + value, + out, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_start_loc=kv_start_loc, + kv_seqlens=kv_seqlens, + max_seqlen=max_q_seqlen, + window_size=self.sliding_window, + sm_scale=self.scale, + logit_softcapping=self.logical_softcapping, + causal=self.causal, + kv_layout='shd', + ) + + return out + + +class TritonFlashAttentionBuilder(FlashAttentionBuilder): + """triton attention builder.""" + + @staticmethod + def build( + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logical_softcapping: float = None, + **kwargs, + ) -> FlashAttentionImpl: + """build.""" + return TritonFlashAttentionImpl( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + v_head_dim=v_head_dim, + causal=causal, + sliding_window=sliding_window, + logical_softcapping=logical_softcapping, + ) diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index 3e7fc2372..d710f3891 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -23,9 +23,12 @@ def get_name() -> str: @classmethod def get_layer_impl_builder(cls, layer_type: OpType): """get cuda layer builder.""" - if layer_type == OpType.Attention: + if layer_type == OpType.PagedAttention: from .attention import TritonAttentionBuilder return TritonAttentionBuilder + elif layer_type == OpType.FlashAttention: + from .flash_attention import TritonFlashAttentionBuilder + return TritonFlashAttentionBuilder elif layer_type == OpType.ApplyRotaryEmb: from .apply_rotary_emb import TritonApplyRotaryEmbBuilder return TritonApplyRotaryEmbBuilder @@ -121,30 +124,30 @@ def update_step_context(cls, step_context): quant_policy=step_context.kv_quant_policy, ) - cross_attn_metadata = None - fill_seqlens = None - if step_context.cross_attention_states is not None: - fill_seqlens = torch.zeros_like(q_seqlens) - for idx, state in enumerate(step_context.cross_attention_states): - if state is not None: - fill_seqlens[idx] = state.shape[-2] + cross_seqlens = step_context.cross_seqlens cross_kv_seqlens = 
step_context.cross_kv_seqlens - cross_kv_start_loc = None - cross_kv_flatten_size = None - if not step_context.is_decoding and cross_kv_seqlens is not None: - cross_kv_start_loc = cross_kv_seqlens.cumsum(0) - cross_kv_seqlens - cross_kv_flatten_size = cross_kv_seqlens.sum().item() - cross_attn_metadata = attn_meta_cls( - step_context.is_decoding, - step_context.block_offsets, - q_start_loc=q_start_loc, - q_seqlens=q_seqlens, - kv_start_loc=cross_kv_start_loc, - kv_seqlens=cross_kv_seqlens, - kv_flatten_size=cross_kv_flatten_size, - fill_seqlens=fill_seqlens, - quant_policy=step_context.kv_quant_policy, - ) + cross_attn_metadata = None + if cross_seqlens is not None: + fill_seqlens = cross_seqlens + if fill_seqlens.sum().item() == 0: + fill_seqlens = None + cross_kv_start_loc = None + cross_kv_flatten_size = None + if not step_context.is_decoding and cross_kv_seqlens is not None: + cross_kv_start_loc = cross_kv_seqlens.cumsum( + 0) - cross_kv_seqlens + cross_kv_flatten_size = cross_kv_seqlens.sum().item() + cross_attn_metadata = attn_meta_cls( + step_context.is_decoding, + step_context.block_offsets, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_start_loc=cross_kv_start_loc, + kv_seqlens=cross_kv_seqlens, + kv_flatten_size=cross_kv_flatten_size, + fill_seqlens=fill_seqlens, + quant_policy=step_context.kv_quant_policy, + ) step_context.attn_metadata = attn_metadata step_context.cross_attn_metadata = cross_attn_metadata diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index 0d666c913..46da12469 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -30,8 +30,10 @@ def __init__( alibi: bool = None, sliding_window: int = None, logit_softcapping: float = None, + causal: bool = True, **kwargs, ): + assert causal super().__init__( num_heads, head_size, @@ -41,6 +43,7 @@ def __init__( alibi, sliding_window, logit_softcapping, + causal=causal, **kwargs, ) @@ -121,6 +124,7 @@ def build( alibi_scale: float = None, sliding_window: int = None, logical_softcapping: float = None, + causal: bool = True, **kwargs, ) -> DlinferAttentionImpl: """build.""" @@ -132,4 +136,5 @@ def build( alibi_scale=alibi_scale, sliding_window=sliding_window, logical_softcapping=logical_softcapping, + causal=causal, **kwargs) diff --git a/lmdeploy/pytorch/backends/dlinfer/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/op_backend.py index 52a883059..93733fbf5 100644 --- a/lmdeploy/pytorch/backends/dlinfer/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/op_backend.py @@ -22,7 +22,7 @@ def get_name() -> str: @classmethod def get_layer_impl_builder(cls, layer_type: OpType): """get dlinfer layer builder.""" - if layer_type == OpType.Attention: + if layer_type == OpType.PagedAttention: from .attention import DlinferAttentionBuilder return DlinferAttentionBuilder elif layer_type == OpType.ApplyRotaryEmb: diff --git a/lmdeploy/pytorch/backends/flash_attention.py b/lmdeploy/pytorch/backends/flash_attention.py new file mode 100644 index 000000000..bed3af8d6 --- /dev/null +++ b/lmdeploy/pytorch/backends/flash_attention.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
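[Illustrative aside, not part of the patch] With the dispatch change above, the CUDA backend now hands out `TritonFlashAttentionBuilder` for `OpType.FlashAttention`, while paged (KV-cache) attention keeps its own builder. A minimal usage sketch of the new builder; all sizes are invented, tensor shapes follow the packed varlen layout the kernel is called with ([total_tokens, num_heads, head_dim], kv_layout='shd'), and a CUDA device with a supported triton is assumed:

import torch

from lmdeploy.pytorch.backends.cuda.flash_attention import \
    TritonFlashAttentionBuilder

# bidirectional (non-causal) attention, e.g. for an encoder branch
impl = TritonFlashAttentionBuilder.build(num_heads=8,
                                         head_dim=128,
                                         num_kv_heads=8,
                                         causal=False)

# two sequences of lengths 3 and 5 packed along the first dimension
q_seqlens = torch.tensor([3, 5], device='cuda')
q_start_loc = torch.tensor([0, 3], device='cuda')
total_tokens = int(q_seqlens.sum())
query = torch.randn(total_tokens, 8, 128, dtype=torch.float16, device='cuda')
key = torch.randn(total_tokens, 8, 128, dtype=torch.float16, device='cuda')
value = torch.randn(total_tokens, 8, 128, dtype=torch.float16, device='cuda')

# self-attention here, so the kv layout mirrors the query layout
out = impl.forward(query, key, value,
                   q_start_loc=q_start_loc, q_seqlens=q_seqlens,
                   kv_start_loc=q_start_loc, kv_seqlens=q_seqlens,
                   max_q_seqlen=5)
print(out.shape)  # (total_tokens, 8, 128)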
+from abc import ABC, abstractmethod + +from torch import Tensor + + +class FlashAttentionImpl(ABC): + """FlashAttention implementation.""" + + def forward(self, + query: Tensor, + key: Tensor, + value: Tensor, + q_start_loc: Tensor, + q_seqlens: Tensor, + kv_start_loc: Tensor, + kv_seqlens: Tensor, + max_q_seqlen: int = None): + """forward.""" + raise NotImplementedError + + +class FlashAttentionBuilder(ABC): + """FlashAttention implementation builder.""" + + @staticmethod + @abstractmethod + def build( + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logical_softcapping: float = None, + **kwargs, + ) -> FlashAttentionImpl: + """build.""" + raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py index 9ab66b26a..9347995e0 100644 --- a/lmdeploy/pytorch/backends/graph_runner.py +++ b/lmdeploy/pytorch/backends/graph_runner.py @@ -46,3 +46,26 @@ def prepare_inputs_for_generation( inputs_embeds, context, ) + + def update_model_metas( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, + context: StepContext = None, + ): + """prepare inputs.""" + if hasattr(self.model, 'update_model_metas'): + return self.model.update_model_metas( + past_key_values, + inputs_embeds, + context, + ) + + return None + + def get_input_processor(self): + """get input processor.""" + if hasattr(self.model, 'get_input_processor'): + return self.model.get_input_processor() + else: + return None diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py index 7d7243822..b943a0727 100644 --- a/lmdeploy/pytorch/check_env/__init__.py +++ b/lmdeploy/pytorch/check_env/__init__.py @@ -58,6 +58,7 @@ def check_env_torch(): _handle_exception(e, 'PyTorch', logger) +MIN_TRITON_VERSION = '3.0.0' MAX_TRITON_VERSION = '3.0.0' @@ -74,8 +75,10 @@ def check_env_triton(device: str): logger.debug('Checking environment.') import torch import triton + max_version = version.parse(MAX_TRITON_VERSION) triton_version = version.parse(triton.__version__) - if triton_version > version.parse(MAX_TRITON_VERSION): + + if triton_version > max_version: logger.warning( f'Engine has not been tested on triton>{MAX_TRITON_VERSION}.') @@ -96,16 +99,12 @@ def check_env_triton(device: str): _handle_exception(e, 'Triton', logger, msg) if device == 'cuda': - device_cap = torch.cuda.get_device_capability() - TRITON_VER_231 = version.parse('2.3.1') - - if device_cap[0] <= 7: - if triton_version <= TRITON_VER_231: - err = RuntimeError( - 'Attention triton kernel does not fully support ' - 'triton<3.0.0 on device with capability<8. ' - 'Please upgrade your triton version.') - _handle_exception(err, 'Triton', logger) + min_version = version.parse(MIN_TRITON_VERSION) + if triton_version < min_version: + msg = (f'triton>={MIN_TRITON_VERSION} is required. 
' + f'Found triton=={triton_version}') + e = RuntimeError(msg) + _handle_exception(e, 'Triton', logger, msg) def check_env(device_type: str): diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index cffe13bbd..60b9e4644 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -8,8 +8,7 @@ import numpy as np import torch -from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig, - ResponseType) +from lmdeploy.messages import PytorchEngineConfig, ResponseType from lmdeploy.utils import (get_logger, get_max_batch_size, get_model, logging_timer) @@ -17,9 +16,8 @@ from ..check_env import check_adapters, check_env, check_model from ..config import BackendConfig, CacheConfig, SchedulerConfig from ..devices import DeviceContext, get_device_manager -from ..messages import (InputEmbeddingRangeType, InputEmbeddingType, - MessageStatus, SchedulerSequence) -from ..model_inputs import ModelInputs, MRopeModelInputs, VisionModelInputs +from ..messages import MessageStatus, SchedulerSequence +from ..model_inputs import ModelInputs, VisionModelInputs from ..paging import Scheduler from .logits_process import FusedLogitsProcessor, SamplingInputs from .model_agent import build_model_agent @@ -156,6 +154,8 @@ def __init__(self, dtype=engine_config.dtype, custom_module_map=engine_config.custom_module_map) + self.input_processor = self.model_agent.get_input_processor() + cache_config = self.model_agent.cache_config self.adapter_manager = self._build_adapter_manager(adapters) self.scheduler = Scheduler(scheduler_config, cache_config) @@ -170,7 +170,6 @@ def __init__(self, # create main thread self._start_loop() self._create_buffers() - self.engine_instance = self.create_instance() @classmethod def from_pretrained(cls, @@ -300,6 +299,10 @@ def _on_end_session(self, reqs: Request, **kwargs): def _on_add_message(self, reqs: Request, **kwargs): """on add message callback.""" + self._msg_preprocess_inque.put_nowait(reqs) + + def _add_message(self, que): + def __update_bad_words(msg): """update bad words.""" sampling_param = msg.sampling_param @@ -322,6 +325,11 @@ def __update_max_new_tokens(msg): sampling_param.max_new_tokens, max_session_len - msg.num_all_tokens()) + if que.qsize() == 0: + return + + reqs = que.get_nowait() + for req in reqs: session_id = req.data['session_id'] if session_id not in self.scheduler.sessions: @@ -339,11 +347,8 @@ def __update_max_new_tokens(msg): sampling_param=req.data['sampling_param'], adapter_name=req.data['adapter_name'], return_logits=req.data.get('return_logits', False), + multimodals=req.data.get('input_multimodals'), input_embeddings=req.data.get('input_embeddings'), - mrope_position_ids=req.data.get('mrope_position_ids'), - mrope_position_delta=req.data.get('mrope_position_delta'), - cross_attention_states=req.data.get( - 'cross_attention_states'), ) msg = next(iter(sess.sequences.values())) __update_bad_words(msg) @@ -351,9 +356,11 @@ def __update_max_new_tokens(msg): self.scheduler.add_sequence(msg) else: msg = next(iter(sess.sequences.values())) - msg.update_token_ids(req.data['token_ids'], - req.data.get('input_embeddings'), - req.data.get('cross_attention_states')) + msg.update_token_ids( + req.data['token_ids'], + multimodals=req.data.get('input_multimodals'), + embeddings=req.data.get('input_embeddings'), + ) msg.num_new_tokens = 0 msg.sampling_param = req.data['sampling_param'] msg.return_logits = req.data.get('return_logits', False) @@ -399,7 +406,6 @@ def create_model_inputs(self, messages: 
SeqList, is_prefill: bool): seq_length = self._seq_length_buf[:batch_size] max_q_seq_length = seq_length.max().item() - # TODO: get block offsets is slow when block_size = 1 block_offsets = self.scheduler.get_block_tables(messages) block_offsets = _tensorlize_block_offsets(block_offsets) @@ -417,6 +423,8 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): num_ignored_history = [msg.num_ignored_history for msg in messages] num_ignored_history = torch.tensor(num_ignored_history) + model_metas = [msg.model_meta for msg in messages] + def __get_cogvlm_image_info(): """Get cogvlm history image info for position ids.""" history_image_nums = torch.LongTensor( @@ -448,12 +456,6 @@ def __get_vlm_embeddings(): return (input_embeddings, input_embedding_indexing, input_embedding_ranges) - def __get_mrope_inputs(): - """get multimodal rotary position inputs.""" - position_ids = [msg.mrope_position_ids for msg in messages] - deltas = [msg.mrope_position_delta for msg in messages] - return MRopeModelInputs(position_ids=position_ids, deltas=deltas) - # for inputs with embeddings history_image_nums = None history_image_token_lengths = None @@ -461,12 +463,6 @@ def __get_mrope_inputs(): if self.model_config.cogvlm_style: (history_image_nums, history_image_token_lengths) = __get_cogvlm_image_info() - # only for qwen2_vl - mrope_inputs = None - has_mrope_params = any( - [msg.mrope_position_ids is not None for msg in messages]) - if has_mrope_params: - mrope_inputs = __get_mrope_inputs() input_embeddings = None input_embedding_indexing = None @@ -477,25 +473,37 @@ def __get_mrope_inputs(): (input_embeddings, input_embedding_indexing, input_embedding_ranges) = __get_vlm_embeddings() + input_multimodals = None + has_multimodal = any( + [not msg.history_multimodals.empty() for msg in messages]) + if has_multimodal: + has_multimodal = False + input_multimodals = [ + msg.get_input_multimodals() for msg in messages + ] + for input_mm in input_multimodals: + for val in input_mm.values(): + if len(val) > 0: + has_multimodal = True + break + if has_multimodal: + break + vision_embedding_inputs = None - if has_embedding or history_image_nums is not None: + if has_embedding or has_multimodal or history_image_nums is not None: vision_embedding_inputs = VisionModelInputs( history_lengths=history_lengths, history_image_nums=history_image_nums, history_image_token_lengths=history_image_token_lengths, input_embeddings=input_embeddings, input_embedding_indexing=input_embedding_indexing, - input_embedding_ranges=input_embedding_ranges) - - # only for mllama - cross_attention_states = None - history_cross_kv_seqlens = None - if any([msg.cross_attention_states is not None for msg in messages]): - cross_attention_states = [ - msg.cross_attention_states for msg in messages - ] - history_cross_kv_seqlens = torch.tensor( - [msg.history_cross_kv_seqlens for msg in messages]) + input_embedding_ranges=input_embedding_ranges, + input_multimodals=input_multimodals) + + # cross + cross_length = torch.tensor([msg.num_cross for msg in messages]) + history_cross_length = torch.tensor( + [msg.num_history_cross for msg in messages]) return ModelInputs( input_ids=input_ids, @@ -506,9 +514,9 @@ def __get_mrope_inputs(): num_ignored_history=num_ignored_history, local_adapter_ids=local_adapter_ids, vision_inputs=vision_embedding_inputs, - mrope_inputs=mrope_inputs, - cross_attention_states=cross_attention_states, - history_cross_kv_seqlens=history_cross_kv_seqlens, + cross_length=cross_length, + 
history_cross_length=history_cross_length, + model_metas=model_metas, ) def _batch_stopping_criteria(self, token_ids: torch.Tensor, @@ -552,11 +560,15 @@ def __get_last_logits(): @logging_timer('UpdateRunning', logger) def update_running(self, running: SeqList, next_token_ids: torch.Tensor, - stopped: torch.Tensor): + stopped: torch.Tensor, model_metas: List[Dict[str, + Any]]): """update scheduler.""" + if model_metas is None: + model_metas = [None] * len(running) next_token_ids = next_token_ids.numpy() eos_token_id = self.model_config.eos_token_id - for token, msg, stop in zip(next_token_ids, running, stopped): + for token, msg, stop, model_meta in zip(next_token_ids, running, + stopped, model_metas): if msg.status != MessageStatus.RUNNING: continue update_token = token @@ -565,7 +577,7 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor, update_token = _EMPTY_TOKEN else: msg.num_new_tokens += 1 - msg.update_token_ids(update_token) + msg.update_token_ids(update_token, model_meta=model_meta) if stop: msg.status = MessageStatus.STOPPED @@ -631,12 +643,14 @@ async def __long_context_single_forward(inputs): batch_size = seq_len.size(0) assert batch_size == 1 - new_inputs = inputs.split(max_prefill_token_num, - self.cache_config.block_size) + new_inputs = inputs.split(max_prefill_token_num) + model_metas = new_inputs[0].model_metas output_gather = _OutputGather(max_seq_len) for inp in new_inputs: + inp.model_metas = model_metas tmp_out = await __forward(inp) + model_metas = tmp_out.get('model_metas') output_gather.gather(tmp_out) tmp_out.pop('hidden_states', None) tmp_out['hidden_states'] = output_gather.get_output() @@ -659,7 +673,8 @@ async def __long_context_single_forward(inputs): return ret def _make_infer_outputs(self, next_token_ids: torch.LongTensor, - logits: torch.Tensor, stopped: torch.Tensor): + logits: torch.Tensor, stopped: torch.Tensor, + model_metas: List[Dict[str, Any]]): """make infer output.""" def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence, @@ -683,7 +698,7 @@ def __get_q_start_loc(): running = self._running is_run = [seq.status == MessageStatus.RUNNING for seq in running] stopped = stopped.tolist() - self.update_running(running, next_token_ids, stopped) + self.update_running(running, next_token_ids, stopped, model_metas) # generate output next_token_ids = next_token_ids.tolist() @@ -771,12 +786,15 @@ def __update_inputs(next_token_ids): next_token_ids, sampling_inputs.stop_words, num_appendable_ids) # send output + model_metas = output.get('model_metas') stopped = stopped.cpu() finish = stopped.all().item() or (idx == loop_count - 1) finish = finish or _check_finish(self.scheduler, idx) - output = (next_token_ids.cpu(), logits, stopped) + output = (next_token_ids.cpu(), logits, stopped, model_metas) output_que.put_nowait((finish, output)) + inputs.model_metas = model_metas + if finish: break @@ -786,6 +804,33 @@ def __update_inputs(next_token_ids): swap_out_map = dict() __update_inputs(next_token_ids) + @torch.inference_mode() + async def _async_loop_preprocess_message(self, inque, outque): + """preprocess msg.""" + while True: + reqs = await inque.get() + + for req in reqs: + req_data = req.data + if req_data.get('input_multimodals', None) is None: + continue + input_ids = req_data['token_ids'] + input_multimodals = req_data['input_multimodals'] + if len(input_multimodals) == 0: + req_data['input_multimodals'] = None + continue + result = self.input_processor.preprocess_input( + input_ids, input_multimodals) + + input_ids = 
result.input_ids + input_multimodals = result.input_multimodals + + req_data['token_ids'] = input_ids + req_data['input_multimodals'] = input_multimodals + + if len(reqs) > 0: + outque.put_nowait(reqs) + @torch.inference_mode() async def _async_loop_background(self, in_que: asyncio.Queue, out_que: asyncio.Queue): @@ -894,12 +939,20 @@ async def _async_loop(self): Each engine instance would communicate with the engine by queue. """ + + self._msg_preprocess_inque = asyncio.Queue() + self._msg_preprocess_outque = asyncio.Queue() + prefill_interval = self.scheduler_config.prefill_interval in_que = asyncio.Queue() out_que = asyncio.Queue() loop_background = asyncio.get_event_loop().create_task( self._async_loop_background(in_que, out_que), name='MainLoopBackground') + loop_background = asyncio.get_event_loop().create_task( + self._async_loop_preprocess_message(self._msg_preprocess_inque, + self._msg_preprocess_outque), + name='MainLoopPreprocessMessage') loop_background.add_done_callback(_raise_exception_on_finish) def __send_resp(out: InferOutput): @@ -933,13 +986,14 @@ async def __step(): while not finish: if self.req_manager.has_requests(): self.req_manager.step() + self._add_message(self._msg_preprocess_outque) finish, out = await out_que.get() try: if isinstance(out, Exception): raise out - next_token_ids, logits, stopped = out + next_token_ids, logits, stopped, model_metas = out step_outputs = self._make_infer_outputs( - next_token_ids, logits, stopped) + next_token_ids, logits, stopped, model_metas) __send_resps(step_outputs) except Exception as e: raise e @@ -949,6 +1003,7 @@ async def __step(): while True: if self.req_manager.has_requests(): self.req_manager.step() + self._add_message(self._msg_preprocess_outque) if not self.scheduler.has_unfinished(): await asyncio.sleep(0.01) @@ -972,78 +1027,3 @@ def create_instance(self, cuda_stream_id=0): """ from .engine_instance import EngineInstance return EngineInstance(self) - - async def async_batched_infer( - self, - session_ids: List[int], - token_ids: List[List[int]] = None, - gen_config: GenerationConfig = None, - adapter_names: List[str] = None, - keep_cache: bool = False, - input_embeddings: List[InputEmbeddingType] = None, - input_embedding_ranges: List[InputEmbeddingRangeType] = None): - """Send inference request. - - Args: - session_ids (List[int]): The session id. - token_ids (List[int]): The input token ids. - gen_config (GenerationConfig): The sampling parameters. - adapter_names (List[str]): The name of the adapters. - keep_cache (bool): Keep kv cache after infer. - - Returns: - int: Error flags. 0 if success. - List[int]: The streaming output tokens. - int: The number of the output tokens. 
- """ - return await self.engine_instance.async_batched_infer( - session_ids=session_ids, - token_ids=token_ids, - gen_config=gen_config, - adapter_names=adapter_names, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - keep_cache=keep_cache) - - def batched_infer( - self, - session_ids: List[int], - token_ids: List[List[int]] = None, - gen_config: GenerationConfig = None, - adapter_names: List[str] = None, - keep_cache: bool = False, - input_embeddings: List[InputEmbeddingType] = None, - input_embedding_ranges: List[InputEmbeddingRangeType] = None): - """batched infer.""" - return self.engine_instance.batched_infer( - session_ids=session_ids, - token_ids=token_ids, - gen_config=gen_config, - adapter_names=adapter_names, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - keep_cache=keep_cache) - - async def async_add_session(self, session_id: int): - """Add new session.""" - return await self.engine_instance._async_try_add_session(session_id) - - def add_session(self, session_id: int): - """Add new session.""" - return self.engine_instance._try_add_session(session_id) - - async def async_cancel(self, session_id: int): - """Stop the given session.""" - return await self.engine_instance.async_cancel(session_id) - - def cancel(self, session_id: int): - """Add new session.""" - return self.engine_instance.cancel(session_id) - - async def async_end(self, session_id: int): - """End the given session.""" - return await self.engine_instance.async_end(session_id) - - def end(self, session_id: int): - """Add new session.""" - return self.engine_instance.end(session_id) diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py index 3e741c7ba..1d90b8a43 100644 --- a/lmdeploy/pytorch/engine/engine_instance.py +++ b/lmdeploy/pytorch/engine/engine_instance.py @@ -1,16 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List +from typing import Any, Dict, List from lmdeploy.messages import EngineOutput, GenerationConfig from lmdeploy.utils import get_logger -from ..messages import (InputEmbeddingRangeType, InputEmbeddings, - InputEmbeddingType, SamplingParam) +from ..messages import SamplingParam from .engine import Engine from .request import RequestSender, RequestType, Response, ResponseType logger = get_logger('lmdeploy') +InputMultiModalType = List[Dict[str, Any]] + def _check_resp(resp: Response, state: ResponseType, warning_msg: str = None): """check if response has state.""" @@ -125,15 +126,13 @@ def _try_add_session(self, session_id: int): """ return try_add_session(self.req_sender, session_id) - async def async_stream_infer( - self, - session_id: int, - input_ids: List[int], - gen_config: GenerationConfig = None, - adapter_name: str = None, - input_embeddings: InputEmbeddingType = None, - input_embedding_ranges: InputEmbeddingRangeType = None, - **kwargs): + async def async_stream_infer(self, + session_id: int, + input_ids: List[int], + gen_config: GenerationConfig = None, + multimodal: InputMultiModalType = None, + adapter_name: str = None, + **kwargs): """Send stream inference request. 
Args: @@ -155,21 +154,13 @@ async def async_stream_infer( await self.req_sender.async_send_async( RequestType.ADD_SESSION, dict(session_id=session_id, response=False)) - input_embeddings_new: List[InputEmbeddings] = None - if input_embeddings is not None and len(input_embeddings) > 0: - assert len(input_embeddings) == len(input_embedding_ranges) - input_embeddings_new = [ - InputEmbeddings(emb, rg[0], rg[1]) - for emb, rg in zip(input_embeddings, input_embedding_ranges) - ] - msg = dict(token_ids=input_ids, - session_id=session_id, - sampling_param=sampling_param, - adapter_name=adapter_name, - input_embeddings=input_embeddings_new, - mrope_position_ids=kwargs.get('mrope_position_ids'), - mrope_position_delta=kwargs.get('mrope_position_delta'), - cross_attention_states=kwargs.get('cross_attention_states')) + msg = dict( + token_ids=input_ids, + session_id=session_id, + sampling_param=sampling_param, + adapter_name=adapter_name, + input_multimodals=multimodal, + ) req_id = await self.req_sender.async_send_async( RequestType.ADD_MESSAGE, msg) @@ -190,14 +181,12 @@ async def async_stream_infer( yield EngineOutput(resp.type, [], 0) break - async def async_infer( - self, - session_id: int, - input_ids: List[int] = None, - gen_config: GenerationConfig = None, - input_embeddings: InputEmbeddingType = None, - input_embedding_ranges: InputEmbeddingRangeType = None, - **kwargs): + async def async_infer(self, + session_id: int, + input_ids: List[int] = None, + multimodal: InputMultiModalType = None, + gen_config: GenerationConfig = None, + **kwargs): """Send inference request. Args: @@ -211,13 +200,11 @@ async def async_infer( int: The number of the output tokens. """ token_ids = [] - async for outputs in self.async_stream_infer( - session_id, - input_ids, - gen_config=gen_config, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - **kwargs): + async for outputs in self.async_stream_infer(session_id, + input_ids, + multimodal=multimodal, + gen_config=gen_config, + **kwargs): status, tmp_ids = outputs.status, outputs.token_ids if status not in [ResponseType.SUCCESS, ResponseType.FINISH]: return EngineOutput(status, token_ids, len(token_ids)) @@ -228,10 +215,9 @@ async def async_infer( def stream_infer(self, session_id: int, input_ids: List[int], + multimodal: InputMultiModalType = None, gen_config: GenerationConfig = None, adapter_name: str = None, - input_embeddings: InputEmbeddingType = None, - input_embedding_ranges: InputEmbeddingRangeType = None, **kwargs): """Send stream inference request. 
@@ -252,14 +238,12 @@ def stream_infer(self, def __call_async(): """call async.""" - coro_gen = self.async_stream_infer( - session_id, - input_ids, - gen_config, - adapter_name, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - **kwargs) + coro_gen = self.async_stream_infer(session_id, + input_ids, + multimodal=multimodal, + gen_config=gen_config, + adapter_name=adapter_name, + **kwargs) while True: try: yield self.req_sender.run_until_complete( @@ -275,19 +259,12 @@ def __call_async(): sampling_param = SamplingParam.from_gen_config(gen_config=gen_config) self.req_sender.send_async(RequestType.ADD_SESSION, dict(session_id=session_id, response=False)) - input_embeddings_new: List[InputEmbeddings] = None - if input_embeddings is not None and len(input_embeddings) > 0: - assert len(input_embeddings) == len(input_embedding_ranges) - input_embeddings_new = [ - InputEmbeddings(emb, rg[0], rg[1]) - for emb, rg in zip(input_embeddings, input_embedding_ranges) - ] msg = dict( token_ids=input_ids, session_id=session_id, sampling_param=sampling_param, adapter_name=adapter_name, - input_embeddings=input_embeddings_new, + input_multimodals=multimodal, ) req_id = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg) @@ -311,9 +288,8 @@ def __call_async(): def infer(self, session_id: int, input_ids: List[int] = None, + multimodal: InputMultiModalType = None, gen_config: GenerationConfig = None, - input_embeddings: InputEmbeddingType = None, - input_embedding_ranges: InputEmbeddingRangeType = None, **kwargs): """Send inference request. @@ -328,13 +304,11 @@ def infer(self, int: The number of the output tokens. """ token_ids = [] - for outputs in self.stream_infer( - session_id, - input_ids, - gen_config=gen_config, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - **kwargs): + for outputs in self.stream_infer(session_id, + input_ids, + multimodal=multimodal, + gen_config=gen_config, + **kwargs): status, tmp_ids = outputs.status, outputs.token_ids if status not in [ResponseType.SUCCESS, ResponseType.FINISH]: return EngineOutput(status, token_ids, len(token_ids)) @@ -342,127 +316,6 @@ def infer(self, return EngineOutput(0, token_ids, len(token_ids)) - async def async_batched_infer( - self, - session_ids: List[int], - token_ids: List[List[int]] = None, - gen_config: GenerationConfig = None, - adapter_names: List[str] = None, - keep_cache: bool = False, - input_embeddings: List[InputEmbeddingType] = None, - input_embedding_ranges: List[InputEmbeddingRangeType] = None, - ): - """Send inference request. - - Args: - session_ids (List[int]): The session id. - token_ids (List[int]): The input token ids. - gen_config (GenerationConfig): The sampling parameters. - adapter_names (List[str]): The name of the adapters. - keep_cache (bool): Keep kv cache after infer. - - Returns: - int: Error flags. 0 if success. - List[int]: The streaming output tokens. - int: The number of the output tokens. 
- """ - batch_size = len(token_ids) - assert len(session_ids) == batch_size - if adapter_names is not None: - assert len(adapter_names) == batch_size - else: - adapter_names = [None for _ in range(batch_size)] - - if input_embeddings is not None: - assert len(input_embeddings) == batch_size - assert len(input_embedding_ranges) == batch_size - else: - input_embeddings = [None] * batch_size - input_embedding_ranges = [None] * batch_size - - async def _add_sessions(session_ids): - for session_id in session_ids: - await self._async_try_add_session(session_id) - - async def _add_messages(session_ids, token_ids, adapter_names, - input_embeddings, input_embedding_ranges): - add_msgs = [] - sampling_param = SamplingParam.from_gen_config(gen_config) - for session_id, token_id, adapter_name, input_emb, input_ranges in zip( # noqa: E501 - session_ids, token_ids, adapter_names, input_embeddings, - input_embedding_ranges): - cur_input_embeddings: List[InputEmbeddings] = None - if input_emb is not None and len(input_emb) > 0: - assert len(input_emb) == len(input_ranges) - cur_input_embeddings = [ - InputEmbeddings(emb, rg[0], rg[1]) - for emb, rg in zip(input_emb, input_ranges) - ] - msg = dict( - token_ids=token_id, - session_id=session_id, - sampling_param=sampling_param, - adapter_name=adapter_name, - input_embeddings=cur_input_embeddings, - ) - add_msgs.append(msg) - req_types = [RequestType.ADD_MESSAGE] * batch_size - req_ids = await self.req_sender.async_batched_send_async( - req_types, data=add_msgs) - return req_ids - - await _add_sessions(session_ids) - req_ids = await _add_messages(session_ids, token_ids, adapter_names, - input_embeddings, input_embedding_ranges) - - # receive messages - req_idx_map = dict(zip(req_ids, range(len(req_ids)))) - output_token_ids = [list() for _ in req_ids] - status = 0 - finish_count = batch_size - while finish_count: - resp = await self.req_sender.async_recv_any() - if resp.req_id not in req_ids: - continue - idx = req_idx_map[resp.req_id] - token_ids = output_token_ids[idx] - if resp.type == ResponseType.SUCCESS: - token_ids += resp.data['token_ids'] - elif resp.type == ResponseType.FINISH: - token_ids += resp.data['token_ids'] - if not keep_cache: - session_id = session_ids[idx] - await self.async_end(session_id=session_id) - finish_count -= 1 - else: - logger.error(f'Unexpected response: {resp.type}') - status = 1 - break - - output_token_len = [len(token_ids) for token_ids in output_token_ids] - return EngineOutput(status, output_token_ids, output_token_len) - - def batched_infer( - self, - session_ids: List[int], - token_ids: List[List[int]] = None, - gen_config: GenerationConfig = None, - adapter_names: List[str] = None, - keep_cache: bool = False, - input_embeddings: List[InputEmbeddingType] = None, - input_embedding_ranges: List[InputEmbeddingRangeType] = None, - ): - """batched infer.""" - coro = self.async_batched_infer( - session_ids, - token_ids, - gen_config=gen_config, - adapter_names=adapter_names, - input_embeddings=input_embeddings, - input_embedding_ranges=input_embedding_ranges, - keep_cache=keep_cache) - return self.req_sender.run_until_complete(coro) - async def async_end(self, session_id: int): """End the given session.""" return await async_end(self.req_sender, session_id) @@ -481,8 +334,7 @@ def cancel(self, session_id: int): def decode(self, input_ids, - input_embeddings: List[InputEmbeddingType] = None, - input_embedding_ranges: List[InputEmbeddingRangeType] = None, + multimodal: List[InputMultiModalType] = None, steps: List[int] = 
None, sequence_start: bool = True, sequence_end: bool = True, @@ -492,10 +344,8 @@ def decode(self, Args: input_ids (numpy.ndarray): the batch of input token ids steps (List[int]): the offset of the k/v cache - input_embeddings (List[List[Union[torch.Tensor, np.ndarray]]]): - embeddings features - input_embedding_ranges: (List[List[Tuple[int, int]]]): - the begin/end offsets of input_embeddings to input_ids + multimodal (List[InputMultiModalType]): + multimodals inputs. sequence_start (bool): indicator for starting a sequence sequence_end (bool): indicator for ending a sequence adapter_names (List[str]): The name of the adapters. @@ -505,33 +355,24 @@ def decode(self, batch_size = len(input_ids) def __add_messages(session_ids, input_ids, adapter_names, - input_embeddings, input_embedding_ranges): + input_multimodals): add_msgs = [] sampling_param = SamplingParam(max_new_tokens=0) batch_size = len(input_ids) - if input_embeddings is None: - input_embeddings = [None] * batch_size - input_embedding_ranges = [None] * batch_size - for (session_id, token_id, adapter_name, input_emb, - input_ranges) in zip(session_ids, input_ids, adapter_names, - input_embeddings, - input_embedding_ranges): + if input_multimodals is None: + input_multimodals = [None] * batch_size + for (session_id, token_id, adapter_name, + in_mm) in zip(session_ids, input_ids, adapter_names, + input_multimodals): if len(token_id) > self.max_input_len: raise RuntimeError( f'Expect input length<={self.max_input_len} ' f'but get {len(token_id)}') - cur_input_embeddings: List[InputEmbeddings] = None - if input_emb is not None and len(input_emb) > 0: - assert len(input_emb) == len(input_ranges) - cur_input_embeddings = [ - InputEmbeddings(emb, rg[0], rg[1]) - for emb, rg in zip(input_emb, input_ranges) - ] msg = dict(token_ids=token_id, session_id=session_id, sampling_param=sampling_param, adapter_name=adapter_name, - input_embeddings=cur_input_embeddings, + input_multimodals=in_mm, return_logits=True) add_msgs.append(msg) req_types = [RequestType.ADD_MESSAGE] * batch_size @@ -547,13 +388,6 @@ def __add_messages(session_ids, input_ids, adapter_names, else: adapter_names = [None] * batch_size - if input_embeddings is not None: - assert len(input_embeddings) == batch_size - assert len(input_embedding_ranges) == batch_size - else: - input_embeddings = [None] * batch_size - input_embedding_ranges = [None] * batch_size - session_ids = tuple(range(batch_size)) if sequence_start: for sid in session_ids: @@ -562,7 +396,7 @@ def __add_messages(session_ids, input_ids, adapter_names, self._try_add_session(sid) req_ids = __add_messages(session_ids, input_ids, adapter_names, - input_embeddings, input_embedding_ranges) + multimodal) req_idx_map = dict(zip(req_ids, range(len(req_ids)))) finish_count = batch_size diff --git a/lmdeploy/pytorch/engine/input_process.py b/lmdeploy/pytorch/engine/input_process.py new file mode 100644 index 000000000..7f442e153 --- /dev/null +++ b/lmdeploy/pytorch/engine/input_process.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
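[Illustrative aside, not part of the patch] The new file that follows defines the input-processor hook the engine now relies on: the engine fetches it via `model_agent.get_input_processor()` and the preprocessing loop calls `preprocess_input()` on every request before it reaches the scheduler. A sketch of what a model-specific processor might look like (the class `MyVLMInputProcessor` and its token-expansion logic are hypothetical):

from typing import Any, Dict, List

from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor,
                                                   PreprocessInputResult)


class MyVLMInputProcessor(BaseModelInputProcessor):
    """Expand image placeholder tokens before scheduling (hypothetical)."""

    def __init__(self, image_token_id: int, tokens_per_image: int):
        self.image_token_id = image_token_id
        self.tokens_per_image = tokens_per_image

    def preprocess_input(self,
                         input_ids: List[int],
                         input_mms: List[Dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        if not input_mms:
            return PreprocessInputResult(input_ids=input_ids,
                                         input_multimodals=input_mms)
        # replace each image placeholder with the number of vision tokens
        new_ids = []
        for tok in input_ids:
            if tok == self.image_token_id:
                new_ids.extend([self.image_token_id] * self.tokens_per_image)
            else:
                new_ids.append(tok)
        return PreprocessInputResult(input_ids=new_ids,
                                     input_multimodals=input_mms)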
+from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs + +TypeModelMetas = Dict[str, Any] + +InputMultiModalType = List[Dict[str, Any]] + + +@dataclass +class PreprocessInputResult: + """results of preprocess input.""" + input_ids: List[int] + input_multimodals: Optional[MultiModalInputs] = None + model_metas: Optional[TypeModelMetas] = None + + +class BaseModelInputProcessor(ABC): + """processor of model inputs.""" + + @abstractmethod + def preprocess_input(self, + input_ids: List[int], + input_mms: InputMultiModalType = None, + **kwargs) -> PreprocessInputResult: + """preprocess input.""" + raise NotImplementedError('Not implemented.') + + +class DefaultModelInputProcessor(BaseModelInputProcessor): + """default model input processor.""" + + def preprocess_input(self, + input_ids: List[int], + input_mms: MultiModalInputs = None, + **kwargs) -> PreprocessInputResult: + """preprocess input.""" + return PreprocessInputResult( + input_ids=input_ids, + input_multimodals=input_mms, + ) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 74938de81..014bee65a 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -146,12 +146,17 @@ def model_forward( kv_quant_policy=cache_engine.cache_config.quant_policy, ) with ctx_mgr.context(context): + model_metas = None + model_metas = model.update_model_metas( + past_key_values=cache_engine.gpu_cache, + context=context, + ) input_dict = model.prepare_inputs_for_generation( past_key_values=cache_engine.gpu_cache, context=context, ) output = model(**input_dict) - return dict(hidden_states=output) + return dict(hidden_states=output, model_metas=model_metas) SwapMap = Dict[int, int] @@ -164,10 +169,6 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig): self.model_config = model_config self.cache_config = cache_config - def get_block_numel(self): - """get block nelement.""" - raise NotImplementedError('Not implemented') - async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """model forward. @@ -194,6 +195,10 @@ def get_logits(self, hidden_states: torch.Tensor): """get logits of model output.""" raise NotImplementedError('Not implemented.') + def get_input_processor(self): + """get input processor.""" + raise NotImplementedError('Not implemented.') + class BaseModelAgent(AutoModelAgent): """Base model agent. 
@@ -257,11 +262,6 @@ def _build_model(self, device=device) return patched_model - def get_block_numel(self): - """get block nelement.""" - k_cache = self.cache_engine.local_gpu_cache[0][0] - return k_cache[0].numel() - def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): cache_swapping(self.cache_engine, @@ -311,6 +311,10 @@ def get_logits(self, hidden_states: torch.Tensor): """get logits of model output.""" return self.patched_model.get_logits(hidden_states) + def get_input_processor(self): + """get input processor..""" + return self.patched_model.get_input_processor() + @torch.inference_mode() def _tp_build_model( @@ -690,11 +694,6 @@ def _build_model( return model, cache_engine, cache_config - def get_block_numel(self): - """get block nelement.""" - k_cache = self.cache_engine.local_gpu_cache[0][0] - return k_cache[0].numel() - def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """forward impl.""" @@ -750,6 +749,10 @@ def get_logits(self, hidden_states: torch.Tensor): """get logits of model output.""" return self.patched_model.get_logits(hidden_states) + def get_input_processor(self): + """get input processor..""" + return self.patched_model.get_input_processor() + def _exit_handler(agent: TPModelAgent): if hasattr(agent, 'patched_model'): diff --git a/lmdeploy/pytorch/kernels/cuda/flashattention.py b/lmdeploy/pytorch/kernels/cuda/flashattention.py index 7521a3e2b..7df67ce78 100644 --- a/lmdeploy/pytorch/kernels/cuda/flashattention.py +++ b/lmdeploy/pytorch/kernels/cuda/flashattention.py @@ -47,6 +47,17 @@ def softcapping(qk, logit_softcapping: tl.constexpr): return qk +@triton.jit +def _load_kv(ptrs, causal_mask: tl.constexpr, boundary_check: tl.constexpr): + """load kv.""" + if causal_mask: + return tl.load(ptrs, + boundary_check=boundary_check, + padding_option='zero') + else: + return tl.load(ptrs) + + @triton.jit def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start, loop_end, qk_scale, history_mask, @@ -63,11 +74,11 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, for start_n in range(loop_start, loop_end, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - k = tl.load(k_ptrs) + k = _load_kv(k_ptrs, causal_mask, boundary_check=(1, )) qk = tl.dot(q, k) if BLOCK_DK1 != 0: - k1 = tl.load(k1_ptrs) + k1 = _load_kv(k1_ptrs, causal_mask, boundary_check=(1, )) qk += tl.dot(q1, k1) if causal_mask: @@ -113,7 +124,7 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, acc = acc * alpha[:, None] # update acc - v = tl.load(v_ptrs) + v = _load_kv(v_ptrs, causal_mask, boundary_check=(0, )) p = p.to(v.dtype) acc += tl.dot(p, v) # update m_i and l_i @@ -168,6 +179,7 @@ def _flash_prefill_fwd_kernel( kv_group_num, head_dim_k, head_dim_v, + causal: tl.constexpr, window_size: tl.constexpr, logit_softcapping: tl.constexpr, BLOCK_M: tl.constexpr, @@ -257,9 +269,13 @@ def _flash_prefill_fwd_kernel( acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) qk_scale = sm_scale * tl_log2(math.e) - history_mask = history_len + start_m * BLOCK_M + tl.arange(0, BLOCK_M) + if causal: + history_mask = history_len + start_m * BLOCK_M + tl.arange(0, BLOCK_M) + loop_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N + else: + history_mask = tl.full([BLOCK_M], kv_seqlen - 1, dtype=tl.int32) + loop_end = kv_seqlen // BLOCK_N * BLOCK_N - loop_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N acc, l_i, m_i = _prefill_fwd_inner(acc, l_i, m_i, @@ -280,7 
+296,10 @@ def _flash_prefill_fwd_kernel( BLOCK_DK1=BLOCK_DK1) loop_start = loop_end - loop_end = tl.minimum(kv_seqlen, loop_start + BLOCK_M + BLOCK_N) + if causal: + loop_end = tl.minimum(kv_seqlen, loop_start + BLOCK_M + BLOCK_N) + else: + loop_end = kv_seqlen acc, l_i, m_i = _prefill_fwd_inner(acc, l_i, m_i, @@ -330,6 +349,7 @@ def flash_attention_fwd( window_size: int = None, sm_scale: float = None, logit_softcapping: float = None, + causal: bool = True, kv_layout: str = 'hsd', ): """varlen flash Attention forward. @@ -380,6 +400,7 @@ def grid(args): BLOCK_M = max(16, 8192 // BLOCK_DK) else: BLOCK_M = max(16, 16384 // BLOCK_DK) + BLOCK_M = min(128, BLOCK_M) num_warps = 4 num_stages = min(4, max(2, 1024 // BLOCK_DK)) if BLOCK_DK >= 512: @@ -413,6 +434,7 @@ def grid(args): kv_group_num=kv_group_num, head_dim_k=head_dim_k, head_dim_v=head_dim_v, + causal=causal, window_size=window_size, logit_softcapping=logit_softcapping, BLOCK_DK=BLOCK_DK, diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index b16a78f1f..0aaba98c9 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -8,6 +8,7 @@ from torch import Tensor from lmdeploy.messages import GenerationConfig, LogitsProcessor +from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs from lmdeploy.utils import get_logger from .block import LogicalTokenBlocks @@ -205,10 +206,9 @@ def add_sequence( sampling_param: SamplingParam = None, adapter_name: str = None, return_logits: bool = False, - input_embeddings: List[InputEmbeddings] = None, - mrope_position_ids: Tensor = None, - mrope_position_delta: Tensor = None, - cross_attention_states: Tensor = None) -> 'SchedulerSequence': + multimodals: MultiModalInputs = None, + input_embeddings: List[InputEmbeddings] = None + ) -> 'SchedulerSequence': """Add a new message.""" if isinstance(token_ids, Tensor): token_ids = token_ids.numpy() @@ -228,10 +228,8 @@ def add_sequence( adapter_name=adapter_name, arrive_time=time.time(), history_embeddings=HistoryEmbeddings(input_embeddings), + history_multimodals=HistoryMultiModals(multimodals), return_logits=return_logits, - mrope_position_ids=mrope_position_ids, - mrope_position_delta=mrope_position_delta, - cross_attention_states=cross_attention_states, ) self.sequences[seq.seq_id] = seq if self.seq_manager is not None: @@ -361,6 +359,66 @@ def copy(self): return self.clone() +class HistoryMultiModals: + + def __init__(self, multimodals: MultiModalInputs): + if multimodals is None: + multimodals = dict() + self.multimodals = multimodals + + def get_datas(self, start=0, end=-1): + """get multimodals from prompts position [start, end).""" + outs = dict() + test_range = range(start, end) + for modal_type, modal_datas in self.multimodals.items(): + data = [] + for modal_data in modal_datas: + if (modal_data.start not in test_range + and modal_data.end not in test_range): + continue + data.append(modal_data) + if len(data) > 0: + outs[modal_type] = data + return outs + + def add_inputs(self, input_mms: MultiModalInputs): + """add new inputs.""" + for modal_type, vals in input_mms.items(): + if modal_type in self.multimodals: + self.multimodals[modal_type] += vals + else: + self.multimodals[modal_type] = vals + + def empty(self): + if len(self.multimodals) == 0: + return 0 + + return all(len(vals) == 0 for vals in self.multimodals) + + @staticmethod + def update_multimodals(input_mms: MultiModalInputs, prev_len: int): + """update multimodals.""" + for vals in input_mms.values(): + for val in vals: + val.start 
+= prev_len + val.end += prev_len + return input_mms + + def get_encoder_len(self, start=0, end=-1): + """get lens of encoder.""" + test_range = range(start, end) + out_len = 0 + for _, modal_datas in self.multimodals.items(): + for modal_data in modal_datas: + if modal_data.encoder_len is None: + continue + if (modal_data.start not in test_range + and modal_data.end not in test_range): + continue + out_len += modal_data.encoder_len + return out_len + + @dataclass class SchedulerSequence: """Scheduler message.""" @@ -369,6 +427,8 @@ class SchedulerSequence: history_cache: HistoryTokenIds = field(default_factory=HistoryTokenIds) history_embeddings: HistoryEmbeddings = field( default_factory=HistoryEmbeddings) + history_multimodals: HistoryMultiModals = field( + default_factory=HistoryMultiModals) num_new_tokens: int = 0 sampling_param: SamplingParam = field(default_factory=SamplingParam) logical_blocks: LogicalTokenBlocks = field( @@ -382,10 +442,7 @@ class SchedulerSequence: random_offsets: int = 0 _status: MessageStatus = field(default=MessageStatus.WAITING, init=False) num_ignored_history: int = 0 - mrope_position_ids: Optional[Tensor] = None - mrope_position_delta: Optional[int] = None - cross_attention_states: Optional[Tensor] = None - history_cross_kv_seqlens: int = 0 + model_meta: Dict[str, Any] = None def __post_init__(self): """post init.""" @@ -394,6 +451,10 @@ def __post_init__(self): self._num_images: int = len(self.history_embeddings) self._num_token_ids: int = len(self.history_cache) + self._num_history_cross: int = 0 + self._num_cross: int = self.history_multimodals.get_encoder_len( + 0, self._num_token_ids) + @property def block_size(self) -> int: """block size.""" @@ -464,6 +525,16 @@ def num_all_ids(self): """num all tokens.""" return self.history_len + self._num_token_ids + @property + def num_cross(self): + """num cross.""" + return self._num_cross + + @property + def num_history_cross(self): + """num history cross.""" + return self._num_history_cross + @property def num_blocks(self): """num blocks.""" @@ -489,22 +560,22 @@ def num_all_tokens(self): def num_all_cross_tokens(self): """num of all cross tokens.""" - if self.cross_attention_states is None: - self.history_cross_kv_seqlens = 0 - else: - self.history_cross_kv_seqlens = self.cross_attention_states.shape[ - -2] - return self.history_cross_kv_seqlens + return self._num_cross + self._num_history_cross + + def get_input_multimodals(self): + """get input multimodals.""" + start = self.num_history_ids + end = self.num_all_ids + return self.history_multimodals.get_datas(start, end) def update_token_ids(self, token_ids: Tensor, + multimodals: MultiModalInputs = None, embeddings: List[InputEmbeddings] = None, - cross_attention_states: List[Tensor] = None): + model_meta: Dict[str, Any] = None): """Update token ids, old token ids will be added to history.""" - # cross attention - if cross_attention_states is not None: - self.history_cross_kv_seqlens += cross_attention_states.shape[-2] - self.cross_attention_states = cross_attention_states + old_num_history_ids = self._num_history_ids + self._num_history_ids += self._num_token_ids # update history image nums self._num_history_images += self._num_images @@ -516,6 +587,23 @@ def update_token_ids(self, self._num_images = len(new_embeddings) self.history_embeddings.append(new_embeddings) + # update multimodals + if multimodals is not None: + multimodals = HistoryMultiModals.update_multimodals( + multimodals, self.num_all_ids) + self.history_multimodals.add_inputs(multimodals) + 
+ # cross + self._num_history_cross += self._num_cross + if multimodals is not None: + self._num_cross = self.history_multimodals.get_encoder_len( + old_num_history_ids, self._num_history_ids) + else: + self._num_cross = 0 + + if model_meta is not None: + self.model_meta = model_meta + if isinstance(token_ids, Tensor): token_ids = token_ids.numpy() elif not isinstance(token_ids, np.ndarray): @@ -539,3 +627,12 @@ def set_step(self, step: int): self._num_history_ids = step self._num_token_ids = num_all_ids - step self.num_ignored_history = min(step, self.num_ignored_history) + + self.model_meta = None + + # cross + if self.history_multimodals is not None: + self._num_history_cross = self.history_multimodals.get_encoder_len( + 0, self.num_history_ids) + self._num_cross = self.history_multimodals.get_encoder_len( + self._num_history_ids, num_all_ids) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 669625d43..99355fa39 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -6,45 +6,7 @@ import torch from lmdeploy.pytorch.backends import get_backend - - -@dataclass -class MRopeModelInputs: - """Multimodal rotary position inputs.""" - position_ids: List[torch.LongTensor] = None - deltas: List[torch.LongTensor] = None - - def get_inputs(self, history_lengths: torch.Tensor, - seq_lengths: torch.Tensor): - mrope_position_ids = [] - for (his_len, seq_len, pos_ids, - delta) in zip(history_lengths, seq_lengths, self.position_ids, - self.deltas): - assert pos_ids.dim() == 2, 'invalid mrope_position_ids' - if his_len + seq_len <= pos_ids.shape[1]: - mrope_position_ids.append(pos_ids[:, - his_len:his_len + seq_len]) - else: - mrope_position_ids.append( - torch.tensor([his_len], device=delta.device).expand(3, -1) - + delta) - - mrope_position_ids = torch.cat(mrope_position_ids, dim=-1) - return mrope_position_ids - - def to_device(self, device: str): - """to device.""" - out_dict = dict() - for f in fields(self): - k = f.name - v = getattr(self, k) - if isinstance(v, torch.Tensor): - v = v.to(device) - elif isinstance(v, list): - v = [x.to(device) for x in v] - out_dict[k] = v - - return MRopeModelInputs(**out_dict) +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor @dataclass @@ -56,6 +18,7 @@ class VisionModelInputs: input_embeddings: List[List[torch.Tensor]] = None input_embedding_ranges: List[torch.LongTensor] = None input_embedding_indexing: torch.BoolTensor = None + input_multimodals: List[MultiModalTensor] = None def to_device(self, device: str): """to device.""" @@ -63,12 +26,19 @@ def to_device(self, device: str): for f in fields(self): k = f.name v = getattr(self, k) + if v is None: + continue if isinstance(v, torch.Tensor): v = v.to(device) - elif k == 'input_embedding_ranges' and v is not None: + elif k == 'input_embedding_ranges': v = [e.to(device) for e in v] - elif k == 'input_embeddings' and v is not None: + elif k == 'input_embeddings': v = [[e.to(device) for e in li] for li in v] + elif k == 'input_multimodals': + for mm_datas in v: + for modal_type, data in mm_datas.items(): + data = [d.to_device(device) for d in data] + mm_datas[modal_type] = data out_dict[k] = v return VisionModelInputs(**out_dict) @@ -119,9 +89,9 @@ class ModelInputs: num_ignored_history: torch.LongTensor local_adapter_ids: torch.LongTensor = None vision_inputs: VisionModelInputs = None - mrope_inputs: MRopeModelInputs = None - cross_attention_states: torch.Tensor = None - history_cross_kv_seqlens: torch.LongTensor = None + 
cross_length: torch.LongTensor = None + history_cross_length: torch.LongTensor = None + model_metas: List[Dict[str, Any]] = None def update(self, input_ids: torch.LongTensor): """update input ids.""" @@ -132,44 +102,87 @@ def update(self, input_ids: torch.LongTensor): self.input_ids = input_ids return self - def split(self, split_size: int, block_size: int): + def split(self, split_size: int): """split inputs.""" assert len( self.seq_length) == 1, ('Can not perform split on batched input.') - assert split_size % block_size == 0, ( - 'split_size should be multi of block_size.') input_ids = self.input_ids if input_ids.numel() < split_size: return self - num_blocks = split_size // block_size - overlap = (self.history_lengths[0] % block_size != 0) + flatten_mms = [] + vision_inputs = self.vision_inputs + if vision_inputs is not None: + if vision_inputs.input_multimodals is not None: + input_mms = vision_inputs.input_multimodals[0] + + flatten_mms = [] + for k, mms in input_mms.items(): + mms = [(k, mm) for mm in mms] + flatten_mms += mms + + flatten_mms = sorted(flatten_mms, key=lambda mm: mm[1].start) + max_seq_len = self.seq_length[0].item() ret = [] - block_start = 0 - for i in range(0, max_seq_len, split_size): - start = i - end = min(max_seq_len, i + split_size) - block_end = block_start + num_blocks - if overlap: - block_end += 1 - - block_offsets = self.block_offsets + start = 0 + history_cross_length = self.history_cross_length + cross_length = None + if history_cross_length is not None: + cross_length = self.history_cross_length.clone() + while start < max_seq_len: + vision_inputs = None + if len(flatten_mms) > 0: + mm_start = flatten_mms[0][1].start + mm_end = flatten_mms[0][1].end + if mm_start > self.history_lengths + start: + end = min(mm_start, start + split_size) + else: + input_mms = dict() + key, mm = flatten_mms.pop(0) + input_mms.setdefault(key, []) + input_mms[key].append(mm) + end = start + mm.end - mm.start + while len(flatten_mms) > 0: + next_mm = flatten_mms[0] + next_start = next_mm[1].start + next_end = next_mm[1].end + if next_start < mm_end: + key = next_mm[0] + input_mms.setdefault(key, []) + input_mms[key].append(next_mm[1]) + end += max(0, next_end - mm_end) + flatten_mms.pop(0) + + if cross_length is not None: + encoder_len = next_mm[1].encoder_len + if encoder_len is not None: + cross_length += encoder_len + else: + break + vision_inputs = VisionModelInputs( + input_multimodals=[input_mms], ) + else: + end = min(max_seq_len, start + split_size) + inp = ModelInputs( input_ids=self.input_ids[:, start:end], seq_length=input_ids.new_tensor([end - start]), - block_offsets=block_offsets, + block_offsets=self.block_offsets, history_lengths=self.history_lengths + start, is_decoding=self.is_decoding, num_ignored_history=self.num_ignored_history, local_adapter_ids=self.local_adapter_ids, - vision_inputs=self.vision_inputs, - mrope_inputs=self.mrope_inputs, - cross_attention_states=self.cross_attention_states, + vision_inputs=vision_inputs, + model_metas=self.model_metas, + cross_length=cross_length, + history_cross_length=history_cross_length, ) ret.append(inp) - block_start += num_blocks + history_cross_length = cross_length + + start = end return ret @@ -183,8 +196,6 @@ def to_device(self, device: str): v = v.to(device) elif isinstance(v, VisionModelInputs): v = v.to_device(device) - elif isinstance(v, MRopeModelInputs): - v = v.to_device(device) out_dict[k] = v return ModelInputs(**out_dict) @@ -210,13 +221,14 @@ class StepContext: local_adapter_ids: torch.LongTensor 
= None input_embeddings: torch.Tensor = None input_embedding_indexing: torch.Tensor = None + input_multimodals: List[MultiModalTensor] = None vision_inputs: VisionModelInputs = None - mrope_position_ids: torch.Tensor = None attn_metadata: Any = None - cross_attn_metadata: Any = None - cross_attention_states: torch.Tensor = None + cross_seqlens: torch.LongTensor = None cross_kv_seqlens: torch.LongTensor = None + cross_attn_metadata: Any = None kv_quant_policy: Literal[0, 4, 8] = 0 + model_metas: List[Dict[str, Any]] = None _outputs: Dict = field(default_factory=dict) @@ -239,24 +251,21 @@ def new( history_seqlens = inputs.history_lengths device = q_seqlens.device + input_multimodals = None + if inputs.vision_inputs is not None: + input_multimodals = inputs.vision_inputs.input_multimodals + # for vlm input_embeddings, input_embedding_indexing = None, None if (inputs.vision_inputs is not None and inputs.vision_inputs.input_embeddings is not None): input_embeddings, input_embedding_indexing = \ inputs.vision_inputs.get_inputs(history_seqlens, q_seqlens) - # for mrope - mrope_position_ids = None - if inputs.mrope_inputs is not None: - mrope_position_ids = inputs.mrope_inputs.get_inputs( - history_seqlens, q_seqlens) # kv_seqlens - cross_attention_states = inputs.cross_attention_states if inputs.is_decoding: attention_mask = torch.ones_like(q_seqlens)[:, None] position_ids = history_seqlens.unsqueeze(-1) - cross_attention_states = None else: max_q_seqlen = q_seqlens.max().item() mask_range = torch.arange(max_q_seqlen, device=device)[None, :] @@ -265,6 +274,13 @@ def new( position_ids += history_seqlens.unsqueeze(-1) q_start_loc = q_seqlens.cumsum(0) - q_seqlens + # cross + cross_seqlens = inputs.cross_length + cross_kv_seqlens = None + if cross_kv_seqlens is not None: + cross_kv_seqlens = (inputs.cross_length + + inputs.history_cross_length) + # position ids 1d position_ids = cls.get_position_ids_1d(position_ids, q_seqlens)[None] # seq_len + history_length @@ -277,6 +293,7 @@ def new( position_ids=position_ids, input_embeddings=input_embeddings, input_embedding_indexing=input_embedding_indexing, + input_multimodals=input_multimodals, attention_mask=attention_mask, q_seqlens=q_seqlens, kv_seqlens=kv_seqlens, @@ -286,10 +303,10 @@ def new( world_size=world_size, local_adapter_ids=inputs.local_adapter_ids, vision_inputs=inputs.vision_inputs, - mrope_position_ids=mrope_position_ids, - cross_attention_states=cross_attention_states, - cross_kv_seqlens=inputs.history_cross_kv_seqlens, kv_quant_policy=kv_quant_policy, + model_metas=inputs.model_metas, + cross_seqlens=cross_seqlens, + cross_kv_seqlens=cross_kv_seqlens, ) ret = get_backend().update_step_context(ret) diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py index 8d7a21a0a..ac69fea2a 100644 --- a/lmdeploy/pytorch/models/chatglm2.py +++ b/lmdeploy/pytorch/models/chatglm2.py @@ -1,101 +1,29 @@ # Copyright (c) OpenMMLab. All rights reserved. 
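The chatglm2.py changes below wire the GLM-4V vision tower (EVA2CLIP) and a ChatGLMInputProcessor into the language model. Instead of the removed embedding-indexing helpers, image features are now written directly into the token embedding wherever input_ids equals the image token id, via masked_scatter_. A small self-contained illustration of that pattern (hypothetical token id and sizes, not the real config values):

    import torch

    IMAGE_TOKEN_ID = 9          # placeholder; the real id comes from the processor meta
    hidden_size = 4

    input_ids = torch.tensor([[3, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 7]])
    inputs_embeds = torch.zeros(1, 4, hidden_size)    # language embeddings
    image_feats = torch.ones(2, hidden_size)          # one row per image-token slot

    # the boolean mask marks the image-token positions; masked_scatter_ copies the
    # flattened vision features into exactly those positions, in order
    image_mask = input_ids == IMAGE_TOKEN_ID
    inputs_embeds.masked_scatter_(image_mask[..., None], image_feats)

    assert inputs_embeds[0, 1].eq(1).all() and inputs_embeds[0, 0].eq(0).all()

The same mask is rebuilt in prepare_inputs_for_generation from the image_token_id stored in each MultiModalTensor's meta, so no separate index bookkeeping is needed on the model side.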
-from typing import Any, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn +from torch.nn import functional as F from transformers.configuration_utils import PretrainedConfig +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding) -from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, + build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixin LANGUAGE_TOKEN_TYPE = 0 VISION_TOKEN_TYPE = 1 -def get_vision_expert_mask(token_type_ids: torch.LongTensor): - vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool) - vision_token_mask[:, :-1] = (token_type_ids[:, :-1] - == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] - == VISION_TOKEN_TYPE) - language_token_mask = ~vision_token_mask - return vision_token_mask, language_token_mask - - -def build_position_ids(x: torch.BoolTensor) -> torch.LongTensor: - tmp = x.clone() - # image boi eoi token as LANGUAGE_TOKEN_TYPE - is_boi_eoi = torch.zeros_like(x, dtype=torch.bool) - is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & ( - tmp[:, :-1] == LANGUAGE_TOKEN_TYPE) - is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE) - is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & ( - tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) - is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE) - tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE - # final position ids - y = torch.zeros_like(x, dtype=torch.long) - y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ( - (tmp[:, 1:] == VISION_TOKEN_TYPE) & - (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)) - y = y.cumsum(dim=-1) - return y - - -def _get_cogvlm_position_ids(context): - """get cogvlm position_ids.""" - q_seqlens = context.q_seqlens - history_lengths = context.kv_seqlens - q_seqlens - vision_input_info = context.vision_inputs - position_id_offsets = (vision_input_info.history_image_token_lengths - - vision_input_info.history_image_nums * 3) - lang_ids = None - vis_ids = None - if context.is_decoding: - position_ids = history_lengths - position_id_offsets - else: - if vision_input_info.input_embeddings is not None and len( - vision_input_info.input_embeddings) > 0: - starts = history_lengths - vision_input_info.history_lengths - ends = starts + q_seqlens - token_type_ids = vision_input_info.input_embedding_indexing.to( - torch.int) - history_position_lengths = (vision_input_info.history_lengths - - position_id_offsets) - position_ids_all = (history_position_lengths[:, None] + - build_position_ids(token_type_ids)) - position_ids = torch.cat([ - pids[s:e] - for (pids, s, e) in zip(position_ids_all, starts, ends) - ]) - vision_token_mask_all, _ = get_vision_expert_mask(token_type_ids) - vision_token_mask = torch.cat([ - masks[s:e] - for (masks, s, e) in zip(vision_token_mask_all, starts, ends) - ]) - mask_indexing = torch.arange(vision_token_mask.shape[-1], - device=vision_token_mask.device) - vis_ids = mask_indexing[vision_token_mask] - lang_ids = mask_indexing[~vision_token_mask] - - else: - position_ids = 
context.attention_mask.long().cumsum(-1) - 1 - position_ids += (history_lengths - - position_id_offsets).unsqueeze(-1) - device = position_ids.device - position_ids_1d = [ - ids[:l] for ids, l in zip(position_ids.cpu(), q_seqlens.cpu()) - ] - position_ids = torch.cat(position_ids_1d).to(device) - - return position_ids, lang_ids, vis_ids - - class SelfAttention(torch.nn.Module): """Parallel self-attention layer abstract class. @@ -410,6 +338,286 @@ def forward(self, input_ids): return embeddings +class PatchEmbedding(nn.Module): + """vision embedding.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.proj = nn.Conv2d(config.in_channels, + config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + dtype=dtype, + device=device) + self.cls_embedding = nn.Parameter( + torch.empty(1, config.hidden_size, dtype=dtype, device=device)) + self.position_embedding = nn.Embedding(config.num_positions, + config.hidden_size, + dtype=dtype, + device=device) + + def forward(self, images): + """forward.""" + x = self.proj(images) + x = x.flatten(2).transpose(1, 2) + cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x += self.position_embedding.weight.unsqueeze(0) + return x + + +class EVA2CLIPAttention(nn.Module): + """vision attention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + hidden_size = config.hidden_size + num_heads = config.num_heads + head_dim = config.hidden_size // config.num_heads + self.scale = head_dim**-0.5 + + # packed qkv + self.query_key_value = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_heads, + head_size=head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # o_proj + self.dense = build_rowwise_linear(hidden_size, + hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """forward.""" + # qkv proj + qkv_states = self.query_key_value(hidden_states) + q, k, v = self.query_key_value.split_qkv(qkv_states) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale) + + # o proj + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(-2, -1) + attn_output = self.dense(attn_output) + return attn_output + + +class EVA2CLIPMLP(nn.Module): + """vision MLP.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + + # gate up + quantization_config = getattr(config, 'quantization_config', None) + self.fc1 = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + if config.hidden_act in [ + 'gelu', 'gelu_fast', 'quick_gelu', 'gelu_python' + ]: + self.activation_fn = nn.GELU() + else: + self.activation_fn = ACT2FN[config.hidden_act] + + # down + self.fc2 = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + 
is_tp=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """forward.""" + x = self.fc1(x) + x = self.activation_fn(x) + x = self.fc2(x) + return x + + +class EVA2CLIPTransformerLayer(nn.Module): + """vision trans layer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + self.attention = EVA2CLIPAttention(config, dtype=dtype, device=device) + self.mlp = EVA2CLIPMLP(config, dtype=dtype, device=device) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + + def forward(self, hidden_states): + """forward.""" + attention_input = hidden_states + attention_output = self.input_layernorm( + self.attention(attention_input)) + hidden_states = attention_input + attention_output + mlp_input = hidden_states + mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)) + output = mlp_input + mlp_output + return output + + +class EVA2CLIPTransformer(nn.Module): + """vision transformer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layers = nn.ModuleList([ + EVA2CLIPTransformerLayer(config, dtype=dtype, device=device) + for _ in range(config.num_hidden_layers) + ]) + + def forward(self, hidden_states): + """forward.""" + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class GLU(nn.Module): + """GLU.""" + + def __init__(self, + config: PretrainedConfig, + in_features: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.linear_proj = nn.Linear(in_features, + config.hidden_size, + bias=False, + dtype=dtype, + device=device) + self.norm1 = nn.LayerNorm(config.hidden_size, + dtype=dtype, + device=device) + self.act1 = nn.GELU() + self.act2 = nn.functional.silu + self.dense_h_to_4h = nn.Linear(config.hidden_size, + config.ffn_hidden_size, + bias=False, + dtype=dtype, + device=device) + self.gate_proj = nn.Linear(config.hidden_size, + config.ffn_hidden_size, + bias=False, + dtype=dtype, + device=device) + self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, + config.hidden_size, + bias=False, + dtype=dtype, + device=device) + + def forward(self, x): + x = self.linear_proj(x) + x = self.act1(self.norm1(x)) + x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x) + x = self.dense_4h_to_h(x) + return x + + +class EVA2CLIPModel(nn.Module): + """vision model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from argparse import Namespace + vision_config = Namespace(**config.vision_config) + + self.patch_embedding = PatchEmbedding(vision_config, + dtype=dtype, + device=device) + self.transformer = EVA2CLIPTransformer(vision_config, + dtype=dtype, + device=device) + self.linear_proj = GLU(config, + in_features=config.hidden_size, + dtype=dtype, + device=device) + self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, + out_channels=config.hidden_size, + kernel_size=2, + stride=2, + dtype=dtype, + device=device) + self.boi = nn.Parameter( + torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device)) + self.eoi = nn.Parameter( + torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device)) + self.scaling_factor = 
vision_config.scaling_factor + + def forward(self, images): + """forward.""" + x = self.patch_embedding(images) + x = self.transformer(x) + + x = x[:, 1:] + + b, s, h = x.shape + grid_size = int(s**0.5) + x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2) + x = self.conv(x) + + x = x.flatten(2).transpose(1, 2) + x = self.linear_proj(x) + boi = self.boi.expand(x.shape[0], -1, -1) + eoi = self.eoi.expand(x.shape[0], -1, -1) + x = torch.cat((boi, x, eoi), dim=1) + x = x / self.scaling_factor + return x + + class ChatGLMModel(nn.Module): def __init__(self, @@ -442,19 +650,32 @@ def __init__(self, dtype=dtype, device=device) + self.vision = None + if hasattr(config, 'vision_config'): + self.vision = EVA2CLIPModel(config, dtype=dtype, device=device) + def forward( self, input_ids: torch.LongTensor = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, attn_metadata: Any = None, + images: torch.Tensor = None, + image_mask: torch.Tensor = None, inputs_embeds: Optional[torch.FloatTensor] = None, ): """forward.""" # token embedding if inputs_embeds is None: + images_features = None + if images is not None: + images_features = self.vision(images) + images_features = images_features.flatten(0, 1)[None] inputs_embeds = self.embedding(input_ids) + if images is not None: + inputs_embeds.masked_scatter_(image_mask[..., None], + images_features) hidden_states = inputs_embeds @@ -477,7 +698,8 @@ def get_input_embeddings(self): return self.embedding -class ChatGLMForConditionalGeneration(nn.Module, CudaGraphMixin): +class ChatGLMForConditionalGeneration(nn.Module, DeployModelMixin, + CudaGraphMixin): """rewrote model of LlamaForCausalLM.""" def __init__(self, @@ -491,12 +713,16 @@ def __init__(self, # build Model self.transformer = ChatGLMModel(config, dtype=dtype, device=device) + self.input_processor = ChatGLMInputProcessor(self.config, dtype) + def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, + images: torch.Tensor = None, + image_mask: torch.Tensor = None, inputs_embeds: torch.Tensor = None, **kwargs, ): @@ -506,6 +732,8 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + images=images, + image_mask=image_mask, inputs_embeds=inputs_embeds, ) return hidden_states @@ -529,8 +757,23 @@ def prepare_inputs_for_generation( input_ids = context.input_ids position_ids = context.position_ids attn_metadata = context.attn_metadata - if context.vision_inputs is not None: - position_ids = _get_cogvlm_position_ids(context)[0][None] + + images = None + image_mask = None + if context.input_multimodals is not None: + images = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + images = [data for im_data in images for data in im_data] + if len(images) != 0: + image_token_id = images[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + images = torch.stack([data.data for data in images]) + else: + images = None + image_mask = None # process vision embeddings vision_embeddings = context.input_embeddings @@ -548,9 +791,92 @@ def prepare_inputs_for_generation( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + images=images, + image_mask=image_mask, inputs_embeds=inputs_embeds, ) + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: 
StepContext = None): + """update model meta.""" + model_metas = context.model_metas + if not hasattr(self.config, 'vision_config'): + return model_metas + + input_multimodals = context.input_multimodals + if input_multimodals is None: + input_imgs = [[] for _ in model_metas] + else: + input_imgs = [] + for mm in input_multimodals: + if mm is None: + input_imgs.append([]) + else: + input_imgs.append(mm.get('image', [])) + + config = self.config + image_size: int = config.vision_config['image_size'] + patch_size: int = config.vision_config['patch_size'] + vision_token_num = ((image_size // patch_size // 2) * + (image_size // patch_size // 2) + 2) + num_pad = vision_token_num - 3 + + batched_num_img_tokens = [] + new_model_metas = [] + for meta, imgs in zip(model_metas, input_imgs): + if meta is None: + num_img_tokens = 0 + else: + num_img_tokens = meta.get('num_img_tokens', 0) + + batched_num_img_tokens.append(num_img_tokens) + + num_img_tokens += num_pad * len(imgs) + new_model_metas.append(dict(num_img_tokens=num_img_tokens)) + + # prepare cogvlm position_ids + q_seqlens = context.q_seqlens + position_ids = context.position_ids + + if context.is_decoding or all(len(imgs) == 0 for imgs in input_imgs): + num_img_tokens = torch.tensor(batched_num_img_tokens, + device=position_ids.device) + position_ids -= num_img_tokens[None] + else: + batched_position_ids = position_ids[0].split(q_seqlens) + for pos_ids, num_img_tok, imgs in zip(batched_position_ids, + batched_num_img_tokens, + input_imgs): + pos_ids -= num_img_tok + if len(imgs) == 0: + continue + + seq_len = pos_ids.size(0) + start = pos_ids[0].cpu().item() + new_pos_ids = [] + + imgs = sorted(imgs, key=lambda img: img.start) + for img in imgs: + img_pad_pos = img.start + 1 - num_img_tok + num_pad = img.end - img.start - 2 + new_pos_ids += list(range(start, img_pad_pos)) + new_pos_ids += [img_pad_pos] * num_pad + start = img_pad_pos + 1 + num_img_tok += num_pad + + remain = seq_len - len(new_pos_ids) + new_pos_ids += list(range(start, start + remain)) + + new_pos_ids = pos_ids.new_tensor(new_pos_ids) + pos_ids[:] = new_pos_ids + + position_ids = torch.cat(batched_position_ids)[None] + context.position_ids = position_ids + + return new_model_metas + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" # modify from vllm @@ -558,7 +884,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if 'transformer.vision' in name: + if '.query_key_value' in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) continue + if 'rotary_pos_emb.inv_freq' in name: continue if ('rotary_pos_emb.cos_cached' in name @@ -581,3 +917,53 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class ChatGLMInputProcessor(BaseModelInputProcessor): + """input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + + if hasattr(config, 'vision_config'): + vision_config = config.vision_config + self.image_size = vision_config['image_size'] + 
self.patch_size = vision_config['patch_size'] + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + self.vision_token_num = self.num_patches // 4 + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + offset = input_mm['offset'] + num_pad = input_mm['image_tokens'] + image_token_id = input_mm.get('image_token_id', 0) + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor( + data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py index 6caf10df0..f4f1baaff 100644 --- a/lmdeploy/pytorch/models/cogvlm.py +++ b/lmdeploy/pytorch/models/cogvlm.py @@ -1,20 +1,27 @@ # Copyright (c) OpenMMLab. All rights reserved. +from argparse import Namespace from typing import Any, Iterable, List, Optional, Tuple import torch import torch.distributed as dist +import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.distributed import get_world_rank +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding) -from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, + build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixin class VisionExpertAttention(nn.Module): @@ -322,6 +329,283 @@ def forward( return outputs +class PatchEmbedding(nn.Module): + """vision embedding.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.proj = nn.Conv2d(config.in_channels, + config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + dtype=dtype, + device=device) + self.cls_embedding = nn.Parameter( + torch.empty(1, config.hidden_size, dtype=dtype, device=device)) + self.position_embedding = nn.Embedding(config.num_positions, + config.hidden_size, + dtype=dtype, + device=device) + + def forward(self, images): + """forward.""" + x = self.proj(images) + x = x.flatten(2).transpose(1, 2) + cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x += self.position_embedding.weight.unsqueeze(0) + return x + + +class EVA2CLIPAttention(nn.Module): + """vision attention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', 
None) + hidden_size = config.hidden_size + num_heads = config.num_heads + head_dim = config.hidden_size // config.num_heads + self.scale = head_dim**-0.5 + + # packed qkv + self.query_key_value = build_qkv_proj( + hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_heads, + head_size=head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # o_proj + self.dense = build_rowwise_linear(hidden_size, + hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """forward.""" + # qkv proj + qkv_states = self.query_key_value(hidden_states) + q, k, v = self.query_key_value.split_qkv(qkv_states) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale) + + # o proj + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(-2, -1) + attn_output = self.dense(attn_output) + return attn_output + + +class EVA2CLIPMLP(nn.Module): + """vision MLP.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + + # gate up + quantization_config = getattr(config, 'quantization_config', None) + self.fc1 = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + if config.hidden_act in [ + 'gelu', 'gelu_fast', 'quick_gelu', 'gelu_python' + ]: + self.activation_fn = nn.GELU() + else: + self.activation_fn = ACT2FN[config.hidden_act] + + # down + self.fc2 = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """forward.""" + x = self.fc1(x) + x = self.activation_fn(x) + x = self.fc2(x) + return x + + +class EVA2CLIPTransformerLayer(nn.Module): + """vision trans layer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + self.attention = EVA2CLIPAttention(config, dtype=dtype, device=device) + self.mlp = EVA2CLIPMLP(config, dtype=dtype, device=device) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + + def forward(self, hidden_states): + """forward.""" + attention_input = hidden_states + attention_output = self.input_layernorm( + self.attention(attention_input)) + hidden_states = attention_input + attention_output + mlp_input = hidden_states + mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)) + output = mlp_input + mlp_output + return output + + +class EVA2CLIPTransformer(nn.Module): + """vision transformer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layers = nn.ModuleList([ + EVA2CLIPTransformerLayer(config, dtype=dtype, device=device) + for _ in range(config.num_hidden_layers) + ]) + + def forward(self, hidden_states): + """forward.""" + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + return 
hidden_states + + +class GLU(nn.Module): + """GLU.""" + + def __init__(self, + config: PretrainedConfig, + in_features: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.linear_proj = nn.Linear(in_features, + config.hidden_size, + bias=False, + dtype=dtype, + device=device) + self.norm1 = nn.LayerNorm(config.hidden_size, + dtype=dtype, + device=device) + self.act1 = nn.GELU() + self.act2 = nn.functional.silu + self.dense_h_to_4h = nn.Linear(config.hidden_size, + config.intermediate_size, + bias=False, + dtype=dtype, + device=device) + self.gate_proj = nn.Linear(config.hidden_size, + config.intermediate_size, + bias=False, + dtype=dtype, + device=device) + self.dense_4h_to_h = nn.Linear(config.intermediate_size, + config.hidden_size, + bias=False, + dtype=dtype, + device=device) + + def forward(self, x): + x = self.linear_proj(x) + x = self.act1(self.norm1(x)) + x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x) + x = self.dense_4h_to_h(x) + return x + + +class EVA2CLIPModel(nn.Module): + """vision model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + vision_config = Namespace(**config.vision_config) + + self.patch_embedding = PatchEmbedding(vision_config, + dtype=dtype, + device=device) + self.transformer = EVA2CLIPTransformer(vision_config, + dtype=dtype, + device=device) + self.linear_proj = GLU(config, + in_features=vision_config.hidden_size, + dtype=dtype, + device=device) + self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, + out_channels=vision_config.hidden_size, + kernel_size=2, + stride=2, + dtype=dtype, + device=device) + self.boi = nn.Parameter( + torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device)) + self.eoi = nn.Parameter( + torch.empty(1, 1, config.hidden_size, dtype=dtype, device=device)) + + def forward(self, images): + """forward.""" + x = self.patch_embedding(images) + x = self.transformer(x) + + x = x[:, 1:] + + b, s, h = x.shape + grid_size = int(s**0.5) + x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2) + x = self.conv(x) + + x = x.flatten(2).transpose(1, 2) + x = self.linear_proj(x) + boi = self.boi.expand(x.shape[0], -1, -1) + eoi = self.eoi.expand(x.shape[0], -1, -1) + x = torch.cat((boi, x, eoi), dim=1) + return x + + class CogVLMModel(nn.Module): """model.""" @@ -353,6 +637,9 @@ def __init__(self, dtype=dtype, device=device) + # vision model + self.vision = EVA2CLIPModel(config, dtype=dtype, device=device) + # build rotary embedding emb_type = RopeType.LinearScaling rope_dim = config.hidden_size // config.num_attention_heads @@ -371,6 +658,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, attn_metadata: Any = None, + images: torch.Tensor = None, inputs_embeds: Optional[torch.FloatTensor] = None, lang_ids: torch.LongTensor = None, vision_ids: torch.LongTensor = None, @@ -379,7 +667,12 @@ def forward( # token embedding if inputs_embeds is None: + if images is not None: + images_features = self.vision(images) + inputs_embeds = self.embed_tokens(input_ids) + if vision_ids is not None: + inputs_embeds[0, vision_ids] = images_features.flatten(0, 1) hidden_states = inputs_embeds @@ -416,85 +709,7 @@ def get_input_embeddings(self): VISION_TOKEN_TYPE = 1 -def get_vision_expert_mask(token_type_ids: torch.LongTensor): - vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool) - vision_token_mask[:, :-1] = 
(token_type_ids[:, :-1] - == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] - == VISION_TOKEN_TYPE) - language_token_mask = ~vision_token_mask - return vision_token_mask, language_token_mask - - -def build_position_ids(x: torch.BoolTensor) -> torch.LongTensor: - tmp = x.clone() - # image boi eoi token as LANGUAGE_TOKEN_TYPE - is_boi_eoi = torch.zeros_like(x, dtype=torch.bool) - is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & ( - tmp[:, :-1] == LANGUAGE_TOKEN_TYPE) - is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE) - is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & ( - tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) - is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE) - tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE - # final position ids - y = torch.zeros_like(x, dtype=torch.long) - y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ( - (tmp[:, 1:] == VISION_TOKEN_TYPE) & - (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)) - y = y.cumsum(dim=-1) - return y - - -def _get_cogvlm_position_ids(context): - """get cogvlm position_ids.""" - q_seqlens = context.q_seqlens - history_lengths = context.kv_seqlens - q_seqlens - vision_input_info = context.vision_inputs - position_id_offsets = (vision_input_info.history_image_token_lengths - - vision_input_info.history_image_nums * 3) - lang_ids = None - vis_ids = None - if context.is_decoding: - position_ids = history_lengths - position_id_offsets - else: - if vision_input_info.input_embeddings is not None and len( - vision_input_info.input_embeddings) > 0: - starts = history_lengths - vision_input_info.history_lengths - ends = starts + q_seqlens - token_type_ids = vision_input_info.input_embedding_indexing.to( - torch.int) - history_position_lengths = (vision_input_info.history_lengths - - position_id_offsets) - position_ids_all = (history_position_lengths[:, None] + - build_position_ids(token_type_ids)) - position_ids = torch.cat([ - pids[s:e] - for (pids, s, e) in zip(position_ids_all, starts, ends) - ]) - vision_token_mask_all, _ = get_vision_expert_mask(token_type_ids) - vision_token_mask = torch.cat([ - masks[s:e] - for (masks, s, e) in zip(vision_token_mask_all, starts, ends) - ]) - mask_indexing = torch.arange(vision_token_mask.shape[-1], - device=vision_token_mask.device) - vis_ids = mask_indexing[vision_token_mask] - lang_ids = mask_indexing[~vision_token_mask] - - else: - position_ids = context.attention_mask.long().cumsum(-1) - 1 - position_ids += (history_lengths - - position_id_offsets).unsqueeze(-1) - device = position_ids.device - position_ids_1d = [ - ids[:l] for ids, l in zip(position_ids.cpu(), q_seqlens.cpu()) - ] - position_ids = torch.cat(position_ids_1d).to(device) - - return position_ids, lang_ids, vis_ids - - -class CogVLMForCausalLM(nn.Module, CudaGraphMixin): +class CogVLMForCausalLM(nn.Module, CudaGraphMixin, DeployModelMixin): """ModelForCausalLM.""" packed_modules_mapping = { @@ -512,6 +727,8 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr + # preprocessor + self.input_processor = CogVLMInputProcessor(self.config, dtype) # build model self.model = CogVLMModel(config, dtype=dtype, device=device) # build lm_head @@ -527,6 +744,7 @@ def forward( position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, + images: torch.Tensor = None, inputs_embeds: torch.Tensor = None, lang_ids: torch.LongTensor = None, vision_ids: torch.LongTensor = None, @@ -538,6 +756,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + 
images=images, inputs_embeds=inputs_embeds, lang_ids=lang_ids, vision_ids=vision_ids, @@ -561,8 +780,36 @@ def prepare_inputs_for_generation( """prepare input.""" # get input_ids, position_ids and attention metadatas input_ids = context.input_ids - position_ids, lang_ids, vis_ids = _get_cogvlm_position_ids(context) - position_ids = position_ids[None] + + # position_ids, lang_ids, vis_ids = _get_cogvlm_position_ids(context) + position_ids = context.position_ids + lang_ids = None + vis_ids = None + + # vision inputs + images = None + if context.input_multimodals is not None: + images = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + images = [data for im_data in images for data in im_data] + if len(images) == 0: + images = None + + if images is not None: + image_token_id = images[0].meta['image_token_id'] + vis_mask = input_ids[0] == image_token_id + images = torch.stack([data.data for data in images]) + + # get lang_ids + vis_range = torch.arange(0, + input_ids.size(-1), + device=input_ids.device) + vis_ids = vis_range[vis_mask] + lang_ids = vis_range[~vis_mask] + attn_metadata = context.attn_metadata # process vision embeddings @@ -581,6 +828,7 @@ def prepare_inputs_for_generation( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + images=images, inputs_embeds=inputs_embeds, lang_ids=lang_ids, vision_ids=vis_ids, @@ -597,8 +845,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - if 'model.vision' in name: - continue if 'rotary_emb.inv_freq' in name: continue if ('rotary_emb.cos_cached' in name @@ -607,6 +853,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if self.config.tie_word_embeddings and 'lm_head.weight' in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: + if '.vision.' 
in name: + continue if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -620,6 +868,136 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): load_weight(param, q, shard_id='q') load_weight(param, k, shard_id='k') load_weight(param, v, shard_id='v') + elif '.query_key_value' in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') else: param = params_dict[name] load_weight(param, loaded_weight) + + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None): + """update model meta.""" + model_metas = context.model_metas + input_multimodals = context.input_multimodals + if input_multimodals is None: + input_imgs = [[] for _ in model_metas] + else: + input_imgs = [] + for mm in input_multimodals: + if mm is None: + input_imgs.append([]) + else: + input_imgs.append(mm.get('image', [])) + + config = self.config + image_size: int = config.vision_config['image_size'] + patch_size: int = config.vision_config['patch_size'] + vision_token_num = ((image_size // patch_size // 2) * + (image_size // patch_size // 2) + 2) + num_pad = vision_token_num - 3 + + batched_num_img_tokens = [] + new_model_metas = [] + for meta, imgs in zip(model_metas, input_imgs): + if meta is None: + num_img_tokens = 0 + else: + num_img_tokens = meta.get('num_img_tokens', 0) + + batched_num_img_tokens.append(num_img_tokens) + + num_img_tokens += num_pad * len(imgs) + new_model_metas.append(dict(num_img_tokens=num_img_tokens)) + + # prepare cogvlm position_ids + q_seqlens = context.q_seqlens + position_ids = context.position_ids + + if context.is_decoding or all(len(imgs) == 0 for imgs in input_imgs): + num_img_tokens = torch.tensor(batched_num_img_tokens, + device=position_ids.device) + position_ids -= num_img_tokens[None] + else: + batched_position_ids = position_ids[0].split(q_seqlens) + for pos_ids, num_img_tok, imgs in zip(batched_position_ids, + batched_num_img_tokens, + input_imgs): + pos_ids -= num_img_tok + if len(imgs) == 0: + continue + + seq_len = pos_ids.size(0) + start = pos_ids[0].cpu().item() + new_pos_ids = [] + + imgs = sorted(imgs, key=lambda img: img.start) + for img in imgs: + img_pad_pos = img.start + 1 - num_img_tok + num_pad = img.end - img.start - 2 + new_pos_ids += list(range(start, img_pad_pos)) + new_pos_ids += [img_pad_pos] * num_pad + start = img_pad_pos + 1 + num_img_tok += num_pad + + remain = seq_len - len(new_pos_ids) + new_pos_ids += list(range(start, start + remain)) + + new_pos_ids = pos_ids.new_tensor(new_pos_ids) + pos_ids[:] = new_pos_ids + + position_ids = torch.cat(batched_position_ids)[None] + context.position_ids = position_ids + + return new_model_metas + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class CogVLMInputProcessor(BaseModelInputProcessor): + """input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + image_size: int = config.vision_config['image_size'] + patch_size: int = config.vision_config['patch_size'] + self.vision_token_num = ((image_size // patch_size // 2) * + (image_size // patch_size // 2) + 2) + + def preprocess_input(self, + input_ids: List[int], + input_multimodals=None, + **kwargs) -> PreprocessInputResult: + """prepare 
multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + offset = input_mm['offset'] + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor( + data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 70dd8f215..fe3b2997c 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -1,17 +1,310 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Iterable, List, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch +import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn import LayerNorm, RMSNorm +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_qkv_proj, + build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .patch import build_model_from_hf_config from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixin -class InternVLChatModel(nn.Module, CudaGraphMixin): +class InternVisionEmbeddings(nn.Module): + """intern vision embedding.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter( + torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device), ) + + self.patch_embedding = nn.Conv2d(in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + dtype=dtype, + device=device) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter( + torch.empty(1, + self.num_positions, + self.embed_dim, + dtype=dtype, + device=device)) + + def _get_pos_embed(self, pos_embed, H, W): + target_dtype = pos_embed.dtype + pos_embed = pos_embed.float().reshape( + 1, self.image_size // self.patch_size, + self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, + size=(H, W), + mode='bicubic', + align_corners=False).reshape( + 1, -1, H * W).permute(0, 2, + 1).to(target_dtype) + return pos_embed + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values) # shape = [*, channel, width, height] + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, + 
-1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = torch.cat([ + self.position_embedding[:, :1, :], + self._get_pos_embed(self.position_embedding[:, 1:, :], height, + width) + ], + dim=1) + embeddings = embeddings + position_embedding.to(target_dtype) + return embeddings + + +NORM2FN = { + 'rms_norm': RMSNorm, + 'layer_norm': LayerNorm, +} + + +class InternAttention(nn.Module): + """intern vl attention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.qkv = build_qkv_proj( + self.embed_dim, + num_q_heads=self.num_heads, + num_kv_heads=self.num_heads, + head_size=self.head_dim, + bias=config.qkv_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = RMSNorm( + self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device, + ) + self.k_norm = RMSNorm( + self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device, + ) + + self.scale = self.head_dim**-0.5 + + # o_proj + self.proj = build_rowwise_linear(self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, hidden_states): + """forward.""" + + # qkv proj + qkv_states = self.qkv(hidden_states) + q, k, v = self.qkv.split_qkv(qkv_states) + + if self.qk_normalization: + q_shape = q.shape + q = self.q_norm(q.flatten(-2, -1)).view(q_shape) + k = self.k_norm(k.flatten(-2, -1)).view(q_shape) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention(q, k, v, scale=self.scale) + + # o proj + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(-2, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class InternMLP(nn.Module): + """intern vl mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + self.act = ACT2FN[config.hidden_act] + + self.fc1 = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + self.fc2 = build_rowwise_linear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class InternVisionEncoderLayer(nn.Module): + """intern vision encoder layer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.intermediate_size = config.intermediate_size + self.norm_type = config.norm_type + + self.attn = InternAttention(config, dtype=dtype, 
device=device) + self.mlp = InternMLP(config, dtype=dtype, device=device) + self.norm1 = NORM2FN[self.norm_type](self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + self.norm2 = NORM2FN[self.norm_type](self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + + self.ls1 = nn.Parameter( + torch.empty(self.embed_dim, dtype=dtype, device=device)) + self.ls2 = nn.Parameter( + torch.empty(self.embed_dim, dtype=dtype, device=device)) + + def forward( + self, + hidden_states: torch.Tensor, + ): + """forward.""" + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1 + + hidden_states = hidden_states + self.mlp( + self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2 + + return hidden_states + + +class InternVisionEncoder(nn.Module): + """intern vision encoder.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + InternVisionEncoderLayer(config, dtype=dtype, device=device) + for idx in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds, + ): + """forward.""" + hidden_states = inputs_embeds + for _, encoder_layer in enumerate(self.layers): + layer_outputs = encoder_layer(hidden_states, ) + hidden_states = layer_outputs + return hidden_states + + +class InternVisionModel(nn.Module): + """intern vision model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + + self.embeddings = InternVisionEmbeddings(config, + dtype=dtype, + device=device) + self.encoder = InternVisionEncoder(config, dtype=dtype, device=device) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + ): + """forward.""" + assert pixel_values.dim() == 4 + hidden_states = self.embeddings(pixel_values) + + encoder_outputs = self.encoder(inputs_embeds=hidden_states) + last_hidden_state = encoder_outputs + + return last_hidden_state + + +class InternVLChatModel(nn.Module, DeployModelMixin, CudaGraphMixin): def __init__(self, config: PretrainedConfig, @@ -21,31 +314,106 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr + self.select_layer = config.select_layer + llm_config = config.llm_config + self.llm_arch_name = llm_config.architectures[0] + self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM' + + vision_config = config.vision_config + if self.is_mono: + from .internvl_patch import InternVisionPatchModel + self.vision_model = InternVisionPatchModel( + vision_config, + dtype=dtype, + device=device, + ) + else: + self.vision_model = InternVisionModel(vision_config, + dtype=dtype, + device=device) + self.language_model = build_model_from_hf_config(llm_config, dtype=dtype, device=device) - self.llm_arch_name = llm_config.architectures[0] + vit_hidden_size = config.vision_config.hidden_size + llm_hidden_size = config.llm_config.hidden_size + self.downsample_ratio = config.downsample_ratio + self.mlp1 = nn.Sequential( + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2, + dtype=dtype, + device=device), + nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, + llm_hidden_size, + dtype=dtype, + device=device), nn.GELU(), + nn.Linear(llm_hidden_size, + llm_hidden_size, + dtype=dtype, + device=device)) # for Mono-InternVL - self.is_mono = self.llm_arch_name == 
'InternLM2VEForCausalLM' if self.is_mono: assert dtype != torch.float16, ( 'Currently Mono-InternVL does not support FP16 due to' 'numerical instability. Please use BF16 instead.') + self.input_processor = InternVLInputProcessor(self.config, dtype) + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> + # N, H * scale, W * scale, C // (scale ** 2) + x = x.view(n, int(h * scale_factor), int(w * scale_factor), + int(c / (scale_factor * scale_factor))) + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values): + """extract vision feature.""" + assert self.select_layer == -1 + vit_embeds = self.vision_model(pixel_values) + if self.is_mono: + if int(vit_embeds.shape[1]**0.5)**2 != vit_embeds.shape[1]: + vit_embeds = vit_embeds[:, 1:, :] + else: + vit_embeds = vit_embeds[:, 1:, :] + + h = w = int(vit_embeds.shape[1]**0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, + scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, + vit_embeds.shape[-1]) + vit_embeds = self.mlp1(vit_embeds) + return vit_embeds + def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, + pixel_values: torch.Tensor = None, + image_mask: torch.Tensor = None, inputs_embeds: torch.Tensor = None, vision_embedding_indexing: torch.Tensor = None, text_embedding_indexing: torch.Tensor = None, **kwargs, ): + if inputs_embeds is None and pixel_values is not None: + # extract feature + vit_embeds = self.extract_feature(pixel_values) + lang_embeds = self.language_model.get_input_embeddings()(input_ids) + lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds) + + inputs_embeds = lang_embeds + if self.is_mono: return self.language_model.forward( input_ids=input_ids, @@ -80,11 +448,38 @@ def prepare_inputs_for_generation( input_ids = context.input_ids position_ids = context.position_ids attn_metadata = context.attn_metadata - # get inputs from context vision_embeddings = context.input_embeddings - vision_embedding_indexing = context.input_embedding_indexing + vision_embedding_indexing = None + # vision inputs + pixel_values = None + image_mask = None + if context.input_multimodals is not None: + pixel_values = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + pixel_values = [ + data for im_data in pixel_values for data in im_data + ] + if len(pixel_values) > 0: + image_token_id = pixel_values[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + pixel_values = torch.cat([data.data for data in pixel_values]) + else: + pixel_values = None + image_mask = None + + if self.is_mono and pixel_values is not None: + vision_embedding_indexing = torch.arange(input_ids.shape[1], + device=input_ids.device) + vision_embedding_indexing = vision_embedding_indexing[ + image_mask[0]] + + # get inputs from context if vision_embeddings is not None and len(vision_embeddings) > 0: + vision_embedding_indexing = context.input_embedding_indexing if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) inputs_embeds[:, @@ -104,6 +499,8 @@ def prepare_inputs_for_generation( 
position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + pixel_values=pixel_values, + image_mask=image_mask, inputs_embeds=inputs_embeds, vision_embedding_indexing=vision_embedding_indexing, text_embedding_indexing=text_embedding_indexing, @@ -114,18 +511,85 @@ def prepare_inputs_for_generation( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + pixel_values=pixel_values, + image_mask=image_mask, inputs_embeds=inputs_embeds, ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" - prefix_length = len('language_model.') + lang_prefix = 'language_model.' + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if name.startswith(lang_prefix): + continue + + if 'qkv' in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + lang_prefix_length = len(lang_prefix) new_weights = dict() for key, val in weights: - if not key.startswith('language_model.'): + if not key.startswith(lang_prefix): continue - new_key = key[prefix_length:] + new_key = key[lang_prefix_length:] new_weights[new_key] = val self.language_model.load_weights(new_weights.items()) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class InternVLInputProcessor(BaseModelInputProcessor): + """internvl input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + + vision_config = config.vision_config + self.image_size = vision_config.image_size + self.patch_size = vision_config.patch_size + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + self.vision_token_num = self.num_patches // 4 + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + offset = input_mm['offset'] + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor( + data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/internvl_patch.py b/lmdeploy/pytorch/models/internvl_patch.py new file mode 100644 index 000000000..d13ad2d39 --- /dev/null +++ b/lmdeploy/pytorch/models/internvl_patch.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
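A minimal sketch of how the image_mask built in prepare_inputs_for_generation lets masked_scatter_ splice the projected vision features into the text embeddings (toy ids and sizes, not the real InternVL vocabulary):

import torch

IMAGE_TOKEN_ID = 92546                                  # placeholder id for illustration
input_ids = torch.tensor([[101, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 102]])
image_mask = input_ids == IMAGE_TOKEN_ID                # (1, 4) bool

lang_embeds = torch.zeros(1, 4, 8)                      # token embeddings from the LLM
vit_embeds = torch.ones(2, 8)                           # one projected feature per image token

# masked positions are filled row by row from vit_embeds; text rows stay untouched
lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)
assert lang_embeds[0, 1].sum() == 8 and lang_embeds[0, 0].sum() == 0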
+from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn +from transformers.configuration_utils import PretrainedConfig + + +class InternVisionEmbeddings(nn.Module): + """mono vision.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter( + torch.empty(1, 1, self.embed_dim, dtype=dtype, device=device), ) + + self.patch_embedding = nn.Conv2d(in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + dtype=dtype, + device=device) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter( + torch.empty(1, + self.num_positions, + self.embed_dim, + dtype=dtype, + device=device)) + + def _get_pos_embed(self, pos_embed, H, W): + target_dtype = pos_embed.dtype + pos_embed = pos_embed.float().reshape( + 1, self.image_size // self.patch_size, + self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, + size=(H, W), + mode='bicubic', + align_corners=False) + pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2, + 1).to(target_dtype) + return pos_embed + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values) # shape = [*, channel, width, height] + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, + -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = torch.cat([ + self.position_embedding[:, :1, :], + self._get_pos_embed(self.position_embedding[:, 1:, :], height, + width) + ], + dim=1) + embeddings = embeddings + position_embedding.to(target_dtype) + return embeddings + + +class InternVisionPatchModel(nn.Module): + """mono vision.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.embeddings = InternVisionEmbeddings(config, + dtype=dtype, + device=device) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + ): + if len(pixel_values.shape) != 4: + raise ValueError(f'wrong pixel_values size: {pixel_values.shape}') + + hidden_states = self.embeddings(pixel_values)[:, 1:] + return hidden_states diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py index f38c5ef02..8acd20a8d 100644 --- a/lmdeploy/pytorch/models/llama.py +++ b/lmdeploy/pytorch/models/llama.py @@ -450,22 +450,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) - - -class LlavaLlamaForCausalLM(LlamaForCausalLM): - """llava llama for causallm.""" - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - """load weights.""" - - new_weights = dict() - for key, val in weights: - if key.startswith('model.vision_tower'): - continue - if key.startswith('model.mm_projector'): - continue - if key.startswith('model.image_newline'): - continue - new_weights[key] = val - - super().load_weights(new_weights.items()) diff --git 
a/lmdeploy/pytorch/models/llava.py b/lmdeploy/pytorch/models/llava.py index 56cb5ca67..4c330fd84 100644 --- a/lmdeploy/pytorch/models/llava.py +++ b/lmdeploy/pytorch/models/llava.py @@ -1,17 +1,443 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Iterable, List, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch +import torch.nn.functional as F from torch import nn from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.models.llava.configuration_llava import LlavaConfig +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_qkv_proj, + build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .patch import build_model_from_hf_config from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixin -class LlavaForConditionalGeneration(nn.Module, CudaGraphMixin): +class LlavaMultiModalProjector(nn.Module): + + def __init__(self, + config: LlavaConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + + self.linear_1 = nn.Linear(config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=True, + dtype=dtype, + device=device) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear(config.text_config.hidden_size, + config.text_config.hidden_size, + bias=True, + dtype=dtype, + device=device) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class CLIPVisionEmbeddings(nn.Module): + """clip vision embedding.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter( + torch.empty(self.embed_dim, dtype=dtype, device=device)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + dtype=dtype, + device=device, + ) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding( + self.num_positions, + self.embed_dim, + dtype=dtype, + device=device, + ) + self.register_buffer('position_ids', + torch.arange(self.num_positions, + device=device).expand((1, -1)), + persistent=False) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, + width: int) -> torch.Tensor: + """This method allows to interpolate the pre-trained position + encodings, to be able to use the model on higher resolution images. + + This method is also adapted to support torch.jit tracing. 
+ """ + + num_patches = embeddings.shape[1] - 1 + position_embedding = self.position_embedding.weight.unsqueeze(0) + num_positions = position_embedding.shape[1] - 1 + + # always interpolate when tracing + # to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing( + ) and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + from transformers.utils import torch_int + + class_pos_embed = position_embedding[:, :1] + patch_pos_embed = position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, + sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode='bicubic', + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size + or width != self.image_size): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f' ({self.image_size}*{self.image_size}).') + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding( + self.position_ids) + return embeddings + + +class CLIPAttention(nn.Module): + """clip attention.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.qkv_proj = build_qkv_proj( + self.embed_dim, + num_q_heads=self.num_heads, + num_kv_heads=self.num_heads, + head_size=self.head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + self.scale = self.head_dim**-0.5 + + # o_proj + self.out_proj = build_rowwise_linear(self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + ): + """forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + q, k, v = self.qkv_proj.split_qkv(qkv_states) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if attention_mask is not None and causal_attention_mask is not None: + attn_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attn_mask = causal_attention_mask + else: + attn_mask = attention_mask + + attn_output 
= F.scaled_dot_product_attention(q, + k, + v, + attn_mask=attn_mask, + scale=self.scale) + + # o proj + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.flatten(-2, -1) + attn_output = self.out_proj(attn_output) + return attn_output + + +class CLIPMLP(nn.Module): + """clip mlp.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + from transformers.activations import ACT2FN + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + self.fc2 = build_rowwise_linear( + config.intermediate_size, + config.hidden_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """forward.""" + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class CLIPEncoderLayer(nn.Module): + """clip encoder layer.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = CLIPAttention(config, dtype=dtype, device=device) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + self.mlp = CLIPMLP(config, dtype=dtype, device=device) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + ): + """forward.""" + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class CLIPEncoder(nn.Module): + """clip encoder.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + CLIPEncoderLayer(config, dtype=dtype, device=device) + for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + vision_feature_layer: int = -1, + ): + """forward.""" + hidden_states = inputs_embeds + num_vision_layers = len(self.layers) + vision_feature_layer + 1 + for _, encoder_layer in enumerate(self.layers[:num_vision_layers]): + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask=causal_attention_mask, + ) + + hidden_states = layer_outputs + + return hidden_states + + +class CLIPVisionTransformer(nn.Module): + """clip vision transformer.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + 
self.embeddings = CLIPVisionEmbeddings(config, + dtype=dtype, + device=device) + self.pre_layrnorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + self.encoder = CLIPEncoder(config, dtype=dtype, device=device) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps, + dtype=dtype, + device=device) + + def forward( + self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + vision_feature_layer: int = -1, + ) -> BaseModelOutputWithPooling: + """forward.""" + hidden_states = self.embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + vision_feature_layer=vision_feature_layer) + + last_hidden_state = encoder_outputs + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=None, + attentions=None, + ) + + +class CLIPVisionModel(nn.Module): + """clip vision model.""" + + def __init__(self, + config, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.vision_model = CLIPVisionTransformer(config, + dtype=dtype, + device=device) + + def forward(self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + vision_feature_layer: int = -1, + **kwargs): + """forward.""" + return self.vision_model( + pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + vision_feature_layer=vision_feature_layer) + + +def build_vision_model(vision_config, + dtype: torch.dtype = None, + device: torch.device = None): + """build vision model.""" + model_type = vision_config.model_type + + if model_type == 'clip_vision_model': + return CLIPVisionModel(vision_config, dtype, device) + else: + raise NotImplementedError(f'<{model_type}> is not implemented.') + + +class LlavaForConditionalGeneration(nn.Module, CudaGraphMixin, + DeployModelMixin): def __init__(self, config: PretrainedConfig, @@ -22,19 +448,67 @@ def __init__(self, self.config = config self.ctx_mgr = ctx_mgr text_config = config.text_config + + self.vision_tower = build_vision_model(config.vision_config, + dtype=dtype, + device=device) + self.language_model = build_model_from_hf_config(text_config, dtype=dtype, device=device) + self.multi_modal_projector = LlavaMultiModalProjector(config, + dtype=dtype, + device=device) + + self.input_processor = LLavaInputProcessor(config, dtype) + + def get_image_features(self, + pixel_values, + vision_feature_layer: int = -1, + vision_feature_select_strategy: str = 'default'): + """get image features.""" + selected_image_feature = self.vision_tower( + pixel_values, vision_feature_layer=vision_feature_layer)[0] + if vision_feature_select_strategy == 'default': + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == 'full': + selected_image_feature = selected_image_feature + else: + raise ValueError( + f'Unexpected select feature strategy: {vision_feature_select_strategy}' # noqa: E501 + ) + image_features = self.multi_modal_projector(selected_image_feature) + image_features = image_features.flatten(0, 1)[None] + + return image_features + def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], attn_metadata: Any = None, + pixel_values: torch.Tensor = None, + 
image_mask: torch.Tensor = None, inputs_embeds: torch.Tensor = None, **kwargs, ): + if inputs_embeds is None: + image_features = None + if pixel_values is not None: + vision_feature_layer = self.config.vision_feature_layer + select_strategy = self.config.vision_feature_select_strategy + image_features = self.get_image_features( + pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=select_strategy) + inputs_embeds = self.language_model.get_input_embeddings()( + input_ids) + if pixel_values is not None: + inputs_embeds.masked_scatter_(image_mask[..., None], + image_features) + return self.language_model.forward(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, @@ -59,6 +533,27 @@ def prepare_inputs_for_generation( input_ids = context.input_ids position_ids = context.position_ids attn_metadata = context.attn_metadata + + # vision inputs + pixel_values = None + image_mask = None + if context.input_multimodals is not None: + pixel_values = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + pixel_values = [ + data for im_data in pixel_values for data in im_data + ] + if len(pixel_values) > 0: + image_token_id = pixel_values[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + pixel_values = torch.cat([data.data for data in pixel_values]) + else: + pixel_values = None + image_mask = None + # get inputs from context vision_embeddings = context.input_embeddings vision_embedding_indexing = context.input_embedding_indexing @@ -75,18 +570,403 @@ def prepare_inputs_for_generation( position_ids=position_ids, past_key_values=past_key_values, attn_metadata=attn_metadata, + pixel_values=pixel_values, + image_mask=image_mask, inputs_embeds=inputs_embeds, ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" - prefix_length = len('language_model.') + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ] + + # vis model + lang_prefix = 'language_model.' 
+ params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if name.startswith(lang_prefix): + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + # language model + prefix_length = len(lang_prefix) new_weights = dict() for key, val in weights: - if not key.startswith('language_model.'): + if not key.startswith(lang_prefix): continue new_key = key[prefix_length:] new_weights[new_key] = val self.language_model.load_weights(new_weights.items()) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class LLavaInputProcessor(BaseModelInputProcessor): + """llava input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + offset = input_mm['offset'] + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor( + data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict(image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + + from transformers.image_processing_utils import select_best_resolution + + if not isinstance(grid_pinpoints, list): + raise TypeError('grid_pinpoints should be a list of tuples or lists') + + if not isinstance(image_size, (list, tuple)): + image_size = image_size.tolist() + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def unpad_image(tensor, original_size): + """Unpads a PyTorch tensor of a padded and resized image.""" + if not isinstance(original_size, (list, tuple)): + original_size = original_size.tolist() + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(round(original_height * scale_factor, 7)) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding:current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(round(original_width * scale_factor, 7)) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding:current_width - padding] + + return unpadded_tensor + + +def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int): + """Calculate the number of patches after the preprocessing for images of + any resolution.""" + from 
transformers.image_processing_utils import select_best_resolution + if not isinstance(grid_pinpoints, list): + raise TypeError('grid_pinpoints should be a list of tuples or lists') + + if not isinstance(image_size, (list, tuple)): + image_size = image_size.tolist() + + best_resolution = select_best_resolution(image_size, grid_pinpoints) + height, width = best_resolution + + num_patches = (height // patch_size) * (width // patch_size) + # add the base patch + num_patches += 1 + return num_patches + + +class LlavaNextForConditionalGeneration(LlavaForConditionalGeneration): + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__(config=config, + ctx_mgr=ctx_mgr, + dtype=dtype, + device=device) + self.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size, + dtype=dtype, + device=device)) + self.input_processor = LLavaNextInputProcessor(config, dtype) + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: int, + vision_feature_select_strategy: str, + ): + # ! infer image_num_patches from image_sizes + image_num_patches = [ + image_size_to_num_patches( + image_size=imsize, + grid_pinpoints=self.config.image_grid_pinpoints, + patch_size=self.config.vision_config.image_size, + ) for imsize in image_sizes + ] + if pixel_values.dim() == 5: + # stacked if input is + # (batch_size, num_patches, num_channels, height, width) + _pixel_values_list = [ + pix_val[:num_patch] + for pix_val, num_patch in zip(pixel_values, image_num_patches) + ] + pixel_values = torch.cat(_pixel_values_list, dim=0) + elif pixel_values.dim() != 4: + # otherwise has to be stacked from list of + # (num_patches, num_channels, height, width) + raise ValueError(f'pixel_values of shape {pixel_values.shape}, ' + 'expect to be of 4 or 5 dimensions') + + selected_image_feature = self.vision_tower( + pixel_values, vision_feature_layer=vision_feature_layer)[0] + if vision_feature_select_strategy == 'default': + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == 'full': + selected_image_feature = selected_image_feature + image_features = self.multi_modal_projector(selected_image_feature) + image_features = torch.split(image_features, image_num_patches, dim=0) + return image_features + + def pack_image_features(self, + image_features, + image_sizes, + vision_feature_select_strategy, + image_newline=None): + + new_image_features = [] + feature_lens = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = (self.config.vision_config.image_size // + self.config.vision_config.patch_size) + + if vision_feature_select_strategy == 'default': + expected_num_patches = height * width + elif vision_feature_select_strategy == 'full': + expected_num_patches = height * width + 1 + if expected_num_patches != base_image_feature.shape[0]: + raise ValueError('The number of patches is ' + 'not consistent with the image size.') + + (num_patch_height, + num_patch_width) = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + image_feature = image_feature.view(num_patch_height, + num_patch_width, height, + width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, + 3).contiguous() + image_feature = 
image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, + image_sizes[image_idx]) + if image_newline is not None: + image_feature = torch.cat( + ( + image_feature, + image_newline[:, None, None].expand( + *image_feature.shape[:-1], 1).to( + image_feature.dtype), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), + dim=0) + else: + image_feature = image_feature[0] + if image_newline is not None: + image_feature = torch.cat( + (image_feature, image_newline[None].to(image_feature)), + dim=0) + new_image_features.append(image_feature) + feature_lens.append(image_feature.size(0)) + image_features = torch.cat(new_image_features, dim=0) + return image_features + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + pixel_values: torch.Tensor = None, + image_sizes: torch.Tensor = None, + image_mask: torch.Tensor = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + if inputs_embeds is None: + image_features = None + if pixel_values is not None: + vision_feature_layer = self.config.vision_feature_layer + select_strategy = self.config.vision_feature_select_strategy + image_sizes = image_sizes.tolist() + image_features = self.get_image_features( + pixel_values, + image_sizes, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=select_strategy) + image_features = self.pack_image_features( + image_features, + image_sizes, + vision_feature_select_strategy=select_strategy, + image_newline=self.image_newline, + ) + image_features = image_features[None] + inputs_embeds = self.language_model.get_input_embeddings()( + input_ids) + if pixel_values is not None: + inputs_embeds.masked_scatter_(image_mask[..., None], + image_features) + + return self.language_model.forward(input_ids=input_ids, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + position_ids=position_ids, + attn_metadata=attn_metadata) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, + context: StepContext = None, + ): + """prepare input.""" + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # vision inputs + pixel_values = None + image_sizes = None + image_mask = None + if context.input_multimodals is not None: + img_mms = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + img_mms = [data for im_data in img_mms for data in im_data] + if len(img_mms) > 0: + image_token_id = img_mms[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + pixel_values = torch.cat([data.data for data in img_mms]) + image_sizes = torch.cat( + [data.meta['image_sizes'] for data in img_mms]) + else: + pixel_values = None + image_sizes = None + + # get inputs from context + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, + vision_embedding_indexing, :] = vision_embeddings.to( + inputs_embeds) + + return dict( + input_ids=input_ids, + 
position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + pixel_values=pixel_values, + image_sizes=image_sizes, + image_mask=image_mask, + inputs_embeds=inputs_embeds, + ) + + +class LLavaNextInputProcessor(BaseModelInputProcessor): + """llava input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + image_sizes = input_mm['image_sizes'] + offset = input_mm['offset'] + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict( + image_sizes=image_sizes, + image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/mistral.py b/lmdeploy/pytorch/models/mistral.py index 04af4c852..ad2796309 100644 --- a/lmdeploy/pytorch/models/mistral.py +++ b/lmdeploy/pytorch/models/mistral.py @@ -420,22 +420,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) - - -class LlavaMistralForCausalLM(MistralForCausalLM): - """llava forcausallm.""" - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - """load weights.""" - - new_weights = dict() - for key, val in weights: - if key.startswith('model.vision_tower'): - continue - if key.startswith('model.mm_projector'): - continue - if key.startswith('model.image_newline'): - continue - new_weights[key] = val - - super().load_weights(new_weights.items()) diff --git a/lmdeploy/pytorch/models/mllama.py b/lmdeploy/pytorch/models/mllama.py index 2596fe529..0a0f0e9f1 100644 --- a/lmdeploy/pytorch/models/mllama.py +++ b/lmdeploy/pytorch/models/mllama.py @@ -3,23 +3,61 @@ import torch from torch import nn +from torch.nn import functional as F from transformers.models.llama import LlamaConfig -from transformers.models.mllama.modeling_mllama import MllamaTextConfig +from transformers.models.mllama.modeling_mllama import (MllamaTextConfig, + MllamaVisionConfig) +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, - SiluAndMul, build_rotary_embedding) -from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, LayerNorm, RMSNorm, + RopeType, SiluAndMul, build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, + build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.nn.rotary_embedding import Llama3Parameters from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -from .utils.cudagraph import CudaGraphMixin 
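For the LLaVA-NeXT helpers above, image_size_to_num_patches and get_anyres_image_grid_shape both reduce to select_best_resolution over the configured grid pinpoints. A small worked example (the pinpoints and image size here are made up; real values come from config.image_grid_pinpoints and config.vision_config.image_size):

from transformers.image_processing_utils import select_best_resolution

grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
image_size = (500, 800)     # (height, width) of the raw image
tile_size = 336             # vision_config.image_size, one ViT tile per grid cell

best_h, best_w = select_best_resolution(image_size, grid_pinpoints)  # (672, 672) for this input
grid_shape = (best_h // tile_size, best_w // tile_size)              # (2, 2) tiles
num_patches = grid_shape[0] * grid_shape[1] + 1                      # + 1 for the base image patch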
+from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin, next_power_of_2 +from .utils.model import DeployModelMixin MLLAMA_IMAGE_TOKEN_ID = 128256 MLLAMA_IMAGE_TOKEN = '<|image|>' +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, + 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, + # max_num_tiles * target_length) + attention_mask = attention_mask.reshape(batch_size, + max_num_tiles * target_length, 1) + attention_mask = attention_mask * attention_mask.transpose( + -1, -2) * torch.finfo(dtype).min + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + class LlamaAttention(nn.Module): """Rewrite module of LlamaAttention.""" @@ -157,6 +195,7 @@ def __init__(self, self.head_dim, num_kv_heads=self.num_key_value_heads, v_head_size=self.head_dim, + causal=False, ) self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) @@ -579,7 +618,542 @@ def get_logits(self, hidden_states: torch.Tensor): return self.lm_head(hidden_states) -class MllamaForConditionalGeneration(nn.Module, CudaGraphMixin): +class MllamaPrecomputedPositionEmbedding(nn.Module): + """vis position embedding.""" + + def __init__(self, + config: MllamaVisionConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.config = config + self.num_patches = (config.image_size // config.patch_size)**2 + 1 + self.hidden_size = config.hidden_size + + self.gate = nn.Parameter(torch.empty(1, dtype=dtype, device=device)) + + # position embedding + self.embedding = nn.Parameter( + torch.empty(self.num_patches, + self.hidden_size, + dtype=dtype, + device=device)) + + # tile position embedding + self.tile_embedding = nn.Embedding(self.max_aspect_ratio_id + 1, + self.max_num_tiles * + self.num_patches * self.hidden_size, + dtype=dtype, + device=device) + + self._weight_inited = False + + def _init_weight(self): + """init weight.""" + if self._weight_inited: + return + + gate_tanh = self.gate.tanh() + gated_position_embedding = (1 - gate_tanh) * self.embedding + self.gate_tanh = gate_tanh + self.gated_position_embedding = gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size) + + self._weight_inited = True + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + """forward.""" + self._init_weight() + + # position embeddings + hidden_state = hidden_state + self.gated_position_embedding + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size) + gated_tile_position_embedding = (self.gate_tanh * + tile_position_embedding) + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +class 
MllamaPrecomputedAspectRatioEmbedding(nn.Module): + + def __init__(self, + config: MllamaVisionConfig, + is_gated: bool = True, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.is_gated = is_gated + + self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.hidden_size, + dtype=dtype, + device=device) + if is_gated: + self.gate = nn.Parameter(torch.empty(1, dtype=dtype, + device=device)) + + self._weight_inited = False + + def _init_weight(self): + """init weight.""" + if self._weight_inited: + return + + gate_tanh = self.gate.tanh() + self.gate_tanh = gate_tanh + + self._weight_inited = True + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + self._init_weight() + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, + self.hidden_size) + + if self.is_gated: + embeddings = embeddings * self.gate_tanh + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaVisionAttention(nn.Module): + """mllama vision attention.""" + + def __init__(self, + config: MllamaVisionConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + self.embed_dim = config.hidden_size + self.num_heads = config.attention_heads + self.head_dim = config.hidden_size // config.attention_heads + + # packed qkv + self.qkv_proj = build_qkv_proj( + self.embed_dim, + num_q_heads=self.num_heads, + num_kv_heads=self.num_heads, + head_size=self.head_dim, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # o_proj + self.o_proj = build_rowwise_linear(self.num_heads * self.head_dim, + self.embed_dim, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size = hidden_state.size(0) + qkv_states = self.qkv_proj(hidden_state) + qkv_states = qkv_states.flatten(0, -2) + query, key, value = self.qkv_proj.split_qkv(qkv_states) + + query = query.unflatten(0, (batch_size, -1)) + key = key.unflatten(0, (batch_size, -1)) + value = value.unflatten(0, (batch_size, -1)) + q_seq_len = query.shape[1] + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention(query, + key, + value, + attn_mask=attention_mask) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + + output = self.o_proj(attn_output) + + return output + + +class MllamaVisionMLP(nn.Module): + """mllama vision mlp.""" + + def __init__(self, + config: MllamaVisionConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = build_colwise_linear( + config.hidden_size, + config.intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + self.fc2 = build_rowwise_linear(config.intermediate_size, + 
config.hidden_size, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MllamaVisionEncoderLayer(nn.Module): + """vision encoder layer.""" + + def __init__(self, + config: MllamaVisionConfig, + is_gated: bool, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.hidden_size = config.hidden_size + self.is_gated = is_gated + self.self_attn = MllamaVisionAttention(config, + dtype=dtype, + device=device) + self.mlp = MllamaVisionMLP(config, dtype=dtype, device=device) + + self.input_layernorm = LayerNorm(self.hidden_size, + eps=config.norm_eps, + dtype=dtype, + device=device) + self.post_attention_layernorm = LayerNorm(self.hidden_size, + eps=config.norm_eps, + dtype=dtype, + device=device) + + if is_gated: + self.gate_attn = nn.Parameter( + torch.empty(1, dtype=dtype, device=device)) + self.gate_ffn = nn.Parameter( + torch.empty(1, dtype=dtype, device=device)) + + self._weight_inited = not is_gated + + def _init_weight(self): + """init weight.""" + if self._weight_inited: + return + + self.gate_attn_tanh = self.gate_attn.tanh() + self.gate_ffn_tanh = self.gate_ffn.tanh() + + self._weight_inited = True + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + """forward.""" + self._init_weight() + + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, + attention_mask=attention_mask) + if self.is_gated: + hidden_state = self.gate_attn_tanh * hidden_state + hidden_state = residual + hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + if self.is_gated: + hidden_state = self.gate_ffn_tanh * hidden_state + hidden_state = residual + hidden_state + + outputs = hidden_state + + return outputs + + +class MllamaVisionEncoder(nn.Module): + """vision encoder.""" + + def __init__(self, + config: MllamaVisionConfig, + num_layers=32, + is_gated=False, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + MllamaVisionEncoderLayer(config, + is_gated, + dtype=dtype, + device=device) for _ in range(num_layers) + ]) + self.gradient_checkpointing = False + self.config = config + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + """forward.""" + encoder_states = () + for encoder_layer in self.layers: + encoder_states = encoder_states + (hidden_states, ) + hidden_states = encoder_layer( + hidden_state=hidden_states, + attention_mask=attention_mask, + ) + encoder_states = encoder_states + (hidden_states, ) + + return hidden_states, encoder_states + + +class MllamaVisionModel(nn.Module): + """vision model.""" + + def __init__(self, + config: MllamaVisionConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + + self.config = config + self.image_size = config.image_size + self.patch_size = config.patch_size + self.hidden_size = config.hidden_size + self.intermediate_layers_indices = config.intermediate_layers_indices + self.dtype = dtype + + self.num_patches = (self.image_size // 
self.patch_size)**2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + padding='valid', + bias=False, + dtype=dtype, + device=device, + ) + + self.class_embedding = nn.Parameter( + torch.empty(self.hidden_size, dtype=dtype, device=device)) + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( + config, + dtype=dtype, + device=device, + ) + + self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( # noqa: E501 + config, + is_gated=True, + dtype=dtype, + device=device, + ) + self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding( # noqa: E501 + config, + is_gated=True, + dtype=dtype, + device=device, + ) + + # layer norms + self.layernorm_pre = nn.LayerNorm( + self.hidden_size, + dtype=dtype, + device=device, + ) + self.layernorm_post = nn.LayerNorm( + self.hidden_size, + dtype=dtype, + device=device, + ) + + # encoders + self.transformer = MllamaVisionEncoder( + config, + config.num_hidden_layers, + is_gated=False, + dtype=dtype, + device=device, + ) + self.global_transformer = MllamaVisionEncoder( + config, + config.num_global_layers, + is_gated=True, + dtype=dtype, + device=device, + ) + + def apply_class_embedding(self, + hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, + hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor, + ): + """forward.""" + (batch_size, num_concurrent_media, num_tiles, num_channels, height, + width) = pixel_values.shape + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, + height, width) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1) + + # Patch embedding + patch_embeds = self.patch_embedding(pixel_values.to(self.dtype)) + hidden_state = patch_embeds.flatten(2).transpose(1, 2) + + # Tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, -1, dim) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids) + + # Add cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # Position embeddings + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, num_patches, dim) + hidden_state = self.gated_positional_embedding(hidden_state, + aspect_ratio_ids) + + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, 0, 0, num_padding_patches + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode='constant', value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + # Prepare attention mask + attention_mask = aspect_ratio_mask.reshape( + batch_size * num_concurrent_media, -1) + attention_mask = _prepare_aspect_ratio_attention_mask( + 
aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.dtype, + ) + + # Apply encoder + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, + dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + ) + hidden_state = output[0] + + hidden_state = self.layernorm_post(hidden_state) + + # Apply global encoder + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), dim) + global_output = self.global_transformer( + hidden_state, + attention_mask=attention_mask, + ) + hidden_state = global_output[0] + + # Remove padding form hidden state + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) + hidden_state = hidden_state[:, :, :slice_index] + hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, + num_tiles, num_patches, dim) + + # Collect intermediate layer outputs from encoder output + all_intermediate_hidden_states = output[1] + all_intermediate_hidden_states = [ + all_intermediate_hidden_states[i] + for i in self.intermediate_layers_indices + ] + intermediate_hidden_states = torch.stack( + all_intermediate_hidden_states, dim=-1) + + # Remove padding from intermediate hidden states + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, num_tiles, + num_patches + num_padding_patches, -1) + intermediate_hidden_states = intermediate_hidden_states[:, :, : + slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1) + + # Concatenate final hidden state and intermediate hidden states + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], + dim=-1) + + return hidden_state + + +class MllamaForConditionalGeneration(nn.Module, CudaGraphMixin, + DeployModelMixin): """rewrote model of MllamaForConditionalGeneration.""" packed_modules_mapping = { @@ -602,16 +1176,32 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr + + self.vision_model = MllamaVisionModel( + config.vision_config, + dtype=dtype, + device=device, + ) # build MllamaForCausalLM self.language_model = MllamaForCausalLM(config.text_config, dtype=dtype, device=device) + + self.multi_modal_projector = build_rowwise_linear( + config.vision_config.vision_output_dim, + config.text_config.hidden_size, + bias=True, + dtype=dtype, + device=device, + ) self.dtype = dtype - def flat_encoder_result(self, cross_attention_states: torch.Tensor, - attn_metadata: Any, input_ids: torch.LongTensor): + # preprocessor + self.input_processor = MLlamaInputProcessor(self.config, dtype) + + def flat_encoder_result(self, attn_metadata: Any, + input_ids: torch.LongTensor): # since every state share the same shape - cross_attention_states = torch.cat(cross_attention_states, 0) full_text_row_masked_out_mask = torch.ones( (attn_metadata.q_seqlens.sum(), 1), dtype=torch.bool) start_pos = 0 @@ -621,39 +1211,51 @@ def flat_encoder_result(self, cross_attention_states: torch.Tensor, full_text_row_masked_out_mask[start_pos:img_id] = False start_pos += q_seq_len full_text_row_masked_out_mask = 
full_text_row_masked_out_mask.to( - cross_attention_states.device) + input_ids.device) - return cross_attention_states, full_text_row_masked_out_mask + return full_text_row_masked_out_mask def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List[List[torch.Tensor]], - cross_attention_states: Optional[torch.Tensor] = None, + pixel_values: torch.Tensor = None, + aspect_ratio_ids: torch.Tensor = None, + aspect_ratio_mask: torch.Tensor = None, attn_metadata: Any = None, inputs_embeds: torch.Tensor = None, cross_attn_metadata: Any = None, **kwargs, ): """model forward, return logits.""" + if cross_attn_metadata is None: full_text_row_masked_out_mask = None # FIXME basically, we want to inference # text requests and image requests separately - elif cross_attention_states is None and ( - cross_attn_metadata.kv_seqlens is None - or int(cross_attn_metadata.kv_seqlens.sum()) == 0): + elif pixel_values is None and (cross_attn_metadata.kv_seqlens is None): full_text_row_masked_out_mask = None elif cross_attn_metadata.is_decoding: - cross_attention_states = None - full_text_row_masked_out_mask = torch.ones( - (attn_metadata.q_seqlens.sum(), 1), - dtype=torch.bool, - device=input_ids.device) + full_text_row_masked_out_mask = input_ids.new_ones( + input_ids.size(-1), 1) else: - cross_attention_states, full_text_row_masked_out_mask = \ - self.flat_encoder_result(cross_attention_states, cross_attn_metadata, input_ids) # noqa + full_text_row_masked_out_mask = self.flat_encoder_result( + cross_attn_metadata, input_ids) # noqa + + cross_attention_states = None + if pixel_values is not None: + cross_attention_states = self.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + ) + cross_attention_states = self.multi_modal_projector( + cross_attention_states) + _, bsz, _, _, image_token_dim = tuple(cross_attention_states.shape) + cross_attention_states = cross_attention_states.view( + bsz, -1, image_token_dim) + hidden_states = self.language_model( input_ids=input_ids, position_ids=position_ids, @@ -670,15 +1272,6 @@ def get_logits(self, hidden_states: torch.Tensor): """compute logits of the model output.""" return self.language_model.get_logits(hidden_states) - def support_cuda_graph( - self, - input_ids: torch.Tensor, - **kwargs, - ): - """support cudagraph.""" - - return False - def get_input_embeddings(self): """get input embeddings.""" return self.language_model.model.get_input_embeddings() @@ -694,13 +1287,30 @@ def prepare_inputs_for_generation( input_ids = context.input_ids position_ids = context.position_ids attn_metadata = context.attn_metadata - cross_attention_states = context.cross_attention_states - if cross_attention_states is not None: - cross_attention_states = [ - t.to(input_ids.device) for t in cross_attention_states - if t is not None - ] cross_attn_metadata = context.cross_attn_metadata + if int(cross_attn_metadata.kv_seqlens.sum()) == 0: + cross_attn_metadata.kv_seqlens = None + device = input_ids.device + + # process image input + pixel_values = None + aspect_ratio_ids = None + aspect_ratio_mask = None + if context.input_multimodals is not None: + pixel_values = [] + aspect_ratio_ids = [] + aspect_ratio_mask = [] + batched_image_data = [ + input_mm['image'] for input_mm in context.input_multimodals + ] + for image_data in batched_image_data: + for data in image_data: + pixel_values.append(data.data) + aspect_ratio_ids.append(data.meta['aspect_ratio_ids']) + 
aspect_ratio_mask.append(data.meta['aspect_ratio_mask']) + pixel_values = torch.cat(pixel_values, dim=0).to(device) + aspect_ratio_ids = torch.cat(aspect_ratio_ids, dim=0).to(device) + aspect_ratio_mask = torch.cat(aspect_ratio_mask, dim=0).to(device) # process vision embeddings vision_embeddings = context.input_embeddings @@ -719,7 +1329,9 @@ def prepare_inputs_for_generation( past_key_values=past_key_values, attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, - cross_attention_states=cross_attention_states, + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, cross_attn_metadata=cross_attn_metadata, ) @@ -742,8 +1354,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): continue - if 'vision_model' in name or 'multi_modal_projector' in name: - continue if self.config.text_config.tie_word_embeddings and 'lm_head.weight' in name: # noqa continue for (param_name, weight_name, shard_id) in stacked_params_mapping: @@ -756,3 +1366,161 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = params_dict[name] load_weight(param, loaded_weight) + + def support_cuda_graph( + self, + input_ids: torch.Tensor, + attn_metadata: Any, + cross_attn_metadata: Any, + **kwargs, + ): + """support cudagraph.""" + + if not attn_metadata.is_decoding: + return False + + if cross_attn_metadata is None: + return False + + if cross_attn_metadata.kv_seqlens is None: + return False + + return True + + def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): + """make cudagraph buffers from forward inputs.""" + input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, + **kwargs) + + device = graph_meta.device + max_batches = graph_meta.max_batchs + input_buffers['cross_kv_seqlens'] = torch.zeros(max_batches, + dtype=torch.int64, + device=device) + + return input_buffers + + def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): + """fill cudagraph buffers from forward inputs.""" + input_buffers = graph_meta.input_buffers + + new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, + **kwargs) + + attn_metadata = new_inputs['attn_metadata'] + cross_attn_metadata = new_inputs['cross_attn_metadata'] + block_offsets = attn_metadata.block_offsets + batch_size, _ = block_offsets.size() + + kv_seqlens = cross_attn_metadata.kv_seqlens + if kv_seqlens.data_ptr() != input_buffers['cross_kv_seqlens'].data_ptr( + ): + input_buffers['cross_kv_seqlens'].zero_() + input_buffers['cross_kv_seqlens'][:batch_size] = kv_seqlens + + new_batch_size = next_power_of_2(batch_size) + cross_attn_metadata.block_offsets = input_buffers[ + 'block_offsets'][:new_batch_size] + cross_attn_metadata.q_start_loc = input_buffers[ + 'q_start_loc'][:new_batch_size] + cross_attn_metadata.q_seqlens = input_buffers[ + 'q_seqlens'][:new_batch_size] + cross_attn_metadata.kv_seqlens = input_buffers[ + 'cross_kv_seqlens'][:new_batch_size] + + new_inputs['cross_attn_metadata'] = cross_attn_metadata + return new_inputs + + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None): + """update model meta.""" + model_metas = context.model_metas + if model_metas is None: + batch_size = context.q_seqlens.size(0) + model_metas = [dict(cross_kv_len=0) for _ in range(batch_size)] + + if context.is_decoding: + return model_metas + + vision_inputs = 
context.vision_inputs + if vision_inputs is None: + return model_metas + + input_mms = vision_inputs.input_multimodals + if input_mms is None: + return model_metas + + config = self.config.vision_config + image_size = config.image_size + patch_size = config.patch_size + wh = image_size // patch_size + img_kv_len = wh * wh + 1 + img_kv_len = img_kv_len * 4 + + new_model_metas = [] + for idx, input_mm in enumerate(input_mms): + if input_mm is None: + new_model_metas.append(model_metas[idx]) + images = input_mm['image'] + num_img = len(images) + + cross_kv_len = 0 + if model_metas[idx] is not None: + cross_kv_len = model_metas[idx].get('cross_kv_len', + cross_kv_len) + cross_kv_len += img_kv_len * num_img + new_model_metas.append(dict(cross_kv_len=cross_kv_len)) + + return model_metas + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class MLlamaInputProcessor(BaseModelInputProcessor): + """mllama input processor.""" + + def __init__(self, config: LlamaConfig, dtype: torch.dtype) -> None: + self.config = config + self.dtype = dtype + + vision_config = self.config.vision_config + image_size = vision_config.image_size + patch_size = vision_config.patch_size + wh = image_size // patch_size + encoder_len = wh * wh + 1 + encoder_len = encoder_len * 4 + self.encoder_len = encoder_len + + def preprocess_input(self, input_ids, input_multimodals, **kwargs): + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'] + aspect_ratio_ids = input_mm['aspect_ratio_ids'] + aspect_ratio_mask = input_mm['aspect_ratio_mask'] + offset = input_mm['offset'] + + if pixel_values.dtype != self.dtype: + pixel_values = pixel_values.to(self.dtype) + + mm_data = MultiModalTensor( + data=pixel_values, + start=offset, + end=offset + 1, + encoder_len=self.encoder_len, + meta=dict(aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 1059bfee4..e7b460026 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -85,14 +85,10 @@ # llava MODULE_MAP.update( { - 'LlavaLlamaForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlavaLlamaForCausalLM', - 'LlavaMistralForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.mistral.LlavaMistralForCausalLM', 'LlavaForConditionalGeneration': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration', # noqa: E501 'LlavaNextForConditionalGeneration': # noqa: E501 - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaForConditionalGeneration' + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llava.LlavaNextForConditionalGeneration' # noqa: E501 }) # qwen @@ -158,7 +154,7 @@ # phi3 vision MODULE_MAP.update({ 'Phi3VForCausalLM': - f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3.Phi3VForCausalLM', + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.phi3_v.Phi3VForCausalLM', }) # phi-3.5-moe diff --git a/lmdeploy/pytorch/models/phi3.py b/lmdeploy/pytorch/models/phi3.py index f9477fdab..288fdf3b1 100644 --- a/lmdeploy/pytorch/models/phi3.py +++ b/lmdeploy/pytorch/models/phi3.py @@ -435,7 +435,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: param = 
params_dict[name] load_weight(param, loaded_weight) - - -class Phi3VForCausalLM(Phi3ForCausalLM): - ... diff --git a/lmdeploy/pytorch/models/phi3_v.py b/lmdeploy/pytorch/models/phi3_v.py new file mode 100644 index 000000000..c4bf72c76 --- /dev/null +++ b/lmdeploy/pytorch/models/phi3_v.py @@ -0,0 +1,476 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig + +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn.linear import build_rowwise_linear +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + +from .phi3 import Phi3ForCausalLM, Phi3Model +from .utils.model import DeployModelMixin + +CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(attention_dropout=0.0, + dropout=0.0, + hidden_act='quick_gelu', + hidden_size=1024, + image_size=336, + initializer_factor=1.0, + initializer_range=0.02, + intermediate_size=4096, + layer_norm_eps=1e-05, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + projection_dim=768) + + +class Phi3ImageEmbedding(nn.Module): + """image embedding.""" + + def __init__(self, + config: PretrainedConfig, + wte=None, + dtype: torch.dtype = None, + device: torch.device = None, + **kwargs): + super().__init__() + self.config = config + hidden_size = config.n_embd if hasattr( + config, 'n_embd') else config.hidden_size + + self.wte = wte + + if (isinstance(config.img_processor, dict) and + config.img_processor.get('name', None) == 'clip_vision_model'): + assert 'model_name' in config.img_processor, ( + 'model_name must be provided for CLIPVisionModel') + assert 'image_dim_out' in config.img_processor, ( + 'image_dim_out must be provided for CLIPVisionModel') + assert 'num_img_tokens' in config.img_processor, ( + 'num_img_tokens must be provided for CLIPVisionModel') + assert config.img_processor[ + 'model_name'] == 'openai/clip-vit-large-patch14-336' + clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG + self.img_processor = CLIPVisionModel(clip_config).to(device).to( + dtype) + image_dim_out = config.img_processor['image_dim_out'] + self.num_img_tokens = config.img_processor['num_img_tokens'] + else: + raise NotImplementedError( + f'img_processor = {config.img_processor}, not implemented') + + self.image_dim_out = image_dim_out + self.img_sizes = None + + self.use_hd_transform = kwargs.get('use_hd_transform', False) + self.with_learnable_separator = kwargs.get('with_learnable_separator', + False) + self.hd_transform_order = kwargs.get('hd_transform_order', 'glb_sub') + # with_hd_transform and with_learnable_separator should have same value + assert (self.use_hd_transform == self.with_learnable_separator), ( + 'use_hd_transform and with_learnable_separator ' + 'should have same value') + if self.with_learnable_separator: + assert self.use_hd_transform, ( + 'learnable separator is only for hd transform') + # 1024 * 4, merge spatial to channel dimension + self.glb_GN = nn.Parameter( + torch.empty([1, 1, self.image_dim_out * 4], + dtype=dtype, + device=device)) + self.sub_GN = nn.Parameter( + torch.empty([1, 1, 1, self.image_dim_out * 4], + dtype=dtype, + device=device)) + + projection_cls = kwargs.get('projection_cls', 'linear') + 
if projection_cls == 'linear': + self.img_projection = nn.Linear(image_dim_out, + hidden_size, + dtype=dtype, + device=device) + elif projection_cls == 'mlp' and self.use_hd_transform: + dim_projection = hidden_size + depth = 2 + layers = [ + nn.Linear(image_dim_out * 4, + dim_projection, + dtype=dtype, + device=device) + ] + for _ in range(1, depth): + layers.extend([ + nn.GELU(), + nn.Linear(dim_projection, + dim_projection, + dtype=dtype, + device=device) + ]) + self.img_projection = nn.Sequential(*layers) + elif projection_cls == 'mlp': + dim_projection = hidden_size + depth = 2 + layers = [ + nn.Linear(image_dim_out, + dim_projection, + dtype=dtype, + device=device) + ] + for _ in range(1, depth): + layers.extend([ + nn.GELU(), + nn.Linear(dim_projection, + dim_projection, + dtype=dtype, + device=device) + ]) + self.img_projection = nn.Sequential(*layers) + else: + raise NotImplementedError( + f'projection_cls = {projection_cls}, not implemented') + + self.vocab_size = config.vocab_size + self.img_features = None + + if isinstance(config.img_processor, dict): + self.layer_idx = config.img_processor.get('layer_idx', -2) + self.type_feature = config.img_processor.get( + 'type_feature', 'patch') + else: + self.layer_idx = -2 + self.type_feature = 'patch' + + def get_img_features(self, + img_embeds: torch.FloatTensor) -> torch.FloatTensor: + LAYER_IDX = self.layer_idx + TYPE_FEATURE = self.type_feature + + img_processor_output = self.img_processor(img_embeds, + output_hidden_states=True) + img_feature = img_processor_output.hidden_states[LAYER_IDX] + + if TYPE_FEATURE == 'patch': + patch_feature = img_feature[:, 1:] + return patch_feature + + if TYPE_FEATURE == 'cls_patch': + return img_feature + + raise NotImplementedError + + def forward( + self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor, + image_sizes=None, + image_mask: torch.Tensor = None, + ) -> torch.FloatTensor: + """forward.""" + + target_device = pixel_values.device + target_dtype = pixel_values.dtype + + img_embeds = pixel_values + img_sizes = image_sizes + img_sizes = img_sizes.cpu() + + if self.use_hd_transform and img_sizes is not None and len(img_sizes): + assert img_embeds.ndim == 5, f'img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform' # noqa E501 + # img_embeds: (num_images, max_num_crops, 3, H, W) + # img_sizes: (num_images, 2).view(1, -1) + + bs = img_embeds.shape[0] + # Nx(HW)xC + img_features = self.get_img_features(img_embeds.flatten(0, 1)) + base_feat_height = base_feat_width = int( + img_features.shape[1]**0.5) + + assert base_feat_height == 24 and base_feat_width == 24, f'base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect 24x24 features for hd transform' # noqa E501 + + # bs x max_num_crops x (24x24) x C + img_features = img_features.view( + bs, -1, base_feat_height * base_feat_width, self.image_dim_out) + C = self.image_dim_out + H = base_feat_height + + output_imgs = [] + output_len = [] + # training is tensor, inference is list + if isinstance(img_sizes, torch.Tensor): + img_sizes = img_sizes.view(-1, 2) + for _bs in range(bs): + h, w = img_sizes[_bs] + h = h // 336 + w = w // 336 + B_ = h * w + + # 1 x (24x24) x 1024 + global_img_feature = img_features[_bs, :1] + + # 1 x 12 x 12 x 4096 + glb_img = global_img_feature.reshape( + 1, H // 2, 2, H // 2, 2, + C).permute(0, 1, 3, 2, 4, + 5).reshape(1, H // 2, H // 2, 4 * C) + temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1) + + # 1 x 156 x 4096 + glb_img = torch.cat([glb_img, temp_glb_GN], + 
dim=2).reshape(1, -1, 4 * C) + + # (max_num_crops-1) x (12x12) x C + sub_img = img_features[_bs, 1:] + # 16x574x1024 + # get rid of padding sub_img + sub_img = sub_img[:B_] + + # (num_crops, 12, 2, 12, 2, 1024) + # ->(num_crops, 12, 12, 2, 2, 1024) + # -> (num_crops, 12*12, 4*1024) + sub_img = (sub_img.reshape(B_, H // 2, 2, H // 2, 2, + C).permute(0, 1, 3, 2, 4, 5)) + sub_img = sub_img.reshape(1, h, w, 12, 12, -1).permute( + 0, 1, 3, 2, 4, 5).reshape(1, h * 12, w * 12, 4 * C) + temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1) + sub_img = torch.cat([sub_img, temp_sub_GN], + dim=2).reshape(1, -1, 4 * C) + # (1, num_img_tokens, 1024*4) + + # glb + sub + if self.hd_transform_order == 'glb_sub': + output_imgs.append( + torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) + elif self.hd_transform_order == 'sub_glb': + output_imgs.append( + torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) + else: + raise NotImplementedError( + f'hd_transform_order = {self.hd_transform_order}' + ) # noqa E501 + + temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12) + assert temp_len == output_imgs[-1].shape[ + 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}' # noqa E501 + output_len.append(temp_len) + + img_set_tensor = [] + for _output_img in output_imgs: + img_feature_proj = self.img_projection( + _output_img.to(target_device).to(target_dtype)) + img_feature_proj = img_feature_proj.flatten(0, 1) + img_set_tensor.append(img_feature_proj) + img_set_tensor = torch.cat(img_set_tensor)[None] + elif img_embeds.ndim == 4: + tt = (self.get_img_features(img_embeds).to(target_device).to( + target_dtype).reshape(-1, self.image_dim_out)) + img_set_tensor = self.img_projection( + tt) # adapted visual features. + elif img_embeds.ndim == 3: + tt = (img_embeds.to(target_device).to(target_dtype).view( + -1, self.image_dim_out)) + img_set_tensor = self.img_projection( + tt) # adapted visual features. 
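As a quick sanity check of the temp_len assertion above (an illustration, not part of the patch): each 336x336 crop yields 24x24 CLIP patch features that are merged 2x2 into 12x12 tokens of 4*C channels; the global crop gets 12 sub_GN row separators, the h*w sub crops get h*12, and a single glb_GN token sits between the two groups.

def hd_num_image_tokens(h_crops: int, w_crops: int) -> int:
    """Hypothetical helper mirroring the temp_len formula above."""
    merged = 144                                # 12 * 12 tokens per crop after the 2x2 merge
    separators = 12 * (h_crops + 1)             # sub_GN rows: 12 (global) + h * 12 (sub)
    return merged * (h_crops * w_crops + 1) + separators + 1   # +1 for glb_GN

assert hd_num_image_tokens(2, 3) == 1045        # e.g. an image padded to 672x1008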
+ else: + raise NotImplementedError + + hidden_states = self.wte(input_ids) + + hidden_states.masked_scatter_(image_mask[..., None], img_set_tensor) + + return hidden_states + + +class Phi3VModel(Phi3Model): + """phi3v model.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__(config=config, dtype=dtype, device=device) + + self.vision_embed_tokens = None + if isinstance(config.embd_layer, dict): + # vision embedding layer + embedding_config = { + 'embedding_cls': config.embd_layer['embedding_cls'], + **config.embd_layer + } + self.vision_embed_tokens = Phi3ImageEmbedding( + config, + wte=self.embed_tokens, + dtype=dtype, + device=device, + **embedding_config) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + image_mask: torch.Tensor = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of LlamaModel.forward.""" + + if inputs_embeds is None and pixel_values is not None: + inputs_embeds = self.vision_embed_tokens( + input_ids, + pixel_values, + image_sizes, + image_mask, + ) + + return super().forward( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + +class Phi3VForCausalLM(Phi3ForCausalLM, DeployModelMixin): + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__(config, ctx_mgr, dtype=dtype, device=device) + self.config = config + self.ctx_mgr = ctx_mgr + # build model + self.model = Phi3VModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + self.input_processor = Phi3VInputProcessor(config, dtype) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + pixel_values: torch.Tensor = None, + image_sizes: torch.Tensor = None, + image_mask: torch.Tensor = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """forward.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + pixel_values=pixel_values, + image_sizes=image_sizes, + image_mask=image_mask, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: torch.Tensor = None, + context: StepContext = None, + ): + """prepare input.""" + output = super().prepare_inputs_for_generation( + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + context=context) + + # vision inputs + pixel_values = None + if context.input_multimodals is not None: + input_mms = [ + input_mm.get('image', []) + for input_mm in context.input_multimodals + ] + # flatten batch + input_mms = [data for im_data in input_mms for data in im_data] + if len(input_mms) > 0: + pixel_values = torch.cat([data.data for data in input_mms]) + image_sizes = torch.cat( + [data.meta['image_sizes'] for data in input_mms]) + image_token_id = 
input_mms[0].meta['image_token_id'] + image_mask = output['input_ids'] == image_token_id + output['pixel_values'] = pixel_values + output['image_sizes'] = image_sizes + output['image_mask'] = image_mask + + return output + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + super().load_weights(weights) + + vis_prefix = 'vision_embed_tokens.' + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if not (vis_prefix in name): + continue + param = params_dict[name] + load_weight(param, loaded_weight) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +class Phi3VInputProcessor(BaseModelInputProcessor): + """Phi3V input processor.""" + + def __init__(self, config: PretrainedConfig, dtype) -> None: + self.config = config + self.dtype = dtype + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'].to(self.dtype) + image_sizes = input_mm['image_sizes'] + offset = input_mm['offset'] + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor(data=pixel_values, + start=offset, + end=offset + num_pad, + meta=dict( + image_sizes=image_sizes, + image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py index b10baaa4d..4e2b1017b 100644 --- a/lmdeploy/pytorch/models/qwen2_vl.py +++ b/lmdeploy/pytorch/models/qwen2_vl.py @@ -1,18 +1,24 @@ # Copyright (c) OpenMMLab. All rights reserved. 
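The vision blocks added to qwen2_vl.py below run full (non-causal) attention over variable-length patch sequences packed into one tensor. A minimal sketch of the cu_seqlens convention they rely on (values and names here are illustrative, not part of the patch):

import torch

seqlens = torch.tensor([64, 256, 16])                              # patches per image in the packed batch
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0))    # tensor([  0,  64, 320, 336])
q_start_loc = cu_seqlens[:-1]                                      # start offset of each image
q_seqlens = cu_seqlens[1:] - cu_seqlens[:-1]                       # length of each image

VisionAttention passes exactly these slices to the new FlashAttention module, so each image attends only to its own patches.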
-from typing import Any, Callable, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple import torch from torch import nn from transformers.configuration_utils import PretrainedConfig +from lmdeploy.pytorch.engine.input_process import (BaseModelInputProcessor, + PreprocessInputResult) from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, - SiluAndMul, build_rotary_embedding) -from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear, +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, FlashAttention, + LayerNorm, RMSNorm, RopeType, SiluAndMul, + build_rotary_embedding) +from lmdeploy.pytorch.nn.linear import (build_colwise_linear, + build_merged_colwise_linear, build_qkv_proj, build_rowwise_linear) from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin, next_power_of_2 +from .utils.model import DeployModelMixin def _apply_mrope_selection(hidden_states: torch.Tensor, @@ -337,7 +343,337 @@ def get_input_embeddings(self): return self.embed_tokens -class Qwen2VLForConditionalGeneration(nn.Module, CudaGraphMixin): +class PatchEmbed(nn.Module): + """Patch Embed.""" + + def __init__(self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + dtype: torch.dtype = None, + device: torch.device = None) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + dtype=dtype, + device=device) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view(-1, self.in_channels, + self.temporal_patch_size, + self.patch_size, self.patch_size) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view( + -1, self.embed_dim) + return hidden_states + + +class VisionRotaryEmbedding(nn.Module): + """vision rotary embedding.""" + + def __init__(self, + dim: int, + theta: float = 10000.0, + device: torch.device = None) -> None: + super().__init__() + inv_freq = 1.0 / (theta**( + torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class VisionAttention(nn.Module): + """Vision attention.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + dim = config.embed_dim + num_heads = config.num_heads + head_dim = dim // num_heads + self.head_dim = head_dim + + # packed qkv + self.qkv = build_qkv_proj( + dim, + num_q_heads=num_heads, + num_kv_heads=num_heads, + head_size=head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + 
self.attention = FlashAttention( + num_heads, + head_dim, + causal=False, + ) + + # o_proj + self.proj = build_rowwise_linear(dim, + dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor] + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + # qkv proj + qkv_states = self.qkv(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + q, k, v = self.qkv.split_qkv(qkv_states) + + cos, sin = rotary_pos_emb + q, k = self.apply_rotary_pos_emb(q, k, cos, sin) + + attn_output = self.attention( + q, + k, + v, + q_start_loc=cu_seqlens[:-1], + q_seqlens=cu_seqlens[1:] - cu_seqlens[:-1], + ) + + attn_output = attn_output.reshape(seq_length, -1) + + # o proj + attn_output = self.proj(attn_output) + return attn_output + + +class VisionMlp(nn.Module): + """Vision mlp.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + dim = config.embed_dim + hidden_dim = int(config.embed_dim * config.mlp_ratio) + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.fc1 = build_colwise_linear( + dim, + hidden_dim, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + if config.hidden_act in [ + 'gelu', 'gelu_fast', 'quick_gelu', 'gelu_python' + ]: + self.act = nn.GELU() + else: + self.act = ACT2FN[config.hidden_act] + + # down + self.fc2 = build_rowwise_linear(hidden_dim, + dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + return self.fc2(self.act(self.fc1(x))) + + +class Qwen2VLVisionBlock(nn.Module): + """Vision block.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + self.norm1 = LayerNorm(config.embed_dim, + eps=1e-6, + dtype=dtype, + device=device) + self.norm2 = LayerNorm(config.embed_dim, + eps=1e-6, + dtype=dtype, + device=device) + + self.attn = VisionAttention(config, dtype=dtype, device=device) + + self.mlp = VisionMlp(config, dtype=dtype, device=device) + + def forward(self, + hidden_states, + cu_seqlens, + rotary_pos_emb, + residual: Optional[torch.Tensor] = None) -> torch.Tensor: + if residual is None: + residual = hidden_states + hidden_states = self.norm1(hidden_states) + else: + hidden_states, residual = self.norm1(hidden_states, residual) + + hidden_states = self.attn(hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb) + + hidden_states, residual = self.norm2(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class PatchMerger(nn.Module): + """PatchMerger.""" + + def __init__(self, + dim: int, + context_dim: int, + spatial_merge_size: int = 2, + dtype: torch.dtype = None, + device: torch.device = None) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = nn.LayerNorm(context_dim, + eps=1e-6, + dtype=dtype, + device=device) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, + self.hidden_size, + dtype=dtype, + device=device), + nn.GELU(), + nn.Linear(self.hidden_size, dim, dtype=dtype, device=device), + ) 
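PatchMerger folds every spatial_merge_size x spatial_merge_size (2x2 by default) group of consecutive patch tokens into a single vision token before projecting to the language-model width, cutting the vision sequence length by a factor of four. A rough shape walk-through with hypothetical sizes (not taken from any config in this patch):

import torch

context_dim, out_dim, merge = 1280, 3584, 2        # hypothetical embed/output dims
seq_len = 36 * 36                                  # packed patch tokens of one image, divisible by merge**2
x = torch.randn(seq_len, context_dim)
grouped = x.view(-1, context_dim * merge ** 2)     # (324, 5120): merge**2 consecutive patches per row

In the forward defined just below, ln_q runs before this view and self.mlp maps context_dim * merge**2 to out_dim, so the 1296 raw patch tokens become 324 tokens handed to the language model.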
+ + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +class Qwen2VisionTransformerPretrainedModel(nn.Module): + """Vision transformer.""" + + def __init__(self, + config: PretrainedConfig, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = PatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.embed_dim, + dtype=dtype, + device=device, + ) + + head_dim = config.embed_dim // config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2, + device=device) + + self.blocks = nn.ModuleList([ + Qwen2VLVisionBlock(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.depth) + ]) + self.merger = PatchMerger(dim=config.hidden_size, + context_dim=config.embed_dim, + spatial_merge_size=config.spatial_merge_size, + dtype=dtype, + device=device) + + def rot_pos_emb(self, grid_thw): + """rotary position embedding.""" + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor) -> torch.Tensor: + """forward.""" + hidden_states = self.patch_embed(hidden_states) + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + + residual = None + for blk in self.blocks: + hidden_states, residual = blk(hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + residual=residual) + + hidden_states = hidden_states + residual + + return self.merger(hidden_states) + + +class Qwen2VLForConditionalGeneration(nn.Module, DeployModelMixin, + CudaGraphMixin): """ModelForCausalLM.""" packed_modules_mapping = { @@ -360,6 +696,16 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr + + # preprocessor + self.input_processor = Qwen2VLInputProcessor(self.config) + + # build vision model + self.visual = Qwen2VisionTransformerPretrainedModel( + config.vision_config, + dtype=dtype, + device=device, + ) # build model self.model = Qwen2Model(config, dtype=dtype, device=device) # build lm_head @@ -377,9 +723,26 @@ def forward( attn_metadata: Any = None, inputs_embeds: torch.Tensor = None, mrope_position_ids: torch.Tensor = None, + pixel_values: torch.Tensor = None, + vis_cu_seqlens: torch.Tensor = None, + vis_pos_emb: torch.Tensor = None, + image_mask: torch.Tensor = None, **kwargs, ): """model forward, return logits.""" + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + if 
pixel_values is not None: + dtype = inputs_embeds.dtype + pixel_values = pixel_values.to(dtype) + vis_pos_emb = (vis_pos_emb[0].to(dtype), + vis_pos_emb[1].to(dtype)) + image_embeds = self.visual(pixel_values, + cu_seqlens=vis_cu_seqlens, + rotary_pos_emb=vis_pos_emb) + inputs_embeds = inputs_embeds.masked_scatter( + image_mask[..., None], image_embeds) + hidden_states = self.model( input_ids=input_ids, position_ids=position_ids, @@ -416,6 +779,36 @@ def prepare_inputs_for_generation( position_ids = context.position_ids attn_metadata = context.attn_metadata + pixel_values = None + vis_cu_seqlens = None + vis_pos_emb = None + image_mask = None + if context.input_multimodals is not None: + image_data = [ + input_mm['image'] for input_mm in context.input_multimodals + ] + + if len(image_data) > 0: + # flatten batch + image_data = [ + data for im_data in image_data for data in im_data + ] + pixel_values = torch.cat([data.data for data in image_data]) + image_token_id = image_data[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + grid_thw = torch.cat( + [data.meta['grid_thw'] for data in image_data]).cpu() + vis_pos_emb = self.visual.rot_pos_emb(grid_thw) + vis_cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).to(pixel_values.device) + vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, + dtype=torch.int32) + vis_pos_emb = vis_pos_emb.repeat(1, 2) + vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin()) + + mrope_position_ids = getattr(context, 'mrope_position_ids', None) + # process vision embeddings vision_embeddings = context.input_embeddings vision_embedding_indexing = context.input_embedding_indexing @@ -433,7 +826,11 @@ def prepare_inputs_for_generation( past_key_values=past_key_values, attn_metadata=attn_metadata, inputs_embeds=inputs_embeds, - mrope_position_ids=context.mrope_position_ids, + mrope_position_ids=mrope_position_ids, + pixel_values=pixel_values, + vis_cu_seqlens=vis_cu_seqlens, + vis_pos_emb=vis_pos_emb, + image_mask=image_mask, ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -450,8 +847,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - if 'visual' in name: - continue if 'rotary_emb.inv_freq' in name: continue if ('rotary_emb.cos_cached' in name @@ -467,8 +862,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): load_weight(param, loaded_weight, shard_id=shard_id) break else: - param = params_dict[name] - load_weight(param, loaded_weight) + if '.qkv.' 
in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): """make cudagraph buffers from forward inputs.""" @@ -510,3 +912,130 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): 'mrope_position_ids'] return new_inputs + + def _update_model_meta_decoding(self, context: StepContext): + """update model meta for decoding.""" + model_metas = context.model_metas + position_ids = context.position_ids + + mrope_deltas = [meta['mrope_delta'] for meta in model_metas] + mrope_deltas = position_ids.new_tensor(mrope_deltas) + mrope_position_ids = position_ids + mrope_deltas[None] + mrope_position_ids = mrope_position_ids.expand(3, -1) + + context.mrope_position_ids = mrope_position_ids + return model_metas + + def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): + """get mrope ids.""" + t, h, w = grid_thw + h //= 2 + w //= 2 + stride = torch.tensor([h * w, w, 1], device=device)[:, None] + size = torch.tensor([t, h, w], device=device)[:, None] + pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1) + pos_ids = pos_ids // stride % size + return pos_ids + + def _update_model_meta_prefilling(self, context: StepContext): + """update model meta for prefilling.""" + model_metas = context.model_metas + input_multimodals = context.input_multimodals + if input_multimodals is None: + input_multimodals = [None] * len(model_metas) + position_ids = context.position_ids + batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) + mrope_position_ids = [] + new_model_metas = [] + for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, + input_multimodals): + images = [] + if input_mm is not None: + images = input_mm['image'] + if model_meta is None or 'mrope_delta' not in model_meta: + mrope_delta = 0 + else: + mrope_delta = model_meta['mrope_delta'] + + pos_start = pos_ids[0].item() + mrope_pos_ids = pos_ids + mrope_delta + mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() + for img in images: + grid_thw = img.meta['grid_thw'][0].tolist() + _, h, w = grid_thw + h //= 2 + w //= 2 + num_pad = img.end - img.start - max(h, w) + mrope_delta -= num_pad + fill_start = img.start - pos_start + fill_end = img.end - pos_start + img_pos_ids = self._get_multimodal_pos_ids( + grid_thw, pos_ids.device) + img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] + mrope_pos_ids[:, fill_end:] -= num_pad + mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids + + mrope_position_ids.append(mrope_pos_ids) + new_model_metas.append(dict(mrope_delta=mrope_delta)) + + mrope_position_ids = torch.cat(mrope_position_ids, dim=1) + context.mrope_position_ids = mrope_position_ids + + return new_model_metas + + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None): + """update model meta.""" + if context.is_decoding: + return self._update_model_meta_decoding(context) + else: + return self._update_model_meta_prefilling(context) + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return self.input_processor + + +InputMultiModalType = List[Dict[str, Any]] + + +class Qwen2VLInputProcessor(BaseModelInputProcessor): + """qwen2 input processor.""" + 
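The preprocess_input method defined below expects, per image, the tensors produced by the Qwen2-VL image processor plus the placeholder offset inside input_ids. A hypothetical input dict (shapes and ids are made up for illustration):

import torch

grid_thw = torch.tensor([[1, 36, 36]])             # (t, h, w) patch grid of one image
pixel_values = torch.randn(1 * 36 * 36, 1176)      # flattened patches, 1176 = 3 * 2 * 14 * 14
image_tokens = (36 // 2) * (36 // 2)               # 324 placeholder tokens after the 2x2 merge

mm_input = dict(pixel_values=pixel_values,
                image_grid_thw=grid_thw,
                offset=5,                          # index of the first placeholder token
                image_tokens=image_tokens,
                image_token_id=0)                  # placeholder id, model-specific in practice

preprocess_input then wraps this into a MultiModalTensor spanning [offset, offset + image_tokens) under the 'image' key of the returned PreprocessInputResult.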
+ def __init__(self, config: PretrainedConfig) -> None: + self.config = config + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'] + image_grid_thw = input_mm['image_grid_thw'] + offset = input_mm['offset'] + start = offset + image_token_id = input_mm.get('image_token_id', 0) + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor(data=pixel_values, + start=start, + end=start + num_pad, + meta=dict( + grid_thw=image_grid_thw, + image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return result diff --git a/lmdeploy/pytorch/models/utils/model.py b/lmdeploy/pytorch/models/utils/model.py new file mode 100644 index 000000000..99bd4c4bf --- /dev/null +++ b/lmdeploy/pytorch/models/utils/model.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Iterable, List, Optional, Tuple + +import torch + +from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor +from lmdeploy.pytorch.model_inputs import StepContext + + +class DeployModelMixin: + + def forward(self, *args, **kwargs): + """forward of model.""" + raise NotImplementedError('Not Implemented') + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """prepare input.""" + raise NotImplementedError('Not Implemented') + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """load weights.""" + raise NotImplementedError('Not Implemented') + + def get_logits(self, hidden_states: torch.Tensor): + """compute logits of the model output.""" + return hidden_states + + def update_weights(self): + """update weights.""" + pass + + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None): + """update model meta.""" + return None + + def get_input_processor(self) -> BaseModelInputProcessor: + """get input processor.""" + return None diff --git a/lmdeploy/pytorch/models/utils/multimodal.py b/lmdeploy/pytorch/models/utils/multimodal.py new file mode 100644 index 000000000..aebcaf407 --- /dev/null +++ b/lmdeploy/pytorch/models/utils/multimodal.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs + +PreparedInputs = Tuple[List[int], MultiModalInputs] + + +class MultiModalMixin: + + def prepare_multimodal_input(self, input_ids, input_multimodals, + **kwargs) -> PreparedInputs: + """prepare multimodals inputs.""" + raise NotImplementedError('prepare input not implemented.') diff --git a/lmdeploy/pytorch/multimodal/__init__.py b/lmdeploy/pytorch/multimodal/__init__.py new file mode 100644 index 000000000..c3e8c6a16 --- /dev/null +++ b/lmdeploy/pytorch/multimodal/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
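The data types re-exported here are defined in data_type.py just below. A small usage sketch once the patch is applied (values are illustrative):

import torch

from lmdeploy.pytorch.multimodal import MultiModalTensor

mm = MultiModalTensor(
    data=torch.randn(1, 3, 336, 336),                    # raw pixel tensor; shape is arbitrary here
    start=5,                                             # first placeholder token this item covers
    end=5 + 144,                                         # one past the last placeholder token
    meta=dict(image_sizes=torch.tensor([[336, 336]])),   # extra tensors travel in meta
)
mm.to_device('cuda')                                     # moves data and any tensors inside meta
mm_inputs = dict(image=[mm])                             # MultiModalInputs: modality name -> list of items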
+from .data_type import MultiModalData, MultiModalTensor + +__all__ = ['MultiModalData', 'MultiModalTensor'] diff --git a/lmdeploy/pytorch/multimodal/data_type.py b/lmdeploy/pytorch/multimodal/data_type.py new file mode 100644 index 000000000..95ec72d26 --- /dev/null +++ b/lmdeploy/pytorch/multimodal/data_type.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dataclasses import dataclass +from typing import Any, Dict, List, Union + +from torch import Tensor + + +class MultiModalData: + pass + + +MultiModalDataList = List[MultiModalData] + +NestedTensor = Union[Tensor, List[Tensor]] + + +@dataclass +class MultiModalTensor: + data: NestedTensor + start: int + end: int = None + encoder_len: int = None + meta: Dict[str, Any] = None + + def __post_init__(self): + if self.end is None: + self.end = self.start + + def to_device(self, device: str, non_blocking: bool = False): + """to device.""" + if isinstance(self.data, Tensor): + self.data = self.data.to(device=device, non_blocking=non_blocking) + else: + data = [ + d.to(device=device, non_blocking=non_blocking) + for d in self.data + ] + self.data = data + + if self.meta is not None: + for k, v in self.meta.items(): + if isinstance(v, Tensor): + v = v.to(device=device, non_blocking=non_blocking) + self.meta[k] = v + elif hasattr(v, 'to_device'): + v = v.to_device(device=device, non_blocking=non_blocking) + self.meta[k] = v + return self + + +MultiModalInputs = Dict[str, List[MultiModalTensor]] diff --git a/lmdeploy/pytorch/multimodal/image_type.py b/lmdeploy/pytorch/multimodal/image_type.py new file mode 100644 index 000000000..19211a381 --- /dev/null +++ b/lmdeploy/pytorch/multimodal/image_type.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dataclasses import dataclass +from typing import Any, ClassVar, Dict + +from PIL import Image + +from .data_type import MultiModalData + + +@dataclass +class ImageData(MultiModalData): + data: Image + loc: int + meta: Dict[str, Any] = None + type: ClassVar[str] = 'image' diff --git a/lmdeploy/pytorch/nn/__init__.py b/lmdeploy/pytorch/nn/__init__.py index 63df9a5ae..4705115bf 100644 --- a/lmdeploy/pytorch/nn/__init__.py +++ b/lmdeploy/pytorch/nn/__init__.py @@ -2,7 +2,7 @@ # attention module is modified from: # https://github.com/vllm-project/vllm/blob/main/vllm/attention/ from .activation import GeluAndMul, SiluAndMul # noqa: F401 -from .attention import Attention # noqa: F401 +from .attention import Attention, FlashAttention # noqa: F401 from .norm import LayerNorm, RMSNorm # noqa: F401 from .rotary_embedding import ApplyRotaryEmb # noqa: F401 from .rotary_embedding import RopeType # noqa: F401 diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py index 26f1034d3..484041dfc 100644 --- a/lmdeploy/pytorch/nn/attention.py +++ b/lmdeploy/pytorch/nn/attention.py @@ -9,6 +9,15 @@ from .utils import get_distribute_size +def _update_num_heads(num_heads: int, num_kv_heads: int, replicate_kv: bool): + """update heads.""" + world_size, rank = get_world_rank() + num_heads = get_distribute_size(num_heads, world_size, rank) + if not replicate_kv: + num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) + return num_heads, num_kv_heads + + class Attention(nn.Module): """Attention layer.""" @@ -23,14 +32,20 @@ def __init__( sliding_window: int = None, logit_softcapping: float = None, replicate_kv: bool = False, + causal: bool = True, **kwargs, ): super().__init__() - num_heads, num_kv_heads = self._update_num_heads( - num_heads, 
num_kv_heads, replicate_kv) + if num_kv_heads is None: + num_kv_heads = num_heads + if v_head_size is None: + v_head_size = head_size + num_heads, num_kv_heads = _update_num_heads(num_heads, num_kv_heads, + replicate_kv) layer_backend = get_backend() - impl_builder = layer_backend.get_layer_impl_builder(OpType.Attention) + impl_builder = layer_backend.get_layer_impl_builder( + OpType.PagedAttention) self.impl = impl_builder.build( num_heads=num_heads, @@ -41,18 +56,10 @@ def __init__( alibi=alibi, sliding_window=sliding_window, logit_softcapping=logit_softcapping, + causal=causal, **kwargs, ) - def _update_num_heads(self, num_heads: int, num_kv_heads: int, - replicate_kv: bool): - """update heads.""" - world_size, rank = get_world_rank() - num_heads = get_distribute_size(num_heads, world_size, rank) - if not replicate_kv: - num_kv_heads = get_distribute_size(num_kv_heads, world_size, rank) - return num_heads, num_kv_heads - def forward( self, query: torch.Tensor, @@ -77,3 +84,77 @@ def forward( v_scales_zeros=v_scales_zeros, inplace=inplace, ) + + +class FlashAttention(nn.Module): + """flash attention w/o paging.""" + + def __init__( + self, + num_heads: int, + head_dim: int, + scale: float = None, + num_kv_heads: int = None, + v_head_dim: int = None, + causal: bool = True, + sliding_window: int = None, + logit_softcapping: float = None, + replicate_kv: bool = False, + **kwargs, + ): + super().__init__() + if num_kv_heads is None: + num_kv_heads = num_heads + if v_head_dim is None: + v_head_dim = head_dim + num_heads, num_kv_heads = _update_num_heads(num_heads, num_kv_heads, + replicate_kv) + + layer_backend = get_backend() + + impl_builder = layer_backend.get_layer_impl_builder( + OpType.FlashAttention) + + self.impl = impl_builder.build( + num_heads=num_heads, + head_dim=head_dim, + scale=scale, + num_kv_heads=num_kv_heads, + v_head_dim=v_head_dim, + causal=causal, + sliding_window=sliding_window, + logit_softcapping=logit_softcapping, + **kwargs, + ) + + def forward(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + q_start_loc: torch.Tensor, + q_seqlens: torch.Tensor, + kv_start_loc: torch.Tensor = None, + kv_seqlens: torch.Tensor = None, + max_q_seqlen: int = None) -> torch.Tensor: + """forward.""" + + if max_q_seqlen is None: + max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) + + if kv_start_loc is None and kv_seqlens is None: + kv_start_loc = q_start_loc + kv_seqlens = q_seqlens + + assert kv_start_loc is not None + assert kv_seqlens is not None + + return self.impl.forward( + query, + key, + value, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_start_loc=kv_start_loc, + kv_seqlens=kv_seqlens, + max_q_seqlen=max_q_seqlen, + ) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 400c492b0..ec4957608 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -18,5 +18,5 @@ tiktoken torch<=2.4.0,>=2.0.0 torchvision<=0.19.0,>=0.15.0 transformers -triton>=2.2.0,<=3.0.0; sys_platform == "linux" +triton==3.0.0; sys_platform == "linux" uvicorn diff --git a/tests/pytorch/kernel/test_flash_attention.py b/tests/pytorch/kernel/test_flash_attention.py index 7d4b7a7f3..e56de44b3 100644 --- a/tests/pytorch/kernel/test_flash_attention.py +++ b/tests/pytorch/kernel/test_flash_attention.py @@ -10,20 +10,26 @@ def _conti_input(data, q_seqlens): return data -def _make_bias(q_seqlens, history_lens, neg_val): - full_seq_lens = q_seqlens + history_lens +def _make_bias(q_seqlens, history_lens, neg_val, causal): + kv_seqlens = 
q_seqlens + history_lens max_seq_len = q_seqlens.max().item() - max_full_len = full_seq_lens.max().item() - seq_ranges = [torch.arange(max_seq_len) for _ in q_seqlens] - for r, l in zip(seq_ranges, q_seqlens): - r[l:] = -max_full_len - seq_ranges = torch.stack(seq_ranges, dim=0).cuda() - kv_ranges = [torch.arange(max_full_len) for _ in full_seq_lens] - kv_ranges = torch.stack(kv_ranges, 0).cuda() - mask = kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, - None, - None] - return mask.float() * neg_val + max_kv_len = kv_seqlens.max().item() + if causal: + seq_ranges = [torch.arange(max_seq_len) for _ in q_seqlens] + for r, l in zip(seq_ranges, q_seqlens): + r[l:] = -max_kv_len + seq_ranges = torch.stack(seq_ranges, dim=0).cuda() + kv_ranges = [torch.arange(max_kv_len) for _ in kv_seqlens] + kv_ranges = torch.stack(kv_ranges, 0).cuda() + mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > + history_lens[:, None, None]) + return mask.float() * neg_val + else: + q_mask = torch.arange(max_seq_len)[None].cuda() < q_seqlens[:, None] + k_mask = torch.arange(max_kv_len)[None].cuda() < kv_seqlens[:, None] + mask = q_mask[:, :, None] & k_mask[:, None, :] + + return (~mask).float() * neg_val def _naive_attention(batched_q, batched_kv, bias): @@ -100,6 +106,10 @@ def num_heads_q(self, request): def num_heads_k(self, request): yield request.param + @pytest.fixture + def causal(self, request): + yield request.param + @pytest.fixture def q_seqlens(self, request): yield torch.tensor(request.param, device='cuda') @@ -138,8 +148,8 @@ def batched_kv(self, q_seqlens, history_lens, num_heads_k, head_dim_k, head_dim_v, dtype): torch.manual_seed(123) batch_size = len(q_seqlens) - full_seq_lens = q_seqlens + history_lens - max_seq_len = full_seq_lens.max().item() + kv_seqlens = q_seqlens + history_lens + max_seq_len = kv_seqlens.max().item() k = torch.rand(batch_size, max_seq_len, num_heads_k, @@ -167,9 +177,9 @@ def conti_kv(self, kv_seqlens, batched_kv): yield (conti_k, conti_v) @pytest.fixture - def mask(self, q_seqlens, history_lens): + def mask(self, q_seqlens, history_lens, causal): neg_val = -1e30 - yield _make_bias(q_seqlens, history_lens, neg_val) + yield _make_bias(q_seqlens, history_lens, neg_val, causal) @pytest.fixture def gt(self, batched_q, batched_kv, mask): @@ -183,11 +193,13 @@ def conti_gt(self, gt, q_seqlens): @pytest.mark.parametrize('head_dim_v', [32], indirect=True) @pytest.mark.parametrize('num_heads_q', [8, 2], indirect=True) @pytest.mark.parametrize('num_heads_k', [2], indirect=True) + @pytest.mark.parametrize('causal', [True, False], indirect=True) @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [([30, 50, 70, 90], [50, 40, 30, 20])], indirect=True) def test_flash_attention(self, conti_q, conti_kv, q_start_loc, q_seqlens, - kv_start_loc, kv_seqlens, head_dim_v, conti_gt): + kv_start_loc, kv_seqlens, head_dim_v, causal, + conti_gt): from lmdeploy.pytorch.kernels.cuda.flashattention import \ flash_attention_fwd max_seq_len = q_seqlens.max().item() @@ -202,7 +214,8 @@ def test_flash_attention(self, conti_q, conti_kv, q_start_loc, q_seqlens, q_seqlens=q_seqlens, kv_start_loc=kv_start_loc, kv_seqlens=kv_seqlens, - max_seqlen=max_seq_len) + max_seqlen=max_seq_len, + causal=causal) torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5) @pytest.fixture