Merge pull request #868 from PygmalionAI/dry_zoom
Rewrite DRY sampler to be a lot faster
50h100a authored Dec 4, 2024
2 parents e182d00 + fc3c1cd commit 9b56927
Showing 1 changed file with 60 additions and 81 deletions.
141 changes: 60 additions & 81 deletions aphrodite/modeling/layers/sampler.py
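Background, before the diff: the DRY ("Don't Repeat Yourself") sampler penalizes any candidate token that would extend an n-gram already present in the context, with a penalty that grows exponentially in the length of the repeated match. The removed and the added code below compute the same penalty; only the matching strategy changed. A minimal sketch of the penalty curve, using illustrative parameter values rather than the engine's defaults:

def dry_penalty(match_length: int, multiplier: float = 0.8,
                base: float = 1.75, allowed_length: int = 2) -> float:
    # Penalty for a token that would extend a repeated n-gram of
    # `match_length` tokens; zero until the allowed length is reached.
    if match_length < allowed_length:
        return 0.0
    return multiplier * base ** (match_length - allowed_length)

for n in range(1, 7):
    print(n, dry_penalty(n))  # 0.0, then 0.8, 1.4, 2.45, ... from n == 2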
@@ -631,92 +631,71 @@ def _apply_dry(
     Reference: https://github.com/oobabooga/text-generation-webui/pull/5677
     """
     if torch.all(multipliers == 0):
         return logits
 
-    # DRY needs to be applied to both input AND output tokens
-    input_ids = torch.cat((input_token_ids, output_token_ids), dim=1)
-    vocab_size = logits.size(-1)
-
-    def compute_z_array(s: List[int], end: int, search_start: int) -> List[int]:
-        """
-        Compute Z array using two-pointer technique for linear time complexity
-        """
-        z = [0] * len(s)
-        right = end - 1
-        left = end - 1
-
-        while right >= search_start:
-            while left == right and left >= search_start:
-                if s[right] == s[end]:
-                    break
-                right -= 1
-                left -= 1
-
-            while left >= search_start and s[left] == s[end - (right - left)]:
-                z[right] += 1
-                left -= 1
-
-            helper = right
-            while right > left:
-                right -= 1
-                if left == right:
-                    break
-                z[right] = min(z[end - (helper - right)], right - left)
-                if left >= search_start and right - z[right] <= left:
-                    break
-
-        return z
-
-    # Process each sequence in the batch
-    for i, (input_ids_row, logits_row) in enumerate(zip(input_ids, logits)):
-        multiplier = multipliers[i].item()
-        if multiplier == 0:
-            continue
-
-        seq_breakers = set(sequence_breakers_ids[i].tolist())
-        input_ids_list = input_ids_row.tolist()
-        last_token = input_ids_list[-1]
-
-        if last_token in seq_breakers:
+    VOCAB_SIZE = logits.size(-1)
+    MAX_NGRAM = 100
+
+    # Process each sequence that has a nonzero multiplier
+    applies_to = multipliers.nonzero(as_tuple=True)[0]
+    for irow in applies_to.tolist():
+        # DRY applies to input AND output tokens, so concat w/o padding tokens
+        prompt_len = (len(input_token_ids[irow]) -
+                      (input_token_ids[irow] == VOCAB_SIZE).sum().item())
+        output_len = (len(output_token_ids[irow]) -
+                      (output_token_ids[irow] == VOCAB_SIZE).sum().item())
+        token_seq = torch.cat((input_token_ids[irow][:prompt_len],
+                               output_token_ids[irow][:output_len]), dim=0)
+
+        range_limit = ranges[irow].item()
+        if range_limit:  # could this be done in cat? yes. do i care? no.
+            token_seq = token_seq[-range_limit:]
+
+        last_token = token_seq[-1].item()
+        if last_token in sequence_breakers_ids[irow]:
+            continue  # early out for everything up to the min_ngram check?
+
+        # Build a mask of all the breaking tokens in the context
+        break_mask = torch.zeros(len(token_seq), dtype=torch.bool,
+                                 device=logits.device)
+        for break_tok in sequence_breakers_ids[irow]:
+            break_mask.logical_or_(token_seq == break_tok)
+
+        # Find the most recent breaking token (sets ngram limit)
+        max_ngram = 0
+        for max_ngram in range(min(len(break_mask), MAX_NGRAM + 1)):
+            if break_mask[-max_ngram - 1]:
+                break
+
+        min_ngram = allowed_lengths[irow].item()
+        if max_ngram <= min_ngram:  # Too close to a break to match anything
             continue

-        range_limit = ranges[i].item()
-        if range_limit == 0:
-            search_start = 0
-        else:
-            search_start = max(0, len(input_ids_list) - range_limit)
-
-        # Find max match length based on sequence breakers
-        max_match_length = 0
-        MAX_LENGTH = min(len(input_ids_list), 1000)  # Prevent overflow
-        while (max_match_length < MAX_LENGTH and
-               input_ids_list[len(input_ids_list) - max_match_length - 1]
-               not in seq_breakers):
-            max_match_length += 1
-
-        z_array = compute_z_array(
-            input_ids_list, len(input_ids_list) - 1, search_start)

+        # If [token] is picked, what's the longest ngram that would match?
+        ngram_lens = torch.zeros(VOCAB_SIZE, dtype=torch.int32,
+                                 device=logits.device)

-        z_array = [min(length, max_match_length) for length in z_array]
-
-        penalties = {}
-        allowed_length = allowed_lengths[i]
-        base = bases[i]
-
-        for idx, match_length in enumerate(z_array[:-1]):
-            if match_length >= allowed_length:
-                next_token = input_ids_list[idx + 1]
-                if (next_token >= vocab_size or next_token in
-                        seq_breakers):
-                    continue
+        # Find all instances of the last token- potential ngrams!
+        endpoint_indexes = torch.nonzero(token_seq == last_token,
+                                         as_tuple=True)[0].tolist()
+        # NOTE: This seems like the slow part. Haven't benchmarked.
+        for idx in endpoint_indexes[:-1]:  # Skip the last_token match
+            unwind = 0
+            # Check up to max_ngram tokens prior to idx (we know idx matches)
+            for unwind in range(1, min(idx, max_ngram) + 1):
+                if break_mask[idx - unwind]:
+                    break
+                if token_seq[idx - unwind] != token_seq[-unwind - 1]:
+                    break
+            next_tok = token_seq[idx + 1]
+            # The repeated tokens BEFORE next_tok (+1 to include [idx]).
+            ngram_lens[next_tok] = max(ngram_lens[next_tok].item(), unwind + 1)

-                penalty = multiplier * (base ** (match_length - allowed_length))
-                penalties[next_token] = max(
-                    penalty, penalties.get(next_token, 0))
+        # Convert ngram lengths to penalty exponents
+        penalty_mask = ngram_lens > 0
+        scales = bases[irow] ** (ngram_lens[penalty_mask] - min_ngram)

-        for token, penalty in penalties.items():
-            logits_row[token] -= penalty
+        # Calculate and apply penalties
+        logits[irow][penalty_mask] -= multipliers[irow] * scales

     return logits
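A note on the removed helper: compute_z_array was a variant of the textbook Z-function, indexed from the end of the sequence so that match lengths are measured against the suffix ending at the last token. For reference, a standard forward formulation of the same linear-time two-pointer idea is sketched below; it is a generic illustration, not the exact helper deleted above:

def z_function(s: list) -> list:
    # z[i] = length of the longest common prefix of s and s[i:].
    # The sliding (l, r) window is what makes the pass O(len(s)).
    n = len(s)
    z = [0] * n
    l = r = 0
    for i in range(1, n):
        if i < r:
            z[i] = min(r - i, z[i - l])
        while i + z[i] < n and s[z[i]] == s[i + z[i]]:
            z[i] += 1
        if i + z[i] > r:
            l, r = i, i + z[i]
    return z

assert z_function(list("aabxaab")) == [0, 1, 0, 0, 3, 1, 0]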

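The replacement drops the Z-array entirely: it finds every earlier occurrence of the current last token, then walks backwards from each occurrence until a mismatch or a sequence breaker. A plain-Python restatement of that inner loop on ordinary lists (the name dry_ngram_lens is ours, for illustration; the committed code does this with torch tensors):

def dry_ngram_lens(token_seq, breakers, max_ngram, vocab_size):
    # For each candidate token: the longest repeated ngram that
    # picking it next would extend (0 if none).
    ngram_lens = [0] * vocab_size
    last = token_seq[-1]
    # Every earlier occurrence of the last token is a potential endpoint.
    for idx in (i for i, t in enumerate(token_seq[:-1]) if t == last):
        unwind = 0
        for unwind in range(1, min(idx, max_ngram) + 1):
            if token_seq[idx - unwind] in breakers:
                break
            if token_seq[idx - unwind] != token_seq[-unwind - 1]:
                break
        nxt = token_seq[idx + 1]
        ngram_lens[nxt] = max(ngram_lens[nxt], unwind + 1)
    return ngram_lens

# Toy check: in [5, 6, 7, 9, 5, 6, 7] the last token 7 ends a repeat of
# (5, 6, 7), so picking 9 next would extend that repetition.
print(dry_ngram_lens([5, 6, 7, 9, 5, 6, 7], set(), 100, 16)[9])  # 3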

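Finally, a hypothetical smoke test. The full signature of _apply_dry is not visible in this hunk, so the parameter order below is inferred from the names used in the body and should be treated as an assumption; likewise, the prompt_len/output_len computation implies rows are padded with the vocabulary size as the pad id:

import torch
from aphrodite.modeling.layers.sampler import _apply_dry  # private helper

VOCAB = 16
logits = torch.zeros(1, VOCAB)
out = _apply_dry(
    logits,
    torch.tensor([[5, 6, 7, 9, 5, 6]]),  # input_token_ids (prompt)
    torch.tensor([[7]]),                 # output_token_ids (generated)
    multipliers=torch.tensor([0.8]),
    bases=torch.tensor([1.75]),
    allowed_lengths=torch.tensor([2]),
    sequence_breakers_ids=torch.tensor([[VOCAB + 1]]),  # no breaker occurs
    ranges=torch.tensor([0]),            # 0 = no range limit
)
print(out[0, 9])  # expected < 0: token 9 would extend the repeated 5, 6, 7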