Commit dfb567d: Format
tgaddair committed Nov 11, 2024
1 parent 428d304 commit dfb567d
Showing 3 changed files with 13 additions and 20 deletions.
server/lorax_server/utils/layers.py (3 changes: 2 additions & 1 deletion)
@@ -79,7 +79,8 @@ def forward_layer_type(
        # Triton Punica kernels
        key = (layer_type, self.layer_id)
        if (
-            adapter_data.punica_wrapper is not None and adapter_data.punica_wrapper.enabled
+            adapter_data.punica_wrapper is not None
+            and adapter_data.punica_wrapper.enabled
            and key in adapter_data.layer_to_lora_weights
            and input.shape[0] <= adapter_data.punica_wrapper.max_batch_size
            and can_vectorize
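For context, the condition above gates the fused Triton Punica path: every precondition must hold, or execution falls back to the slower per-adapter loop. A minimal, self-contained sketch of the gating pattern (the dataclasses and function name here are illustrative, not LoRAX's actual API):

# Illustrative sketch only; mirrors the boolean gate in the hunk above.
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple


@dataclass
class PunicaWrapper:
    enabled: bool = True
    max_batch_size: int = 256


@dataclass
class AdapterData:
    punica_wrapper: Optional[PunicaWrapper] = None
    layer_to_lora_weights: Dict[Tuple[str, int], object] = field(default_factory=dict)


def use_punica_kernels(
    adapter_data: AdapterData, layer_type: str, layer_id: int, batch_size: int, can_vectorize: bool
) -> bool:
    key = (layer_type, layer_id)
    return (
        adapter_data.punica_wrapper is not None
        and adapter_data.punica_wrapper.enabled
        and key in adapter_data.layer_to_lora_weights
        and batch_size <= adapter_data.punica_wrapper.max_batch_size
        and can_vectorize
    )


adapter_data = AdapterData(punica_wrapper=PunicaWrapper())
print(use_punica_kernels(adapter_data, "q_proj", 0, batch_size=8, can_vectorize=True))
# False: no LoRA weights registered for ("q_proj", 0)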
server/lorax_server/utils/logits_process.py (24 changes: 9 additions & 15 deletions)
@@ -1,10 +1,10 @@
import math
from contextlib import contextmanager
from functools import lru_cache
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union

-from loguru import logger
import torch
+from loguru import logger
from transformers import (
    LogitsProcessor,
    LogitsWarper,
@@ -99,9 +99,7 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
    def __init__(self, penalty: float):
        self.penalty = penalty

-    def __call__(
-        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
-    ) -> torch.FloatTensor:
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        score = torch.gather(scores, 1, input_ids)
        # if score < 0 then penalty has to be multiplied to reduce the previous token probability
        score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
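For intuition, a small worked example of the gather-and-where rule above (values are illustrative; the class then negates the result and combines it with the original scores in code the diff truncates):

import torch

penalty = 1.2
scores = torch.tensor([[2.0, -1.0, 0.5]])   # logits over a 3-token vocab
input_ids = torch.tensor([[1, 2]])          # tokens 1 and 2 were generated earlier

score = torch.gather(scores, 1, input_ids)  # tensor([[-1.0000, 0.5000]])
# negative logits are multiplied (pushed further down), positive ones divided
print(torch.where(score < 0, score * penalty, score / penalty))
# tensor([[-1.2000, 0.4167]])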
@@ -154,26 +152,22 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
        The parameter for frequency penalty. 0.0 means no penalty.
    """

-    def __init__(self, frequency_penalty: List[float], presence_penalty: List[float], dtype: torch.dtype, device: torch.device):
+    def __init__(
+        self, frequency_penalty: List[float], presence_penalty: List[float], dtype: torch.dtype, device: torch.device
+    ):
        self.frequency_penalty = frequency_penalty
-        self.frequency_penalty_tensor = torch.tensor(
-            frequency_penalty, dtype=dtype, device=device
-        ).unsqueeze(1)
+        self.frequency_penalty_tensor = torch.tensor(frequency_penalty, dtype=dtype, device=device).unsqueeze(1)

        self.presence_penalty = presence_penalty
-        self.presence_penalty_tensor = torch.tensor(
-            presence_penalty, dtype=dtype, device=device
-        ).unsqueeze(1)
+        self.presence_penalty_tensor = torch.tensor(presence_penalty, dtype=dtype, device=device).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        batch_size, input_size = input_ids.size()
        vocab_size = scores.size(1)

        # Calculate the frequency for each token so far
        token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
-        token_freq.scatter_add_(
-            1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
-        )
+        token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float))
        mask = token_freq > 0
        token_freq /= input_size
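To make the tensor shapes concrete, a self-contained sketch of the penalty math in this hunk; the final subtraction is an assumption based on the standard frequency/presence-penalty definition, since the diff truncates before the scores are updated:

import torch

frequency_penalty = torch.tensor([[0.5], [0.0]])  # one row per sequence in the batch
presence_penalty = torch.tensor([[0.2], [0.0]])

input_ids = torch.tensor([[1, 1, 2], [0, 3, 3]])  # tokens generated so far
scores = torch.zeros(2, 5)                        # logits over a 5-token vocab

batch_size, input_size = input_ids.size()
vocab_size = scores.size(1)

token_freq = torch.zeros(batch_size, vocab_size)
token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float))
mask = token_freq > 0          # which tokens appeared at least once
token_freq /= input_size       # occurrence count normalized to a frequency

# assumed combination step (elided by the diff): subtract both penalties
scores = scores - frequency_penalty * token_freq - presence_penalty * mask
print(scores[0])  # token 1 penalized by 0.5 * 2/3 + 0.2; token 2 by 0.5 * 1/3 + 0.2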
server/lorax_server/utils/tokens.py (6 changes: 2 additions & 4 deletions)
@@ -291,9 +291,7 @@ def __init__(
        )

        self.frequency_processor = (
-            HeterogeneousFrequencyPenaltyLogitsProcessor(
-                frequency_penalty, presence_penalty, dtype, device
-            )
+            HeterogeneousFrequencyPenaltyLogitsProcessor(frequency_penalty, presence_penalty, dtype, device)
            if any([x != 0.0 for x in frequency_penalty]) or any([x != 0.0 for x in presence_penalty])
            else None
        )
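The conditional expression keeps the processor slot None when no request in the batch opts into either penalty, so the penalty pass is skipped entirely. A trivial illustration of the guard:

frequency_penalty = [0.0, 0.5, 0.0]  # per-request penalties in one batch
presence_penalty = [0.0, 0.0, 0.0]

needs_processor = any(x != 0.0 for x in frequency_penalty) or any(x != 0.0 for x in presence_penalty)
print(needs_processor)  # True: request 1 asked for a frequency penalty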
@@ -462,7 +460,7 @@ def filter(self, indices):

        if self.repetition_processor is not None:
            self.repetition_processor = self.repetition_processor.filter(indices)

        if self.frequency_processor is not None:
            self.frequency_processor = self.frequency_processor.filter(indices)
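For context, a hypothetical sketch (not LoRAX's actual implementation) of the filter(indices) contract used above: when requests leave a continuous batch, each heterogeneous processor narrows its per-row state to the surviving rows, and the slot can drop back to None once every survivor is at the default:

import torch
from typing import List, Optional


class ToyHeterogeneousProcessor:
    """One penalty per batch row; filter() keeps only surviving rows."""

    def __init__(self, penalties: List[float]):
        self.penalties = penalties
        self.penalty_tensor = torch.tensor(penalties).unsqueeze(1)

    def filter(self, indices: List[int]) -> Optional["ToyHeterogeneousProcessor"]:
        self.penalties = [self.penalties[i] for i in indices]
        if any(p != 0.0 for p in self.penalties):
            self.penalty_tensor = self.penalty_tensor[indices]
            return self
        return None  # every survivor uses the default; drop the processor


proc = ToyHeterogeneousProcessor([0.5, 0.0, 1.0])
proc = proc.filter([0, 2])  # requests 0 and 2 remain in the batch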
