diff --git a/server/lorax_server/models/custom_modeling/mllama.py b/server/lorax_server/models/custom_modeling/mllama.py
index 5c367ac7f..d912f4d5d 100644
--- a/server/lorax_server/models/custom_modeling/mllama.py
+++ b/server/lorax_server/models/custom_modeling/mllama.py
@@ -277,13 +277,6 @@ def __init__(self, *, prefix, config, weights, layer_id, model_type):
         self.head_size = config.hidden_size // self.num_heads
         self.num_key_value_heads = getattr(config, "n_head_kv", None) or self.num_heads

-        self.qkv_proj = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
-            weights=weights,
-            bias=False,
-        )
         self.qkv_proj = load_attention(
             config,
             prefix,
@@ -929,11 +922,29 @@ def __init__(self, prefix, config, weights):
         config.text_config._attn_implementation = "sdpa"
         self.hidden_size = config.text_config.hidden_size
         cross_attention_layers = getattr(config.text_config, "cross_attention_layers", [])
+        # note(ajinkya): Since cross attention layers are not currently targeted, we need to handle
+        # the case of some layers not having LoRA adapters, which lorax doesn't currently support.
+        # Hence this hack, where we build a mapping from the actual layer index to the index the
+        # layer would have if the layers were filtered by type. For example:
+        # all layers = [0, 1, 2, 3, 4]
+        # cross attention layers = [1, 3]
+        # layer wise layer ids = [0, 0, 1, 1, 2]
+        # Since layers 1 and 3 are of a different type, they are indexed as if they were sequential.
+        # This prevents illegal memory access errors when running the punica kernels.
+        layer_wise_layer_id = [0] * config.text_config.num_hidden_layers
+        i = j = 0
+        for k in range(config.text_config.num_hidden_layers):
+            if j == len(cross_attention_layers) or k < cross_attention_layers[j]:
+                layer_wise_layer_id[k] = i
+                i += 1
+            else:
+                layer_wise_layer_id[k] = j
+                j += 1

         def create_layer(layer_id, prefix, config, weights):
             layer_cls = FlashLlamaCrossLayer if layer_id in cross_attention_layers else FlashLlamaLayer
             return layer_cls(
-                layer_id,
+                layer_wise_layer_id[layer_id],
                 prefix=prefix,
                 config=config,
                 weights=weights,
diff --git a/server/lorax_server/models/mllama.py b/server/lorax_server/models/mllama.py
index 4c514f508..a80d5ce22 100644
--- a/server/lorax_server/models/mllama.py
+++ b/server/lorax_server/models/mllama.py
@@ -186,11 +186,9 @@ def supports_adapter_loading(self) -> bool:

     @property
     def adapter_layers(self) -> List[str]:
-        return (
-            [f"TEXT_{layer_type}" for layer_type in TEXT_ADAPTER_LAYERS]
-            + [f"VISION_GLOBAL_TRANSFORMER_{layer_type}" for layer_type in VISION_ADAPTER_LAYERS]
-            + [f"VISION_TRANSFORMER_{layer_type}" for layer_type in VISION_ADAPTER_LAYERS]
-        )
+        return TEXT_ADAPTER_LAYERS \
+            + [f'VISION_GLOBAL_TRANSFORMER_{layer_type}' for layer_type in VISION_ADAPTER_LAYERS] \
+            + [f'VISION_TRANSFORMER_{layer_type}' for layer_type in VISION_ADAPTER_LAYERS]

     @property
     def default_traced_adapter_layers(self) -> List[str]:
@@ -199,16 +197,15 @@ def default_traced_adapter_layers(self) -> List[str]:
     def get_num_layers_for_type(self, layer_type: str) -> int:
         if "LM_HEAD" in layer_type:
             return 1
-        if "TEXT_" in layer_type:
-            return [
-                layer_id
-                for layer_id, layer in enumerate(self.model.text_model.model.layers)
-                if not isinstance(layer, FlashLlamaCrossLayer)
-            ]
-        if "VISION_GLOBAL_TRANSFORMER_" in layer_type:
+        if 'VISION_GLOBAL_TRANSFORMER_' in layer_type:
            return len(self.model.vision_model.global_transformer.layers)
         if "VISION_TRANSFORMER_" in layer_type:
            return len(self.model.vision_model.transformer.layers)
+        return [
+            layer_id
+            for layer_id, layer in enumerate(self.model.text_model.model.layers)
+            if not isinstance(layer, FlashLlamaCrossLayer)
+        ]

     def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
         layer_weights = {}
@@ -217,27 +214,15 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
         for i, layer in enumerate(self.model.text_model.model.layers):
             if isinstance(layer, FlashLlamaCrossLayer):
                 continue
-            layer_weights[(i, f"TEXT_{Q_PROJ}")] = (
-                f"{prefix}.{i}.self_attn.q_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, f"TEXT_{K_PROJ}")] = (
-                f"{prefix}.{i}.self_attn.k_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, f"TEXT_{V_PROJ}")] = (
-                f"{prefix}.{i}.self_attn.v_proj",
-                layer.self_attn.query_key_value,
-            )
-            layer_weights[(i, f"TEXT_{O_PROJ}")] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj)
-
-            layer_weights[(i, f"TEXT_{GATE_PROJ}")] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj)
-            layer_weights[(i, f"TEXT_{UP_PROJ}")] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj)
-            layer_weights[(i, f"TEXT_{DOWN_PROJ}")] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)
-        layer_weights[(0, f"TEXT_{LM_HEAD}")] = (
-            "base_model.model.language_model.lm_head",
-            self.model.text_model.lm_head,
-        )
+            layer_weights[(i, Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value)
+            layer_weights[(i, K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value)
+            layer_weights[(i, V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value)
+            layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj)
+
+            layer_weights[(i, GATE_PROJ)] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj)
+            layer_weights[(i, UP_PROJ)] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj)
+            layer_weights[(i, DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)
+        layer_weights[(0, LM_HEAD)] = ("base_model.model.language_model.lm_head", self.model.text_model.lm_head)

         vision_layer_mappings = [
             ("vision_model.global_transformer.layers", self.model.vision_model.global_transformer.layers),
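
For reference, a minimal standalone sketch of the layer-index remapping added in the first file (not part of the patch; `num_hidden_layers` and `cross_attention_layers` values are just the worked example from the added comment):

```python
# Sketch of the layer_wise_layer_id mapping built in MllamaTextModel's __init__.
# Values mirror the example in the patch comment; this is an illustration only,
# not the patched module.
num_hidden_layers = 5
cross_attention_layers = [1, 3]  # indices of cross-attention layers (sorted)

layer_wise_layer_id = [0] * num_hidden_layers
i = j = 0  # i counts self-attention layers, j counts cross-attention layers
for k in range(num_hidden_layers):
    if j == len(cross_attention_layers) or k < cross_attention_layers[j]:
        # k is a self-attention layer: assign the next sequential index
        # among layers of its own type.
        layer_wise_layer_id[k] = i
        i += 1
    else:
        # k is a cross-attention layer: index it among cross-attention layers.
        layer_wise_layer_id[k] = j
        j += 1

print(layer_wise_layer_id)  # [0, 0, 1, 1, 2], matching the comment's example
```

Each `FlashLlamaLayer` thus receives a contiguous id among the LoRA-targeted layers, which, per the patch comment, is what avoids the illegal memory access in the punica kernels when cross-attention layers carry no adapter weights.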