Update restrictions
tgaddair committed May 2, 2024
1 parent ccc4003 commit f9afe58
Showing 1 changed file with 13 additions and 1 deletion.
server/lorax_server/adapters/medusa.py: 14 changes (13 additions & 1 deletion)
@@ -18,6 +18,8 @@
 
 EMPTY_TENSOR = torch.tensor([])
 
+_MEDUSA_ENABLED = False
+
 
 @dataclass
 class MedusaConfig(AdapterConfig):
@@ -45,20 +47,30 @@ def load_batched_adapter_weights(
         unused_weight_names: Set[str],
         dynamic: bool,
     ) -> Optional[AdapterWeights]:
+        global _MEDUSA_ENABLED
         if dynamic:
+            if not _MEDUSA_ENABLED:
+                raise ValueError(
+                    "Medusa adapters can only be loaded at request time when LoRAX was initialized with a default "
+                    "Medusa adapter."
+                )
+
             if self.version < 2:
                 raise ValueError(
                     f"Dynamic adapter loading is not supported for Medusa version {self.version} at this time. "
                     f"Instead, initialize the LoRAX server with the Medusa adapter and it will be applied to every "
-                    f"request."
+                    f"request, or migrate to a v2 adapter."
                 )
 
             if get_speculative_tokens() != self.medusa_num_heads:
                 raise ValueError(
                     f"Cannot load a Medusa adapter dynamically with a different number of heads "
                     f"({self.medusa_num_heads}) from the default speculative tokens ({get_speculative_tokens()})."
                 )
+        else:
+            _MEDUSA_ENABLED = True
+
         # TODO(travis): load to GPU and offload to CPU in accordance with lorax scheduler
         return MedusaWeights.load(
             self,
             model,
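
For reference, a minimal standalone sketch of the gating pattern this change introduces: request-time (dynamic) Medusa loads are refused unless a default Medusa adapter was installed at server startup (which sets the module-level _MEDUSA_ENABLED flag), the adapter is version 2 or newer, and its head count matches the configured speculative tokens. The helper name check_dynamic_load, the MedusaLikeConfig dataclass, and the speculative_tokens parameter below are illustrative assumptions, not LoRAX APIs.

# Sketch only; names other than _MEDUSA_ENABLED are hypothetical, not from the repo.
from dataclasses import dataclass

_MEDUSA_ENABLED = False  # flipped once a default Medusa adapter is loaded at init


@dataclass
class MedusaLikeConfig:
    version: int
    medusa_num_heads: int


def check_dynamic_load(config: MedusaLikeConfig, dynamic: bool, speculative_tokens: int) -> None:
    """Raise if a request-time (dynamic) Medusa load is not permitted."""
    global _MEDUSA_ENABLED
    if dynamic:
        if not _MEDUSA_ENABLED:
            raise ValueError("server was not initialized with a default Medusa adapter")
        if config.version < 2:
            raise ValueError("dynamic loading requires a v2 Medusa adapter")
        if speculative_tokens != config.medusa_num_heads:
            raise ValueError("adapter heads must match the server's speculative tokens")
    else:
        # An init-time (static) load enables later dynamic loads.
        _MEDUSA_ENABLED = True


# Init-time load sets the flag; a matching v2 adapter can then be loaded per request.
check_dynamic_load(MedusaLikeConfig(version=2, medusa_num_heads=3), dynamic=False, speculative_tokens=3)
check_dynamic_load(MedusaLikeConfig(version=2, medusa_num_heads=3), dynamic=True, speculative_tokens=3)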
