diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index b206907ad..fadfc0823 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -127,7 +127,7 @@ def valid_adapter_id(self): @field_validator("adapter_source") def valid_adapter_source(cls, v): if v is not None and v not in ADAPTER_SOURCES: - raise ValidationError(f"`adapter_source` must be one of {ADAPTER_SOURCES}") + raise ValidationError(f"`adapter_source={v}` must be one of {ADAPTER_SOURCES}") return v @field_validator("best_of") diff --git a/server/lorax_server/adapters/medusa.py b/server/lorax_server/adapters/medusa.py index 38fceb8ed..aff5ea5d0 100644 --- a/server/lorax_server/adapters/medusa.py +++ b/server/lorax_server/adapters/medusa.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple import torch import torch.distributed @@ -8,16 +8,24 @@ from lorax_server.adapters.types import MEDUSA from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights from lorax_server.utils.layers import FastLinear, TensorParallelColumnLinear +from lorax_server.utils.sgmv import segmented_matmul +from lorax_server.utils.state import get_speculative_tokens from lorax_server.utils.weights import AbstractWeights, InMemoryWeights if TYPE_CHECKING: from lorax_server.models.model import Model +EMPTY_TENSOR = torch.tensor([]) + +_MEDUSA_ENABLED = False + + @dataclass class MedusaConfig(AdapterConfig): medusa_num_heads: int medusa_num_layers: int + version: int @property def quantize(self) -> Optional[str]: @@ -39,12 +47,30 @@ def load_batched_adapter_weights( unused_weight_names: Set[str], dynamic: bool, ) -> Optional[AdapterWeights]: + global _MEDUSA_ENABLED if dynamic: - raise ValueError( - "Dynamic adapter loading is not supported for Medusa at this time. " - "Instead, initialize the LoRAX server with the Medusa adapter and it will be applied to every request." - ) + if not _MEDUSA_ENABLED: + raise ValueError( + "Medusa adapters can only be loaded at request time when LoRAX was initialized with a default " + "Medusa adapter." + ) + + if self.version < 2: + raise ValueError( + f"Dynamic adapter loading is not supported for Medusa version {self.version} at this time. " + f"Instead, initialize the LoRAX server with the Medusa adapter and it will be applied to every " + f"request, or migrate to a v2 adapter." + ) + + if get_speculative_tokens() != self.medusa_num_heads: + raise ValueError( + f"Cannot load a Medusa adapter dynamically with a different number of heads " + f"({self.medusa_num_heads}) from the default speculative tokens ({get_speculative_tokens()})." + ) + else: + _MEDUSA_ENABLED = True + # TODO(travis): load to GPU and offload to CPU in accordance with lorax scheduler return MedusaWeights.load( self, model, @@ -59,9 +85,18 @@ def load(cls, config: dict) -> "MedusaConfig": base_model_name_or_path=config["base_model_name_or_path"], medusa_num_heads=config["medusa_num_heads"], medusa_num_layers=config["medusa_num_layers"], + version=float(config.get("version", 1)), ) +@dataclass +class MedusaSegments: + w: List[torch.Tensor] + b: List[torch.Tensor] + s_start: torch.Tensor + s_end: torch.Tensor + + class ResBlock(torch.nn.Module): def __init__(self, config: MedusaConfig, prefix: str, weights: AbstractWeights): super().__init__() @@ -96,7 +131,7 @@ def __init__(self, config: MedusaConfig, weights: AbstractWeights): [MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config.medusa_num_heads)] ) - def forward(self, x, lm_head): + def forward(self, x, lm_head, segments: Optional[MedusaSegments] = None): logits = lm_head(x) speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) return logits, speculative_logits @@ -121,7 +156,7 @@ def __init__(self, config: MedusaConfig, weights: AbstractWeights): self.act = torch.nn.SiLU() - def forward(self, x, lm_head): + def forward(self, x, lm_head, segments: Optional[MedusaSegments] = None): # If we have too many tokens, we skip speculative logits if x.shape[0] > 128: logits = lm_head(x) @@ -134,8 +169,23 @@ def forward(self, x, lm_head): x_block = x[:, start:stop] + if segments is not None: + # Multi-Medusa + # TODO(travis): custom kernel similar to SGMV + y = torch.empty((x.shape[0], self.n_medusa_heads * x_block.shape[-1]), device=x.device, dtype=x.dtype) + segmented_matmul( + y, + x, + segments.w, + segments.b, + segments.s_start, + segments.s_end, + ) + else: + y = self.linear(x) + # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1 - medusa_res = self.act(self.linear(x)).reshape(*x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]) + medusa_res = self.act(y).reshape(*x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]) # Apply all residual medusa heads output = x[:, start:stop].unsqueeze(-2) + medusa_res @@ -167,14 +217,15 @@ def __init__(self, config: MedusaConfig, weights: AbstractWeights): else: self.medusa = MedusaV2(config, weights) - def forward(self, x, lm_head): - return self.medusa(x, lm_head) + def forward(self, x, lm_head, segments: Optional[MedusaSegments] = None): + return self.medusa(x, lm_head, segments) class MedusaWeights(AdapterWeights): def __init__(self, config: MedusaConfig, module_map: ModuleMap, model: "Model"): self.config = config self.model = MedusaModel(config, InMemoryWeights(module_map, model.device, model.dtype, model.process_group)) + self.process_group = model.process_group @classmethod def get_batch_type(cls) -> BatchAdapterWeights: @@ -202,6 +253,7 @@ def load( class BatchMedusaWeights(BatchAdapterWeights): adapter_to_medusa: Dict[int, MedusaWeights] default_medusa: Optional[MedusaWeights] = None + segments: Optional[MedusaSegments] = None def has_adapter(self, adapter_index: int) -> bool: return adapter_index in self.adapter_to_medusa @@ -210,13 +262,30 @@ def has_adapter(self, adapter_index: int) -> bool: def key(cls) -> str: return MEDUSA + def __call__(self, x, lm_head): + if self.default_medusa: + return self.default_medusa.model(x, lm_head, self.segments) + return lm_head(x) + @classmethod def load(cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata") -> "BatchMedusaWeights": adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, MedusaWeights)} - adapter_to_medusa = {idx: adapter_weights[idx] for idx in meta.segment_indices if idx in adapter_weights} + indices = [idx for idx, s in enumerate(meta.segment_indices) if s in adapter_to_medusa] return BatchMedusaWeights( adapter_to_medusa=adapter_to_medusa, default_medusa=adapter_weights.get(0), + segments=MedusaSegments( + w=[ + (adapter_weights[idx].model.medusa.linear.linear.weight if idx in adapter_weights else EMPTY_TENSOR) + for idx in meta.segment_indices + ], + b=[ + (adapter_weights[idx].model.medusa.linear.linear.bias if idx in adapter_weights else EMPTY_TENSOR) + for idx in meta.segment_indices + ], + s_start=meta.adapter_segments[indices], + s_end=meta.adapter_segments[[i + 1 for i in indices]], + ), ) diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index bbc65b213..a1bb3e3c8 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -687,7 +687,7 @@ def forward( if data is not None and data.default_medusa is not None: forward = super().forward lm_head = lambda x: forward(x, adapter_data) # noqa: E731 - logits, speculative_logits = data.default_medusa.model(input, lm_head) + logits, speculative_logits = data(input, lm_head) # TODO(travis): support multiple medusa adapters with masking: # for adapter_index in adapter_data.meta.adapter_set: diff --git a/server/lorax_server/utils/sgmv.py b/server/lorax_server/utils/sgmv.py index b78c38756..3e565d990 100644 --- a/server/lorax_server/utils/sgmv.py +++ b/server/lorax_server/utils/sgmv.py @@ -1,7 +1,7 @@ import os import warnings from functools import lru_cache -from typing import Tuple +from typing import List, Tuple import torch import torch.nn.functional as F @@ -215,3 +215,21 @@ def add_lora_b_bgmv( layer_idx: int, ): _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0) + + +def segmented_matmul( + y: torch.Tensor, + x: torch.Tensor, + w: List[torch.Tensor], + b: List[torch.Tensor], + s_start: torch.IntTensor, + s_end: torch.IntTensor, +): + for i in range(len(w)): + if s_end[i] - s_start[i] <= 0: + continue + + xi = x[s_start[i] : s_end[i]] + wi = w[i] + bi = b[i] + y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)