diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
index 88b69dbb5..83a43e848 100644
--- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
@@ -514,6 +514,7 @@ def forward(
         max_s: int,
         adapter_data: AdapterBatchData,
         prefill_cache_indices: Optional[torch.Tensor],
+        cross_attention_states: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds

diff --git a/server/lorax_server/models/vlm_causal_lm.py b/server/lorax_server/models/vlm_causal_lm.py
index 7c4a3b543..bc7bbfd05 100644
--- a/server/lorax_server/models/vlm_causal_lm.py
+++ b/server/lorax_server/models/vlm_causal_lm.py
@@ -1,5 +1,5 @@
 from io import BytesIO
-from typing import Iterable, List, Optional, Tuple, Type
+from typing import Dict, Iterable, List, Optional, Tuple, Type

 import torch
 import torch.distributed
@@ -19,12 +19,22 @@
 from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.state import PREFIX_CACHING
 from lorax_server.utils.tokenizer import TokenizerManager
+from lorax_server.utils.lora import LM_HEAD

 tracer = trace.get_tracer(__name__)

 IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
 IDEFICS2_IMAGE_TOKEN = "<image>"

+LANGUAGE_ATTN_Q_PROJ = "self_attn.language.q_proj"
+LANGUAGE_ATTN_K_PROJ = "self_attn.language.k_proj"
+LANGUAGE_ATTN_V_PROJ = "self_attn.language.v_proj"
+LANGUAGE_ATTN_O_PROJ = "self_attn.language.out_proj"
+VISION_ATTN_Q_PROJ = "self_attn.vision.q_proj"
+VISION_ATTN_K_PROJ = "self_attn.vision.k_proj"
+VISION_ATTN_V_PROJ = "self_attn.vision.v_proj"
+VISION_ATTN_O_PROJ = "self_attn.vision.out_proj"
+

 def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     """
@@ -290,6 +300,53 @@ def batch_type(self) -> Type[VlmCausalLMBatch]:
     def max_past(self) -> Optional[int]:
         return getattr(self.model.text_model, "max_past", None)

+    @property
+    def supports_adapter_loading(self) -> bool:
+        return True
+
+    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
+        layer_weights = {}
+        language_prefix = "language_model.model.layers"
+        vision_prefix = "vision_tower.vision_model.encoder.layers"
+        for i, layer in enumerate(self.model.text_model.model.layers):
+            layer_weights[(i, LANGUAGE_ATTN_K_PROJ)] = (
+                f"{language_prefix}.{i}.self_attn.k_proj",
+                layer.self_attn.query_key_value,
+            )
+            layer_weights[(i, LANGUAGE_ATTN_V_PROJ)] = (
+                f"{language_prefix}.{i}.self_attn.v_proj",
+                layer.self_attn.query_key_value,
+            )
+            layer_weights[(i, LANGUAGE_ATTN_O_PROJ)] = (
+                f"{language_prefix}.{i}.self_attn.out_proj",
+                layer.self_attn.o_proj,
+            )
+            layer_weights[(i, LANGUAGE_ATTN_Q_PROJ)] = (
+                f"{language_prefix}.{i}.self_attn.q_proj",
+                layer.self_attn.query_key_value,
+            )
+        for i, layer in enumerate(self.model.vision_tower.encoder.layers):
+            layer_weights[(i, VISION_ATTN_K_PROJ)] = (
+                f"{vision_prefix}.{i}.self_attn.k_proj",
+                layer.self_attn.qkv,
+            )
+            layer_weights[(i, VISION_ATTN_V_PROJ)] = (
+                f"{vision_prefix}.{i}.self_attn.v_proj",
+                layer.self_attn.qkv,
+            )
+            layer_weights[(i, VISION_ATTN_O_PROJ)] = (
+                f"{vision_prefix}.{i}.self_attn.out_proj",
+                layer.self_attn.out_proj,
+            )
+            layer_weights[(i, VISION_ATTN_Q_PROJ)] = (
+                f"{vision_prefix}.{i}.self_attn.q_proj",
+                layer.self_attn.qkv,
+            )
+
+        layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.text_model.lm_head)
+        return layer_weights
+
+
     def forward(
         self,
         batch: VlmCausalLMBatch,