Skip to content

Commit

Permalink
fix: Phi LoRA loading (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair authored Dec 16, 2023
1 parent 2fae25e commit af59e54
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions server/lorax_server/models/flash_phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights[(i, MLP_FC1)] = (f"{prefix}.{i}.mlp.fc1", layer.mlp.fc1)
layer_weights[(i, MLP_FC2)] = (f"{prefix}.{i}.mlp.fc2", layer.mlp.fc2)

layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
layer_weights[(0, LM_HEAD)] = ("lm_head.linear", self.model.lm_head.linear)
return layer_weights

@property
Expand All @@ -138,10 +138,17 @@ def get_num_layers_for_type(self, layer_type: str) -> int:
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL

def split_lora_b_qkv(self, t: torch.Tensor, projection_size: int) -> torch.Tensor:
def split_lora_b_qkv(self, t: torch.Tensor, head_size: int, num_heads: int, num_key_value_heads: int) -> torch.Tensor:
# Because we're splitting on the hidden size dimension, we need to
# account for the separate q, k, and v matrices.
chunks = torch.split(t, projection_size, dim=1)
chunks = t.split(
[
head_size * num_heads,
head_size * num_key_value_heads,
head_size * num_key_value_heads,
],
dim=1,
)
assert len(chunks) == 3
chunks = [
shard_on_dim(w, dim=1, process_group=self.process_group)
Expand All @@ -167,9 +174,12 @@ def shard_lora_weights(
# [r, hidden_size]
# Because we're splitting on the hidden size dimension, we need to
# account for the separate q, k, and v matrices.
projection_size = (self.config.hidden_size // self.config.num_attention_heads) * self.config.num_attention_heads
num_heads = self.config.n_head
hidden_size = self.config.n_embd
head_size = hidden_size // num_heads
num_key_value_heads = getattr(self.config, "n_head_kv", None) or num_heads
weights_b = [
self.split_lora_b_qkv(w, projection_size)
self.split_lora_b_qkv(w, head_size, num_heads, num_key_value_heads)
for w in weights_b
]

Expand Down

0 comments on commit af59e54

Please sign in to comment.