Add support for targeting cross_attn layers in mllama (#693)
ajtejankar authored Nov 27, 2024
1 parent da95224 commit c96ff88
Showing 3 changed files with 107 additions and 66 deletions.
136 changes: 82 additions & 54 deletions server/lorax_server/models/custom_modeling/mllama.py
@@ -38,11 +38,14 @@
TensorParallelRowLinear,
)
from lorax_server.utils.lora import (
DOWN_PROJ,
FC1,
FC2,
GATE_PROJ,
K_PROJ,
O_PROJ,
Q_PROJ,
UP_PROJ,
V_PROJ,
)

@@ -242,7 +245,7 @@ def __init__(self, *, prefix, config, weights, layer_id, model_type):
def forward(self, hidden_states: torch.Tensor, adapter_data: AdapterBatchData) -> torch.Tensor:
hidden_states = self.fc1(hidden_states, adapter_data)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states, adapter_data)
hidden_states = self.fc2(hidden_states.view(-1, hidden_states.shape[-1]), adapter_data)
return hidden_states


@@ -329,7 +332,7 @@ def forward(
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
attn_output = attn_output.view(batch_size * q_seq_len, -1)

output = self.o_proj(attn_output, adapter_data)
return output
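Both reshape changes above flatten the activations to a 2-D (tokens, hidden) layout before the adapter-aware output projection. A minimal sketch of that shape contract, assuming the adapter-aware linear tracks LoRA state per token row; `AdapterAwareLinear` is a hypothetical stand-in, not the LoRAX class:

```python
import torch
import torch.nn as nn


class AdapterAwareLinear(nn.Module):
    """Hypothetical stand-in for a LoRA-aware projection: it expects a 2-D
    (num_tokens, hidden) input so per-token adapter bookkeeping lines up."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 2, "adapter kernels index rows, so inputs are flattened to (tokens, hidden)"
        return self.base(x)  # a real implementation would add the per-token LoRA delta here


batch_size, q_seq_len, hidden = 2, 5, 16
o_proj = AdapterAwareLinear(hidden, hidden)
attn_output = torch.randn(batch_size, q_seq_len, hidden)
flat = attn_output.view(batch_size * q_seq_len, -1)  # mirrors the view(...) change above
print(o_proj(flat).shape)                            # torch.Size([10, 16])
```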
@@ -691,29 +694,55 @@ def __init__(self, *, prefix, config, weights, layer_idx):
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = self.num_key_value_heads // weights.process_group.size()

self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=False,
self.q_proj = TensorParallelMultiAdapterLinear.load(
TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj"],
weights=weights,
dim=0,
bias=False,
),
layer_idx,
[Q_PROJ],
sizes=[self.head_size * self.num_heads],
process_group=weights.process_group,
)
self.k_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.k_proj",
weights=weights,
bias=False,
self.k_proj = TensorParallelMultiAdapterLinear.load(
TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.k_proj"],
weights=weights,
dim=0,
bias=False,
),
layer_idx,
[K_PROJ],
sizes=[self.head_size * self.num_key_value_heads],
process_group=weights.process_group,
)
self.v_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.v_proj",
weights=weights,
bias=False,
self.v_proj = TensorParallelMultiAdapterLinear.load(
TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.v_proj"],
weights=weights,
dim=0,
bias=False,
),
layer_idx,
[V_PROJ],
sizes=[self.head_size * self.num_key_value_heads],
process_group=weights.process_group,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
self.o_proj = TensorParallelAdapterRowLinear.load(
TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
),
layer_idx,
O_PROJ,
process_group=weights.process_group,
)

self.q_norm = MllamaTextRMSNorm.load(prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps)
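Each of the four projections above follows the same pattern: the base tensor-parallel linear is wrapped in an adapter-aware layer keyed by `layer_idx` and a layer-type constant (Q_PROJ, K_PROJ, V_PROJ, O_PROJ), which is what makes the cross-attention projections targetable by LoRA adapters. A simplified, hypothetical sketch of that wrapping idea (not the LoRAX `TensorParallelMultiAdapterLinear` API, which also handles sharding and batched adapter metadata):

```python
import torch
import torch.nn as nn


class MultiAdapterLinear(nn.Module):
    """Simplified stand-in: wraps a base linear and applies an optional
    LoRA delta looked up by (layer_idx, layer_name)."""

    def __init__(self, base: nn.Linear, layer_idx: int, layer_names: list[str]):
        super().__init__()
        self.base = base
        self.layer_idx = layer_idx
        self.layer_names = layer_names

    def forward(self, x: torch.Tensor, adapter_data: dict) -> torch.Tensor:
        out = self.base(x)
        for name in self.layer_names:
            lora = adapter_data.get((self.layer_idx, name))
            if lora is not None:
                lora_a, lora_b = lora  # low-rank factors
                out = out + (x @ lora_a) @ lora_b
        return out


hidden, rank = 16, 4
q_proj = MultiAdapterLinear(nn.Linear(hidden, hidden, bias=False), layer_idx=3, layer_names=["q_proj"])
adapters = {(3, "q_proj"): (torch.randn(hidden, rank), torch.randn(rank, hidden))}
x = torch.randn(10, hidden)
print(q_proj(x, adapters).shape)  # torch.Size([10, 16])
```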
@@ -727,11 +756,12 @@ def forward(
# past_key_value=None,
# attention_mask: Optional[torch.Tensor] = None,
# cache_position: Optional[torch.LongTensor] = None,
adapter_data: Optional[AdapterBatchData] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# hidden_states = hidden_states.unsqueeze(0)
# bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = self.q_proj(hidden_states, adapter_data)
query_states = query_states.view(-1, self.num_heads, self.head_size)
query_states = self.q_norm(query_states)

@@ -744,8 +774,8 @@ def forward(
indices,
) = cross_attention_states

key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = self.k_proj(cross_attention_states, adapter_data)
value_states = self.v_proj(cross_attention_states, adapter_data)
key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
key_states = self.k_norm(key_states)
@@ -779,38 +809,54 @@ def forward(
False,
None,
)[0]
attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data)

return attn_output


# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
class MllamaTextMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
def __init__(self, *, prefix, config, weights, layer_idx):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size // weights.process_group.size()
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
layer_idx,
[GATE_PROJ, UP_PROJ],
sizes=[
config.intermediate_size,
config.intermediate_size,
],
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
layer_idx,
DOWN_PROJ,
process_group=weights.process_group,
)
self.act_fn = ACT2FN[config.hidden_act]

def forward(self, x):
def forward(self, x, adapter_data):
shape = x.shape
gate_up_states = self.gate_up_proj(x)
gate_up_states = self.gate_up_proj(x, adapter_data)
gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)
result = self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1])
result = self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data)
return result
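`gate_up_proj` fuses `gate_proj` and `up_proj` into a single matmul, and the forward splits the result with a view. A small standalone shape check of that split, using plain tensors (no tensor parallelism) and assuming a SiLU activation as in Llama-style configs:

```python
import torch
import torch.nn.functional as F

hidden_size, intermediate_size, num_tokens = 16, 32, 10
x = torch.randn(num_tokens, hidden_size)

# One fused weight holding gate_proj stacked on top of up_proj (dim=0), as load_multi does.
gate_up_weight = torch.randn(2 * intermediate_size, hidden_size)
gate_up_states = x @ gate_up_weight.T                              # (tokens, 2 * intermediate)
gate_up_states = gate_up_states.view(num_tokens, 2, intermediate_size)

gate = F.silu(gate_up_states[:, 0])                                # act_fn(gate_proj(x))
up = gate_up_states[:, 1]                                          # up_proj(x)
print((gate * up).shape)                                           # torch.Size([10, 32]) -> fed to down_proj
```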


@@ -834,7 +880,7 @@ def __init__(self, layer_id, prefix, config, weights) -> None:
weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False
)

self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_idx=layer_idx)
self.post_attention_layernorm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
@@ -877,12 +923,13 @@ def forward(
hidden_states=hidden_states,
# attention_mask=cross_attention_mask,
cross_attention_states=cross_attention_states,
adapter_data=adapter_data,
)
hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states

residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.mlp(hidden_states, adapter_data)
hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states

out_hidden_states[indices] = hidden_states
@@ -922,29 +969,10 @@ def __init__(self, prefix, config, weights):
config.text_config._attn_implementation = "sdpa"
self.hidden_size = config.text_config.hidden_size
cross_attention_layers = getattr(config.text_config, "cross_attention_layers", [])
# note(ajinkya): Since cross attention layers are not currently targeted, we need to handle
# the case of some layers not having lora adapters which lorax doesn't currently support.
# Hence, this hack where we use a dict that maps the actual layer index to the index it would
# have if the layers were filtered according to their types. For example:
# all layers = [0, 1, 2, 3, 4]
# cross attention layers = [1, 3]
# layer wise layer ids = [0, 0, 1, 1, 2]
# since layers 1 and 3 are of different type they are indexed as if they are sequential
# this prevents illegal memory access errors from running the punica kernels
layer_wise_layer_id = [0] * config.text_config.num_hidden_layers
i = j = 0
for k in range(config.text_config.num_hidden_layers):
if j == len(cross_attention_layers) or k < cross_attention_layers[j]:
layer_wise_layer_id[k] = i
i += 1
else:
layer_wise_layer_id[k] = j
j += 1
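For reference, the remapping computed by the hack removed above (made unnecessary now that cross-attention layers are targeted directly) can be reproduced standalone; the example values match the ones in the removed comment:

```python
def layer_wise_layer_ids(num_hidden_layers: int, cross_attention_layers: list[int]) -> list[int]:
    """Index self-attention layers and cross-attention layers separately,
    so each group looks sequential to the adapter kernels."""
    ids = [0] * num_hidden_layers
    i = j = 0
    for k in range(num_hidden_layers):
        if j == len(cross_attention_layers) or k < cross_attention_layers[j]:
            ids[k] = i  # regular self-attention layer
            i += 1
        else:
            ids[k] = j  # cross-attention layer
            j += 1
    return ids


print(layer_wise_layer_ids(5, [1, 3]))  # [0, 0, 1, 1, 2], matching the removed comment's example
```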

def create_layer(layer_id, prefix, config, weights):
layer_cls = FlashLlamaCrossLayer if layer_id in cross_attention_layers else FlashLlamaLayer
return layer_cls(
layer_wise_layer_id[layer_id],
layer_id,
prefix=prefix,
config=config,
weights=weights,
10 changes: 8 additions & 2 deletions server/lorax_server/models/flash_causal_lm.py
@@ -1568,6 +1568,11 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) ->

return out

# note(ajinkya): hack needed to make sure that we can target cross_attn layers in mllama
# default behavior is to just return prefill state, but mllama always returns True
def adapter_prefill_state(self, prefill: bool) -> bool:
return prefill

@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: FlashCausalLMBatch, is_warmup: bool = False
@@ -1595,13 +1600,14 @@ def generate_token(

# Assign pointers to adapter weights
# TODO(travis): don't update this if indices haven't changed
self.punica_wrapper.update_metadata(adapter_meta, prefill)
adapter_prefill_state = self.adapter_prefill_state(prefill)
self.punica_wrapper.update_metadata(adapter_meta, adapter_prefill_state)
adapter_data = AdapterBatchData.from_meta(
adapter_meta,
self.layer_to_adapter_weights,
self.layer_to_lora_weights,
self.punica_wrapper,
prefill,
adapter_prefill_state,
batch.prefill_head_indices,
)
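Putting the two pieces of this file together: the new `adapter_prefill_state` hook lets a model override the prefill flag seen by the adapter machinery, and `generate_token` threads the result into the punica metadata and the `AdapterBatchData`. A condensed sketch of that control flow with simplified, hypothetical signatures (the real methods live on the LoRAX model classes shown in this diff):

```python
class FlashCausalLM:
    # default: adapters follow the batch's actual prefill/decode state
    def adapter_prefill_state(self, prefill: bool) -> bool:
        return prefill


class Mllama(FlashCausalLM):  # stands in for the mllama model class in this diff
    # cross_attn adapters should always take the sgmv (prefill) path,
    # so the override ignores the real prefill state
    def adapter_prefill_state(self, prefill: bool) -> bool:
        return True


def generate_token(model: FlashCausalLM, prefill: bool) -> bool:
    # the real method also calls punica_wrapper.update_metadata(adapter_meta, adapter_prefill_state)
    # and builds AdapterBatchData with the same flag; here we just return it
    adapter_prefill_state = model.adapter_prefill_state(prefill)
    return adapter_prefill_state


assert generate_token(FlashCausalLM(), prefill=False) is False
assert generate_token(Mllama(), prefill=False) is True  # decode still uses the prefill/sgmv path
```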

27 changes: 17 additions & 10 deletions server/lorax_server/models/mllama.py
@@ -203,23 +203,24 @@ def get_num_layers_for_type(self, layer_type: str) -> int:
return len(self.model.vision_model.global_transformer.layers)
if "VISION_TRANSFORMER_" in layer_type:
return len(self.model.vision_model.transformer.layers)
return [
layer_id
for layer_id, layer in enumerate(self.model.text_model.model.layers)
if not isinstance(layer, FlashLlamaCrossLayer)
]

return len(self.model.text_model.model.layers)

def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}

prefix = "language_model.model.layers"
for i, layer in enumerate(self.model.text_model.model.layers):
if isinstance(layer, FlashLlamaCrossLayer):
continue
layer_weights[(i, Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value)
layer_weights[(i, K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value)
layer_weights[(i, V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value)
layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj)
layer_weights[(i, Q_PROJ)] = (f"{prefix}.{i}.cross_attn.q_proj", layer.cross_attn.q_proj)
layer_weights[(i, K_PROJ)] = (f"{prefix}.{i}.cross_attn.k_proj", layer.cross_attn.k_proj)
layer_weights[(i, V_PROJ)] = (f"{prefix}.{i}.cross_attn.v_proj", layer.cross_attn.v_proj)
layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.cross_attn.o_proj", layer.cross_attn.o_proj)
else:
layer_weights[(i, Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value)
layer_weights[(i, K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value)
layer_weights[(i, V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value)
layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj)

layer_weights[(i, GATE_PROJ)] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj)
layer_weights[(i, UP_PROJ)] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj)
@@ -255,6 +256,12 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:

return layer_weights

# note(ajinkya): for cross_attn in mllama we need to disable bgmv kernels
# during decode, but doing this selectively for cross_attn is tricky so
# simply resorting to sgmv kernels by always passing prefill=True
def adapter_prefill_state(self, prefill: bool) -> bool:
return True

def forward(
self,
batch: VlmCausalLMBatch,
