Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
flozi00 committed Dec 19, 2023
1 parent 60c837b commit 505ad9d
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 13 deletions.
59 changes: 59 additions & 0 deletions server/lorax_server/utils/medusa.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@ class Output:


class ResBlock(torch.nn.Module):
"""
Residual block module.
Args:
config (dict): Configuration for the block.
prefix (str): Prefix for the block.
weights (torch.Tensor): Weights for the block.
Attributes:
linear (FastLinear): Linear layer.
act (torch.nn.SiLU): Activation function.
"""

def __init__(self, config, prefix, weights):
super().__init__()
self.linear = FastLinear.load(
Expand All @@ -18,10 +32,33 @@ def __init__(self, config, prefix, weights):
self.act = torch.nn.SiLU()

def forward(self, x):
"""
Forward pass of the residual block.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor.
"""
return x + self.act(self.linear(x))


class MedusaModel(torch.nn.Module):
"""
MedusaModel is a PyTorch module that represents the Medusa model.
Args:
config (dict): Configuration parameters for the Medusa model.
weights (list): List of weights for the Medusa model.
lm_head (torch.nn.Module): Language model head for the Medusa model.
Attributes:
heads (torch.nn.ModuleList): List of MedusaHead modules.
lm_head (torch.nn.Module): Language model head for the Medusa model.
"""

def __init__(self, config, weights, lm_head):
super().__init__()

Expand All @@ -33,12 +70,34 @@ def __init__(self, config, weights, lm_head):
self.lm_head = lm_head

def forward(self, x):
"""
Forward pass of the MedusaModel.
Args:
x (torch.Tensor): Input tensor.
Returns:
tuple: A tuple containing the logits and speculative logits.
"""
logits = self.lm_head(x)
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
return logits, speculative_logits


class MedusaHead(torch.nn.Module):
"""
MedusaHead is a module that represents the head of the Medusa network.
Args:
config (dict): Configuration parameters for the Medusa network.
prefix (str): Prefix for naming the layers of the MedusaHead module.
weights (dict): Pretrained weights for the Medusa network.
Attributes:
blocks (torch.nn.ModuleList): List of ResBlock modules.
out (FastLinear): Output layer of the MedusaHead module.
"""

def __init__(self, config, prefix, weights):
super().__init__()

Expand Down
28 changes: 15 additions & 13 deletions server/lorax_server/utils/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,12 +313,12 @@ def __call__(
- speculative_ids (Optional[torch.Tensor]): The selected speculative token IDs.
"""
if speculation_ids is not None:
B = scores.shape[0] // (speculation_ids.shape[1] + 1) if speculation_ids is not None else scores.shape[0]
S = speculation_ids.shape[1] + 1 if speculation_ids is not None else 1
scores = scores.view(B, S, -1)
_batches = scores.shape[0] // (speculation_ids.shape[1] + 1) if speculation_ids is not None else scores.shape[0]
_speculations = speculation_ids.shape[1] + 1 if speculation_ids is not None else 1
scores = scores.view(_batches, _speculations, -1)

next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
for j in range(S):
next_ids = torch.zeros((_batches, _speculations), device=scores.device, dtype=torch.long)
for j in range(_speculations):
_scores = scores[:, j]
if self.watermark_processor is not None:
_scores = self.watermark_processor(input_ids, _scores)
Expand All @@ -331,19 +331,21 @@ def __call__(
_next_ids = self.choice(_scores)
scores[:, j] = _scores
next_ids[:, j] = _next_ids
next_ids = next_ids.view(B * S)
scores = scores.view(B * S, -1)
next_ids = next_ids.view(_batches * _speculations)
scores = scores.view(_batches * _speculations, -1)

if speculation_ids is not None:
accepted_ids = []
B = next_ids.shape[0] // (speculation_ids.shape[1] + 1)
S = speculation_ids.shape[1] + 1
# number of batches
_batches = next_ids.shape[0] // (speculation_ids.shape[1] + 1)
# number of speculations
_speculations = speculation_ids.shape[1] + 1
indices = []
for i in range(B):
_next_ids = next_ids[i * S : (i + 1) * S]
for i in range(_batches):
_next_ids = next_ids[i * _speculations : (i + 1) * _speculations]
_speculated_ids = speculation_ids[i]
validate_speculative = _next_ids[:-1] == _speculated_ids
index = i * S
index = i * _speculations
accepted = 1
indices.append(index)
for valid in validate_speculative.tolist():
Expand All @@ -360,7 +362,7 @@ def __call__(
)
next_ids = next_ids[indices]
scores = scores[indices]
indices = torch.arange(B, device=input_ids.device) * S
indices = torch.arange(_batches, device=input_ids.device) * _speculations
if speculation_scores is not None:
speculation_scores = speculation_scores[indices + accepted_ids - 1]
else:
Expand Down

0 comments on commit 505ad9d

Please sign in to comment.