Commit

Fix Llama-70b adapter merging (#1)
geoffreyangus authored Oct 30, 2023
1 parent fdb1242 commit c747d27
Showing 2 changed files with 39 additions and 23 deletions.
@@ -232,6 +232,27 @@ def __init__(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)

+    def get_query_key_value_weights(self, clone=True):
+        """Gets the query, key, and value weights from the attention layer.
+
+        If `clone`, then the weights are cloned before being returned.
+
+        NOTE: if not `clone`, then the weights are returned as views, meaning
+        that changes to the weights will be reflected in the attention layer.
+        """
+        query, key, value = self.query_key_value.linear.weight.split(
+            [
+                self.head_size * self.num_heads,
+                self.head_size * self.num_key_value_heads,
+                self.head_size * self.num_key_value_heads,
+            ],
+            dim=0,
+        )
+
+        if clone:
+            return query.clone(), key.clone(), value.clone()
+        return query, key, value
+
     def forward(
         self,
         hidden_states,
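The get_query_key_value_weights method added above exists because a Llama-70b attention block uses grouped-query attention: the fused QKV weight holds `head_size * num_heads` query rows but only `head_size * num_key_value_heads` rows each for key and value, so the three blocks are not equal and cannot be recovered by slicing the matrix into thirds. A minimal sketch of the size mismatch follows; the head counts are assumed Llama-2-70B-style values (64 query heads, 8 key/value heads, head size 128) and the column dimension is shrunk to keep the example light, none of which is read from this repository.

```python
# Sketch (assumed dimensions): why an equal three-way split of the fused QKV
# weight breaks under grouped-query attention, while a size-aware split works.
import torch

head_size = 128
num_heads = 64              # query heads (assumed Llama-2-70B-style value)
num_key_value_heads = 8     # shared key/value heads under GQA (assumed value)

d_q = head_size * num_heads              # 8192 query rows
d_kv = head_size * num_key_value_heads   # 1024 rows each for key and value
fused = torch.zeros(d_q + 2 * d_kv, 4)   # stand-in fused weight; the real column dim is hidden_size

# Size-aware split, as get_query_key_value_weights does:
q, k, v = fused.split([d_q, d_kv, d_kv], dim=0)
print(q.shape[0], k.shape[0], v.shape[0])   # 8192 1024 1024

# Equal thirds only line up when num_heads == num_key_value_heads (e.g. Llama-7b):
d_third = fused.shape[0] // 3
print(d_third)   # 3413 -- a slice like [:d_third] cuts through the query block
```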
server/text_generation_server/models/flash_llama.py (41 changes: 18 additions & 23 deletions)
@@ -121,18 +121,15 @@ def __init__(
         self.orig_weights = {}
         prefix = "model.layers"
         for i, layer in enumerate(self.model.model.layers):
-            d_qkv, _ = layer.self_attn.query_key_value.linear.weight.shape
-            d_q = d_qkv // 3 # break up d_qkv into 3 parts
+            q_proj, _, v_proj = layer.self_attn.get_query_key_value_weights(clone=True)

-            orig_q_proj = layer.self_attn.query_key_value.linear.weight[:d_q]
-            orig_q_proj_device = orig_q_proj.device
+            orig_q_proj_device = q_proj.device
             weight_name = f"{prefix}.{i}.self_attn.q_proj"
-            self.orig_weights[weight_name] = (orig_q_proj.cpu(), orig_q_proj_device)
+            self.orig_weights[weight_name] = (q_proj.cpu(), orig_q_proj_device)

-            orig_v_proj = layer.self_attn.query_key_value.linear.weight[2*d_q:]
-            orig_v_proj_device = orig_v_proj.device
+            orig_v_proj_device = v_proj.device
             weight_name = f"{prefix}.{i}.self_attn.v_proj"
-            self.orig_weights[weight_name] = (orig_v_proj.cpu(), orig_v_proj_device)
+            self.orig_weights[weight_name] = (v_proj.cpu(), orig_v_proj_device)

     def load_adapter(self, adapter_id, adapter_source):
         if not self.dynamic_adapter_loading_enabled:
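A short standalone sketch (illustrative tensors only, not repository code) of why the snapshot loop in the hunk above asks for `clone=True`: the split returns views into the live fused weight, so saving a bare view would let a later in-place adapter merge silently overwrite the "original" copy.

```python
# Sketch: views vs. clones when snapshotting slices of a live weight.
import torch

fused = torch.arange(12.0).reshape(6, 2)          # stand-in for a fused QKV weight
q_view, k_view, v_view = fused.split([2, 2, 2], dim=0)

saved_view = q_view           # aliases the same storage as `fused`
saved_copy = q_view.clone()   # independent copy, like clone=True returns

fused[0].fill_(-1.0)          # simulate an in-place adapter merge later on

print(saved_view[0])          # tensor([-1., -1.])  -> snapshot corrupted
print(saved_copy[0])          # tensor([0., 1.])    -> original values preserved
```

(`.cpu()` already copies when the weight lives on a GPU, but it returns the same tensor for a CPU-resident weight, so the explicit clone keeps the snapshot safe either way.)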
@@ -151,14 +148,15 @@ def load_adapter(self, adapter_id, adapter_source):
             # if the adapter_id is the base model, then just reset the weights
             prefix = "model.layers"
             for i, layer in enumerate(self.model.model.layers):
-                qkv_d, _ = layer.self_attn.query_key_value.linear.weight.shape
-                q_d = qkv_d // 3 # break up qkv_d into 3 parts
-
-                # place the original weights (on their original device) back into the model
-                q_proj, q_proj_device = self.orig_weights[f"{prefix}.{i}.self_attn.q_proj"]
-                layer.self_attn.query_key_value.linear.weight[:q_d] = q_proj.to(q_proj_device)
-                v_proj, v_proj_device = self.orig_weights[f"{prefix}.{i}.self_attn.v_proj"]
-                layer.self_attn.query_key_value.linear.weight[2*q_d:] = v_proj.to(v_proj_device)
+                # replace the target matrices in place
+                q_proj, _, v_proj = layer.self_attn.get_query_key_value_weights(clone=False)
+
+                # place original weights (on their original device) by setting in place
+                orig_q_proj, orig_q_proj_device = self.orig_weights[f"{prefix}.{i}.self_attn.q_proj"]
+                q_proj[:] = orig_q_proj.to(orig_q_proj_device)
+                orig_v_proj, orig_v_proj_device = self.orig_weights[f"{prefix}.{i}.self_attn.v_proj"]
+                v_proj[:] = orig_v_proj.to(orig_v_proj_device)

             self.adapter_id = adapter_id
         else:
             weight_names = tuple(self.orig_weights.keys())
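The reset branch in the hunk above relies on `clone=False` handing back views: slice-assigning with `[:]` copies data into the storage of the fused parameter itself, whereas rebinding the name would leave the model untouched. A minimal standalone illustration (stand-in tensors, not repository code):

```python
# Sketch: restoring a block of a fused weight through a view.
import torch

fused = torch.full((6, 2), 7.0)                         # fused weight after a merge (stand-in)
q_view, k_view, v_view = fused.split([2, 2, 2], dim=0)  # views, like clone=False returns

orig_q = torch.zeros(2, 2)    # saved original block, e.g. just moved back from CPU

q_view[:] = orig_q            # writes through to `fused`'s storage
print(fused[:2])              # zeros again: the model's weight is actually restored

# By contrast, `q_view = orig_q` would merely rebind the local name and leave
# `fused` full of 7s, which is why the diff assigns with q_proj[:] = ...
```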
@@ -189,11 +187,8 @@ def compute_merged_weight(weight_name):
                 desc=f"Merging weights for adapter {adapter_id}",
                 total=len(self.model.model.layers)
             ):
-                d_qkv, _ = layer.self_attn.query_key_value.linear.weight.shape
-                d_q = d_qkv // 3 # break up d_qkv into 3 parts
-
-                layer.self_attn.query_key_value.linear.weight[:d_q] = compute_merged_weight(
-                    f"{prefix}.{i}.self_attn.q_proj")
-                layer.self_attn.query_key_value.linear.weight[2*d_q:] = compute_merged_weight(
-                    f"{prefix}.{i}.self_attn.v_proj")
+                # replace the target matrices in place
+                q_proj, _, v_proj = layer.self_attn.get_query_key_value_weights(clone=False)
+                q_proj[:] = compute_merged_weight(f"{prefix}.{i}.self_attn.q_proj")
+                v_proj[:] = compute_merged_weight(f"{prefix}.{i}.self_attn.v_proj")
             self.adapter_id = adapter_id
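compute_merged_weight itself is outside this diff; the sketch below is only a hypothetical illustration of the kind of LoRA-style merge such a helper might compute (base weight plus a scaled low-rank update). Every name, shape, and the alpha/r scaling convention here is an assumption, not taken from the repository.

```python
# Hypothetical sketch of a LoRA-style merge; NOT the repository's compute_merged_weight.
# Assumed convention: W_merged = W + (alpha / r) * (B @ A), with A: (r, in), B: (out, r).
import torch

def merge_lora_weight(base_weight: torch.Tensor,
                      lora_a: torch.Tensor,
                      lora_b: torch.Tensor,
                      alpha: float,
                      r: int) -> torch.Tensor:
    """Return the base weight with a scaled low-rank adapter update folded in."""
    return base_weight + (alpha / r) * (lora_b @ lora_a)

# Toy shapes: an (8, 8) base weight with a rank-2 adapter.
W = torch.randn(8, 8)
A = torch.randn(2, 8)
B = torch.randn(8, 2)
merged = merge_lora_weight(W, A, B, alpha=16.0, r=2)
assert merged.shape == W.shape
```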
