feat: typical_p threshold sampling #343

Open
wants to merge 10 commits into base: main
8 changes: 8 additions & 0 deletions aphrodite/common/sampling_params.py
@@ -72,6 +72,8 @@ class SamplingParams:
typical_p: Float that controls the cumulative probability of tokens
closest in surprise to the expected surprise to consider.
Must be in (0, 1]. Set to 1 to disable.
typical_p_sigma: Used to scale the maximum threshold for positive
deviations in typical_p. Range in [0, inf). Set to 0 to disable.
mirostat_mode: Can either be 0 (disabled) or 2 (Mirostat v2).
mirostat_tau: Target "surprisal" that mirostat works towards.
Range [0, inf).
@@ -137,6 +139,7 @@ def __init__(
eta_cutoff: float = 0.0,
epsilon_cutoff: float = 0.0,
typical_p: float = 1.0,
typical_p_sigma: float = 0.0,
mirostat_mode: int = 0,
mirostat_tau: float = 0,
mirostat_eta: float = 0,
@@ -175,6 +178,7 @@ def __init__(
self.eta_cutoff = eta_cutoff
self.epsilon_cutoff = epsilon_cutoff
self.typical_p = typical_p
self.typical_p_sigma = typical_p_sigma
self.mirostat_mode = mirostat_mode
self.mirostat_tau = mirostat_tau
self.mirostat_eta = mirostat_eta
@@ -219,6 +223,7 @@ def __init__(
"eta_cutoff": 0.0,
"epsilon_cutoff": 0.0,
"typical_p": 1.0,
"typical_p_sigma": 0.0,
"mirostat_mode": 0,
"mirostat_tau": 0,
"mirostat_eta": 0,
@@ -295,6 +300,9 @@ def _verify_args(self) -> None:
if not 0.0 <= self.typical_p <= 1.0:
raise ValueError(
f"typical_p must be in (0, 1], got {self.typical_p}.")
if not self.typical_p_sigma >= 0:
raise ValueError(f"typical_p_sigma must be non negative, got "
f"{self.typical_p_sigma}.")
if not self.dynatemp_min >= 0:
raise ValueError(
f"dynatemp_min must be non negative, got {self.dynatemp_min}.")
5 changes: 5 additions & 0 deletions aphrodite/endpoints/openai/protocol.py
@@ -64,6 +64,7 @@ class ChatCompletionRequest(BaseModel):
eta_cutoff: Optional[float] = 0.0
epsilon_cutoff: Optional[float] = 0.0
typical_p: Optional[float] = 1.0
typical_p_sigma: Optional[float] = 0.0
n: Optional[int] = 1
max_tokens: Optional[int] = None
seed: Optional[int] = None
@@ -132,6 +133,7 @@ def logit_bias_logits_processor(
eta_cutoff=self.eta_cutoff,
epsilon_cutoff=self.epsilon_cutoff,
typical_p=self.typical_p,
typical_p_sigma=self.typical_p_sigma,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
@@ -186,6 +188,7 @@ class CompletionRequest(BaseModel):
eta_cutoff: Optional[float] = 0.0
epsilon_cutoff: Optional[float] = 0.0
typical_p: Optional[float] = 1.0
typical_p_sigma: Optional[float] = 0.0
n: Optional[int] = 1
stream: Optional[bool] = False
logprobs: Optional[int] = None
@@ -254,6 +257,7 @@ def logit_bias_logits_processor(
eta_cutoff=self.eta_cutoff,
epsilon_cutoff=self.epsilon_cutoff,
typical_p=self.typical_p,
typical_p_sigma=self.typical_p_sigma,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
@@ -405,6 +409,7 @@ class KoboldSamplingParams(BaseModel):
eta_cutoff: float = Field(0.0, alias="eta_cutoff")
epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
typical_p: float = Field(1.0, alias="typical_p")
typical_p_sigma: float = Field(0.0, alias="typical_p_sigma")
use_beam_search: bool = Field(False, alias="use_beam_search")
length_penalty: float = Field(1.0, alias="length_penalty")
early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
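
Since both request models above now expose typical_p_sigma, a client can pass it alongside typical_p. A hedged example call against the OpenAI-compatible completions route (endpoint URL, port, and model name are placeholders):

# Illustrative only: URL, port, and model name are placeholders.
import requests

resp = requests.post(
    "http://localhost:2242/v1/completions",
    json={
        "model": "my-model",
        "prompt": "Once upon a time",
        "max_tokens": 64,
        "typical_p": 0.9,
        "typical_p_sigma": 1.5,
    },
)
print(resp.json())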
9 changes: 9 additions & 0 deletions aphrodite/endpoints/openai/samplers.json
@@ -209,6 +209,15 @@
"default": 1,
"description": "Control the cumulative probability of tokens closest in surprise to the expected surprise to consider."
},
"typical_threshold": {
"pretty_name": "Typical Threshold",
"type": "float",
"minimum": 0,
"maximum": 10,
"step": 0.01,
"default": 0,
"description": "Scale the maximum threshold for positive deviations in typical_p."
},
"eta_cutoff": {
"pretty_name": "Eta Cutoff",
"type": "float",
36 changes: 29 additions & 7 deletions aphrodite/modeling/layers/sampler.py
@@ -112,7 +112,8 @@ def forward(
sampling_tensors.epsilon_cutoffs)
if do_typical_ps:
logits = _apply_typical_sampling(logits,
sampling_tensors.typical_ps)
sampling_tensors.typical_ps,
sampling_tensors.typical_p_sigmas)
if do_quadratic:
logits = _apply_quadratic_sampling(
logits, sampling_tensors.smoothing_factors,
@@ -508,24 +509,45 @@ def _apply_epsilon_cutoff(
def _apply_typical_sampling(
logits: torch.Tensor,
typical_p: torch.Tensor,
typical_p_sigma: torch.Tensor,
) -> torch.Tensor:
typ_p = torch.tensor(typical_p, dtype=logits.dtype, device=logits.device)
typ_p = typical_p.clone().detach().to(logits.device).to(logits.dtype)
typ_threshold = typical_p_sigma.clone().detach().to(logits.device).to(
logits.dtype)
THRESHOLD = 1000

shifted_logits = torch.log_softmax(logits, dim=-1)
probs = shifted_logits.exp()
probs = torch.exp(shifted_logits)

neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
# NOTE: We don't take the absolute value of the surprisal deviations
# This deviates from the original implementation
surprisal_deviations = neg_entropy - shifted_logits

@BugReporterZ commented on Mar 19, 2024:

I disagree with this change (or the intentions in the code here). The modification in my original hack (not posted in this PR) was intended to retain the basic behavior of Typical_P, which first sorts the surprisal deviations by their absolute value.

Only after that step would you use the signed surprisal deviations (copied into a separate tensor before taking the absolute values) to obtain a second subset that extends the token selection, as in the algorithm described in the discussion.
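
For illustration only, a minimal sketch (not code from this PR or from the comment author) of the two-pass selection described above: rank tokens by absolute deviation for the standard typical_p cut, then extend the kept set with tokens whose signed deviation is positive but below a sigma-scaled threshold.

import torch

def typical_with_signed_extension(logits: torch.Tensor,
                                  typical_p: torch.Tensor,
                                  sigma: torch.Tensor) -> torch.Tensor:
    # Standard typical_p pass: rank by |surprisal - expected surprisal|.
    shifted = torch.log_softmax(logits, dim=-1)
    probs = shifted.exp()
    neg_entropy = (probs * shifted).nansum(dim=-1, keepdim=True)
    signed_dev = neg_entropy - shifted            # surprisal minus expected surprisal
    abs_dev = signed_dev.abs()

    _, indices = torch.sort(abs_dev, dim=-1)
    sorted_probs = probs.gather(-1, indices)
    drop_sorted = sorted_probs.cumsum(dim=-1) >= typical_p.unsqueeze(-1)
    drop_sorted[..., 0] = False                   # always keep the most typical token
    base_drop = drop_sorted.scatter(-1, indices, drop_sorted)

    # Second pass: re-admit positive deviations up to a threshold derived from
    # the largest negative deviation, scaled by sigma (per row).
    max_thresh = (-signed_dev.min(dim=-1, keepdim=True).values) * sigma.unsqueeze(-1)
    extend = (signed_dev > 0) & (signed_dev <= max_thresh)

    return logits.masked_fill(base_drop & ~extend, -float("inf"))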


surprisal_deviations = (neg_entropy - shifted_logits).abs()
_, indices = torch.sort(surprisal_deviations)
reordered_probs = probs.gather(-1, indices)
typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)

min_tokens_to_keep = 1
# Keep at least min_tokens_to_keep
typ_mask_sorted[..., :min_tokens_to_keep] = 0
# Negate the minimum surprisal deviation and scale it by typical_p_sigma to
# obtain the maximum allowed positive deviation (the extension threshold).
max_threshold = surprisal_deviations.min().negative() * typ_threshold

# Mask negative deviations and positive deviations above the max threshold
max_threshold = max_threshold.unsqueeze(1)
surprisal_deviations[surprisal_deviations <= 0] = THRESHOLD
surprisal_deviations[surprisal_deviations > max_threshold] = THRESHOLD
positive_mask = surprisal_deviations == THRESHOLD

typ_mask_sorted[..., :1] = 0
typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)

# Merging the mask created above with the one from the standard Typical-p.
# Masked out tokens in the distribution are True
typ_mask = typ_mask.bitwise_and(positive_mask)

logits[typ_mask] = -float("inf")

return logits
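
As a quick shape/behavior sanity check of the modified helper, a minimal call sketch (no claim about which specific tokens survive):

import torch
from aphrodite.modeling.layers.sampler import _apply_typical_sampling

logits = torch.randn(2, 8)                     # (batch, vocab)
typical_ps = torch.tensor([0.9, 0.9])
typical_p_sigmas = torch.tensor([0.0, 1.5])    # row 0: plain typical_p; row 1: sigma extension
out = _apply_typical_sampling(logits.clone(), typical_ps, typical_p_sigmas)
print((out == -float("inf")).sum(dim=-1))      # number of masked tokens per row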


36 changes: 24 additions & 12 deletions aphrodite/modeling/sampling_metadata.py
@@ -96,6 +96,7 @@ class SamplingTensors:
eta_cutoffs: torch.Tensor
epsilon_cutoffs: torch.Tensor
typical_ps: torch.Tensor
typical_p_sigmas: torch.Tensor
miro_taus: torch.Tensor
miro_etas: torch.Tensor
miro_mus: torch.Tensor
@@ -129,6 +130,7 @@ def from_sampling_metadata(
eta_cutoffs: List[float] = []
epsilon_cutoffs: List[float] = []
typical_ps: List[float] = []
typical_p_sigmas: List[float] = []
miro_taus: List[float] = []
miro_etas: List[float] = []
miro_mus: List[float] = []
@@ -168,6 +170,7 @@ def from_sampling_metadata(
eta_cutoff = sampling_params.eta_cutoff
epsilon_cutoff = sampling_params.epsilon_cutoff
typical_p = sampling_params.typical_p
typical_p_sigma = sampling_params.typical_p_sigma
miro_tau = sampling_params.mirostat_tau
miro_eta = sampling_params.mirostat_eta
dynatemp_min = sampling_params.dynatemp_min
@@ -196,7 +199,8 @@ def from_sampling_metadata(
do_eta_cutoffs = True
if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
do_epsilon_cutoffs = True
if do_typical_ps is False and typical_p < 1.0 - _SAMPLING_EPS:
if do_typical_ps is False and (typical_p < 1.0 - _SAMPLING_EPS
or typical_p_sigma > 0.0):
do_typical_ps = True
if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
or smoothing_curve > 1.0):
@@ -222,6 +226,7 @@ def from_sampling_metadata(
eta_cutoffs += [0] * (prompt_len - 1)
epsilon_cutoffs += [0] * (prompt_len - 1)
typical_ps += [1] * (prompt_len - 1)
typical_p_sigmas += [0] * (prompt_len - 1)
dynatemp_mins += [dynatemp_min] * (prompt_len - 1)
dynatemp_maxs += [dynatemp_max] * (prompt_len - 1)
dynatemp_exps += [dynatemp_exp] * (prompt_len - 1)
@@ -245,6 +250,7 @@ def from_sampling_metadata(
eta_cutoffs += [eta_cutoff] * len(seq_ids)
epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
typical_ps += [typical_p] * len(seq_ids)
typical_p_sigmas += [typical_p_sigma] * len(seq_ids)
dynatemp_mins += [dynatemp_min] * len(seq_ids)
dynatemp_maxs += [dynatemp_max] * len(seq_ids)
dynatemp_exps += [dynatemp_exp] * len(seq_ids)
@@ -265,10 +271,10 @@ def from_sampling_metadata(
sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
epsilon_cutoffs, typical_ps, dynatemp_mins, dynatemp_maxs,
dynatemp_exps, miro_taus, miro_etas, miro_mus, miro_indices,
miro_seqids, smoothing_factors, smoothing_curves, prompt_tokens,
output_tokens, vocab_size, device, dtype)
epsilon_cutoffs, typical_ps, typical_p_sigmas, dynatemp_mins,
dynatemp_maxs, dynatemp_exps, smoothing_factors, smoothing_curves,
miro_taus, miro_etas, miro_mus, miro_indices, miro_seqids,
prompt_tokens, output_tokens, vocab_size, device, dtype)
return (sampling_tensors, do_temperatures, do_penalties, do_topks,
do_topps, do_topas, do_minps, do_tfss, do_eta_cutoffs,
do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_mirostat)
@@ -280,12 +286,12 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
frequency_penalties: List[float],
repetition_penalties: List[float], tfss: List[float],
eta_cutoffs: List[float], epsilon_cutoffs: List[float],
typical_ps: List[float], dynatemp_mins: List[float],
dynatemp_maxs: List[float], dynatemp_exps: List[float],
miro_taus: List[float], miro_etas: List[float],
miro_mus: List[float], miro_indices: List[int],
miro_seqids: List[int], smoothing_factors: List[float],
smoothing_curves: List[float],
typical_ps: List[float], typical_p_sigmas: List[float],
dynatemp_mins: List[float], dynatemp_maxs: List[float],
dynatemp_exps: List[float], smoothing_factors: List[float],
smoothing_curves: List[float], miro_taus: List[float],
miro_etas: List[float], miro_mus: List[float],
miro_indices: List[int], miro_seqids: List[int],
prompt_tokens: List[List[int]],
output_tokens: List[List[int]], vocab_size: int,
device: torch.device,
@@ -352,6 +358,10 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
device="cpu",
dtype=dtype,
pin_memory=pin_memory)
typical_p_sigmas_t = torch.tensor(typical_p_sigmas,
device="cpu",
dtype=dtype,
pin_memory=pin_memory)
dynatemp_mins_t = torch.tensor(dynatemp_mins,
device="cpu",
dtype=dtype,
@@ -414,6 +424,9 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
non_blocking=True),
typical_ps=typical_ps_t.to(device=device, non_blocking=True),
typical_p_sigmas=typical_p_sigmas_t.to(device=device,
non_blocking=True),
dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
@@ -426,7 +439,6 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
miro_mus=miro_mus_t.to(device=device, non_blocking=True),
miro_indices=miro_indices_t.to(device=device, non_blocking=True),
miro_seqids=miro_seqids,
typical_ps=typical_ps_t.to(device=device, non_blocking=True),
prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
output_tokens=output_tensor.to(device=device, non_blocking=True),
)
39 changes: 39 additions & 0 deletions tests/samplers/test_typical_p_sampling.py
@@ -0,0 +1,39 @@
import pytest
import torch

from aphrodite.modeling.layers.sampler import _apply_typical_sampling

def test_typical_sampling_shape():
logits = torch.randn(10, 5)
typical_p = torch.randn(10)
typical_p_sigma = torch.randn(10)
output = _apply_typical_sampling(logits, typical_p, typical_p_sigma)
assert output.shape == logits.shape, "Output shape should match input shape"

def test_typical_sampling_dtype():
logits = torch.randn(10, 5)
typical_p = torch.randn(10)
typical_p_sigma = torch.randn(10)
output = _apply_typical_sampling(logits, typical_p, typical_p_sigma)
assert output.dtype == logits.dtype, "Output dtype should match input dtype"

def test_typical_sampling_device():
logits = torch.randn(10, 5)
typical_p = torch.randn(10)
typical_p_sigma = torch.randn(10)
output = _apply_typical_sampling(logits, typical_p, typical_p_sigma)
assert output.device == logits.device, "Output dev should match input dev"

def test_typical_sampling_inf():
logits = torch.randn(10, 5)
typical_p = torch.randn(10)
typical_p_sigma = torch.randn(10)
output = _apply_typical_sampling(logits, typical_p, typical_p_sigma)
assert not torch.isinf(output).any(), "Output should not contain inf"

def test_typical_sampling_nan():
logits = torch.randn(10, 5)
typical_p = torch.randn(10)
typical_p_sigma = torch.randn(10)
output = _apply_typical_sampling(logits, typical_p, typical_p_sigma)
assert not torch.isnan(output).any(), "Output should not contain NaN"
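
A possible follow-up test with parameters inside the documented ranges (torch.randn can yield negative typical_p values, which _verify_args would reject); a sketch, not part of this PR:

def test_typical_sampling_keeps_a_token():
    torch.manual_seed(0)
    logits = torch.randn(10, 5)
    typical_p = torch.full((10,), 0.9)        # in (0, 1]
    typical_p_sigma = torch.full((10,), 1.0)  # in [0, inf)
    output = _apply_typical_sampling(logits.clone(), typical_p, typical_p_sigma)
    # Every row should keep at least one selectable (finite) token.
    assert torch.isfinite(output).any(dim=-1).all(), "Each row should keep a token"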