From 1f8e3f0759028ce4ccbc8aee72f5ab0c196cdec5 Mon Sep 17 00:00:00 2001
From: AlpinDale
Date: Tue, 19 Mar 2024 08:20:07 +0000
Subject: [PATCH 01/10] add typical threshold

---
 aphrodite/common/sampling_params.py     |  9 ++++++
 aphrodite/modeling/layers/sampler.py    | 29 +++++++++++++++----
 aphrodite/modeling/sampling_metadata.py | 37 +++++++++++++++++--------
 3 files changed, 58 insertions(+), 17 deletions(-)

diff --git a/aphrodite/common/sampling_params.py b/aphrodite/common/sampling_params.py
index 265ab92b7..2f07bd980 100644
--- a/aphrodite/common/sampling_params.py
+++ b/aphrodite/common/sampling_params.py
@@ -72,6 +72,8 @@ class SamplingParams:
     typical_p: Float that controls the cumulative probability of tokens
         closest in surprise to the expected surprise to consider. Must be
        in (0, 1]. Set to 1 to disable.
+    typical_threshold: Used to scale the maximum threshold for positive
+        deviations in typical_p. Range in [0, inf). Set to 0 to disable.
     mirostat_mode: Can either be 0 (disabled) or 2 (Mirostat v2).
     mirostat_tau: Target "surprisal" that mirostat works towards.
         Range [0, inf).
@@ -137,6 +139,7 @@ def __init__(
         eta_cutoff: float = 0.0,
         epsilon_cutoff: float = 0.0,
         typical_p: float = 1.0,
+        typical_threshold: float = 0.0,
         mirostat_mode: int = 0,
         mirostat_tau: float = 0,
         mirostat_eta: float = 0,
@@ -175,6 +178,7 @@ def __init__(
         self.eta_cutoff = eta_cutoff
         self.epsilon_cutoff = epsilon_cutoff
         self.typical_p = typical_p
+        self.typical_threshold = typical_threshold
         self.mirostat_mode = mirostat_mode
         self.mirostat_tau = mirostat_tau
         self.mirostat_eta = mirostat_eta
@@ -219,6 +223,7 @@ def __init__(
             "eta_cutoff": 0.0,
             "epsilon_cutoff": 0.0,
             "typical_p": 1.0,
+            "typical_threshold": 0.0,
             "mirostat_mode": 0,
             "mirostat_tau": 0,
             "mirostat_eta": 0,
@@ -295,6 +300,10 @@ def _verify_args(self) -> None:
         if not 0.0 <= self.typical_p <= 1.0:
             raise ValueError(
                 f"typical_p must be in (0, 1], got {self.typical_p}.")
+        if not self.typical_threshold >= 0:
+            raise ValueError(
+                f"typical_threshold must be non negative, got "
+                f"{self.typical_threshold}.")
         if not self.dynatemp_min >= 0:
             raise ValueError(
                 f"dynatemp_min must be non negative, got {self.dynatemp_min}.")
diff --git a/aphrodite/modeling/layers/sampler.py b/aphrodite/modeling/layers/sampler.py
index d75a15738..9c9ccb1be 100644
--- a/aphrodite/modeling/layers/sampler.py
+++ b/aphrodite/modeling/layers/sampler.py
@@ -508,24 +508,43 @@ def _apply_epsilon_cutoff(
 def _apply_typical_sampling(
     logits: torch.Tensor,
     typical_p: torch.Tensor,
+    typical_threshold: torch.Tensor,
 ) -> torch.Tensor:
     typ_p = torch.tensor(typical_p, dtype=logits.dtype, device=logits.device)
+    typ_threshold = torch.tensor(typical_threshold, dtype=logits.dtype,
+                                 device=logits.device)
+
     shifted_logits = torch.log_softmax(logits, dim=-1)
-    probs = shifted_logits.exp()
+    probs = torch.exp(shifted_logits)
     neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
+    # NOTE: We don't take the absolute value of the surprisal deviations
+    # This deviates from the original implementation
+    surprisal_deviations = neg_entropy - shifted_logits
 
-    surprisal_deviations = (neg_entropy - shifted_logits).abs()
     _, indices = torch.sort(surprisal_deviations)
     reordered_probs = probs.gather(-1, indices)
     typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)
 
-    min_tokens_to_keep = 1
-    # Keep at least min_tokens_to_keep
-    typ_mask_sorted[..., :min_tokens_to_keep] = 0
+    # Invert the minimum deviation from the expected information content of the
+    # probability distribution for the next token and scale it based on the
+    # provided typical_threshold parameter.
+    max_threshold = surprisal_deviations.min().negative() * typ_threshold
+    # Mask negative deviations and positive deviations above the max threshold
+    surprisal_deviations[surprisal_deviations <= 0] = 1000
+    surprisal_deviations[surprisal_deviations > max_threshold] = 1000
+    positive_mask = surprisal_deviations == 1000
 
+    typ_mask_sorted[..., :1] = 0
     typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
+
+    # Merging the mask created above with the one from the standard Typical-p.
+    # Masked out tokens in the distribution are True
+    typ_mask = typ_mask.bitwise_and(positive_mask)
+
     logits[typ_mask] = -float("inf")
+
     return logits
diff --git a/aphrodite/modeling/sampling_metadata.py b/aphrodite/modeling/sampling_metadata.py
index 3f65f9111..a64ac0747 100644
--- a/aphrodite/modeling/sampling_metadata.py
+++ b/aphrodite/modeling/sampling_metadata.py
@@ -96,6 +96,7 @@ class SamplingTensors:
     eta_cutoffs: torch.Tensor
     epsilon_cutoffs: torch.Tensor
     typical_ps: torch.Tensor
+    typical_thresholds: torch.Tensor
     miro_taus: torch.Tensor
     miro_etas: torch.Tensor
     miro_mus: torch.Tensor
@@ -129,6 +130,7 @@ def from_sampling_metadata(
         eta_cutoffs: List[float] = []
         epsilon_cutoffs: List[float] = []
         typical_ps: List[float] = []
+        typical_thresholds: List[float] = []
         miro_taus: List[float] = []
         miro_etas: List[float] = []
         miro_mus: List[float] = []
@@ -168,6 +170,7 @@ def from_sampling_metadata(
                 eta_cutoff = sampling_params.eta_cutoff
                 epsilon_cutoff = sampling_params.epsilon_cutoff
                 typical_p = sampling_params.typical_p
+                typical_threshold = sampling_params.typical_threshold
                 miro_tau = sampling_params.mirostat_tau
                 miro_eta = sampling_params.mirostat_eta
                 dynatemp_min = sampling_params.dynatemp_min
@@ -196,7 +199,8 @@ def from_sampling_metadata(
                     do_eta_cutoffs = True
                 if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
                     do_epsilon_cutoffs = True
-                if do_typical_ps is False and typical_p < 1.0 - _SAMPLING_EPS:
+                if do_typical_ps is False and (typical_p < 1.0 - _SAMPLING_EPS
+                                               or typical_threshold > 0.0):
                     do_typical_ps = True
                 if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
                                               or smoothing_curve > 1.0):
@@ -222,6 +226,7 @@ def from_sampling_metadata(
                 eta_cutoffs += [0] * (prompt_len - 1)
                 epsilon_cutoffs += [0] * (prompt_len - 1)
                 typical_ps += [1] * (prompt_len - 1)
+                typical_thresholds += [0] * (prompt_len - 1)
                 dynatemp_mins += [dynatemp_min] * (prompt_len - 1)
                 dynatemp_maxs += [dynatemp_max] * (prompt_len - 1)
                 dynatemp_exps += [dynatemp_exp] * (prompt_len - 1)
@@ -245,6 +250,7 @@ def from_sampling_metadata(
             eta_cutoffs += [eta_cutoff] * len(seq_ids)
             epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
             typical_ps += [typical_p] * len(seq_ids)
+            typical_thresholds += [typical_threshold] * len(seq_ids)
             dynatemp_mins += [dynatemp_min] * len(seq_ids)
             dynatemp_maxs += [dynatemp_max] * len(seq_ids)
             dynatemp_exps += [dynatemp_exp] * len(seq_ids)
@@ -265,10 +271,10 @@ def from_sampling_metadata(
         sampling_tensors = SamplingTensors.from_lists(
             temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
             frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
-            epsilon_cutoffs, typical_ps, dynatemp_mins, dynatemp_maxs,
-            dynatemp_exps, miro_taus, miro_etas, miro_mus, miro_indices,
-            miro_seqids, smoothing_factors, smoothing_curves, prompt_tokens,
-            output_tokens, vocab_size, device, dtype)
+            epsilon_cutoffs, typical_ps, typical_thresholds, dynatemp_mins,
+            dynatemp_maxs, dynatemp_exps, smoothing_factors, smoothing_curves,
+            miro_taus, miro_etas, miro_mus, miro_indices, miro_seqids,
+            prompt_tokens, output_tokens, vocab_size, device, dtype)
         return (sampling_tensors, do_temperatures, do_penalties, do_topks,
                 do_topps, do_topas, do_minps, do_tfss, do_eta_cutoffs,
                 do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_mirostat)
@@ -280,12 +286,13 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
                    frequency_penalties: List[float],
                    repetition_penalties: List[float], tfss: List[float],
                    eta_cutoffs: List[float], epsilon_cutoffs: List[float],
-                   typical_ps: List[float], dynatemp_mins: List[float],
-                   dynatemp_maxs: List[float], dynatemp_exps: List[float],
-                   miro_taus: List[float], miro_etas: List[float],
-                   miro_mus: List[float], miro_indices: List[int],
-                   miro_seqids: List[int], smoothing_factors: List[float],
-                   smoothing_curves: List[float],
+                   typical_ps: List[float], typical_thresholds: List[float],
+                   dynatemp_mins: List[float], dynatemp_maxs: List[float],
+                   dynatemp_exps: List[float],
+                   smoothing_factors: List[float],
+                   smoothing_curves: List[float], miro_taus: List[float],
+                   miro_etas: List[float], miro_mus: List[float],
+                   miro_indices: List[int], miro_seqids: List[int],
                    prompt_tokens: List[List[int]],
                    output_tokens: List[List[int]],
                    vocab_size: int,
                    device: torch.device,
@@ -352,6 +359,10 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
                                    device="cpu",
                                    dtype=dtype,
                                    pin_memory=pin_memory)
+        typical_thresholds_t = torch.tensor(typical_thresholds,
+                                           device="cpu",
+                                           dtype=dtype,
+                                           pin_memory=pin_memory)
         dynatemp_mins_t = torch.tensor(dynatemp_mins,
                                        device="cpu",
                                        dtype=dtype,
@@ -414,6 +425,9 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
             eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
             epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
                                                  non_blocking=True),
+            typical_ps=typical_ps_t.to(device=device, non_blocking=True),
+            typical_thresholds=typical_thresholds_t.to(device=device,
+                                                       non_blocking=True),
             dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
             dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
             dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
@@ -426,7 +440,6 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
             miro_mus=miro_mus_t.to(device=device, non_blocking=True),
             miro_indices=miro_indices_t.to(device=device, non_blocking=True),
             miro_seqids=miro_seqids,
-            typical_ps=typical_ps_t.to(device=device, non_blocking=True),
             prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
             output_tokens=output_tensor.to(device=device, non_blocking=True),
         )
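For readers following the masking math above: typical sampling sorts tokens by how far
their surprisal deviates from the distribution's entropy, and typical_threshold adds a
second mask on top of that cut. The sketch below is not part of the series; the helper
name demo_threshold_mask and the toy logits are illustrative only. It walks the same
steps as the patched _apply_typical_sampling on a single row so the effect of the
threshold is visible in isolation.

import torch

# Illustrative only: mirrors the masking steps added in this patch for one
# row of logits.
def demo_threshold_mask(logits: torch.Tensor, typ_threshold: float) -> torch.Tensor:
    shifted = torch.log_softmax(logits, dim=-1)
    probs = shifted.exp()
    neg_entropy = (probs * shifted).nansum(dim=-1, keepdim=True)
    # Signed deviations: negative for tokens MORE likely than the expected
    # surprisal, positive for tokens LESS likely.
    deviations = neg_entropy - shifted
    # The largest negative deviation, negated, scales the cutoff for
    # positive deviations (this is what typical_threshold multiplies).
    max_threshold = deviations.min().negative() * typ_threshold
    # True = token remains maskable by typical_p; tokens whose positive
    # deviation falls at or below max_threshold are protected.
    return (deviations <= 0) | (deviations > max_threshold)

logits = torch.tensor([[2.0, 1.5, 1.0, -1.0, -4.0]])
print(demo_threshold_mask(logits, 2.0))
# tensor([[ True, False, False,  True,  True]]) on this toy input: the two
# mildly atypical tokens are shielded from the typical_p mask.

Because the final mask is an AND of this mask with the standard typical_p mask, tokens
protected here survive even when typical_p alone would have cut them.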
From 36ac0e85277b0ae9777057e63289add2b6bdb2a2 Mon Sep 17 00:00:00 2001
From: AlpinDale
Date: Tue, 19 Mar 2024 08:24:07 +0000
Subject: [PATCH 02/10] add to api

---
 aphrodite/endpoints/openai/protocol.py   | 5 +++++
 aphrodite/endpoints/openai/samplers.json | 9 +++++++++
 2 files changed, 14 insertions(+)

diff --git a/aphrodite/endpoints/openai/protocol.py b/aphrodite/endpoints/openai/protocol.py
index e30de2f0b..bf284c01b 100644
--- a/aphrodite/endpoints/openai/protocol.py
+++ b/aphrodite/endpoints/openai/protocol.py
@@ -64,6 +64,7 @@ class ChatCompletionRequest(BaseModel):
     eta_cutoff: Optional[float] = 0.0
     epsilon_cutoff: Optional[float] = 0.0
     typical_p: Optional[float] = 1.0
+    typical_threshold: Optional[float] = 0.0
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
     seed: Optional[int] = None
@@ -132,6 +133,7 @@ def logit_bias_logits_processor(
             eta_cutoff=self.eta_cutoff,
             epsilon_cutoff=self.epsilon_cutoff,
             typical_p=self.typical_p,
+            typical_threshold=self.typical_threshold,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=self.repetition_penalty,
@@ -186,6 +188,7 @@ class CompletionRequest(BaseModel):
     eta_cutoff: Optional[float] = 0.0
     epsilon_cutoff: Optional[float] = 0.0
     typical_p: Optional[float] = 1.0
+    typical_threshold: Optional[float] = 0.0
     n: Optional[int] = 1
     stream: Optional[bool] = False
     logprobs: Optional[int] = None
@@ -254,6 +257,7 @@ def logit_bias_logits_processor(
             eta_cutoff=self.eta_cutoff,
             epsilon_cutoff=self.epsilon_cutoff,
             typical_p=self.typical_p,
+            typical_threshold=self.typical_threshold,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=self.repetition_penalty,
@@ -405,6 +409,7 @@ class KoboldSamplingParams(BaseModel):
     eta_cutoff: float = Field(0.0, alias="eta_cutoff")
     epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
    typical_p: float = Field(1.0, alias="typical_p")
+    typical_threshold: float = Field(0.0, alias="typical_threshold")
     use_beam_search: bool = Field(False, alias="use_beam_search")
     length_penalty: float = Field(1.0, alias="length_penalty")
     early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
diff --git a/aphrodite/endpoints/openai/samplers.json b/aphrodite/endpoints/openai/samplers.json
index 9c540e51d..8833751fd 100644
--- a/aphrodite/endpoints/openai/samplers.json
+++ b/aphrodite/endpoints/openai/samplers.json
@@ -209,6 +209,15 @@
         "default": 1,
         "description": "Control the cumulative probability of tokens closest in surprise to the expected surprise to consider."
    },
+    "typical_threshold": {
+        "pretty_name": "Typical Threshold",
+        "type": "float",
+        "minimum": 0,
+        "maximum": 10,
+        "step": 0.01,
+        "default": 0,
+        "description": "Scale the maximum threshold for positive deviations in typical_p."
+    },
     "eta_cutoff": {
         "pretty_name": "Eta Cutoff",
         "type": "float",
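With the protocol fields in place, the parameter rides along with any OpenAI-compatible
request. A hypothetical call follows; the base URL, port, and model name depend on your
deployment, the requests package is assumed, and note that the field is renamed to
typical_p_sigma later in the series.

import requests

# Hypothetical endpoint and model; adjust to your deployment.
resp = requests.post(
    "http://localhost:2242/v1/completions",
    json={
        "model": "mistralai/Mistral-7B-v0.1",
        "prompt": "The typical set of a distribution",
        "max_tokens": 64,
        "typical_p": 0.9,
        # New in this patch: 0 disables the extra threshold.
        "typical_threshold": 2.0,
    },
)
print(resp.json()["choices"][0]["text"])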
+    },
     "eta_cutoff": {
         "pretty_name": "Eta Cutoff",
         "type": "float",

From 9818904b62dcff6dd9902763b9a1eb9ad7779638 Mon Sep 17 00:00:00 2001
From: AlpinDale
Date: Tue, 19 Mar 2024 08:28:26 +0000
Subject: [PATCH 03/10] formatting

---
 aphrodite/common/sampling_params.py     | 5 ++---
 aphrodite/modeling/layers/sampler.py    | 7 ++++---
 aphrodite/modeling/sampling_metadata.py | 9 ++++-----
 3 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/aphrodite/common/sampling_params.py b/aphrodite/common/sampling_params.py
index 2f07bd980..0a58a5f5a 100644
--- a/aphrodite/common/sampling_params.py
+++ b/aphrodite/common/sampling_params.py
@@ -301,9 +301,8 @@ def _verify_args(self) -> None:
             raise ValueError(
                 f"typical_p must be in (0, 1], got {self.typical_p}.")
         if not self.typical_threshold >= 0:
-            raise ValueError(
-                f"typical_threshold must be non negative, got "
-                f"{self.typical_threshold}.")
+            raise ValueError(f"typical_threshold must be non negative, got "
+                             f"{self.typical_threshold}.")
         if not self.dynatemp_min >= 0:
             raise ValueError(
                 f"dynatemp_min must be non negative, got {self.dynatemp_min}.")
diff --git a/aphrodite/modeling/layers/sampler.py b/aphrodite/modeling/layers/sampler.py
index 9c9ccb1be..26324b6cf 100644
--- a/aphrodite/modeling/layers/sampler.py
+++ b/aphrodite/modeling/layers/sampler.py
@@ -511,7 +511,8 @@ def _apply_typical_sampling(
     typical_threshold: torch.Tensor,
 ) -> torch.Tensor:
     typ_p = torch.tensor(typical_p, dtype=logits.dtype, device=logits.device)
-    typ_threshold = torch.tensor(typical_threshold, dtype=logits.dtype,
+    typ_threshold = torch.tensor(typical_threshold,
+                                 dtype=logits.dtype,
                                  device=logits.device)
 
     shifted_logits = torch.log_softmax(logits, dim=-1)
@@ -538,11 +539,11 @@ def _apply_typical_sampling(
 
     typ_mask_sorted[..., :1] = 0
     typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
-    
+
     # Merging the mask created above with the one from the standard Typical-p.
     # Masked out tokens in the distribution are True
     typ_mask = typ_mask.bitwise_and(positive_mask)
-    
+
     logits[typ_mask] = -float("inf")
 
     return logits
diff --git a/aphrodite/modeling/sampling_metadata.py b/aphrodite/modeling/sampling_metadata.py
index a64ac0747..8f53179ae 100644
--- a/aphrodite/modeling/sampling_metadata.py
+++ b/aphrodite/modeling/sampling_metadata.py
@@ -288,8 +288,7 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
                    eta_cutoffs: List[float], epsilon_cutoffs: List[float],
                    typical_ps: List[float], typical_thresholds: List[float],
                    dynatemp_mins: List[float], dynatemp_maxs: List[float],
-                   dynatemp_exps: List[float],
-                   smoothing_factors: List[float],
+                   dynatemp_exps: List[float], smoothing_factors: List[float],
                    smoothing_curves: List[float], miro_taus: List[float],
                    miro_etas: List[float], miro_mus: List[float],
                    miro_indices: List[int], miro_seqids: List[int],
@@ -360,9 +359,9 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
                                    dtype=dtype,
                                    pin_memory=pin_memory)
     typical_thresholds_t = torch.tensor(typical_thresholds,
-                                       device="cpu",
-                                       dtype=dtype,
-                                       pin_memory=pin_memory)
+                                        device="cpu",
+                                        dtype=dtype,
+                                        pin_memory=pin_memory)
     dynatemp_mins_t = torch.tensor(dynatemp_mins,
                                    device="cpu",
                                    dtype=dtype,

From 0fdce07ad59a0ea6951e94ce7257b35e76ee95e2 Mon Sep 17 00:00:00 2001
From: AlpinDale
Date: Tue, 19 Mar 2024 08:58:38 +0000
Subject: [PATCH 04/10] correctly pass the threshold param

---
 aphrodite/modeling/layers/sampler.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/aphrodite/modeling/layers/sampler.py b/aphrodite/modeling/layers/sampler.py
index 26324b6cf..87313640e 100644
--- a/aphrodite/modeling/layers/sampler.py
+++ b/aphrodite/modeling/layers/sampler.py
@@ -112,7 +112,8 @@ def forward(
                                            sampling_tensors.epsilon_cutoffs)
         if do_typical_ps:
             logits = _apply_typical_sampling(logits,
-                                             sampling_tensors.typical_ps)
+                                             sampling_tensors.typical_ps,
+                                             sampling_tensors.typical_thresholds)
         if do_quadratic:
             logits = _apply_quadratic_sampling(
                 logits, sampling_tensors.smoothing_factors,

From 0965d5d03b70d026a357fc15f3ad5bc1ab2b7e70 Mon Sep 17 00:00:00 2001
From: AlpinDale
Date: Tue, 19 Mar 2024 09:10:02 +0000
Subject: [PATCH 05/10] clone and detach the params from the original tensors

---
 aphrodite/modeling/layers/sampler.py    |  7 +++----
 tests/samplers/test_typical_sampling.py | 25 +++++++++++++++++++++++++
 2 files changed, 28 insertions(+), 4 deletions(-)
 create mode 100644 tests/samplers/test_typical_sampling.py

diff --git a/aphrodite/modeling/layers/sampler.py b/aphrodite/modeling/layers/sampler.py
index 87313640e..03f7d1041 100644
--- a/aphrodite/modeling/layers/sampler.py
+++ b/aphrodite/modeling/layers/sampler.py
@@ -511,10 +511,9 @@ def _apply_typical_sampling(
     typical_p: torch.Tensor,
     typical_threshold: torch.Tensor,
 ) -> torch.Tensor:
-    typ_p = torch.tensor(typical_p, dtype=logits.dtype, device=logits.device)
-    typ_threshold = torch.tensor(typical_threshold,
-                                 dtype=logits.dtype,
-                                 device=logits.device)
+    typ_p = typical_p.clone().detach().to(logits.device).to(logits.dtype)
+    typ_threshold = typical_threshold.clone().detach().to(logits.device).to(
+        logits.dtype)
 
     shifted_logits = torch.log_softmax(logits, dim=-1)
     probs = torch.exp(shifted_logits)
diff --git a/tests/samplers/test_typical_sampling.py b/tests/samplers/test_typical_sampling.py
new file mode 100644
index 000000000..b95f62048
--- /dev/null
+++ b/tests/samplers/test_typical_sampling.py
@@ -0,0 +1,25 @@
+import torch
+import unittest
+
+from 
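The switch above from torch.tensor(existing_tensor) to .clone().detach() follows
PyTorch's own guidance: copy-constructing a tensor from a tensor emits a UserWarning.
A minimal sketch of the difference, illustrative only:

import torch

src = torch.tensor([0.9, 1.0])

# torch.tensor(src) warns: "To copy construct from a tensor, it is
# recommended to use sourceTensor.clone().detach() ..."
a = torch.tensor(src)

# The pattern PATCH 05 adopts: an explicit copy, cut off from autograd,
# then moved and cast to match the logits tensor.
b = src.clone().detach().to("cpu").to(torch.float32)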
aphrodite.modeling.layers.sampler import _apply_typical_sampling, _apply_clone_typical_sampling
+
+
+class TestTypicalSampling(unittest.TestCase):
+    def setUp(self):
+        self.batch_sizes = [1, 10, 100]
+        self.logits_sizes = [10, 100, 1000]
+
+    def test_consistency(self):
+        for batch_size in self.batch_sizes:
+            for logits_size in self.logits_sizes:
+                logits = torch.randn(batch_size, logits_size)
+                typical_p = torch.rand(1)
+                typical_threshold = torch.rand(1)
+
+                original_result = _apply_typical_sampling(logits.clone(), typical_p.clone(), typical_threshold.clone())
+                modified_result = _apply_clone_typical_sampling(logits.clone(), typical_p.clone(), typical_threshold.clone())
+
+                self.assertTrue(torch.allclose(original_result, modified_result))
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file

From aee216c5c517cbf70e6fd5b4540681ffd1044c7d Mon Sep 17 00:00:00 2001
From: AlpinDale
Date: Tue, 19 Mar 2024 09:10:55 +0000
Subject: [PATCH 06/10] remove unneeded unit test

---
 tests/samplers/test_typical_sampling.py | 25 -------------------------
 1 file changed, 25 deletions(-)
 delete mode 100644 tests/samplers/test_typical_sampling.py

diff --git a/tests/samplers/test_typical_sampling.py b/tests/samplers/test_typical_sampling.py
deleted file mode 100644
index b95f62048..000000000
--- a/tests/samplers/test_typical_sampling.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import torch
-import unittest
-
-from aphrodite.modeling.layers.sampler import _apply_typical_sampling, _apply_clone_typical_sampling
-
-
-class TestTypicalSampling(unittest.TestCase):
-    def setUp(self):
-        self.batch_sizes = [1, 10, 100]
-        self.logits_sizes = [10, 100, 1000]
-
-    def test_consistency(self):
-        for batch_size in self.batch_sizes:
-            for logits_size in self.logits_sizes:
-                logits = torch.randn(batch_size, logits_size)
-                typical_p = torch.rand(1)
-                typical_threshold = torch.rand(1)
-
-                original_result = _apply_typical_sampling(logits.clone(), typical_p.clone(), typical_threshold.clone())
-                modified_result = _apply_clone_typical_sampling(logits.clone(), typical_p.clone(), typical_threshold.clone())
-
-                self.assertTrue(torch.allclose(original_result, modified_result))
-
-if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file

From 356881de03f6272ce23fc053df828bf98221631d Mon Sep 17 00:00:00 2001
From: AlpinDale
Date: Tue, 19 Mar 2024 09:21:56 +0000
Subject: [PATCH 07/10] formatting again

---
 aphrodite/modeling/layers/sampler.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/aphrodite/modeling/layers/sampler.py b/aphrodite/modeling/layers/sampler.py
index 03f7d1041..578608a83 100644
--- a/aphrodite/modeling/layers/sampler.py
+++ b/aphrodite/modeling/layers/sampler.py
@@ -111,9 +111,9 @@ def forward(
             logits = _apply_epsilon_cutoff(logits,
                                            sampling_tensors.epsilon_cutoffs)
         if do_typical_ps:
-            logits = _apply_typical_sampling(logits,
-                                             sampling_tensors.typical_ps,
-                                             sampling_tensors.typical_thresholds)
+            logits = _apply_typical_sampling(
+                logits, sampling_tensors.typical_ps,
+                sampling_tensors.typical_thresholds)
         if do_quadratic:
             logits = _apply_quadratic_sampling(
                 logits, sampling_tensors.smoothing_factors,

From 3e3bb75300222c273c7081756fa7a955ad66cc99 Mon Sep 17 00:00:00 2001
From: AlpinDale
Date: Tue, 19 Mar 2024 14:03:10 +0000
Subject: [PATCH 08/10] typical_threshold -> typical_p_sigma

Co-authored-by: BugReporterZ <26941368+BugReporterZ@users.noreply.github.com>
---
 aphrodite/common/sampling_params.py     | 14 ++++++-------
 aphrodite/endpoints/openai/protocol.py  | 10 ++++-----
 aphrodite/modeling/layers/sampler.py    | 12 +++++------
 aphrodite/modeling/sampling_metadata.py | 28 ++++++++++++-------------
 4 files changed, 32 insertions(+), 32 deletions(-)

diff --git a/aphrodite/common/sampling_params.py b/aphrodite/common/sampling_params.py
index 0a58a5f5a..1d47d6390 100644
--- a/aphrodite/common/sampling_params.py
+++ b/aphrodite/common/sampling_params.py
@@ -72,7 +72,7 @@ class SamplingParams:
     typical_p: Float that controls the cumulative probability of tokens
         closest in surprise to the expected surprise to consider. Must be
         in (0, 1]. Set to 1 to disable.
-    typical_threshold: Used to scale the maximum threshold for positive
+    typical_p_sigma: Used to scale the maximum threshold for positive
         deviations in typical_p. Range in [0, inf). Set to 0 to disable.
     mirostat_mode: Can either be 0 (disabled) or 2 (Mirostat v2).
     mirostat_tau: Target "surprisal" that mirostat works towards.
@@ -139,7 +139,7 @@ def __init__(
         eta_cutoff: float = 0.0,
         epsilon_cutoff: float = 0.0,
         typical_p: float = 1.0,
-        typical_threshold: float = 0.0,
+        typical_p_sigma: float = 0.0,
         mirostat_mode: int = 0,
         mirostat_tau: float = 0,
         mirostat_eta: float = 0,
@@ -178,7 +178,7 @@ def __init__(
         self.eta_cutoff = eta_cutoff
         self.epsilon_cutoff = epsilon_cutoff
         self.typical_p = typical_p
-        self.typical_threshold = typical_threshold
+        self.typical_p_sigma = typical_p_sigma
         self.mirostat_mode = mirostat_mode
         self.mirostat_tau = mirostat_tau
         self.mirostat_eta = mirostat_eta
@@ -223,7 +223,7 @@ def __init__(
             "eta_cutoff": 0.0,
             "epsilon_cutoff": 0.0,
             "typical_p": 1.0,
-            "typical_threshold": 0.0,
+            "typical_p_sigma": 0.0,
             "mirostat_mode": 0,
             "mirostat_tau": 0,
             "mirostat_eta": 0,
@@ -300,9 +300,9 @@ def _verify_args(self) -> None:
         if not 0.0 <= self.typical_p <= 1.0:
             raise ValueError(
                 f"typical_p must be in (0, 1], got {self.typical_p}.")
-        if not self.typical_threshold >= 0:
-            raise ValueError(f"typical_threshold must be non negative, got "
-                             f"{self.typical_threshold}.")
+        if not self.typical_p_sigma >= 0:
+            raise ValueError(f"typical_p_sigma must be non negative, got "
+                             f"{self.typical_p_sigma}.")
         if not self.dynatemp_min >= 0:
             raise ValueError(
                 f"dynatemp_min must be non negative, got {self.dynatemp_min}.")
diff --git a/aphrodite/endpoints/openai/protocol.py b/aphrodite/endpoints/openai/protocol.py
index bf284c01b..513ac9828 100644
--- a/aphrodite/endpoints/openai/protocol.py
+++ b/aphrodite/endpoints/openai/protocol.py
@@ -64,7 +64,7 @@ class ChatCompletionRequest(BaseModel):
     eta_cutoff: Optional[float] = 0.0
     epsilon_cutoff: Optional[float] = 0.0
     typical_p: Optional[float] = 1.0
-    typical_threshold: Optional[float] = 0.0
+    typical_p_sigma: Optional[float] = 0.0
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
     seed: Optional[int] = None
@@ -133,7 +133,7 @@ def logit_bias_logits_processor(
             eta_cutoff=self.eta_cutoff,
             epsilon_cutoff=self.epsilon_cutoff,
             typical_p=self.typical_p,
-            typical_threshold=self.typical_threshold,
+            typical_p_sigma=self.typical_p_sigma,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=self.repetition_penalty,
@@ -188,7 +188,7 @@ class CompletionRequest(BaseModel):
     eta_cutoff: Optional[float] = 0.0
     epsilon_cutoff: Optional[float] = 0.0
     typical_p: Optional[float] = 1.0
-    typical_threshold: Optional[float] = 0.0
+    typical_p_sigma: Optional[float] = 0.0
     n: Optional[int] = 1
     stream: Optional[bool] = False
     logprobs: Optional[int] = None
@@ -257,7 +257,7 @@ def logit_bias_logits_processor(
             eta_cutoff=self.eta_cutoff,
             epsilon_cutoff=self.epsilon_cutoff,
             typical_p=self.typical_p,
-            typical_threshold=self.typical_threshold,
+            typical_p_sigma=self.typical_p_sigma,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=self.repetition_penalty,
@@ -409,7 +409,7 @@ class KoboldSamplingParams(BaseModel):
     eta_cutoff: float = Field(0.0, alias="eta_cutoff")
     epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
     typical_p: float = Field(1.0, alias="typical_p")
-    typical_threshold: float = Field(0.0, alias="typical_threshold")
+    typical_p_sigma: float = Field(0.0, alias="typical_p_sigma")
     use_beam_search: bool = Field(False, alias="use_beam_search")
     length_penalty: float = Field(1.0, alias="length_penalty")
     early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
diff --git a/aphrodite/modeling/layers/sampler.py b/aphrodite/modeling/layers/sampler.py
index 578608a83..14c42398d 100644
--- a/aphrodite/modeling/layers/sampler.py
+++ b/aphrodite/modeling/layers/sampler.py
@@ -111,9 +111,9 @@ def forward(
             logits = _apply_epsilon_cutoff(logits,
                                            sampling_tensors.epsilon_cutoffs)
         if do_typical_ps:
-            logits = _apply_typical_sampling(
-                logits, sampling_tensors.typical_ps,
-                sampling_tensors.typical_thresholds)
+            logits = _apply_typical_sampling(logits,
+                                             sampling_tensors.typical_ps,
+                                             sampling_tensors.typical_p_sigmas)
         if do_quadratic:
             logits = _apply_quadratic_sampling(
                 logits, sampling_tensors.smoothing_factors,
@@ -509,10 +509,10 @@ def _apply_epsilon_cutoff(
 def _apply_typical_sampling(
     logits: torch.Tensor,
     typical_p: torch.Tensor,
-    typical_threshold: torch.Tensor,
+    typical_p_sigma: torch.Tensor,
 ) -> torch.Tensor:
     typ_p = typical_p.clone().detach().to(logits.device).to(logits.dtype)
-    typ_threshold = typical_threshold.clone().detach().to(logits.device).to(
+    typ_threshold = typical_p_sigma.clone().detach().to(logits.device).to(
         logits.dtype)
 
     shifted_logits = torch.log_softmax(logits, dim=-1)
@@ -529,7 +529,7 @@ def _apply_typical_sampling(
 
     # Invert the minimum deviation from the expected information content of the
     # probability distribution for the next token and scale it based on the
-    # provided typical_threshold parameter.
+    # provided typical_p_sigma parameter.
     max_threshold = surprisal_deviations.min().negative() * typ_threshold
 
     # Mask negative deviations and positive deviations above the max threshold
diff --git a/aphrodite/modeling/sampling_metadata.py b/aphrodite/modeling/sampling_metadata.py
index 8f53179ae..ad3e6d008 100644
--- a/aphrodite/modeling/sampling_metadata.py
+++ b/aphrodite/modeling/sampling_metadata.py
@@ -96,7 +96,7 @@ class SamplingTensors:
     eta_cutoffs: torch.Tensor
     epsilon_cutoffs: torch.Tensor
     typical_ps: torch.Tensor
-    typical_thresholds: torch.Tensor
+    typical_p_sigmas: torch.Tensor
     miro_taus: torch.Tensor
     miro_etas: torch.Tensor
     miro_mus: torch.Tensor
@@ -130,7 +130,7 @@ def from_sampling_metadata(
         eta_cutoffs: List[float] = []
         epsilon_cutoffs: List[float] = []
         typical_ps: List[float] = []
-        typical_thresholds: List[float] = []
+        typical_p_sigmas: List[float] = []
         miro_taus: List[float] = []
         miro_etas: List[float] = []
         miro_mus: List[float] = []
@@ -170,7 +170,7 @@ def from_sampling_metadata(
                 eta_cutoff = sampling_params.eta_cutoff
                 epsilon_cutoff = sampling_params.epsilon_cutoff
                 typical_p = sampling_params.typical_p
-                typical_threshold = sampling_params.typical_threshold
+                typical_p_sigma = sampling_params.typical_p_sigma
                 miro_tau = sampling_params.mirostat_tau
                 miro_eta = sampling_params.mirostat_eta
                 dynatemp_min = sampling_params.dynatemp_min
@@ -200,7 +200,7 @@ def from_sampling_metadata(
                 if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
                     do_epsilon_cutoffs = True
                 if do_typical_ps is False and (typical_p < 1.0 - _SAMPLING_EPS
-                                               or typical_threshold > 0.0):
+                                               or typical_p_sigma > 0.0):
                     do_typical_ps = True
                 if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
                                               or smoothing_curve > 1.0):
@@ -226,7 +226,7 @@ def from_sampling_metadata(
                 eta_cutoffs += [0] * (prompt_len - 1)
                 epsilon_cutoffs += [0] * (prompt_len - 1)
                 typical_ps += [1] * (prompt_len - 1)
-                typical_thresholds += [0] * (prompt_len - 1)
+                typical_p_sigmas += [0] * (prompt_len - 1)
                 dynatemp_mins += [dynatemp_min] * (prompt_len - 1)
                 dynatemp_maxs += [dynatemp_max] * (prompt_len - 1)
                 dynatemp_exps += [dynatemp_exp] * (prompt_len - 1)
@@ -250,7 +250,7 @@ def from_sampling_metadata(
             eta_cutoffs += [eta_cutoff] * len(seq_ids)
             epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
             typical_ps += [typical_p] * len(seq_ids)
-            typical_thresholds += [typical_threshold] * len(seq_ids)
+            typical_p_sigmas += [typical_p_sigma] * len(seq_ids)
             dynatemp_mins += [dynatemp_min] * len(seq_ids)
             dynatemp_maxs += [dynatemp_max] * len(seq_ids)
             dynatemp_exps += [dynatemp_exp] * len(seq_ids)
@@ -271,7 +271,7 @@ def from_sampling_metadata(
         sampling_tensors = SamplingTensors.from_lists(
             temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
             frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
-            epsilon_cutoffs, typical_ps, typical_thresholds, dynatemp_mins,
+            epsilon_cutoffs, typical_ps, typical_p_sigmas, dynatemp_mins,
             dynatemp_maxs, dynatemp_exps, smoothing_factors, smoothing_curves,
             miro_taus, miro_etas, miro_mus, miro_indices, miro_seqids,
             prompt_tokens, output_tokens, vocab_size, device, dtype)
@@ -286,7 +286,7 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
                    frequency_penalties: List[float],
                    repetition_penalties: List[float], tfss: List[float],
                    eta_cutoffs: List[float], epsilon_cutoffs: List[float],
-                   typical_ps: List[float], typical_thresholds: List[float],
+                   typical_ps: List[float], typical_p_sigmas: List[float],
                    dynatemp_mins: List[float], dynatemp_maxs: List[float],
                    dynatemp_exps: List[float], smoothing_factors: List[float],
                   smoothing_curves: List[float], miro_taus: List[float],
                   miro_etas: List[float], miro_mus: List[float],
                   miro_indices: List[int], miro_seqids: List[int],
@@ -358,10 +358,10 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
                                    device="cpu",
                                    dtype=dtype,
                                    pin_memory=pin_memory)
-        typical_thresholds_t = torch.tensor(typical_thresholds,
-                                            device="cpu",
-                                            dtype=dtype,
-                                            pin_memory=pin_memory)
+        typical_p_sigmas_t = torch.tensor(typical_p_sigmas,
+                                          device="cpu",
+                                          dtype=dtype,
+                                          pin_memory=pin_memory)
         dynatemp_mins_t = torch.tensor(dynatemp_mins,
                                        device="cpu",
                                        dtype=dtype,
@@ -425,8 +425,8 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
             epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
                                                  non_blocking=True),
             typical_ps=typical_ps_t.to(device=device, non_blocking=True),
-            typical_thresholds=typical_thresholds_t.to(device=device,
-                                                       non_blocking=True),
+            typical_p_sigmas=typical_p_sigmas_t.to(device=device,
+                                                   non_blocking=True),
             dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
             dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
             dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),

From 1e0e0587f355c9211f25f9774459349e946d7007 Mon Sep 17 00:00:00 2001
From: AlpinDale
Date: Tue, 19 Mar 2024 14:27:44 +0000
Subject: [PATCH 09/10] unsqueeze and add test

---
 aphrodite/modeling/layers/sampler.py      |  1 +
 tests/samplers/test_typical_p_sampling.py | 39 +++++++++++++++++++++++
 2 files changed, 40 insertions(+)
 create mode 100644 tests/samplers/test_typical_p_sampling.py

diff --git a/aphrodite/modeling/layers/sampler.py b/aphrodite/modeling/layers/sampler.py
index 14c42398d..76ca97b57 100644
--- a/aphrodite/modeling/layers/sampler.py
+++ b/aphrodite/modeling/layers/sampler.py
@@ -533,6 +533,7 @@ def _apply_typical_sampling(
     max_threshold = surprisal_deviations.min().negative() * typ_threshold
 
     # Mask negative deviations and positive deviations above the max threshold
+    max_threshold = max_threshold.unsqueeze(1)
     surprisal_deviations[surprisal_deviations <= 0] = 1000
     surprisal_deviations[surprisal_deviations > max_threshold] = 1000
     positive_mask = surprisal_deviations == 1000
diff --git a/tests/samplers/test_typical_p_sampling.py b/tests/samplers/test_typical_p_sampling.py
new file mode 100644
index 000000000..150d6f2b2
--- /dev/null
+++ b/tests/samplers/test_typical_p_sampling.py
@@ -0,0 +1,39 @@
+import pytest
+import torch
+
+from aphrodite.modeling.layers.sampler import _apply_typical_sampling
+
+def test_typical_sampling_shape():
+    logits = torch.randn(10, 5)
+    typical_p = torch.rand(10)
+    typical_p_sigma = torch.rand(10)
+    output = _apply_typical_sampling(logits, typical_p, typical_p_sigma)
+    assert output.shape == logits.shape, "Output shape should match input shape"
+
+def test_typical_sampling_dtype():
+    logits = torch.randn(10, 5)
+    typical_p = torch.rand(10)
+    typical_p_sigma = torch.rand(10)
+    output = _apply_typical_sampling(logits, typical_p, typical_p_sigma)
+    assert output.dtype == logits.dtype, "Output dtype should match input dtype"
+
+def test_typical_sampling_device():
+    logits = torch.randn(10, 5)
+    typical_p = torch.rand(10)
+    typical_p_sigma = torch.rand(10)
+    output = _apply_typical_sampling(logits, typical_p, typical_p_sigma)
+    assert output.device == logits.device, "Output dev should match input dev"
+
+def test_typical_sampling_inf():
+    logits = torch.randn(10, 5)
+    typical_p = torch.rand(10)
+    typical_p_sigma = torch.rand(10)
+    output = _apply_typical_sampling(logits, typical_p, typical_p_sigma)
+    assert torch.isfinite(output).any(dim=-1).all(), "Each row keeps a token"
+
+def 
test_typical_sampling_nan():
+    logits = torch.randn(10, 5)
+    typical_p = torch.rand(10)
+    typical_p_sigma = torch.rand(10)
+    output = _apply_typical_sampling(logits, typical_p, typical_p_sigma)
+    assert not torch.isnan(output).any(), "Output should not contain NaN"

From f0aa9b7d31853cabb819b45bed15afc57389fb19 Mon Sep 17 00:00:00 2001
From: AlpinDale
Date: Tue, 19 Mar 2024 14:32:12 +0000
Subject: [PATCH 10/10] use a local variable

---
 aphrodite/modeling/layers/sampler.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/aphrodite/modeling/layers/sampler.py b/aphrodite/modeling/layers/sampler.py
index 76ca97b57..8391fd928 100644
--- a/aphrodite/modeling/layers/sampler.py
+++ b/aphrodite/modeling/layers/sampler.py
@@ -514,6 +514,7 @@ def _apply_typical_sampling(
     typ_p = typical_p.clone().detach().to(logits.device).to(logits.dtype)
     typ_threshold = typical_p_sigma.clone().detach().to(logits.device).to(
         logits.dtype)
+    THRESHOLD = 1000
 
     shifted_logits = torch.log_softmax(logits, dim=-1)
     probs = torch.exp(shifted_logits)
@@ -534,9 +535,9 @@ def _apply_typical_sampling(
 
     # Mask negative deviations and positive deviations above the max threshold
     max_threshold = max_threshold.unsqueeze(1)
-    surprisal_deviations[surprisal_deviations <= 0] = 1000
-    surprisal_deviations[surprisal_deviations > max_threshold] = 1000
-    positive_mask = surprisal_deviations == 1000
+    surprisal_deviations[surprisal_deviations <= 0] = THRESHOLD
+    surprisal_deviations[surprisal_deviations > max_threshold] = THRESHOLD
+    positive_mask = surprisal_deviations == THRESHOLD
 
     typ_mask_sorted[..., :1] = 0
     typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
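Taken together, the series leaves _apply_typical_sampling in the state sketched below.
This standalone copy is illustrative only: the name typical_sampling_reference and the
toy logits are not part of the series. It reproduces the final masking logic so the
effect of typical_p_sigma can be checked without running the engine.

import torch

THRESHOLD = 1000

def typical_sampling_reference(logits, typical_p, typical_p_sigma):
    # Mirrors the post-PATCH-10 logic: signed surprisal deviations, the
    # sigma-scaled ceiling for positive deviations, and the merged mask.
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = torch.exp(shifted_logits)
    neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
    surprisal_deviations = neg_entropy - shifted_logits

    _, indices = torch.sort(surprisal_deviations)
    reordered_probs = probs.gather(-1, indices)
    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(1)

    # Note: .min() here is global across the batch, as in the patch.
    max_threshold = surprisal_deviations.min().negative() * typical_p_sigma
    max_threshold = max_threshold.unsqueeze(1)
    surprisal_deviations[surprisal_deviations <= 0] = THRESHOLD
    surprisal_deviations[surprisal_deviations > max_threshold] = THRESHOLD
    positive_mask = surprisal_deviations == THRESHOLD

    typ_mask_sorted[..., :1] = 0
    typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
    typ_mask = typ_mask.bitwise_and(positive_mask)
    logits[typ_mask] = -float("inf")
    return logits

logits = torch.tensor([[2.0, 1.5, 1.0, -1.0, -4.0]])
for sigma in (0.0, 0.5, 2.0):
    out = typical_sampling_reference(logits.clone(),
                                     torch.tensor([0.9]),
                                     torch.tensor([sigma]))
    print(sigma, torch.isfinite(out)[0].tolist())
# 0.0 and 0.5 keep the two most typical tokens; 2.0 additionally rescues
# the third token, whose positive deviation falls inside the sigma window.

With sigma set to 0 the merged mask reduces to plain typical_p filtering, which matches
the "Set to 0 to disable" wording in the docstring.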