feat: typical_p threshold sampling #343

Open. Wants to merge 10 commits into base: main

Conversation

AlpinDale
Member

@AlpinDale AlpinDale commented Mar 19, 2024

This PR adds a new hyperparameter to typical_p sampling that scales the maximum threshold for positive deviations in typ_p. Credit to Suikamelon (@BugReporterZ). Not yet tested.

@AlpinDale AlpinDale requested a review from StefanGliga March 19, 2024 08:32
@BugReporterZ

BugReporterZ commented Mar 19, 2024

Explanation

Here is an explanation of what the modified code is supposed to do.

  1. Calculate the set of candidate tokens exactly like in the original Typical_p algorithm;
  2. Retrieve the negative deviation $-D$ of the most likely token in the original distribution and invert its sign (let's call it $+D$);
  3. Find the set of tokens having a deviation from $0$ to $+D$ in the original token distribution;
  4. The final range of candidate tokens is the union of both sets.

We can also scale $+D$ by a hyperparameter named Sigma ($\sigma$); results may vary depending on the model and use case. In preliminary testing, 1.0 has been a safe value that consistently appeared beneficial. Values up to around 1.5~2.0 may be used but can occasionally lead to incoherence. A value of 0.0 restores the original Typical-p behavior (i.e., the original set of tokens is extended with an empty set).
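The four steps and the Sigma scaling can be sketched in plain Python on a toy distribution. This is only an illustration, not the code in this PR: the names `typical_candidates` and `extended_candidates` and the five-token distribution are made up for the example, and the Tau cutoff is assumed to follow the `cumsum >= tau` masking used by the sampler code posted later in this thread.

```python
import math


def typical_candidates(probs, tau):
    """Step 1: original Typical-p candidate set (hypothetical helper)."""
    surprisals = [-math.log(p) for p in probs]
    entropy = sum(p * s for p, s in zip(probs, surprisals))
    deviations = [s - entropy for s in surprisals]
    # Sort tokens by absolute deviation, keep them while cumsum < tau,
    # and always keep the first one (min_tokens_to_keep = 1).
    order = sorted(range(len(probs)), key=lambda i: abs(deviations[i]))
    kept, cumulative = set(), 0.0
    for rank, i in enumerate(order):
        cumulative += probs[i]
        if rank == 0 or cumulative < tau:
            kept.add(i)
    return kept, deviations


def extended_candidates(probs, tau, sigma=1.0):
    """Steps 2-4: extend the set with deviations in [0, sigma * D]."""
    kept, deviations = typical_candidates(probs, tau)
    top = max(range(len(probs)), key=lambda i: probs[i])
    # The most likely token has deviation -D; flip the sign and scale it.
    bound = sigma * -deviations[top]
    extension = {i for i, d in enumerate(deviations) if 0.0 <= d <= bound}
    return kept | extension


toy = [0.5, 0.2, 0.15, 0.1, 0.05]  # made-up probabilities, sorted for clarity
for sigma in (0.0, 1.0, 2.0):
    print(sigma, sorted(extended_candidates(toy, tau=0.0, sigma=sigma)))
```

With Sigma at 0.0 the extension only admits tokens whose deviation is exactly 0, which the original algorithm already keeps, so the result matches plain Typical-p.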

This modification has a much greater impact at low Typical-P Tau values than at higher ones. As a starting point, a Typical-P Tau of around 0.35-0.50 with a Sigma of 1.0 appears to work well all-around. With these suggested settings, preferably no other sampler (besides Temperature=1) or repetition penalty should be used alongside the modified Typical-P. Sigma can act as a temperature-like control (higher values cause more of the less likely tokens to be included, and vice versa).

Testing with different language models under various conditions may of course reveal different good values for both Tau and Sigma.

Why?

A problem in LLM-generated text is degeneracy (boredom/repetition and confusion/incoherency traps), with boredom/repetition due to a "pathology" of sampling the most likely tokens too often [1]. By default, Typical-P does a good job of fixing this at low Tau values (by discarding high-probability tokens when appropriate), but it needlessly restricts the set of rarer tokens taken into consideration, ultimately making the token selection oddly deterministic when Tau is set to 0.0 (⇒ the only possible choice is not the most likely token, but the one whose surprisal is closest to the conditional entropy of the token distribution).
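To make the Tau=0.0 case concrete, here is a small sketch on a made-up distribution (the probabilities are an assumption for illustration only). It compares the most likely token with the token whose surprisal is closest to the conditional entropy, which is the only one left when Tau is 0.0.

```python
import math

probs = [0.5, 0.2, 0.15, 0.1, 0.05]  # hypothetical next-token probabilities

surprisals = [-math.log(p) for p in probs]
entropy = sum(p * s for p, s in zip(probs, surprisals))

# Token that greedy decoding would pick.
most_likely = max(range(len(probs)), key=lambda i: probs[i])
# Token with the smallest absolute surprisal deviation, i.e. the single
# survivor of Typical-p when Tau is 0.0 and min_tokens_to_keep is 1.
closest = min(range(len(probs)), key=lambda i: abs(surprisals[i] - entropy))

print("most likely token:", most_likely)
print("token kept at Tau=0.0:", closest)
```

In this toy case the two indices differ, so the sampler becomes deterministic without being greedy.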

Below is one example simulated in a spreadsheet. It can be seen that by default at low Tau values the token selection range is skewed toward rarer tokens on one side, but restricted on the other.

image

The modification in this PR extends that range up to positive deviations that (at least up to a Sigma of 1.0) should be safe to pick, keeping the average perplexity of the generated text high(er), combating repetition while still avoiding confusion/incoherency and preventing the default deterministic behavior at low-to-zero Typical-P Tau values. It also rests on the idea that if the token having $-D$ deviation is an acceptable choice, then the one having $+D$ deviation should also be.

By allowing more tokens with positive deviation to be sampled, we also promote a longer positive tail in the distribution of deviations of the output tokens (as gauged by a language model at temperature=1).

The graphs below from the Typical P paper show the distribution of deviations ($\varepsilon$) of different types of human text, gauged by an LLM.

image

These graphs notably exhibit a long positive deviation tail (roughly 2-3 times the absolute value of the negative deviation, with an upper limit of about 10, although this might be use-case and model-dependent) and an average close to 0 or possibly very slightly positive (this is a universal observation). Boring, machine-sounding text composed mostly of highly probable tokens will exhibit a deviation distribution markedly skewed toward the negative side, with very limited (if any) excursions into the positive side.
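As a rough illustration of how such a deviation distribution could be measured, the sketch below computes the per-token deviation (surprisal minus conditional entropy) for a few made-up steps. The helper name `surprisal_deviation` and the toy distributions are assumptions for the example; a real measurement would take the distributions from a language model at temperature=1.

```python
import math


def surprisal_deviation(probs, chosen):
    """Deviation of the chosen token's surprisal from the conditional entropy."""
    surprisals = [-math.log(p) for p in probs]
    entropy = sum(p * s for p, s in zip(probs, surprisals))
    return surprisals[chosen] - entropy


# Hypothetical next-token distributions for three steps (index 0 is the top token).
steps = [
    [0.5, 0.2, 0.15, 0.1, 0.05],
    [0.7, 0.2, 0.1],
    [0.4, 0.3, 0.2, 0.1],
]

# Always picking the most likely token: every deviation is negative.
greedy = [surprisal_deviation(p, 0) for p in steps]
# Occasionally picking rarer tokens: some deviations become positive.
mixed = [surprisal_deviation(p, c) for p, c in zip(steps, [1, 0, 3])]

print("greedy:", [round(d, 3) for d in greedy])
print("mixed: ", [round(d, 3) for d in mixed])
```

The most likely token always has the lowest surprisal, which cannot exceed the probability-weighted mean of all surprisals, so a purely greedy sequence never produces a positive deviation.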

The proposed Typical-p modification can yield more human-like conditional entropy deviation distributions from language models than the original version.

References


neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
# NOTE: We don't take the absolute value of the surprisal deviations
# This deviates from the original implementation
surprisal_deviations = neg_entropy - shifted_logits

@BugReporterZ BugReporterZ Mar 19, 2024

I disagree with this change (or the intentions in the code here). The modification in my original hack (not posted in this PR) was intended to retain the basic behavior of Typical_P, which first sorts the surprisal deviations by their absolute value.

Only after this is done would you use the signed surprisal deviations (copied into a separate tensor before the absolute values are computed for the other) to obtain a second subset that extends the token selection, as in the algorithm described in the explanation in the discussion.

@BugReporterZ

Hopefully this graphical explanation further clarifies how the modified algorithm is supposed to work.

image

@BugReporterZ

BugReporterZ commented Mar 21, 2024

Further testing over the past few days has revealed that,

[...] if the token having $-D$ deviation is an acceptable choice, then the one having $+D$ deviation should also be.

was actually not a fair assumption and may too easily lead to incoherent text. While lower values such as $+D/2$ could be used with good results, that is probably still too arbitrary.

A saner approach to this (while still retaining an additional hyperparameter over the original algorithm) would be scaling the negative deviations by a factor (Lambda) that defaults to 1.0 (same behavior as the original algorithm), before they are converted to absolute values in the code.

By doing this, the tokens corresponding to the scaled deviations would be de-emphasized, and the token selection progressively skewed toward rarer tokens the higher Lambda is.

Below is a simulated example using Tau=0.35 and Lambda in the 1.0~2.0 range. It can be seen how at Lambda > 1 the selection (yellow boxes) skews more toward rarer tokens:

image

  • Advantages
    • The selection range scales up more progressively toward rarer tokens while retaining most of the original Typical-p behavior (where tokens are selected according to the configured cumulative probability / Tau).
    • The modified algorithm is simpler than the originally-proposed modification.
    • No arbitrary theoretical assumptions made.
  • Disadvantages
    • The oddly deterministic behavior of Typical-p at Tau=0.0 remains (the only token choice is the one with deviation=0, which is not the most likely token).
    • It would be desirable not to affect the deviation of control tokens like EOS, otherwise strange behavior may be observed (e.g. excessively long chat messages). However, this is a general problem that can affect other samplers as well.

I have code for the logic in sampler.py, but it's still based on the main branch. The portion that skips the first few tokens (UNK, EOS and BOS in Mistral and Llama tokenizers) could be omitted. Lambda would need to be wired to an additional hyperparameter (which for the previous version in this PR branch was named Sigma).

import torch


def _apply_typical_sampling(
    logits: torch.Tensor,
    typical_p: torch.Tensor,
) -> torch.Tensor:
    typ_p = torch.tensor(typical_p, dtype=logits.dtype, device=logits.device)
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = shifted_logits.exp()

    # Tensor names and logic have been slightly altered to make the procedure a bit more consistent
    # with what the original Typical-P algorithm is intended to perform, thus clearer.
    surprises = shifted_logits.neg()
    conditional_entropy = (probs * surprises).nansum(dim=-1, keepdim=True)
    surprisal_deviations = surprises - conditional_entropy

    # Scale negative surprisal deviations by a scaling factor Lambda. This de-emphasizes
    # high-probability tokens. Lambda would ideally be an additional hyperparameter to typical-p.
    lambda_factor = 1.50
    lambda_mask = surprisal_deviations < 0

    # Don't affect special tokens (generally the first few tokens) by setting the mask to False
    # for them. This is an ugly hack; ideally we would want to identify such tokens more directly
    # as there is no guarantee that the first few tokens are special tokens or bytes (e.g. Qwen).
    tokens_to_ignore = 3
    lambda_mask[..., :tokens_to_ignore] = False

    # Actual scaling performed here.
    surprisal_deviations[lambda_mask] = surprisal_deviations[lambda_mask] * lambda_factor

    # From now on, the algorithm proceeds as in the original one, sorting tokens by absolute
    # surprisal deviations and picking them depending on their cumulative probability.
    surprisal_deviations = surprisal_deviations.abs()
    _, indices = torch.sort(surprisal_deviations)

    reordered_probs = probs.gather(-1, indices)
    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)

    min_tokens_to_keep = 1
    # Keep at least min_tokens_to_keep
    typ_mask_sorted[..., :min_tokens_to_keep] = 0

    typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
    logits[typ_mask] = -float("inf")
    return logits
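For anyone who wants to see the effect of Lambda without running the engine, below is a plain-Python sketch of the same idea on a toy distribution. It is not the sampler code: `typical_keep_lambda` is a made-up name, Lambda is passed as a parameter instead of being hard-coded, the special-token exemption is left out, and the probabilities are invented for the example.

```python
import math


def typical_keep_lambda(probs, tau, lam=1.0):
    """Indices kept by Typical-p after scaling negative deviations by lam."""
    surprisals = [-math.log(p) for p in probs]
    entropy = sum(p * s for p, s in zip(probs, surprisals))
    deviations = [s - entropy for s in surprisals]
    # Scale only the negative deviations (the high-probability tokens).
    scaled = [d * lam if d < 0 else d for d in deviations]
    # Same selection rule as the function above: sort by absolute deviation,
    # keep tokens while cumsum < tau, always keep the first one.
    order = sorted(range(len(probs)), key=lambda i: abs(scaled[i]))
    kept, cumulative = [], 0.0
    for rank, i in enumerate(order):
        cumulative += probs[i]
        if rank == 0 or cumulative < tau:
            kept.append(i)
    return sorted(kept)


toy = [0.5, 0.2, 0.15, 0.1, 0.05]  # made-up probabilities
for lam in (1.0, 2.0):
    print(lam, typical_keep_lambda(toy, tau=0.5, lam=lam))
```

On this toy input, raising Lambda pushes the most likely token further down the sort order, so a rarer token fits under the Tau budget instead.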

@BugReporterZ

BugReporterZ commented Mar 21, 2024

A different strategy with minimal modifications from the above could be, instead of scaling the negative deviations by a Lambda factor, shifting all the deviations by a small Delta value.

image

Again, code that would apply to the main branch:

import torch


def _apply_typical_sampling(
    logits: torch.Tensor,
    typical_p: torch.Tensor,
) -> torch.Tensor:
    typ_p = torch.tensor(typical_p, dtype=logits.dtype, device=logits.device)
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = shifted_logits.exp()

    # Tensor names and logic have been slightly altered to make the procedure a bit more consistent
    # with what the original Typical-P algorithm is intended to perform, thus clearer.
    surprises = shifted_logits.neg()
    conditional_entropy = (probs * surprises).nansum(dim=-1, keepdim=True)
    surprisal_deviations = surprises - conditional_entropy

    # Shift surprisal deviations by a Delta value. A negative Delta emphasizes high-probability
    # tokens and a positive one de-emphasizes them. Delta would ideally be an additional
    # hyperparameter to typical-p. Small values in the 0.0~0.5 range are the most useful here.
    delta = 0.20
    surprisal_deviations = surprisal_deviations - delta

    # From now on, the algorithm proceeds as in the original one, sorting tokens by absolute
    # surprisal deviations and picking them depending on their cumulative probability.
    surprisal_deviations = surprisal_deviations.abs()
    _, indices = torch.sort(surprisal_deviations)

    reordered_probs = probs.gather(-1, indices)
    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)

    min_tokens_to_keep = 1
    # Keep at least min_tokens_to_keep
    typ_mask_sorted[..., :min_tokens_to_keep] = 0

    typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
    logits[typ_mask] = -float("inf")
    return logits
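The effect of the sign of Delta can be checked with a plain-Python sketch on a toy distribution. As before, this is only an illustration under assumptions: `typical_keep_delta` is a made-up name, Delta is a parameter rather than a hard-coded value, and the probabilities are invented.

```python
import math


def typical_keep_delta(probs, tau, delta=0.0):
    """Indices kept by Typical-p after shifting all deviations by delta."""
    surprisals = [-math.log(p) for p in probs]
    entropy = sum(p * s for p, s in zip(probs, surprisals))
    # Subtracting delta moves the "typical" point toward rarer tokens
    # when delta > 0 and toward likelier tokens when delta < 0.
    shifted = [s - entropy - delta for s in surprisals]
    order = sorted(range(len(probs)), key=lambda i: abs(shifted[i]))
    kept, cumulative = [], 0.0
    for rank, i in enumerate(order):
        cumulative += probs[i]
        if rank == 0 or cumulative < tau:
            kept.append(i)
    return sorted(kept)


toy = [0.5, 0.2, 0.15, 0.1, 0.05]  # made-up probabilities
for delta in (-0.2, 0.0, 0.2):
    print(delta, typical_keep_delta(toy, tau=0.5, delta=delta))
```

On this input a negative Delta brings the most likely token to the front of the sort, while a positive Delta lets one more rare token fit under the Tau budget.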
