Added Medusa adapters per request (#454)
tgaddair authored May 2, 2024
1 parent 3581b26 commit fae986e
Showing 4 changed files with 101 additions and 14 deletions.
2 changes: 1 addition & 1 deletion clients/python/lorax/types.py
@@ -127,7 +127,7 @@ def valid_adapter_id(self):
@field_validator("adapter_source")
def valid_adapter_source(cls, v):
if v is not None and v not in ADAPTER_SOURCES:
raise ValidationError(f"`adapter_source` must be one of {ADAPTER_SOURCES}")
raise ValidationError(f"`adapter_source={v}` must be one of {ADAPTER_SOURCES}")
return v

@field_validator("best_of")
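For illustration, a standalone sketch of what the reworked message buys; the ADAPTER_SOURCES values here are assumed, not taken from this commit, and a plain ValueError stands in for the client's ValidationError:

ADAPTER_SOURCES = ["hub", "local", "s3", "pbase"]  # assumed values, illustration only

def check_adapter_source(v):
    # Mirrors the f-string change above: the rejected value is interpolated
    # into the message so callers can see exactly what they sent.
    if v is not None and v not in ADAPTER_SOURCES:
        raise ValueError(f"`adapter_source={v}` must be one of {ADAPTER_SOURCES}")
    return v

check_adapter_source("huggingface")
# ValueError: `adapter_source=huggingface` must be one of ['hub', 'local', 's3', 'pbase']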
91 changes: 80 additions & 11 deletions server/lorax_server/adapters/medusa.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

import torch
import torch.distributed
@@ -8,16 +8,24 @@
from lorax_server.adapters.types import MEDUSA
from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights
from lorax_server.utils.layers import FastLinear, TensorParallelColumnLinear
from lorax_server.utils.sgmv import segmented_matmul
from lorax_server.utils.state import get_speculative_tokens
from lorax_server.utils.weights import AbstractWeights, InMemoryWeights

if TYPE_CHECKING:
from lorax_server.models.model import Model


EMPTY_TENSOR = torch.tensor([])

_MEDUSA_ENABLED = False


@dataclass
class MedusaConfig(AdapterConfig):
medusa_num_heads: int
medusa_num_layers: int
version: int

@property
def quantize(self) -> Optional[str]:
@@ -39,12 +39,30 @@ def load_batched_adapter_weights(
unused_weight_names: Set[str],
dynamic: bool,
) -> Optional[AdapterWeights]:
global _MEDUSA_ENABLED
if dynamic:
raise ValueError(
"Dynamic adapter loading is not supported for Medusa at this time. "
"Instead, initialize the LoRAX server with the Medusa adapter and it will be applied to every request."
)
if not _MEDUSA_ENABLED:
raise ValueError(
"Medusa adapters can only be loaded at request time when LoRAX was initialized with a default "
"Medusa adapter."
)

if self.version < 2:
raise ValueError(
f"Dynamic adapter loading is not supported for Medusa version {self.version} at this time. "
f"Instead, initialize the LoRAX server with the Medusa adapter and it will be applied to every "
f"request, or migrate to a v2 adapter."
)

if get_speculative_tokens() != self.medusa_num_heads:
raise ValueError(
f"Cannot load a Medusa adapter dynamically with a different number of heads "
f"({self.medusa_num_heads}) from the default speculative tokens ({get_speculative_tokens()})."
)
else:
_MEDUSA_ENABLED = True

# TODO(travis): load to GPU and offload to CPU in accordance with lorax scheduler
return MedusaWeights.load(
self,
model,
@@ -59,9 +59,18 @@ def load(cls, config: dict) -> "MedusaConfig":
base_model_name_or_path=config["base_model_name_or_path"],
medusa_num_heads=config["medusa_num_heads"],
medusa_num_layers=config["medusa_num_layers"],
version=float(config.get("version", 1)),
)


@dataclass
class MedusaSegments:
w: List[torch.Tensor]
b: List[torch.Tensor]
s_start: torch.Tensor
s_end: torch.Tensor
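
A toy construction of MedusaSegments, with all sizes and values invented: rows 0-2 of the batch belong to one adapter and rows 3-4 to another, so each adapter's Medusa projection sees only its contiguous slice.

import torch

hidden, n_heads = 8, 2  # toy sizes
w_a = torch.randn(n_heads * hidden, hidden)  # stand-in head weights, adapter A
w_b = torch.randn(n_heads * hidden, hidden)  # stand-in head weights, adapter B
b_a = torch.randn(n_heads * hidden)
b_b = torch.randn(n_heads * hidden)

segments = MedusaSegments(
    w=[w_a, w_b],
    b=[b_a, b_b],
    s_start=torch.tensor([0, 3]),  # adapter A owns rows 0..2
    s_end=torch.tensor([3, 5]),    # adapter B owns rows 3..4
)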


class ResBlock(torch.nn.Module):
def __init__(self, config: MedusaConfig, prefix: str, weights: AbstractWeights):
super().__init__()
@@ -96,7 +96,7 @@ def __init__(self, config: MedusaConfig, weights: AbstractWeights):
[MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config.medusa_num_heads)]
)

def forward(self, x, lm_head):
def forward(self, x, lm_head, segments: Optional[MedusaSegments] = None):
logits = lm_head(x)
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
return logits, speculative_logits
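
For intuition, a shape sketch with invented sizes: stacking the per-head logits on dim 1 yields one extra next-token distribution per Medusa head.

import torch

batch, vocab, n_heads = 2, 50000, 3  # toy sizes
logits = torch.randn(batch, vocab)                                 # regular LM head
head_logits = [torch.randn(batch, vocab) for _ in range(n_heads)]  # one per head
speculative_logits = torch.stack(head_logits, dim=1)
assert speculative_logits.shape == (batch, n_heads, vocab)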
@@ -121,7 +121,7 @@ def __init__(self, config: MedusaConfig, weights: AbstractWeights):

self.act = torch.nn.SiLU()

def forward(self, x, lm_head):
def forward(self, x, lm_head, segments: Optional[MedusaSegments] = None):
# If we have too many tokens, we skip speculative logits
if x.shape[0] > 128:
logits = lm_head(x)
@@ -134,8 +134,23 @@ def forward(self, x, lm_head):

x_block = x[:, start:stop]

if segments is not None:
# Multi-Medusa
# TODO(travis): custom kernel similar to SGMV
y = torch.empty((x.shape[0], self.n_medusa_heads * x_block.shape[-1]), device=x.device, dtype=x.dtype)
segmented_matmul(
y,
x,
segments.w,
segments.b,
segments.s_start,
segments.s_end,
)
else:
y = self.linear(x)

# Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
medusa_res = self.act(self.linear(x)).reshape(*x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1])
medusa_res = self.act(y).reshape(*x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1])

# Apply all residual medusa heads
output = x[:, start:stop].unsqueeze(-2) + medusa_res
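
A sketch of the reshape-and-residual step with toy sizes (values invented): the flat per-row output y holds all heads side by side, so splitting out the head dim gives one residual per Medusa head.

import torch

batch, n_heads, hidden = 4, 3, 8           # toy sizes
y = torch.randn(batch, n_heads * hidden)   # flat output, as from segmented_matmul
x_block = torch.randn(batch, hidden)       # this rank's slice of hidden states

medusa_res = torch.nn.SiLU()(y).reshape(batch, n_heads, hidden)
output = x_block.unsqueeze(-2) + medusa_res  # broadcast residual over heads
assert output.shape == (batch, n_heads, hidden)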
@@ -167,14 +167,15 @@ def __init__(self, config: MedusaConfig, weights: AbstractWeights):
else:
self.medusa = MedusaV2(config, weights)

def forward(self, x, lm_head):
return self.medusa(x, lm_head)
def forward(self, x, lm_head, segments: Optional[MedusaSegments] = None):
return self.medusa(x, lm_head, segments)


class MedusaWeights(AdapterWeights):
def __init__(self, config: MedusaConfig, module_map: ModuleMap, model: "Model"):
self.config = config
self.model = MedusaModel(config, InMemoryWeights(module_map, model.device, model.dtype, model.process_group))
self.process_group = model.process_group

@classmethod
def get_batch_type(cls) -> BatchAdapterWeights:
@@ -202,6 +202,7 @@ def load(
class BatchMedusaWeights(BatchAdapterWeights):
adapter_to_medusa: Dict[int, MedusaWeights]
default_medusa: Optional[MedusaWeights] = None
segments: Optional[MedusaSegments] = None

def has_adapter(self, adapter_index: int) -> bool:
return adapter_index in self.adapter_to_medusa
@@ -210,13 +210,23 @@ def has_adapter(self, adapter_index: int) -> bool:
def key(cls) -> str:
return MEDUSA

def __call__(self, x, lm_head):
if self.default_medusa:
return self.default_medusa.model(x, lm_head, self.segments)
return lm_head(x)

@classmethod
def load(cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata") -> "BatchMedusaWeights":
adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, MedusaWeights)}

adapter_to_medusa = {idx: adapter_weights[idx] for idx in meta.segment_indices if idx in adapter_weights}
indices = [idx for idx, s in enumerate(meta.segment_indices) if s in adapter_to_medusa]

return BatchMedusaWeights(
adapter_to_medusa=adapter_to_medusa,
default_medusa=adapter_weights.get(0),
segments=MedusaSegments(
w=[
(adapter_weights[idx].model.medusa.linear.linear.weight if idx in adapter_weights else EMPTY_TENSOR)
for idx in meta.segment_indices
],
b=[
(adapter_weights[idx].model.medusa.linear.linear.bias if idx in adapter_weights else EMPTY_TENSOR)
for idx in meta.segment_indices
],
s_start=meta.adapter_segments[indices],
s_end=meta.adapter_segments[[i + 1 for i in indices]],
),
)
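
To make the gathering above concrete, a hypothetical walk-through with invented values:

# Suppose the batch has three segments and both referenced adapters carry
# Medusa weights, with adapter 0 as the default:
#   meta.adapter_segments = tensor([0, 2, 5, 7])  # row boundaries per segment
#   meta.segment_indices  = [0, 3, 0]             # adapter id per segment
#   adapter_weights       = {0: mw0, 3: mw3}
# Every segment then maps to a Medusa adapter, so indices == [0, 1, 2] and:
#   w       == [w0, w3, w0]                        # per-segment head weights
#   s_start == adapter_segments[[0, 1, 2]] == tensor([0, 2, 5])
#   s_end   == adapter_segments[[1, 2, 3]] == tensor([2, 5, 7])
# segmented_matmul thus gets one (weights, row range) pair per segment; a
# segment whose adapter had no Medusa weights would get EMPTY_TENSOR instead.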
2 changes: 1 addition & 1 deletion server/lorax_server/utils/layers.py
@@ -687,7 +687,7 @@ def forward(
if data is not None and data.default_medusa is not None:
forward = super().forward
lm_head = lambda x: forward(x, adapter_data) # noqa: E731
logits, speculative_logits = data.default_medusa.model(input, lm_head)
logits, speculative_logits = data(input, lm_head)

# TODO(travis): support multiple medusa adapters with masking:
# for adapter_index in adapter_data.meta.adapter_set:
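A brief sketch of the call path this one-line change enables (tensor names hypothetical): BatchMedusaWeights is now callable, so the LM head defers the Medusa-vs-plain decision to it.

# data is a BatchMedusaWeights instance; its __call__ (added above) dispatches:
#   with a default Medusa adapter:
#     data(input, lm_head) -> data.default_medusa.model(input, lm_head, data.segments)
#   without one:
#     data(input, lm_head) -> lm_head(input)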
20 changes: 19 additions & 1 deletion server/lorax_server/utils/sgmv.py
@@ -1,7 +1,7 @@
import os
import warnings
from functools import lru_cache
from typing import Tuple
from typing import List, Tuple

import torch
import torch.nn.functional as F
@@ -215,3 +215,21 @@ def add_lora_b_bgmv(
layer_idx: int,
):
_kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)


def segmented_matmul(
y: torch.Tensor,
x: torch.Tensor,
w: List[torch.Tensor],
b: List[torch.Tensor],
s_start: torch.IntTensor,
s_end: torch.IntTensor,
):
for i in range(len(w)):
if s_end[i] - s_start[i] <= 0:
continue

xi = x[s_start[i] : s_end[i]]
wi = w[i]
bi = b[i]
y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)
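
A minimal usage sketch with toy tensors (sizes invented), checking that segmented_matmul applies a different linear layer to each contiguous row segment:

import torch
import torch.nn.functional as F

x = torch.randn(5, 8)                          # 5 rows, hidden size 8
w = [torch.randn(16, 8), torch.randn(16, 8)]   # one weight matrix per segment
b = [torch.randn(16), torch.randn(16)]         # one bias per segment
s_start = torch.tensor([0, 3])                 # segment 0 covers rows 0..2
s_end = torch.tensor([3, 5])                   # segment 1 covers rows 3..4

y = torch.empty(5, 16)
segmented_matmul(y, x, w, b, s_start, s_end)

assert torch.allclose(y[0:3], F.linear(x[0:3], w[0], b[0]))
assert torch.allclose(y[3:5], F.linear(x[3:5], w[1], b[1]))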
