From c747d27a8d1c9421fdcb4542191efe8553535763 Mon Sep 17 00:00:00 2001
From: Geoffrey Angus
Date: Mon, 30 Oct 2023 14:48:08 -0700
Subject: [PATCH] Fix Llama-70b adapter merging (#1)

---
 .../custom_modeling/flash_llama_modeling.py | 21 ++++++++++
 .../models/flash_llama.py                   | 41 ++++++++-----------
 2 files changed, 39 insertions(+), 23 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index db21c9bbb..22d47c304 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -232,6 +232,27 @@ def __init__(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
 
+    def get_query_key_value_weights(self, clone=True):
+        """Gets the query, key, and value weights from the attention layer.
+
+        If `clone`, then the weights are cloned before being returned.
+
+        NOTE: if not `clone`, then the weights are returned as views, meaning
+        that changes to the weights will be reflected in the attention layer.
+        """
+        query, key, value = self.query_key_value.linear.weight.split(
+            [
+                self.head_size * self.num_heads,
+                self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
+            ],
+            dim=0,
+        )
+
+        if clone:
+            return query.clone(), key.clone(), value.clone()
+        return query, key, value
+
     def forward(
         self,
         hidden_states,
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index d67d09099..1930a5f00 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -121,18 +121,15 @@ def __init__(
         self.orig_weights = {}
         prefix = "model.layers"
         for i, layer in enumerate(self.model.model.layers):
-            d_qkv, _ = layer.self_attn.query_key_value.linear.weight.shape
-            d_q = d_qkv // 3 # break up d_qkv into 3 parts
+            q_proj, _, v_proj = layer.self_attn.get_query_key_value_weights(clone=True)
 
-            orig_q_proj = layer.self_attn.query_key_value.linear.weight[:d_q]
-            orig_q_proj_device = orig_q_proj.device
+            orig_q_proj_device = q_proj.device
             weight_name = f"{prefix}.{i}.self_attn.q_proj"
-            self.orig_weights[weight_name] = (orig_q_proj.cpu(), orig_q_proj_device)
+            self.orig_weights[weight_name] = (q_proj.cpu(), orig_q_proj_device)
 
-            orig_v_proj = layer.self_attn.query_key_value.linear.weight[2*d_q:]
-            orig_v_proj_device = orig_v_proj.device
+            orig_v_proj_device = v_proj.device
             weight_name = f"{prefix}.{i}.self_attn.v_proj"
-            self.orig_weights[weight_name] = (orig_v_proj.cpu(), orig_v_proj_device)
+            self.orig_weights[weight_name] = (v_proj.cpu(), orig_v_proj_device)
 
     def load_adapter(self, adapter_id, adapter_source):
         if not self.dynamic_adapter_loading_enabled:
@@ -151,14 +148,15 @@ def load_adapter(self, adapter_id, adapter_source):
             # if the adapter_id is the base model, then just reset the weights
             prefix = "model.layers"
             for i, layer in enumerate(self.model.model.layers):
-                qkv_d, _ = layer.self_attn.query_key_value.linear.weight.shape
-                q_d = qkv_d // 3 # break up qkv_d into 3 parts
-
-                # place the original weights (on their original device) back into the model
-                q_proj, q_proj_device = self.orig_weights[f"{prefix}.{i}.self_attn.q_proj"]
-                layer.self_attn.query_key_value.linear.weight[:q_d] = q_proj.to(q_proj_device)
-                v_proj, v_proj_device = self.orig_weights[f"{prefix}.{i}.self_attn.v_proj"]
-                layer.self_attn.query_key_value.linear.weight[2*q_d:] = v_proj.to(v_proj_device)
+                # replace the target matrices in place
+                q_proj, _, v_proj = layer.self_attn.get_query_key_value_weights(clone=False)
+
+                # place original weights (on their original device) by setting in place
+                orig_q_proj, orig_q_proj_device = self.orig_weights[f"{prefix}.{i}.self_attn.q_proj"]
+                q_proj[:] = orig_q_proj.to(orig_q_proj_device)
+                orig_v_proj, orig_v_proj_device = self.orig_weights[f"{prefix}.{i}.self_attn.v_proj"]
+                v_proj[:] = orig_v_proj.to(orig_v_proj_device)
+
             self.adapter_id = adapter_id
         else:
             weight_names = tuple(self.orig_weights.keys())
@@ -189,11 +187,8 @@ def compute_merged_weight(weight_name):
                 desc=f"Merging weights for adapter {adapter_id}",
                 total=len(self.model.model.layers)
             ):
-                d_qkv, _ = layer.self_attn.query_key_value.linear.weight.shape
-                d_q = d_qkv // 3 # break up d_qkv into 3 parts
-
-                layer.self_attn.query_key_value.linear.weight[:d_q] = compute_merged_weight(
-                    f"{prefix}.{i}.self_attn.q_proj")
-                layer.self_attn.query_key_value.linear.weight[2*d_q:] = compute_merged_weight(
-                    f"{prefix}.{i}.self_attn.v_proj")
+                # replace the target matrices in place
+                q_proj, _, v_proj = layer.self_attn.get_query_key_value_weights(clone=False)
+                q_proj[:] = compute_merged_weight(f"{prefix}.{i}.self_attn.q_proj")
+                v_proj[:] = compute_merged_weight(f"{prefix}.{i}.self_attn.v_proj")
             self.adapter_id = adapter_id
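
Note on the fix: Llama-2-70b uses grouped-query attention, so the fused query_key_value weight is not three equal blocks; slicing it with `d_qkv // 3` cuts Q and V at the wrong rows, which is what broke adapter merging on 70b. The sketch below is illustrative only and not part of the patch: it assumes unsharded Llama-2-70b attention dimensions (head_size=128, 64 query heads, 8 key/value heads) and uses a constant stand-in tensor for `query_key_value.linear.weight` to show why the split must be keyed on head counts, and why the `clone=False` views make the in-place writes in load_adapter work.

import torch

# Unsharded Llama-2-70b attention dimensions (grouped-query attention).
head_size = 128
num_heads = 64            # query heads
num_key_value_heads = 8   # each K/V head is shared by a group of 8 query heads

# Stand-in for layer.self_attn.query_key_value.linear.weight:
# rows = Q rows + K rows + V rows, cols = hidden size.
qkv_weight = torch.ones(
    head_size * (num_heads + 2 * num_key_value_heads), head_size * num_heads
)

# Old approach: three equal blocks. That only holds when
# num_heads == num_key_value_heads (e.g. 7b/13b); here it gives 3413 rows,
# which is not even a multiple of head_size.
d_q_wrong = qkv_weight.shape[0] // 3

# Patched approach, mirroring get_query_key_value_weights(): split by head counts.
q_proj, k_proj, v_proj = qkv_weight.split(
    [
        head_size * num_heads,            # 8192 rows of Q
        head_size * num_key_value_heads,  # 1024 rows of K
        head_size * num_key_value_heads,  # 1024 rows of V
    ],
    dim=0,
)
assert q_proj.shape[0] == 8192 and k_proj.shape[0] == 1024 and v_proj.shape[0] == 1024

# split() returns views, so writing through q_proj[:] updates the fused weight
# in place -- the property load_adapter relies on when it passes clone=False.
q_proj[:] = 0.0
assert torch.all(qkv_weight[: head_size * num_heads] == 0)
assert torch.all(qkv_weight[head_size * num_heads :] == 1)

In the server itself the fused weight is tensor-parallel sharded, so self.num_heads and self.num_key_value_heads are per-shard values, but the same head-count split applies shard by shard.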