Fix LM head interaction with Medusa (#567)
tgaddair authored Aug 5, 2024
1 parent 2e47e77 commit 8454a82
Showing 26 changed files with 56 additions and 14 deletions.
12 changes: 11 additions & 1 deletion server/lorax_server/adapters/lora.py
@@ -9,6 +9,7 @@
from lorax_server.adapters.config import AdapterConfig, ModuleMap
from lorax_server.adapters.types import LORA
from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights
from lorax_server.utils.lora import LM_HEAD
from lorax_server.utils.sgmv import (
BGMV_MAX_RANK,
MAX_RANK_CUSTOM,
@@ -225,12 +226,17 @@ class BatchLoraWeights(BatchAdapterWeights):
adapter_index_configs: Dict[int, LoraConfig]
rank_data: Dict[int, RankSegments]
use_sgmv: bool
layer_name: str
prefill_head_indices: Optional[torch.Tensor]

def has_adapter(self, adapter_index: int) -> bool:
return adapter_index in self.adapter_index_configs

def can_vectorize(self, pg: ProcessGroup) -> bool:
return all(rank_data.rank // pg.size() <= MAX_RANK_CUSTOM for rank_data in self.rank_data.values())
return (
all(rank_data.rank // pg.size() <= MAX_RANK_CUSTOM for rank_data in self.rank_data.values())
and self.layer_name != LM_HEAD
)

@classmethod
def key(cls) -> str:
@@ -241,6 +247,7 @@ def load(
self,
adapter_weights: Dict[int, AdapterWeights],
meta: AdapterBatchMetadata,
layer_name: str,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Optional["BatchLoraWeights"]:
@@ -348,6 +355,7 @@ def load(
if segment_indices[idx] not in idx_locs:
# save the first location of encountering a particular adapter index
idx_locs[segment_indices[idx]] = loc

# second, iterate over the adapter index for each token and find its location in the `indices` array
batch_indices = torch.tensor(
[
@@ -375,6 +383,8 @@ def load(
adapter_index_configs=adapter_index_configs,
rank_data=rank_data,
use_sgmv=use_sgmv,
layer_name=layer_name,
prefill_head_indices=prefill_head_indices,
)


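The two new fields above are what the rest of the commit keys on: `layer_name` lets `can_vectorize` keep the LM head off the SGMV path, and `prefill_head_indices` is carried along for the prefill fix in `layers.py`. Below is an illustrative sketch of the gating, not part of the diff; `MAX_RANK_CUSTOM` is an assumed value and the `ProcessGroup` argument is reduced to a plain world size.

```python
# Minimal sketch of the new can_vectorize gating; MAX_RANK_CUSTOM is an
# assumed value and the ProcessGroup argument is reduced to a world size.
from dataclasses import dataclass
from typing import Dict

MAX_RANK_CUSTOM = 128  # assumed for illustration
LM_HEAD = "lm_head"    # matches lorax_server.utils.lora.LM_HEAD


@dataclass
class RankData:
    rank: int


@dataclass
class BatchLoraWeightsSketch:
    rank_data: Dict[int, RankData]
    layer_name: str

    def can_vectorize(self, world_size: int) -> bool:
        # Even when every rank fits the custom SGMV kernels, the LM head now
        # falls back to the non-vectorized path.
        return (
            all(rd.rank // world_size <= MAX_RANK_CUSTOM for rd in self.rank_data.values())
            and self.layer_name != LM_HEAD
        )


qkv = BatchLoraWeightsSketch({0: RankData(16)}, layer_name="q_proj")
head = BatchLoraWeightsSketch({0: RankData(16)}, layer_name=LM_HEAD)
print(qkv.can_vectorize(1), head.can_vectorize(1))  # True False
```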
1 change: 1 addition & 0 deletions server/lorax_server/adapters/medusa.py
@@ -274,6 +274,7 @@ def load(
cls,
adapter_weights: Dict[int, AdapterWeights],
meta: "AdapterBatchMetadata",
layer_name: str,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Optional["BatchMedusaWeights"]:
6 changes: 4 additions & 2 deletions server/lorax_server/adapters/weights.py
@@ -50,6 +50,7 @@ def load(
cls,
adapter_weights: Dict[int, AdapterWeights],
meta: "AdapterBatchMetadata",
layer_name: str,
prefill: bool,
prefill_head_indices: torch.Tensor,
) -> Optional["BatchAdapterWeights"]:
@@ -80,6 +81,7 @@ def is_empty(self) -> bool:
def get_data(
self,
meta: AdapterBatchMetadata,
layer_name: str,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Dict[str, BatchAdapterWeights]:
@@ -91,7 +93,7 @@ def get_data(

batch_data = {}
for batch_type, adapter_weights in adapter_batch_types.items():
batched_weights = batch_type.load(adapter_weights, meta, prefill, prefill_head_indices)
batched_weights = batch_type.load(adapter_weights, meta, layer_name, prefill, prefill_head_indices)
if batched_weights is not None:
batch_data[batch_type.key()] = batched_weights
return batch_data
@@ -117,7 +119,7 @@ def from_meta(
for k, v in weights.items():
if v.is_empty():
continue
layer_weights = v.get_data(meta, prefill, prefill_head_indices if k == LM_HEAD else None)
layer_weights = v.get_data(meta, k, prefill, prefill_head_indices if k == LM_HEAD else None)
if layer_weights:
data[k] = layer_weights
return AdapterBatchData(meta=meta, data=data, prefill=prefill)
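As the `from_meta` hunk shows, each layer's `get_data` now receives its own name `k`, and `prefill_head_indices` is forwarded only for the LM head. The following is an illustrative sketch of that routing, not part of the diff; `LayerWeights` is a stand-in for `LayerAdapterWeights`.

```python
# Hedged sketch of the from_meta routing; LayerWeights stands in for
# LayerAdapterWeights and only shows how layer_name / prefill_head_indices
# are threaded through to load().
from typing import Dict, Optional

import torch

LM_HEAD = "lm_head"


class LayerWeights:
    def get_data(self, meta, layer_name: str, prefill: bool,
                 prefill_head_indices: Optional[torch.Tensor]) -> dict:
        # The real method batches per adapter type by calling
        # batch_type.load(adapter_weights, meta, layer_name, prefill, prefill_head_indices).
        return {"lora": (layer_name, prefill_head_indices)}


def from_meta(meta, weights: Dict[str, LayerWeights], prefill: bool,
              prefill_head_indices: Optional[torch.Tensor]) -> dict:
    data = {}
    for k, v in weights.items():
        # Only the LM head receives prefill_head_indices; every other layer gets None.
        layer_data = v.get_data(meta, k, prefill,
                                prefill_head_indices if k == LM_HEAD else None)
        if layer_data:
            data[k] = layer_data
    return data


batch = from_meta(meta=None,
                  weights={"q_proj": LayerWeights(), LM_HEAD: LayerWeights()},
                  prefill=True,
                  prefill_head_indices=torch.tensor([2, 5, 9]))
print(batch["q_proj"])  # {'lora': ('q_proj', None)}
print(batch[LM_HEAD])   # {'lora': ('lm_head', tensor([2, 5, 9]))}
```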
3 changes: 2 additions & 1 deletion server/lorax_server/layers/linear.py
@@ -95,10 +95,11 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False, weight_scale=None,
if fan_in_fan_out:
weight = weight.T.contiguous()

if quantize is None or (quantize == 'fp8' and weight_scale is None):
if quantize is None or (quantize == "fp8" and weight_scale is None):
linear = FastLinear(weight, bias)
elif quantize == "fp8":
from lorax_server.layers.fp8 import Fp8Linear

linear = Fp8Linear(weight, bias, weight_scale=weight_scale, input_scale=input_scale)

elif quantize == "bitsandbytes":
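For context on the hunk above: when FP8 is requested but no `weight_scale` is available, `get_linear` falls back to the unquantized path. A small illustrative sketch of that dispatch (not part of the diff; the classes are stand-ins for the real LoRAX layers):

```python
# Sketch of the get_linear dispatch shown above; FastLinear and Fp8Linear are
# stand-ins for the real LoRAX layer classes.
class FastLinear:
    def __init__(self, weight, bias):
        self.weight, self.bias = weight, bias


class Fp8Linear:
    def __init__(self, weight, bias, weight_scale=None, input_scale=None):
        self.weight, self.bias = weight, bias
        self.weight_scale, self.input_scale = weight_scale, input_scale


def pick_linear(weight, bias, quantize=None, weight_scale=None, input_scale=None):
    if quantize is None or (quantize == "fp8" and weight_scale is None):
        # No quantization, or FP8 without a scale: plain (unquantized) linear.
        return FastLinear(weight, bias)
    if quantize == "fp8":
        return Fp8Linear(weight, bias, weight_scale=weight_scale, input_scale=input_scale)
    raise NotImplementedError(f"quantize={quantize!r} is not covered in this sketch")


print(type(pick_linear(None, None, quantize="fp8")).__name__)                    # FastLinear
print(type(pick_linear(None, None, quantize="fp8", weight_scale=1.0)).__name__)  # Fp8Linear
```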
(changed file, name not shown)
@@ -302,6 +302,7 @@ def forward(
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
(changed file, name not shown)
@@ -439,6 +439,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
(changed file, name not shown)
@@ -282,6 +282,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
(changed file, name not shown)
@@ -286,6 +286,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
(changed file, name not shown)
@@ -181,6 +181,7 @@ def forward(
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
(changed file, name not shown)
@@ -343,6 +343,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
(changed file, name not shown)
@@ -356,6 +356,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
(changed file, name not shown)
@@ -424,6 +424,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
(changed file, name not shown)
@@ -157,6 +157,7 @@ def forward(
kv_cache[0],
kv_cache[1],
self.num_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
(changed file, name not shown)
@@ -296,6 +296,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
(changed file, name not shown)
@@ -182,6 +182,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
(changed file, name not shown)
@@ -256,6 +256,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
(changed file, name not shown)
@@ -269,6 +269,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
(changed file, name not shown)
@@ -197,6 +197,7 @@ def forward(
kv_cache[0],
kv_cache[1],
self.num_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
@@ -311,6 +312,7 @@ def forward(
kv_cache[0],
kv_cache[1],
self.num_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
(changed file, name not shown)
@@ -257,6 +257,7 @@ def forward(
kv_cache[0],
kv_cache[1],
self.num_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
3 changes: 1 addition & 2 deletions server/lorax_server/models/flash_llama.py
@@ -29,8 +29,7 @@
tracer = trace.get_tracer(__name__)


# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ] # LM_HEAD
ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD]
ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}


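With the TODO resolved, `lm_head` is once again a valid adapter target for Llama models. Below is a hypothetical helper (not a LoRAX API, not part of the diff) illustrating what inclusion in `ADAPTER_LAYERS` means for an adapter's target modules; the string values of the constants are assumed.

```python
# Hypothetical helper (not a LoRAX API) showing the effect of adding LM_HEAD
# back to ADAPTER_LAYERS for Llama models; constant values are assumed.
Q_PROJ, K_PROJ, V_PROJ, O_PROJ = "q_proj", "k_proj", "v_proj", "o_proj"
GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD = "gate_proj", "up_proj", "down_proj", "lm_head"

ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD]


def unsupported_targets(target_modules: list) -> list:
    """Modules an adapter targets that this model does not apply adapters to."""
    return [m for m in target_modules if m not in ADAPTER_LAYERS]


# An adapter that also targets the output head is now fully applied instead of
# having its lm_head weights skipped.
print(unsupported_targets(["q_proj", "v_proj", "lm_head"]))  # []
```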
4 changes: 1 addition & 3 deletions server/lorax_server/server.py
@@ -114,9 +114,7 @@ async def Classify(self, request: generate_pb2.ClassifyRequest, context):
self.model.device,
)
predicated_token_class, confidence_scores = self.model.classify(batch)
ner_results = self.model.batch_type.to_pb_classify(
batch, predicated_token_class, confidence_scores
)
ner_results = self.model.batch_type.to_pb_classify(batch, predicated_token_class, confidence_scores)
return ner_results

async def Embed(self, request: generate_pb2.EmbedRequest, context):
6 changes: 6 additions & 0 deletions server/lorax_server/utils/graph.py
@@ -16,6 +16,7 @@
from lorax_server.adapters.lora import BatchLoraWeights, RankSegments
from lorax_server.adapters.types import LORA
from lorax_server.utils.constants import BLOCK_SIZE
from lorax_server.utils.lora import LM_HEAD
from lorax_server.utils.sgmv import BGMV_MAX_RANK

if TYPE_CHECKING:
@@ -119,6 +120,8 @@ def get_max_graph_state(
),
},
use_sgmv=False, # bgmv during decode
layer_name=layer_name,
prefill_head_indices=None,
)

return GraphState(
@@ -207,6 +210,8 @@ def trace(
else {}
),
use_sgmv=False, # bgmv during decode
layer_name=layer_name,
prefill_head_indices=None,
)
}

@@ -341,6 +346,7 @@ def can_use_graph(
and nranks <= 1
and max_rank in _allowed_ranks
and all(k == LORA for k in adapter_keys)
and not any(k == LM_HEAD for k in adapter_data.layer_names())
)

def get_estimated_cache_memory(self) -> int:
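The graph-capture state now records `layer_name` and `prefill_head_indices`, and `can_use_graph` additionally refuses CUDA graphs whenever any adapted layer is the LM head, since that layer no longer takes the vectorized decode path. An illustrative, simplified sketch of the combined guard (not part of the diff; the allowed-rank set and the inputs are stand-ins for the real `GraphCache` state):

```python
# Simplified sketch of the can_use_graph guard; the allowed-rank set and the
# inputs are stand-ins for the real GraphCache state.
LORA = "lora"
LM_HEAD = "lm_head"
_allowed_ranks = {8, 16, 32, 64}  # assumed for illustration


def can_use_graph(adapter_keys: set, layer_names: set, nranks: int, max_rank: int) -> bool:
    return (
        nranks <= 1
        and max_rank in _allowed_ranks
        and all(k == LORA for k in adapter_keys)
        # New in this commit: an adapter on the LM head forces the eager path.
        and not any(k == LM_HEAD for k in layer_names)
    )


print(can_use_graph({LORA}, {"q_proj", "v_proj"}, nranks=1, max_rank=16))  # True
print(can_use_graph({LORA}, {"q_proj", LM_HEAD}, nranks=1, max_rank=16))   # False
```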
8 changes: 7 additions & 1 deletion server/lorax_server/utils/layers.py
@@ -11,6 +11,7 @@
from lorax_server.adapters.types import LORA, MEDUSA
from lorax_server.layers.linear import FastLinear, get_linear # noqa: F401
from lorax_server.layers.tensor_parallel import SuperLayer, TensorParallelColumnLinear, TensorParallelHead # noqa: F401
from lorax_server.utils.lora import LM_HEAD
from lorax_server.utils.sgmv import (
add_lora_a_bgmv,
add_lora_b_bgmv,
@@ -135,9 +136,14 @@ def forward_layer_type(
if end_idx - start_idx != result.shape[1]:
result[:, start_idx:end_idx] += proj
else:
adapter_indices = adapter_data.meta.adapter_indices
if data is not None and data.prefill_head_indices is not None and data.layer_name == LM_HEAD:
# LM_HEAD inputs have different shape during prefill than other layers
adapter_indices = adapter_indices[data.prefill_head_indices]

for adapter_index in adapter_data.meta.adapter_set:
if data is not None and data.has_adapter(adapter_index):
adapter_mask = (adapter_data.meta.adapter_indices == adapter_index).to(input.dtype).view(-1, 1)
adapter_mask = (adapter_indices == adapter_index).to(input.dtype).view(-1, 1)
layer_result = self.forward_lora(input, data, adapter_index, adapter_mask)
result[:, start_idx:end_idx] += layer_result

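The `forward_layer_type` fix above aligns the adapter mask with the LM head's prefill inputs: during prefill the LM head only runs on the positions in `prefill_head_indices`, so the per-token `adapter_indices` must be gathered with those indices before the mask is built. A small runnable sketch of that alignment follows (illustrative only; shapes and values are invented):

```python
# Runnable sketch of the index alignment added for LM_HEAD during prefill;
# shapes and values are invented for illustration.
import torch

# One adapter index per input token position, e.g. 3 sequences of lengths 3, 2, 3
# served by adapters 0, 1 and 2 respectively.
adapter_indices = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])

# During prefill the LM head is only evaluated at selected positions (typically
# the last token of each sequence), so its input has fewer rows than the batch.
prefill_head_indices = torch.tensor([2, 4, 7])

# Without this gather the mask below would have 8 rows while the LM-head input
# has 3, so the per-adapter masking would be misaligned.
lm_head_adapter_indices = adapter_indices[prefill_head_indices]  # tensor([0, 1, 2])

lm_head_input = torch.ones(3, 16)  # rows correspond to prefill_head_indices
for adapter_index in (0, 1, 2):
    adapter_mask = (lm_head_adapter_indices == adapter_index).to(lm_head_input.dtype).view(-1, 1)
    rows = (lm_head_input * adapter_mask).sum(dim=1).nonzero().flatten().tolist()
    print(adapter_index, rows)  # each adapter touches exactly its own row
```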
2 changes: 1 addition & 1 deletion server/lorax_server/utils/paged_attention.py
@@ -44,6 +44,7 @@ def attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
@@ -71,7 +72,6 @@
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
num_kv_heads = 1 + kv_head_mapping.max().item()

if SYSTEM == "xpu":
query = query.contiguous()
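The `attention` signature change removes the runtime derivation `num_kv_heads = 1 + kv_head_mapping.max().item()`; each model's forward now passes the KV-head count it already knows, which is the new argument visible in the hunks above. A short sketch contrasting the two, with invented head counts (illustrative, not part of the diff):

```python
# Sketch contrasting the old in-function derivation of num_kv_heads with the
# explicit argument added by this commit; head counts are invented.
import torch

num_heads, num_kv_heads = 8, 2  # grouped-query attention: 4 query heads per KV head
kv_head_mapping = torch.arange(num_kv_heads).repeat_interleave(num_heads // num_kv_heads)
# tensor([0, 0, 0, 0, 1, 1, 1, 1])

# Old behavior: attention() recovered the count from the mapping tensor,
# paying a reduction plus an .item() call on every invocation.
derived_num_kv_heads = 1 + kv_head_mapping.max().item()

# New behavior: callers pass the count explicitly, e.g.
#   attention(query, kv_cache[0], kv_cache[1], self.num_key_value_heads,
#             self.kv_head_mapping, self.softmax_scale, block_tables, ...)
assert derived_num_kv_heads == num_kv_heads
print(derived_num_kv_heads)  # 2
```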
1 change: 1 addition & 0 deletions server/tests/adapters/test_medusa.py
@@ -41,6 +41,7 @@ def test_batched_medusa_weights(default_causal_lm: CausalLM):
1: medusa_weights,
},
meta,
layer_name=LM_HEAD,
prefill=False,
prefill_head_indices=None,
)
8 changes: 5 additions & 3 deletions server/tests/utils/test_lora.py
@@ -8,6 +8,7 @@
from lorax_server.adapters.lora import LoraWeights
from lorax_server.adapters.types import LORA
from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights, LayerAdapterWeights
from lorax_server.utils.lora import LM_HEAD
from lorax_server.utils.sgmv import MIN_RANK_CUSTOM


@@ -34,6 +35,7 @@ def key(cls) -> str:
def load(
cls,
adapter_weights: Dict[int, AdapterWeights],
layer_name: str,
meta: "AdapterBatchMetadata",
prefill: bool,
prefill_head_indices: torch.Tensor,
@@ -78,7 +80,7 @@ def test_batched_lora_weights(lora_ranks: List[int]):
)

with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None).get(LORA)
data = batched_weights.get_data(meta, LM_HEAD, prefill=True, prefill_head_indices=None).get(LORA)

assert len(data.lora_a) == 2
assert data.lora_a.keys() == meta.adapter_set
Expand Down Expand Up @@ -153,7 +155,7 @@ def test_batched_lora_weights_decode(
)

with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
data = batched_weights.get_data(meta, prefill=False, prefill_head_indices=None).get(LORA)
data = batched_weights.get_data(meta, LM_HEAD, prefill=False, prefill_head_indices=None).get(LORA)

for lora_rank, rd in data.rank_data.items():
expected_indices = torch.tensor(expected[lora_rank][1], dtype=rd.indices.dtype, device=rd.indices.device)
@@ -197,6 +199,6 @@ def test_batched_lora_weights_no_segments():
)

with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None).get(LORA)
data = batched_weights.get_data(meta, LM_HEAD, prefill=True, prefill_head_indices=None).get(LORA)

print(data)
