diff --git a/aphrodite/common/sampling_params.py b/aphrodite/common/sampling_params.py
index 265ab92b7..1d47d6390 100644
--- a/aphrodite/common/sampling_params.py
+++ b/aphrodite/common/sampling_params.py
@@ -72,6 +72,8 @@ class SamplingParams:
         typical_p: Float that controls the cumulative probability of tokens
             closest in surprise to the expected surprise to consider.
             Must be in (0, 1]. Set to 1 to disable.
+        typical_p_sigma: Used to scale the maximum threshold for positive
+            deviations in typical_p. Range in [0, inf). Set to 0 to disable.
         mirostat_mode: Can either be 0 (disabled) or 2 (Mirostat v2).
         mirostat_tau: Target "surprisal" that mirostat works towards.
             Range [0, inf).
@@ -137,6 +139,7 @@ def __init__(
         eta_cutoff: float = 0.0,
         epsilon_cutoff: float = 0.0,
         typical_p: float = 1.0,
+        typical_p_sigma: float = 0.0,
         mirostat_mode: int = 0,
         mirostat_tau: float = 0,
         mirostat_eta: float = 0,
@@ -175,6 +178,7 @@ def __init__(
         self.eta_cutoff = eta_cutoff
         self.epsilon_cutoff = epsilon_cutoff
         self.typical_p = typical_p
+        self.typical_p_sigma = typical_p_sigma
         self.mirostat_mode = mirostat_mode
         self.mirostat_tau = mirostat_tau
         self.mirostat_eta = mirostat_eta
@@ -219,6 +223,7 @@ def __init__(
             "eta_cutoff": 0.0,
             "epsilon_cutoff": 0.0,
             "typical_p": 1.0,
+            "typical_p_sigma": 0.0,
             "mirostat_mode": 0,
             "mirostat_tau": 0,
             "mirostat_eta": 0,
@@ -295,6 +300,9 @@ def _verify_args(self) -> None:
         if not 0.0 <= self.typical_p <= 1.0:
             raise ValueError(
                 f"typical_p must be in (0, 1], got {self.typical_p}.")
+        if not self.typical_p_sigma >= 0:
+            raise ValueError(f"typical_p_sigma must be non negative, got "
+                             f"{self.typical_p_sigma}.")
         if not self.dynatemp_min >= 0:
             raise ValueError(
                 f"dynatemp_min must be non negative, got {self.dynatemp_min}.")
diff --git a/aphrodite/endpoints/openai/protocol.py b/aphrodite/endpoints/openai/protocol.py
index e30de2f0b..513ac9828 100644
--- a/aphrodite/endpoints/openai/protocol.py
+++ b/aphrodite/endpoints/openai/protocol.py
@@ -64,6 +64,7 @@ class ChatCompletionRequest(BaseModel):
     eta_cutoff: Optional[float] = 0.0
     epsilon_cutoff: Optional[float] = 0.0
     typical_p: Optional[float] = 1.0
+    typical_p_sigma: Optional[float] = 0.0
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
     seed: Optional[int] = None
@@ -132,6 +133,7 @@ def logit_bias_logits_processor(
             eta_cutoff=self.eta_cutoff,
             epsilon_cutoff=self.epsilon_cutoff,
             typical_p=self.typical_p,
+            typical_p_sigma=self.typical_p_sigma,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=self.repetition_penalty,
@@ -186,6 +188,7 @@ class CompletionRequest(BaseModel):
     eta_cutoff: Optional[float] = 0.0
     epsilon_cutoff: Optional[float] = 0.0
     typical_p: Optional[float] = 1.0
+    typical_p_sigma: Optional[float] = 0.0
     n: Optional[int] = 1
     stream: Optional[bool] = False
     logprobs: Optional[int] = None
@@ -254,6 +257,7 @@ def logit_bias_logits_processor(
             eta_cutoff=self.eta_cutoff,
             epsilon_cutoff=self.epsilon_cutoff,
             typical_p=self.typical_p,
+            typical_p_sigma=self.typical_p_sigma,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=self.repetition_penalty,
@@ -405,6 +409,7 @@ class KoboldSamplingParams(BaseModel):
     eta_cutoff: float = Field(0.0, alias="eta_cutoff")
    epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
     typical_p: float = Field(1.0, alias="typical_p")
+    typical_p_sigma: float = Field(0.0, alias="typical_p_sigma")
     use_beam_search: bool = Field(False, alias="use_beam_search")
     length_penalty: float = Field(1.0, alias="length_penalty")
alias="length_penalty") early_stopping: Union[bool, str] = Field(False, alias="early_stopping") diff --git a/aphrodite/endpoints/openai/samplers.json b/aphrodite/endpoints/openai/samplers.json index 9c540e51d..8833751fd 100644 --- a/aphrodite/endpoints/openai/samplers.json +++ b/aphrodite/endpoints/openai/samplers.json @@ -209,6 +209,15 @@ "default": 1, "description": "Control the cumulative probability of tokens closest in surprise to the expected surprise to consider." }, + "typical_threshold": { + "pretty_name": "Typical Threshold", + "type": "float", + "minimum": 0, + "maximum": 10, + "step": 0.01, + "default": 0, + "description": "Scale the maximum threshold for positive deviations in typical_p." + }, "eta_cutoff": { "pretty_name": "Eta Cutoff", "type": "float", diff --git a/aphrodite/modeling/layers/sampler.py b/aphrodite/modeling/layers/sampler.py index d75a15738..8391fd928 100644 --- a/aphrodite/modeling/layers/sampler.py +++ b/aphrodite/modeling/layers/sampler.py @@ -112,7 +112,8 @@ def forward( sampling_tensors.epsilon_cutoffs) if do_typical_ps: logits = _apply_typical_sampling(logits, - sampling_tensors.typical_ps) + sampling_tensors.typical_ps, + sampling_tensors.typical_p_sigmas) if do_quadratic: logits = _apply_quadratic_sampling( logits, sampling_tensors.smoothing_factors, @@ -508,24 +509,45 @@ def _apply_epsilon_cutoff( def _apply_typical_sampling( logits: torch.Tensor, typical_p: torch.Tensor, + typical_p_sigma: torch.Tensor, ) -> torch.Tensor: - typ_p = torch.tensor(typical_p, dtype=logits.dtype, device=logits.device) + typ_p = typical_p.clone().detach().to(logits.device).to(logits.dtype) + typ_threshold = typical_p_sigma.clone().detach().to(logits.device).to( + logits.dtype) + THRESHOLD = 1000 + shifted_logits = torch.log_softmax(logits, dim=-1) - probs = shifted_logits.exp() + probs = torch.exp(shifted_logits) neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True) + # NOTE: We don't take the absolute value of the surprisal deviations + # This deviates from the original implementation + surprisal_deviations = neg_entropy - shifted_logits - surprisal_deviations = (neg_entropy - shifted_logits).abs() _, indices = torch.sort(surprisal_deviations) reordered_probs = probs.gather(-1, indices) typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1) - min_tokens_to_keep = 1 - # Keep at least min_tokens_to_keep - typ_mask_sorted[..., :min_tokens_to_keep] = 0 + # Invert the minimum deviation from the expected information content of the + # probability distribution for the next token and scale it based on the + # provided typical_p_sigma parameter. + max_threshold = surprisal_deviations.min().negative() * typ_threshold + + # Mask negative deviations and positive deviations above the max threshold + max_threshold = max_threshold.unsqueeze(1) + surprisal_deviations[surprisal_deviations <= 0] = THRESHOLD + surprisal_deviations[surprisal_deviations > max_threshold] = THRESHOLD + positive_mask = surprisal_deviations == THRESHOLD + typ_mask_sorted[..., :1] = 0 typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted) + + # Merging the mask created above with the one from the standard Typical-p. 
+    # only when both masks flag it.
+    typ_mask = typ_mask.bitwise_and(positive_mask)
+
     logits[typ_mask] = -float("inf")
+
     return logits
diff --git a/aphrodite/modeling/sampling_metadata.py b/aphrodite/modeling/sampling_metadata.py
index 3f65f9111..ad3e6d008 100644
--- a/aphrodite/modeling/sampling_metadata.py
+++ b/aphrodite/modeling/sampling_metadata.py
@@ -96,6 +96,7 @@ class SamplingTensors:
     eta_cutoffs: torch.Tensor
     epsilon_cutoffs: torch.Tensor
     typical_ps: torch.Tensor
+    typical_p_sigmas: torch.Tensor
     miro_taus: torch.Tensor
     miro_etas: torch.Tensor
     miro_mus: torch.Tensor
@@ -129,6 +130,7 @@ def from_sampling_metadata(
         eta_cutoffs: List[float] = []
         epsilon_cutoffs: List[float] = []
         typical_ps: List[float] = []
+        typical_p_sigmas: List[float] = []
         miro_taus: List[float] = []
         miro_etas: List[float] = []
         miro_mus: List[float] = []
@@ -168,6 +170,7 @@ def from_sampling_metadata(
                 eta_cutoff = sampling_params.eta_cutoff
                 epsilon_cutoff = sampling_params.epsilon_cutoff
                 typical_p = sampling_params.typical_p
+                typical_p_sigma = sampling_params.typical_p_sigma
                 miro_tau = sampling_params.mirostat_tau
                 miro_eta = sampling_params.mirostat_eta
                 dynatemp_min = sampling_params.dynatemp_min
@@ -196,7 +199,8 @@ def from_sampling_metadata(
                 do_eta_cutoffs = True
             if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
                 do_epsilon_cutoffs = True
-            if do_typical_ps is False and typical_p < 1.0 - _SAMPLING_EPS:
+            if do_typical_ps is False and (typical_p < 1.0 - _SAMPLING_EPS
+                                           or typical_p_sigma > 0.0):
                 do_typical_ps = True
             if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
                                           or smoothing_curve > 1.0):
@@ -222,6 +226,7 @@ def from_sampling_metadata(
                 eta_cutoffs += [0] * (prompt_len - 1)
                 epsilon_cutoffs += [0] * (prompt_len - 1)
                 typical_ps += [1] * (prompt_len - 1)
+                typical_p_sigmas += [0] * (prompt_len - 1)
                 dynatemp_mins += [dynatemp_min] * (prompt_len - 1)
                 dynatemp_maxs += [dynatemp_max] * (prompt_len - 1)
                 dynatemp_exps += [dynatemp_exp] * (prompt_len - 1)
@@ -245,6 +250,7 @@ def from_sampling_metadata(
             eta_cutoffs += [eta_cutoff] * len(seq_ids)
             epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
             typical_ps += [typical_p] * len(seq_ids)
+            typical_p_sigmas += [typical_p_sigma] * len(seq_ids)
             dynatemp_mins += [dynatemp_min] * len(seq_ids)
             dynatemp_maxs += [dynatemp_max] * len(seq_ids)
             dynatemp_exps += [dynatemp_exp] * len(seq_ids)
@@ -265,10 +271,10 @@ def from_sampling_metadata(
         sampling_tensors = SamplingTensors.from_lists(
             temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
             frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
-            epsilon_cutoffs, typical_ps, dynatemp_mins, dynatemp_maxs,
-            dynatemp_exps, miro_taus, miro_etas, miro_mus, miro_indices,
-            miro_seqids, smoothing_factors, smoothing_curves, prompt_tokens,
-            output_tokens, vocab_size, device, dtype)
+            epsilon_cutoffs, typical_ps, typical_p_sigmas, dynatemp_mins,
+            dynatemp_maxs, dynatemp_exps, smoothing_factors, smoothing_curves,
+            miro_taus, miro_etas, miro_mus, miro_indices, miro_seqids,
+            prompt_tokens, output_tokens, vocab_size, device, dtype)
         return (sampling_tensors, do_temperatures, do_penalties, do_topks,
                 do_topps, do_topas, do_minps, do_tfss, do_eta_cutoffs,
                 do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_mirostat)
@@ -280,12 +286,12 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
                    frequency_penalties: List[float],
                    repetition_penalties: List[float], tfss: List[float],
                    eta_cutoffs: List[float], epsilon_cutoffs: List[float],
-                   typical_ps: List[float], dynatemp_mins: List[float],
-                   dynatemp_maxs: List[float], dynatemp_exps: List[float],
-                   miro_taus: List[float], miro_etas: List[float],
-                   miro_mus: List[float], miro_indices: List[int],
-                   miro_seqids: List[int], smoothing_factors: List[float],
-                   smoothing_curves: List[float],
+                   typical_ps: List[float], typical_p_sigmas: List[float],
+                   dynatemp_mins: List[float], dynatemp_maxs: List[float],
+                   dynatemp_exps: List[float], smoothing_factors: List[float],
+                   smoothing_curves: List[float], miro_taus: List[float],
+                   miro_etas: List[float], miro_mus: List[float],
+                   miro_indices: List[int], miro_seqids: List[int],
                    prompt_tokens: List[List[int]],
                    output_tokens: List[List[int]], vocab_size: int,
                    device: torch.device,
@@ -352,6 +358,10 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
                                      device="cpu",
                                      dtype=dtype,
                                      pin_memory=pin_memory)
+        typical_p_sigmas_t = torch.tensor(typical_p_sigmas,
+                                          device="cpu",
+                                          dtype=dtype,
+                                          pin_memory=pin_memory)
         dynatemp_mins_t = torch.tensor(dynatemp_mins,
                                        device="cpu",
                                        dtype=dtype,
@@ -414,6 +424,9 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
             eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
             epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
                                                  non_blocking=True),
+            typical_ps=typical_ps_t.to(device=device, non_blocking=True),
+            typical_p_sigmas=typical_p_sigmas_t.to(device=device,
+                                                   non_blocking=True),
             dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
             dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
             dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
@@ -426,7 +439,6 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
             miro_mus=miro_mus_t.to(device=device, non_blocking=True),
             miro_indices=miro_indices_t.to(device=device, non_blocking=True),
             miro_seqids=miro_seqids,
-            typical_ps=typical_ps_t.to(device=device, non_blocking=True),
             prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
             output_tokens=output_tensor.to(device=device, non_blocking=True),
         )
diff --git a/tests/samplers/test_typical_p_sampling.py b/tests/samplers/test_typical_p_sampling.py
new file mode 100644
index 000000000..150d6f2b2
--- /dev/null
+++ b/tests/samplers/test_typical_p_sampling.py
@@ -0,0 +1,39 @@
+import torch
+
+from aphrodite.modeling.layers.sampler import _apply_typical_sampling
+
+
+def test_typical_sampling_shape():
+    logits = torch.randn(10, 5)
+    # typical_p lives in (0, 1] and typical_p_sigma in [0, inf), so draw the
+    # test values from torch.rand rather than torch.randn.
+    typical_p = torch.rand(10)
+    typical_p_sigma = torch.rand(10)
+    output = _apply_typical_sampling(logits, typical_p, typical_p_sigma)
+    assert output.shape == logits.shape, "Output shape should match input"
+
+
+def test_typical_sampling_dtype():
+    logits = torch.randn(10, 5)
+    typical_p = torch.rand(10)
+    typical_p_sigma = torch.rand(10)
+    output = _apply_typical_sampling(logits, typical_p, typical_p_sigma)
+    assert output.dtype == logits.dtype, "Output dtype should match input"
+
+
+def test_typical_sampling_device():
+    logits = torch.randn(10, 5)
+    typical_p = torch.rand(10)
+    typical_p_sigma = torch.rand(10)
+    output = _apply_typical_sampling(logits, typical_p, typical_p_sigma)
+    assert output.device == logits.device, \
+        "Output device should match input device"
+
+
+def test_typical_sampling_keeps_tokens():
+    # Masked tokens are set to -inf by design, so instead of forbidding inf
+    # we check that every row keeps at least one finite logit to sample from.
+    logits = torch.randn(10, 5)
+    typical_p = torch.rand(10)
+    typical_p_sigma = torch.rand(10)
+    output = _apply_typical_sampling(logits, typical_p, typical_p_sigma)
+    assert torch.isfinite(output).any(dim=-1).all(), \
+        "Each row should keep at least one finite logit"
+
+
+def test_typical_sampling_nan():
+    logits = torch.randn(10, 5)
+    typical_p = torch.rand(10)
+    typical_p_sigma = torch.rand(10)
+    output = _apply_typical_sampling(logits, typical_p, typical_p_sigma)
+    assert not torch.isnan(output).any(), "Output should not contain NaN"
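
A minimal usage sketch of the new knob (values are illustrative, not tuned recommendations; only parameters touched by this diff are shown). Setting `typical_p_sigma=0.0` preserves plain typical-p behaviour, while a positive value protects tokens whose surprisal deviation is positive but within the scaled threshold:

    from aphrodite.common.sampling_params import SamplingParams

    # typical_p: standard typical-p cutoff, in (0, 1]; 1.0 disables it.
    # typical_p_sigma: scales the maximum allowed positive surprisal
    # deviation; 0.0 disables the new behaviour entirely.
    params = SamplingParams(typical_p=0.9, typical_p_sigma=1.0)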
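Since the field is also plumbed through the OpenAI-compatible `CompletionRequest` and `ChatCompletionRequest`, it can be set per request. A sketch using only the standard library; the endpoint URL and model name below are placeholders, not values from this diff:

    import json
    from urllib.request import Request, urlopen

    payload = {
        "model": "my-model",  # placeholder
        "prompt": "Once upon a time",
        "typical_p": 0.9,
        "typical_p_sigma": 1.0,  # field added in this diff
    }
    req = Request(
        "http://localhost:8000/v1/completions",  # placeholder endpoint
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"})
    print(urlopen(req).read().decode())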
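The sampler path itself can also be exercised directly, the same way the new tests do. Note that `_apply_typical_sampling` mutates `logits` in place and returns it, with masked tokens set to `-inf` (toy numbers below):

    import torch

    from aphrodite.modeling.layers.sampler import _apply_typical_sampling

    logits = torch.tensor([[2.0, 1.5, 0.3, -1.0, -4.0]])
    typical_p = torch.tensor([0.9])        # keep ~90% cumulative mass
    typical_p_sigma = torch.tensor([0.5])  # positive-deviation window scale

    out = _apply_typical_sampling(logits.clone(), typical_p, typical_p_sigma)
    print(out)  # surviving tokens keep their logits; the rest are -inf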