(WIP) Support targeting the embedding layer for LoRA #501
base: main
First changed file (LoRA adapter weights):
@@ -9,6 +9,7 @@
 from lorax_server.adapters.config import AdapterConfig, ModuleMap
 from lorax_server.adapters.types import LORA
 from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights
+from lorax_server.utils.lora import EMBED_TOKENS
 from lorax_server.utils.sgmv import (
     BGMV_MAX_RANK,
     MAX_RANK_CUSTOM,
@@ -40,14 +41,20 @@ def map_weights_for_model(
         adapter_weight_names = set()
         module_map = {}
         for weight_name in weight_names:
-            lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
-            lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
+            if EMBED_TOKENS in weight_name:
+                lora_a_name = f"base_model.model.{weight_name}.lora_embedding_A"
+                lora_b_name = f"base_model.model.{weight_name}.lora_embedding_B"
+            else:
+                lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
+                lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
             if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
                 continue

+            # note(ajinkya): popping the weights so that we know which weights
+            # can be used as lora weights (supported) and which cannot
             module_map[weight_name] = {
-                "lora_A": (adapter_weights[lora_a_name], lora_a_name),
-                "lora_B": (adapter_weights[lora_b_name], lora_b_name),
+                "lora_A": (adapter_weights.pop(lora_a_name), lora_a_name),
+                "lora_B": (adapter_weights.pop(lora_b_name), lora_b_name),
             }
             adapter_weight_names.add(lora_a_name)
             adapter_weight_names.add(lora_b_name)

Review thread on the adapter_weights.pop(...) change:

Reviewer: Not sure I understand the purpose of using pop here. In this case, I would suggest cloning the adapter_weights dict instead.

Author: Sounds good. I just realized that I may not even need to change this part, since …
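To make the reviewer's suggestion concrete, here is a minimal sketch of a non-mutating variant: the leftover (unsupported) keys are computed and returned explicitly instead of being inferred from what pop() removed. The function name, signature, and return shape are illustrative only, not LoRAX's actual API.

```python
# Hypothetical sketch of the "don't mutate the input" alternative discussed above.
# Names (map_weights_sketch, used, unused) are illustrative, not part of LoRAX.
from typing import Dict, Set, Tuple

import torch


def map_weights_sketch(
    adapter_weights: Dict[str, torch.Tensor], weight_names: Set[str]
) -> Tuple[dict, Set[str], Set[str]]:
    module_map, used = {}, set()
    for weight_name in weight_names:
        lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
        lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
        if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
            continue
        module_map[weight_name] = {
            "lora_A": (adapter_weights[lora_a_name], lora_a_name),
            "lora_B": (adapter_weights[lora_b_name], lora_b_name),
        }
        used.update((lora_a_name, lora_b_name))
    # report unsupported keys instead of relying on the caller's dict being mutated
    unused = set(adapter_weights) - used
    return module_map, used, unused
```

Compared to pop(), this keeps the caller's dict intact and turns the "unsupported weights" warning further down into a simple check of the returned set.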
@@ -90,6 +97,7 @@ def __init__(
         weights_a: List[torch.Tensor],
         weights_b: List[torch.Tensor],
         adapter_config: LoraConfig,
+        is_embed: bool = False,
     ):
         self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
         self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
@@ -98,7 +106,8 @@
         self._is_transposed = False

         # [num_layers, hidden_size, r]
-        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
+        if not is_embed:
+            weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
         self._weights_a = torch.stack(weights_a)

         # [num_layers, r, hidden_size]
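Some context on why orient_for_rank is skipped for embeddings: in the PEFT formulation, the embedding adapter's A matrix is consumed through an embedding lookup (a row gather of A transposed) rather than a matmul against hidden states, so the rank-based reorientation applied to linear-layer A matrices presumably does not apply. A rough sketch of that forward pass, not LoRAX's batched kernel path:

```python
# Sketch of the PEFT-style embedding LoRA forward pass (illustrative only).
import torch
import torch.nn.functional as F


def embedding_lora_forward(input_ids, base_weight, lora_a, lora_b, scaling=1.0):
    # base_weight: [vocab, hidden], lora_a: [r, vocab], lora_b: [hidden, r]
    h = F.embedding(input_ids, base_weight)        # [*, hidden]
    after_a = F.embedding(input_ids, lora_a.T)     # gather rows of A^T -> [*, r]
    return h + (after_a @ lora_b.T) * scaling      # [*, hidden]


input_ids = torch.tensor([[1, 2, 3]])
out = embedding_lora_forward(
    input_ids,
    base_weight=torch.randn(10, 16),
    lora_a=torch.zeros(4, 10),
    lora_b=torch.zeros(16, 4),
)
assert out.shape == (1, 3, 16)
```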
@@ -158,8 +167,10 @@ def load(
             key = (layer_id, layer_type)
             weight_name, layer = model.target_to_layer[key]

-            base_weight = layer.base_layer.linear.weight
-            base_device = base_weight.device
+            if EMBED_TOKENS in layer_type:
+                base_device = layer.base_layer.weight.device
+            else:
+                base_device = layer.base_layer.linear.weight.device

             if weight_name not in module_map:
                 # There is no LoRA weight for this layer type in the adapter
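The branch above exists because the two kinds of base layers expose their weight tensor at different depths: an embedding module holds weight directly, while the linear layers are wrapped and reached through an inner linear attribute. A toy illustration (the wrapper class below is made up for the example):

```python
# Toy illustration of the two access paths to the base weight's device.
# BaseLayerWrapper is a made-up stand-in, not a LoRAX class.
import torch
import torch.nn as nn


class BaseLayerWrapper(nn.Module):
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear


embed = nn.Embedding(10, 16)
wrapped_linear = BaseLayerWrapper(nn.Linear(16, 16))

embed_device = embed.weight.device                    # embedding: weight hangs off the module
linear_device = wrapped_linear.linear.weight.device   # linear: weight is one level deeper
assert embed_device == linear_device == torch.device("cpu")
```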
@@ -196,13 +207,15 @@ def load(
         return LoraWeights(
             *model.shard_lora_weights(lora_a_list, lora_b_list, layer_type),
-            config,
+            adapter_config=config,
+            is_embed=(layer_type == EMBED_TOKENS),
         )


 @dataclass
 class RankSegments:
     rank: int
+    adapter_index_map: int

     lora_a_ptr: torch.Tensor
     lora_b_ptr: torch.Tensor
@@ -242,6 +255,7 @@ def load(
         meta: AdapterBatchMetadata,
         prefill: bool,
         prefill_head_indices: Optional[torch.Tensor],
+        is_embed: bool,
     ) -> Optional["BatchLoraWeights"]:
         adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
         adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)}
@@ -252,64 +266,39 @@ def load(
         device = first_weights.weights_a.device
         segment_indices = meta.segment_indices

-        lora_a = {idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights}
-        lora_b = {idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights}
-
-        max_rank = max(adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights)
-
-        if prefill or max_rank > BGMV_MAX_RANK:
-            use_sgmv = True
-            lora_a_ptr = torch.tensor(
-                [
-                    (adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
-                    for idx in segment_indices
-                ],
-                dtype=torch.int64,
-                device=device,
-            )
-            lora_b_ptr = torch.tensor(
-                [
-                    (adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
-                    for idx in segment_indices
-                ],
-                dtype=torch.int64,
-                device=device,
-            )
-        else:
-            use_sgmv = False
-            lora_a_ptr = torch.tensor(
-                [
-                    (adapter_weights[idx].weights_a_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
-                    for idx in segment_indices
-                ],
-                dtype=torch.int64,
-                device=device,
-            )
-            lora_b_ptr = torch.tensor(
-                [
-                    (adapter_weights[idx].weights_b_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
-                    for idx in segment_indices
-                ],
-                dtype=torch.int64,
-                device=device,
-            )
-
-        adapter_index_configs = {
-            idx: adapter_weights[idx].adapter_config for idx in segment_indices if idx in adapter_weights
-        }
-
-        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
-
-        rank_indices = defaultdict(list)
+        lora_a, lora_b, adapter_index_configs, adapter_to_segment = {}, {}, {}, {}
+        max_rank, rank_indices = 0, defaultdict(list)
         for segment_idx, adapter_idx in enumerate(segment_indices):
-            if adapter_idx not in adapter_weights:
-                continue
-            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
+            adapter_to_segment[adapter_idx] = segment_idx
+            if adapter_idx in adapter_weights:
+                adapter_weight = adapter_weights[adapter_idx]
+                adapter_index_configs[adapter_idx] = adapter_weight.adapter_config
+                max_rank = max(max_rank, adapter_weight.lora_a_r)
+                rank_indices[adapter_weight.lora_a_r].append(segment_idx)
+                lora_a[adapter_idx] = adapter_weight.weights_a
+                lora_b[adapter_idx] = adapter_weight.weights_b
+
+        use_sgmv = prefill or max_rank > BGMV_MAX_RANK
+        lora_a_ptr, lora_b_ptr = [], []
+        for segment_idx, adapter_idx in enumerate(segment_indices):
+            if adapter_idx in adapter_weights:
+                adapter_weight = adapter_weights[adapter_idx]
+                lora_a_ptr.append(
+                    (adapter_weight.weights_a if use_sgmv or is_embed else adapter_weight.weights_a_t).data_ptr()
+                )
+                lora_b_ptr.append(
+                    (adapter_weight.weights_b if use_sgmv else adapter_weight.weights_b_t).data_ptr()
+                )
+            else:
+                lora_a_ptr.append(EMPTY_TENSOR.data_ptr())
+                lora_b_ptr.append(EMPTY_TENSOR.data_ptr())
+        lora_a_ptr = torch.tensor(lora_a_ptr, dtype=torch.int64, device=device)
+        lora_b_ptr = torch.tensor(lora_b_ptr, dtype=torch.int64, device=device)

         if prefill_head_indices is not None:
             j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
             for head_index in prefill_head_indices:
-                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
+                # j cannot go out of bounds as that would mean there are tokens without segments
                 if head_index < meta.adapter_segments[j]:
                     prefill_head_segment_ends[-1] += 1
                 else:

Review thread on the consolidated loop:

Reviewer: Definitely looks cleaner. I believe we have a few unit tests to verify this is working correctly, right?

Author: I saw some test cases, but I am planning to add the missing ones as well. I need to add test cases anyway to make sure that our implementation and the HF implementation match.

Author: I will take a proper look at the test cases.
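For context on the pointer lists built in this hunk: the downstream SGMV/BGMV kernels appear to consume raw device pointers packed into an int64 tensor, with EMPTY_TENSOR's pointer standing in for segments that have no adapter loaded. A stripped-down sketch of that packing on toy CPU data:

```python
# Stripped-down sketch of the pointer packing above (toy data, CPU only).
import torch

EMPTY_TENSOR = torch.tensor([])
weights = {0: torch.randn(4, 8), 2: torch.randn(4, 8)}  # adapter_idx -> weight tensor
segment_indices = [0, 1, 2]                             # segment 1 has no adapter

ptrs = [
    (weights[idx].data_ptr() if idx in weights else EMPTY_TENSOR.data_ptr())
    for idx in segment_indices
]
lora_a_ptr = torch.tensor(ptrs, dtype=torch.int64)
assert lora_a_ptr.shape == (3,)
```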
@@ -335,13 +324,13 @@ def load(
                     segment_starts[i] = prefill_head_segment_starts[segment_index]
                     segment_ends[i] = prefill_head_segment_ends[segment_index]
             else:
-                rank_indices = set(indices)
                 batch_indices = [adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()]
-                batch_indices = [idx if idx in rank_indices else -1 for idx in batch_indices]
+                batch_indices = [idx if idx in set(indices) else -1 for idx in batch_indices]
                 batch_indices = torch.tensor(batch_indices, dtype=torch.int64, device=device)

             rank_data[rank] = RankSegments(
                 rank=rank,
+                adapter_index_map=indices,
                 tmp_shrink=tmp_shrink,
                 tmp_expand=tmp_expand,
                 lora_a_ptr=lora_a_ptr[indices],

Review thread on inlining set(indices):

Reviewer: Would rather keep the separate variable, as the call to set(indices) would otherwise be repeated inside the list comprehension.

Author: Oh, yes. I will revert to the original code. The reason for the change was to avoid reusing the name rank_indices for two different things.
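The reviewer's point about keeping the named variable is a small complexity issue: the set is built once up front, whereas set(indices) inside the comprehension is rebuilt for every element of batch_indices. A tiny self-contained illustration:

```python
# Tiny illustration of the set-precomputation point from the review thread.
indices = [3, 7, 9]
batch_indices = [3, 5, 7, 9, 11]

rank_indices = set(indices)  # built once, O(len(indices))
fast = [idx if idx in rank_indices else -1 for idx in batch_indices]

# same result, but set(indices) is rebuilt for every element of batch_indices
slow = [idx if idx in set(indices) else -1 for idx in batch_indices]
assert fast == slow == [3, -1, 7, 9, -1]
```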
Second changed file (adapter loading utilities):
@@ -3,6 +3,7 @@
 from functools import lru_cache
 from typing import TYPE_CHECKING, Set, Tuple

+from loguru import logger
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
@@ -158,4 +159,13 @@ def load_module_map(

     # map the model weights to the relevant adapter weights (LoRA A and B matrices)
     module_map, adapter_weight_names = adapter_config.map_weights_for_model(adapter_weights, weight_names)

+    # note(ajinkya): adapter weights are consumed during the above mapping, but if some are not, then we may not
+    # be supporting all the weights in the adapter, which should be an error; for now we just log it
+    if len(adapter_weights) > 0:
+        logger.warning(
+            f"Adapter {adapter_id} for the model {model_id}" + \
+            f" contains unsupported weights: {', '.join(adapter_weights.keys())}"
+        )
+
     return module_map, adapter_config, adapter_weight_names, adapter_tokenizer

Review thread on the unsupported-weights warning:

Reviewer: Per the above comment, I would return the modified adapter weights instead of relying on the dict being mutated in place.

Author: Sounds good.
Review thread on the EMBED_TOKENS constant:

Reviewer: We might need to make this embed_tokens name a property of the model rather than a constant, as I imagine it will vary from one architecture to the next.

Author: Makes sense. I will make the change.
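One possible shape for the suggested change, sketched with made-up class and property names: each model would report the name of its own embedding weight, so the server no longer depends on a single embed_tokens constant.

```python
# Hedged sketch of making the embedding weight name a per-model property.
# Class and property names here are hypothetical, not LoRAX's actual API.
class LlamaModelStub:
    @property
    def embedding_weight_name(self) -> str:
        return "model.embed_tokens"  # Llama-style naming


class GPT2ModelStub:
    @property
    def embedding_weight_name(self) -> str:
        return "transformer.wte"  # GPT-2-style naming


def is_embedding_target(model, layer_type: str) -> bool:
    # replaces checks like `EMBED_TOKENS in layer_type` with a per-model lookup
    return model.embedding_weight_name.endswith(layer_type)


assert is_embedding_target(LlamaModelStub(), "embed_tokens")
assert is_embedding_target(GPT2ModelStub(), "wte")
```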