(WIP) Support targeting the embedding layer for LoRA #501

Open
wants to merge 13 commits into main
115 changes: 52 additions & 63 deletions server/lorax_server/adapters/lora.py
@@ -9,6 +9,7 @@
from lorax_server.adapters.config import AdapterConfig, ModuleMap
from lorax_server.adapters.types import LORA
from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights
from lorax_server.utils.lora import EMBED_TOKENS
from lorax_server.utils.sgmv import (
BGMV_MAX_RANK,
MAX_RANK_CUSTOM,
@@ -40,14 +41,20 @@ def map_weights_for_model(
adapter_weight_names = set()
module_map = {}
for weight_name in weight_names:
lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
if EMBED_TOKENS in weight_name:
Contributor:
We might need to make this embed_tokens name a property of the model rather than a constant, as I imagine it will vary from one architecture to the next.

Contributor Author:
Makes sense. I will make the change.
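A rough sketch of that direction, purely illustrative (the embed_tokens_name property and the subclass shown here are hypothetical, not part of this PR):

```python
# Hypothetical sketch: each model reports the name of its embedding layer,
# so callers consult the model instead of the shared EMBED_TOKENS constant.
class FlashCausalLM:
    @property
    def embed_tokens_name(self) -> str:
        # Default matching Llama-style checkpoints.
        return "embed_tokens"


class SomeOtherFlashModel(FlashCausalLM):
    @property
    def embed_tokens_name(self) -> str:
        # Architectures that name the embedding differently override this.
        return "wte"
```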

lora_a_name = f"base_model.model.{weight_name}.lora_embedding_A"
lora_b_name = f"base_model.model.{weight_name}.lora_embedding_B"
else:
lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
continue

# note(ajinkya): popping the weights so that we know which weights can be
# used as lora weights (supported) and which cannot
module_map[weight_name] = {
"lora_A": (adapter_weights[lora_a_name], lora_a_name),
"lora_B": (adapter_weights[lora_b_name], lora_b_name),
"lora_A": (adapter_weights.pop(lora_a_name), lora_a_name),
Contributor:
Not sure I understand the purpose of using pop here, as it doesn't look like adapter_weights is used below (unless it's used from the caller). In general, it's good to avoid modifying input objects unless it's clear from the function's name, etc., that it does so.

In this case, I would suggest cloning the adapter_weights dict at the top to avoid modifying the input, and then returning the modified adapter_weights if the caller needs to check which elements haven't been popped.

Contributor Author:
Sounds good. I just realized that I may not even need to change this part, since adapter_weight_names captures which weights were consumed, and I can use it in the caller to figure out whether all weights were consumed.
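For reference, a minimal sketch of the cloning approach suggested above (signature simplified, the embedding branch omitted, and names illustrative):

```python
# Sketch only: copy the input dict, pop entries as they are mapped, and hand
# whatever is left back to the caller instead of mutating the argument.
def map_weights_for_model(adapter_weights, weight_names):
    remaining = dict(adapter_weights)  # shallow copy; the caller's dict stays untouched
    module_map, adapter_weight_names = {}, set()
    for weight_name in weight_names:
        lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
        lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
        if lora_a_name not in remaining or lora_b_name not in remaining:
            continue
        module_map[weight_name] = {
            "lora_A": (remaining.pop(lora_a_name), lora_a_name),
            "lora_B": (remaining.pop(lora_b_name), lora_b_name),
        }
        adapter_weight_names.update((lora_a_name, lora_b_name))
    # Anything still in `remaining` is an adapter weight this model cannot target.
    return module_map, adapter_weight_names, remaining
```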

"lora_B": (adapter_weights.pop(lora_b_name), lora_b_name),
}
adapter_weight_names.add(lora_a_name)
adapter_weight_names.add(lora_b_name)
@@ -90,6 +97,7 @@ def __init__(
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
adapter_config: LoraConfig,
is_embed: bool = False,
):
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
@@ -98,7 +106,8 @@
self._is_transposed = False

# [num_layers, hidden_size, r]
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
if not is_embed:
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
self._weights_a = torch.stack(weights_a)

# [num_layers, r, hidden_size]
@@ -158,8 +167,10 @@ def load(
key = (layer_id, layer_type)
weight_name, layer = model.target_to_layer[key]

base_weight = layer.base_layer.linear.weight
base_device = base_weight.device
if EMBED_TOKENS in layer_type:
base_device = layer.base_layer.weight.device
else:
base_device = layer.base_layer.linear.weight.device

if weight_name not in module_map:
# There is no LoRA weight for this layer type in the adapter
@@ -196,13 +207,15 @@ def load(

return LoraWeights(
*model.shard_lora_weights(lora_a_list, lora_b_list, layer_type),
config,
adapter_config=config,
is_embed=(layer_type == EMBED_TOKENS),
)


@dataclass
class RankSegments:
rank: int
adapter_index_map: int

lora_a_ptr: torch.Tensor
lora_b_ptr: torch.Tensor
@@ -242,6 +255,7 @@ def load(
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
is_embed: bool,
) -> Optional["BatchLoraWeights"]:
adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)}
@@ -252,64 +266,39 @@ def load(
device = first_weights.weights_a.device
segment_indices = meta.segment_indices

lora_a = {idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights}
lora_b = {idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights}

max_rank = max(adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights)

if prefill or max_rank > BGMV_MAX_RANK:
use_sgmv = True
lora_a_ptr = torch.tensor(
[
(adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
else:
use_sgmv = False
lora_a_ptr = torch.tensor(
[
(adapter_weights[idx].weights_a_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(adapter_weights[idx].weights_b_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)

adapter_index_configs = {
idx: adapter_weights[idx].adapter_config for idx in segment_indices if idx in adapter_weights
}

adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}

rank_indices = defaultdict(list)
lora_a, lora_b, adapter_index_configs, adapter_to_segment = {}, {}, {}, {}
max_rank, rank_indices = 0, defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx not in adapter_weights:
continue
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
adapter_to_segment[adapter_idx] = segment_idx
Contributor:
Definitely looks cleaner. I believe we have a few unit tests to verify this is working correctly, right?

Contributor Author:
I saw some test cases, but I am planning to add the missing ones as well. In any case, I need to add tests to make sure our implementation matches HF's.

Contributor Author:
I will take a proper look at the test cases.
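For the HF parity piece, one self-contained shape such a test could take (pure torch, not wired into the server's test harness; it assumes the PEFT-style convention where lora_embedding_A is (r, vocab_size) and lora_embedding_B is (hidden_size, r)):

```python
# Hypothetical parity check for the embedding-LoRA math only.
import torch
import torch.nn.functional as F


def test_embedding_lora_delta_matches_reference():
    torch.manual_seed(0)
    vocab, hidden, r = 32, 16, 4
    ids = torch.randint(0, vocab, (5,))
    lora_a = torch.randn(r, vocab)   # lora_embedding_A
    lora_b = torch.randn(hidden, r)  # lora_embedding_B
    scaling = 0.5

    # Reference: materialize the full (vocab, hidden) delta and index it.
    full_delta = (lora_b @ lora_a).T * scaling
    expected = F.embedding(ids, full_delta)

    # Factored path, the way an embedding-LoRA kernel would compute it.
    actual = F.embedding(ids, lora_a.T) @ lora_b.T * scaling

    torch.testing.assert_close(actual, expected)
```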

if adapter_idx in adapter_weights:
adapter_weight = adapter_weights[adapter_idx]
adapter_index_configs[adapter_idx] = adapter_weight.adapter_config
max_rank = max(max_rank, adapter_weight.lora_a_r)
rank_indices[adapter_weight.lora_a_r].append(segment_idx)
lora_a[adapter_idx] = adapter_weight.weights_a
lora_b[adapter_idx] = adapter_weight.weights_b

use_sgmv = prefill or max_rank > BGMV_MAX_RANK
lora_a_ptr, lora_b_ptr = [], []
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx in adapter_weights:
adapter_weight = adapter_weights[adapter_idx]
lora_a_ptr.append(
(adapter_weight.weights_a if use_sgmv or is_embed else adapter_weight.weights_a_t).data_ptr()
)
lora_b_ptr.append(
(adapter_weight.weights_b if use_sgmv else adapter_weight.weights_b_t).data_ptr()
)
else:
lora_a_ptr.append(EMPTY_TENSOR.data_ptr())
lora_b_ptr.append(EMPTY_TENSOR.data_ptr())
lora_a_ptr = torch.tensor(lora_a_ptr, dtype=torch.int64, device=device)
lora_b_ptr = torch.tensor(lora_b_ptr, dtype=torch.int64, device=device)

if prefill_head_indices is not None:
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
for head_index in prefill_head_indices:
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
# j cannot go out of bounds as that would mean there are tokens without segments
if head_index < meta.adapter_segments[j]:
prefill_head_segment_ends[-1] += 1
else:
Expand All @@ -335,13 +324,13 @@ def load(
segment_starts[i] = prefill_head_segment_starts[segment_index]
segment_ends[i] = prefill_head_segment_ends[segment_index]
else:
rank_indices = set(indices)
batch_indices = [adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()]
batch_indices = [idx if idx in rank_indices else -1 for idx in batch_indices]
batch_indices = [idx if idx in set(indices) else -1 for idx in batch_indices]
Contributor:
I would rather keep the separate variable, as calling set(indices) on every iteration of the loop is unnecessary.

Contributor Author:
Oh, yes. I will revert to the original code. The reason for changing it was to avoid having a rank_indices variable inside the for loop, since the loop itself iterates over another rank_indices variable. Maybe I can just rename it here.
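Something along these lines would keep the precomputed set without the shadowing (the segment_set name is just a placeholder; the other variables are the ones already in scope in this function):

```python
# Suggestion sketch: build the set once under a non-clashing name, then reuse it.
segment_set = set(indices)
batch_indices = [adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()]
batch_indices = [idx if idx in segment_set else -1 for idx in batch_indices]
```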

batch_indices = torch.tensor(batch_indices, dtype=torch.int64, device=device)

rank_data[rank] = RankSegments(
rank=rank,
adapter_index_map=indices,
tmp_shrink=tmp_shrink,
tmp_expand=tmp_expand,
lora_a_ptr=lora_a_ptr[indices],
7 changes: 4 additions & 3 deletions server/lorax_server/adapters/weights.py
@@ -6,7 +6,7 @@
import torch

from lorax_server.adapters.types import LORA
from lorax_server.utils.lora import LM_HEAD
from lorax_server.utils.lora import EMBED_TOKENS, LM_HEAD


@dataclass
@@ -82,6 +82,7 @@ def get_data(
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
is_embed: bool,
) -> Dict[str, BatchAdapterWeights]:
# bucket adapters by batch class
adapter_batch_types: Dict[Type[BatchAdapterWeights], Dict[int, AdapterWeights]] = defaultdict(dict)
@@ -91,7 +92,7 @@

batch_data = {}
for batch_type, adapter_weights in adapter_batch_types.items():
batched_weights = batch_type.load(adapter_weights, meta, prefill, prefill_head_indices)
batched_weights = batch_type.load(adapter_weights, meta, prefill, prefill_head_indices, is_embed)
if batched_weights is not None:
batch_data[batch_type.key()] = batched_weights
return batch_data
@@ -117,7 +118,7 @@ def from_meta(
for k, v in weights.items():
if v.is_empty():
continue
data[k] = v.get_data(meta, prefill, prefill_head_indices if k == LM_HEAD else None)
data[k] = v.get_data(meta, prefill, prefill_head_indices if k == LM_HEAD else None, k == EMBED_TOKENS)
return AdapterBatchData(meta=meta, data=data, prefill=prefill)

def ranks(self) -> Set[int]:
server/lorax_server/models/custom_modeling/flash_llama_modeling.py
@@ -33,6 +33,7 @@
from lorax_server.utils.layers import (
MultiAdapterHead,
PositionRotaryEmbedding,
TensorParallelAdapterRowEmbedding,
TensorParallelAdapterRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
@@ -43,6 +44,7 @@
)
from lorax_server.utils.lora import (
DOWN_PROJ,
EMBED_TOKENS,
GATE_PROJ,
K_PROJ,
LM_HEAD,
@@ -457,7 +459,12 @@ def __init__(self, config, weights):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights)
self.embed_tokens = TensorParallelAdapterRowEmbedding(
base_layer=TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights),
layer_id=0,
layer_name=EMBED_TOKENS,
process_group=process_group,
)
self.layers = nn.ModuleList(
[
FlashLlamaLayer(
@@ -488,7 +495,7 @@ def forward(
max_s: int,
adapter_data: AdapterBatchData,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
hidden_states = self.embed_tokens(input_ids, adapter_data)

# Get rotary cos and sin for this forward
# Avoid to index in each layer
7 changes: 4 additions & 3 deletions server/lorax_server/models/flash_llama.py
@@ -17,6 +17,7 @@
)
from lorax_server.utils.lora import (
DOWN_PROJ,
EMBED_TOKENS,
GATE_PROJ,
K_PROJ,
LM_HEAD,
@@ -29,8 +30,7 @@
tracer = trace.get_tracer(__name__)


# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ] # LM_HEAD
ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD, EMBED_TOKENS]
ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}


@@ -136,6 +136,7 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights[(i, DOWN_PROJ)] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)

layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
layer_weights[(0, EMBED_TOKENS)] = ("model.embed_tokens", self.model.model.embed_tokens)
return layer_weights

@property
@@ -147,7 +148,7 @@ def default_traced_adapter_layers(self) -> List[str]:
return [Q_PROJ, V_PROJ]

def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
return 1 if layer_type == LM_HEAD or layer_type == EMBED_TOKENS else len(self.model.model.layers)

def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
10 changes: 10 additions & 0 deletions server/lorax_server/utils/adapter.py
@@ -3,6 +3,7 @@
from functools import lru_cache
from typing import TYPE_CHECKING, Set, Tuple

from loguru import logger
from safetensors.torch import load_file
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

@@ -158,4 +159,13 @@ def load_module_map(

# map the model weights to the relevant adapter weights (LoRA A and B matrices)
module_map, adapter_weight_names = adapter_config.map_weights_for_model(adapter_weights, weight_names)

# note(ajinkya): adapter weights are consumed during the mapping above; any that remain are
# weights we do not support, which should probably be an error, but for now we just log it
if len(adapter_weights) > 0:
Contributor:
Per the above comment, I would return the modified adapter weights as unused_adapter_weights or similar, rather than relying on the input being modified.

Contributor Author:
Sounds good.

logger.warning(
f"Adapter {adapter_id} for the model {model_id}" + \
f" contains unsupported weights: {', '.join(adapter_weights.keys())}"
)

return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
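Following up on the thread above, a rough sketch of how the caller could consume an explicitly returned leftover dict instead of the mutated input (the unused_adapter_weights name and the three-value return are assumptions, not part of this PR):

```python
# Sketch: rely on an explicit return value rather than on adapter_weights
# having been mutated inside map_weights_for_model.
module_map, adapter_weight_names, unused_adapter_weights = adapter_config.map_weights_for_model(
    adapter_weights, weight_names
)
if unused_adapter_weights:
    logger.warning(
        f"Adapter {adapter_id} for the model {model_id}"
        f" contains unsupported weights: {', '.join(unused_adapter_weights.keys())}"
    )
```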
2 changes: 2 additions & 0 deletions server/lorax_server/utils/graph.py
@@ -109,6 +109,7 @@ def get_max_graph_state(
rank_data={
MAX_RANK: RankSegments(
rank=MAX_RANK,
adapter_index_map=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
lora_a_ptr=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
lora_b_ptr=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
indices=torch.zeros((MAX_BATCH_SIZE,), dtype=torch.int64, device=device),
@@ -193,6 +194,7 @@ def trace(
{
max_rank: RankSegments(
rank=max_rank,
adapter_index_map=weight_data.rank_data[MAX_RANK].adapter_index_map[:batch_size],
lora_a_ptr=weight_data.rank_data[MAX_RANK].lora_a_ptr[:segment_size],
lora_b_ptr=weight_data.rank_data[MAX_RANK].lora_b_ptr[:segment_size],
indices=weight_data.rank_data[MAX_RANK].indices[:batch_size],