Merge branch 'PygmalionAI:main' into amd-fix
Naomiusearch authored Nov 25, 2024
2 parents e888bab + 60f7b82 commit 531f7bc
Showing 5 changed files with 99 additions and 8 deletions.
10 changes: 9 additions & 1 deletion aphrodite/common/sampling_params.py
@@ -173,6 +173,8 @@ class SamplingParams(
            input into sections where repetition is evaluated separately.
            Common examples are newlines, quotes, and other structural tokens.
            Defaults to None.
        skew: Bias the token selection towards higher or lower probability
            tokens. Defaults to 0 (disabled).
    """

    n: int = 1
@@ -224,6 +226,7 @@ class SamplingParams(
    dry_base: float = 1.75
    dry_allowed_length: int = 2
    dry_sequence_breaker_ids: List[int] = []
    skew: float = 0.0
    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
@@ -275,6 +278,7 @@ class SamplingParams(
"dry_base": 1.75,
"dry_allowed_length": 2,
"dry_sequence_breaker_ids": [],
"skew": 0.0,
}

    def __post_init__(self) -> None:
@@ -419,7 +423,11 @@ def _verify_args(self) -> None:
        if self.dry_allowed_length < 0:
            raise ValueError(
                "dry_allowed_length must be non-negative, got "
                f"{self.dry_allowed_length}.")
        if self.skew < 0.0:
            raise ValueError(
                "skew must be non-negative, got "
                f"{self.skew}.")

    def _verify_beam_search(self) -> None:
        if self.best_of == 1:
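For orientation, here is a minimal usage sketch of the new parameter. It is not part of the commit, and it assumes the checks in _verify_args run at construction time via __post_init__, as the surrounding code suggests:

    # Illustrative only: import from the module this diff touches.
    from aphrodite.common.sampling_params import SamplingParams

    # skew=0.0 (the default) disables the effect; per the new test below,
    # positive values bias sampling away from the highest-probability tokens.
    params = SamplingParams(temperature=1.0, skew=0.5)

    # Negative values are rejected by the new check in _verify_args.
    try:
        SamplingParams(skew=-1.0)
    except ValueError as err:
        print(err)  # skew must be non-negative, got -1.0.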
4 changes: 4 additions & 0 deletions aphrodite/endpoints/openai/protocol.py
@@ -158,6 +158,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    # doc: end-chat-completion-sampling-params

@@ -314,6 +315,7 @@ def to_sampling_params(
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
        )

@@ -432,6 +434,7 @@ class CompletionRequest(OpenAIBaseModel):
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    # doc: end-completion-sampling-params

@@ -547,6 +550,7 @@ def to_sampling_params(
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
        )

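Both request classes forward the new field into to_sampling_params, so clients can set it per request. A hedged sketch against the OpenAI-compatible completions route (host, port, and model name are placeholders for whatever your aphrodite deployment uses):

    import requests

    resp = requests.post(
        "http://localhost:2242/v1/completions",  # placeholder endpoint
        json={
            "model": "my-model",           # placeholder model name
            "prompt": "Once upon a time",
            "max_tokens": 32,
            "temperature": 1.0,
            "skew": 0.5,                   # the new sampling field
        },
    )
    print(resp.json()["choices"][0]["text"])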
15 changes: 14 additions & 1 deletion aphrodite/modeling/layers/sampler.py
@@ -86,7 +86,7 @@ def _init_sampling_tensors(
        (sampling_tensors, do_penalties, do_no_repeat_ngrams, do_temperatures,
         do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
         do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc, do_nsigmas,
         do_dry, do_temp_last
         do_dry, do_skew, do_temp_last
        ) = SamplingTensors.from_sampling_metadata(
            sampling_metadata, vocab_size, logits.device, logits.dtype)

@@ -105,6 +105,7 @@ def _init_sampling_tensors(
        self._do_xtc = do_xtc
        self._do_nsgimas = do_nsigmas
        self._do_dry = do_dry
        self._do_skew = do_skew
        self._do_temp_last = do_temp_last

    def forward(
@@ -146,6 +147,7 @@ def forward(
        do_xtc = self._do_xtc
        do_nsigmas = self._do_nsgimas
        do_dry = self._do_dry
        do_skew = self._do_skew
        do_temp_last = self._do_temp_last

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)
@@ -230,6 +232,17 @@ def forward(
        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)

        # skew needs to be applied post-softmax
        if do_skew:
            # reference: https://github.com/turboderp/exllamav2/commit/1de4cdd70b09208e7b4f17ee322c190e16f60efd
            cum_probs = torch.cumsum(probs, dim=-1)
            cum_probs = torch.pow(cum_probs, torch.exp(
                sampling_tensors.skews).unsqueeze(dim=1))
            probs = torch.diff(cum_probs, dim=-1,
                               prepend=torch.zeros_like(cum_probs[..., :1]))
            logits = torch.log(probs)

        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

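The transform warps the cumulative distribution: each CDF value in [0, 1] is raised to the power exp(skew), which is a no-op when skew == 0 and, for positive skew, shifts mass away from high-probability tokens (the behavior the new test below asserts). A self-contained sketch of the same math, with names mirroring the diff (the standalone function itself is illustrative, not part of the repo):

    import torch

    def apply_skew(logits: torch.Tensor, skews: torch.Tensor) -> torch.Tensor:
        # logits: [batch, vocab]; skews: [batch], one value per sequence.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Warp the CDF; exp(0) == 1 leaves it unchanged.
        cum_probs = torch.cumsum(probs, dim=-1)
        cum_probs = torch.pow(cum_probs, torch.exp(skews).unsqueeze(dim=1))
        # Recover per-token probabilities by differencing the warped CDF,
        # then hand back log-probabilities as the new logits.
        probs = torch.diff(cum_probs, dim=-1,
                           prepend=torch.zeros_like(cum_probs[..., :1]))
        return torch.log(probs)

    new_logits = apply_skew(torch.randn(2, 8), torch.tensor([0.0, 2.0]))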
25 changes: 19 additions & 6 deletions aphrodite/modeling/sampling_metadata.py
@@ -393,6 +393,7 @@ class SamplingTensors:
    dry_bases: torch.Tensor
    dry_allowed_lengths: torch.Tensor
    dry_sequence_breaker_ids: torch.Tensor
    skews: torch.Tensor
    sampling_seeds: torch.Tensor
    sample_indices: torch.Tensor
    extra_seeds: Optional[torch.Tensor]
@@ -410,7 +411,7 @@ def from_sampling_metadata(
        extra_seeds_to_generate: int = 0,
        extra_entropy: Optional[Tuple[int, ...]] = None
    ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
               bool, bool, bool, bool, bool, bool, bool, bool]:
               bool, bool, bool, bool, bool, bool, bool, bool, bool]:
        """
        extra_seeds_to_generate: extra seeds to generate using the
            user-defined seed for each sequence.
@@ -446,6 +447,7 @@ def from_sampling_metadata(
        dry_bases: List[float] = []
        dry_allowed_lengths: List[int] = []
        dry_sequence_breaker_ids: List[List[int]] = []
        skews: List[float] = []

        do_penalties = False
        do_no_repeat_ngrams = False
@@ -461,6 +463,7 @@
        do_xtc = False
        do_nsigmas = False
        do_dry = False
        do_skews = False
        do_temp_last = False

        if _USE_TRITON_SAMPLER:
@@ -506,6 +509,7 @@
            do_xtc |= params.xtc_probability > _SAMPLING_EPS
            do_nsigmas |= params.nsigma > _SAMPLING_EPS
            do_dry |= params.dry_multiplier > _SAMPLING_EPS
            do_skews |= abs(params.skew) > _SAMPLING_EPS

            do_temp_last |= params.temperature_last

@@ -548,6 +552,7 @@
            dry_allowed_lengths += [params.dry_allowed_length] * n_seqs
            dry_sequence_breaker_ids += (
                [params.dry_sequence_breaker_ids] * n_seqs)
            skews += [params.skew] * n_seqs

            if _USE_TRITON_SAMPLER:
                if is_prompt:
@@ -596,13 +601,14 @@
            no_repeat_ngram_sizes, tfss, eta_cutoffs, epsilon_cutoffs,
            typical_ps, smoothing_factors, smoothing_curves, xtc_thresholds,
            xtc_probabilities, nsigmas, dry_multipliers, dry_bases,
            dry_allowed_lengths, dry_sequence_breaker_ids, sampling_seeds,
            sample_indices, prompt_tokens, output_tokens, vocab_size,
            extra_seeds_to_generate, device, dtype)
            dry_allowed_lengths, dry_sequence_breaker_ids, skews,
            sampling_seeds, sample_indices, prompt_tokens, output_tokens,
            vocab_size, extra_seeds_to_generate, device, dtype)
        return (sampling_tensors, do_penalties, do_no_repeat_ngrams,
                do_temperatures, do_top_p_top_k, do_top_as, do_min_p,
                do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
                do_quadratic, do_xtc, do_nsigmas, do_dry, do_temp_last)
                do_quadratic, do_xtc, do_nsigmas, do_dry, do_skews,
                do_temp_last)

    @classmethod
    def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
@@ -620,7 +626,7 @@ def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
                   dry_multipliers: List[float], dry_bases: List[float],
                   dry_allowed_lengths: List[int],
                   dry_sequence_breaker_ids: List[List[int]],
                   sampling_seeds: List[List[int]],
                   skews: List[float], sampling_seeds: List[List[int]],
                   sample_indices: List[int], prompt_tokens: List[array],
                   output_tokens: List[array], vocab_size: int,
                   extra_seeds_to_generate: int, device: torch.device,
@@ -786,6 +792,12 @@ def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
            dtype=torch.long,
            pin_memory=pin_memory,
        )
        skews_t = torch.tensor(
            skews,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )

        sample_indices_t = torch.tensor(
            sample_indices,
@@ -853,6 +865,7 @@ def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
                                                         non_blocking=True),
            dry_sequence_breaker_ids=dry_sequence_breakers_t.to(device=device,
                                                                non_blocking=True),
            skews=skews_t.to(device=device, non_blocking=True),
            typical_ps=typical_ps_t.to(device=device, non_blocking=True),
            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
            output_tokens=output_t.to(device=device, non_blocking=True),
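The new skews_t tensor follows the file's existing pattern: gather one value per sequence into a pinned CPU tensor, then copy it to the device with non_blocking=True. A standalone sketch of that pattern (the values and the CUDA guard are illustrative):

    import torch

    skews = [0.0, 0.5, 2.0]          # one value per sequence, as gathered above
    pin = torch.cuda.is_available()  # pinning only pays off for CUDA copies

    skews_t = torch.tensor(skews, device="cpu", dtype=torch.float32,
                           pin_memory=pin)
    if pin:
        # Page-locked host memory lets the host-to-device copy overlap
        # with other work instead of blocking the CPU.
        skews_t = skews_t.to(device="cuda", non_blocking=True)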
53 changes: 53 additions & 0 deletions tests/samplers/test_sampler.py
@@ -851,6 +851,59 @@ def test_sampler_nsigma(seed: int, device: str):
"Top-nsigma sampling is not deterministic with same seed"


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_skew(seed: int, device: str):
"""Test that skew sampling behaves as expected."""
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)

high_prob_tokens = {}
for i in range(batch_size):
# Make token i have a much higher logit in sequence i
fake_logits[i, i] = 10.0
high_prob_tokens[i] = i

test_cases = [
# (skew, expected_behavior)
(2.0, "low"), # Strong bias away from high probability tokens
(0.5, "subtle"), # Subtle bias away from high probability tokens
(0.0, "neutral"), # No bias (regular sampling)
]

for skew, expected_behavior in test_cases:
sampling_params = SamplingParams(
temperature=1.0, # neutral temperature
skew=skew,
seed=random.randint(0, 10000), # for determinism
)

sampler_output = _do_sample(batch_size, fake_logits.clone(), sampler,
sampling_params, device)

for batch_idx, sequence_output in enumerate(sampler_output):
token_id = sequence_output.samples[0].output_token

if expected_behavior == "low":
# strong skew should bias away from high probability tokens
assert token_id != high_prob_tokens[batch_idx], \
f"With high skew {skew}, should not select high " \
f"probability token {high_prob_tokens[batch_idx]}"

elif expected_behavior == "subtle":
# we don't assert anything for subtle effect,
# as it's probabilistic
pass

# determinism
second_output = _do_sample(batch_size, fake_logits.clone(), sampler,
sampling_params, device)
assert sampler_output == second_output, \
f"Skew sampling with seed is not deterministic for skew={skew}"


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_include_gpu_probs_tensor(device: str):
set_random_seed(42)
