[V1] Adding min tokens/repetition/presence/frequency penalties to V1 sampler #10681

Merged Dec 26, 2024 · 118 commits · changes shown from 116 commits
5650b95
Merge pull request #1 from vllm-project/main
sroy745 May 29, 2024
8f36146
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
9e75057
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
db2c679
Merge branch 'vllm-project:main' into main
sroy745 Jun 7, 2024
8d7512c
Merge branch 'vllm-project:main' into main
sroy745 Jun 10, 2024
1473f74
Merge branch 'vllm-project:main' into main
sroy745 Jun 12, 2024
4013e1a
Merge branch 'vllm-project:main' into main
sroy745 Jun 14, 2024
2dbdd78
Merge branch 'vllm-project:main' into main
sroy745 Jun 17, 2024
b3575e9
Merge branch 'vllm-project:main' into main
sroy745 Jun 20, 2024
94b0d43
Merge branch 'vllm-project:main' into main
sroy745 Jun 24, 2024
fa8fedf
Merge branch 'vllm-project:main' into main
sroy745 Jun 27, 2024
6ed96b4
Merge branch 'vllm-project:main' into main
sroy745 Jun 27, 2024
b71c533
Merge branch 'vllm-project:main' into main
sroy745 Jun 28, 2024
57babef
Merge branch 'vllm-project:main' into main
sroy745 Jun 29, 2024
4b19bac
Merge branch 'vllm-project:main' into main
sroy745 Jul 1, 2024
eb7a1c4
Merge branch 'vllm-project:main' into main
sroy745 Jul 6, 2024
7e2c87e
Merge branch 'vllm-project:main' into main
sroy745 Jul 10, 2024
6212d5f
Merge branch 'vllm-project:main' into main
sroy745 Jul 15, 2024
5491438
Merge branch 'vllm-project:main' into main
sroy745 Jul 17, 2024
68e080a
Merge branch 'vllm-project:main' into main
sroy745 Jul 31, 2024
55e4332
Merge branch 'vllm-project:main' into main
sroy745 Aug 13, 2024
532eb48
Merge branch 'vllm-project:main' into main
sroy745 Aug 22, 2024
7cea056
Merge branch 'vllm-project:main' into main
sroy745 Aug 22, 2024
185e056
Merge branch 'vllm-project:main' into main
sroy745 Aug 24, 2024
e2be95f
Merge branch 'vllm-project:main' into main
sroy745 Aug 27, 2024
2ed5473
Merge branch 'vllm-project:main' into main
sroy745 Aug 28, 2024
efa4714
Merge branch 'vllm-project:main' into main
sroy745 Aug 29, 2024
fb87d34
Merge branch 'vllm-project:main' into main
sroy745 Aug 29, 2024
5419e49
Merge branch 'vllm-project:main' into main
sroy745 Aug 31, 2024
9ba12f8
Merge branch 'vllm-project:main' into main
sroy745 Sep 2, 2024
25cef3d
Merge branch 'vllm-project:main' into main
sroy745 Sep 3, 2024
9d4cd09
Merge branch 'vllm-project:main' into main
sroy745 Sep 4, 2024
c48cacb
Merge branch 'vllm-project:main' into main
sroy745 Sep 5, 2024
c42c399
Merge branch 'vllm-project:main' into main
sroy745 Sep 7, 2024
3d13e43
Merge branch 'vllm-project:main' into main
sroy745 Sep 9, 2024
7479775
Merge branch 'vllm-project:main' into main
sroy745 Sep 11, 2024
df9b966
Merge branch 'vllm-project:main' into main
sroy745 Sep 17, 2024
9a7ed92
Merge branch 'vllm-project:main' into main
sroy745 Sep 17, 2024
118e838
Merge branch 'vllm-project:main' into main
sroy745 Sep 19, 2024
e640c69
Merge branch 'vllm-project:main' into main
sroy745 Sep 20, 2024
89fb6cd
Merge branch 'vllm-project:main' into main
sroy745 Sep 23, 2024
5d886cc
Merge branch 'vllm-project:main' into main
sroy745 Sep 24, 2024
56f2065
Merge branch 'vllm-project:main' into main
sroy745 Sep 24, 2024
28e103e
Merge branch 'vllm-project:main' into main
sroy745 Sep 25, 2024
2fc1490
Merge branch 'vllm-project:main' into main
sroy745 Sep 25, 2024
8805750
Merge branch 'vllm-project:main' into main
sroy745 Sep 26, 2024
b30e5af
Merge branch 'vllm-project:main' into main
sroy745 Sep 28, 2024
92322f1
Merge branch 'vllm-project:main' into main
sroy745 Sep 30, 2024
85e9001
Merge branch 'vllm-project:main' into main
sroy745 Oct 1, 2024
cd4ff89
Merge branch 'vllm-project:main' into main
sroy745 Oct 1, 2024
0dd96ed
Merge branch 'vllm-project:main' into main
sroy745 Oct 1, 2024
9d4d969
Merge branch 'vllm-project:main' into main
sroy745 Oct 3, 2024
7d223b5
Merge branch 'vllm-project:main' into main
sroy745 Oct 5, 2024
f327d91
Merge branch 'vllm-project:main' into main
sroy745 Oct 5, 2024
b5adf28
Merge branch 'vllm-project:main' into main
sroy745 Oct 6, 2024
caf0d12
Merge branch 'vllm-project:main' into main
sroy745 Oct 7, 2024
28e77b1
Merge branch 'vllm-project:main' into main
sroy745 Oct 8, 2024
db7e46d
Merge branch 'vllm-project:main' into main
sroy745 Oct 9, 2024
59b35f0
Merge branch 'vllm-project:main' into main
sroy745 Oct 17, 2024
dd9affa
Merge branch 'vllm-project:main' into main
sroy745 Oct 17, 2024
f61a15d
Merge branch 'vllm-project:main' into main
sroy745 Oct 21, 2024
0569773
Merge branch 'vllm-project:main' into main
sroy745 Oct 27, 2024
a2090e0
Merge branch 'vllm-project:main' into main
sroy745 Oct 30, 2024
c9a3f00
Merge branch 'vllm-project:main' into main
sroy745 Nov 1, 2024
b59e6a8
Merge branch 'vllm-project:main' into main
sroy745 Nov 3, 2024
fd9fdff
Merge branch 'vllm-project:main' into main
sroy745 Nov 8, 2024
366cbf7
Merge branch 'vllm-project:main' into main
sroy745 Nov 11, 2024
65c9c79
Merge branch 'vllm-project:main' into main
sroy745 Nov 22, 2024
840c89d
Merge branch 'vllm-project:main' into main
sroy745 Nov 26, 2024
8700ecb
Merge branch 'vllm-project:main' into main
sroy745 Nov 26, 2024
8136fa4
Add options for min_tokens/repetition etc penalties to V1 sampler
sroy745 Nov 26, 2024
06d3247
Fixes
sroy745 Dec 1, 2024
b73e3be
Merge branch 'vllm-project:main' into main
sroy745 Dec 1, 2024
ca0313a
Add tests
sroy745 Dec 2, 2024
e19f99b
Add tests
sroy745 Dec 2, 2024
40f4ce2
Fix format
sroy745 Dec 2, 2024
d3e9bb7
Comments
sroy745 Dec 2, 2024
e3468fe
Tests
sroy745 Dec 2, 2024
47c4b74
Fixes
sroy745 Dec 3, 2024
35ac8bc
Fix tests
sroy745 Dec 3, 2024
cce8428
Fixes
sroy745 Dec 3, 2024
9febfbf
Fixes
sroy745 Dec 5, 2024
dc02a4f
Fixes
sroy745 Dec 5, 2024
b3c4472
Merge branch 'vllm-project:main' into main
sroy745 Dec 5, 2024
d0348a4
Merge remote-tracking branch 'origin/main' into sroy-v1-sampling
sroy745 Dec 5, 2024
111ff87
Merge branch 'vllm-project:main' into main
sroy745 Dec 9, 2024
2e1f781
Merge remote-tracking branch 'origin/main' into sroy-v1-sampling
sroy745 Dec 9, 2024
0db8e4f
Addressing comments
sroy745 Dec 9, 2024
f6c416f
Remove debug prints
sroy745 Dec 9, 2024
034ff3f
Adding utils
sroy745 Dec 9, 2024
3798152
Fixes
sroy745 Dec 9, 2024
00ec978
More fixes
sroy745 Dec 9, 2024
cf87280
Format
sroy745 Dec 9, 2024
bde6c9e
Some more tests
sroy745 Dec 9, 2024
a46cd14
Format
sroy745 Dec 9, 2024
750dcd8
Merge branch 'main' into sroy-v1-sampling
sroy745 Dec 10, 2024
795b1f8
Merge
sroy745 Dec 10, 2024
9a1ab49
Rename test file
sroy745 Dec 10, 2024
6472b70
Tests
sroy745 Dec 10, 2024
239a3fd
Changes
sroy745 Dec 10, 2024
0e3179a
Comments
sroy745 Dec 10, 2024
2457596
Merge branch 'main' into sroy-v1-sampling
WoosukKwon Dec 15, 2024
09a73d0
Only pass output token_ids to sampler
sroy745 Dec 17, 2024
abda623
Dummy
sroy745 Dec 17, 2024
c1d6cd1
Rerun tests
sroy745 Dec 17, 2024
6861e97
Remove tolist for prompts
sroy745 Dec 17, 2024
c5ab213
Add TODO
sroy745 Dec 17, 2024
c79fad5
Dummy
sroy745 Dec 18, 2024
b3f7736
Dummy
sroy745 Dec 18, 2024
c74b9bb
Rerun tests
sroy745 Dec 20, 2024
31ba41f
Rerun tests
sroy745 Dec 20, 2024
c7d5318
Merge branch 'main' into sroy-v1-sampling
WoosukKwon Dec 21, 2024
a781c11
Minor
WoosukKwon Dec 21, 2024
6bc8e01
Minor
WoosukKwon Dec 21, 2024
0912e3e
Minor
WoosukKwon Dec 21, 2024
5dd4caa
optimize
WoosukKwon Dec 21, 2024
12ab994
Make prompt_token_ids a class variable
sroy745 Dec 22, 2024
03190e1
Merge branch 'vllm-project:main' into sroy-v1-sampling
sroy745 Dec 22, 2024
38 changes: 38 additions & 0 deletions tests/v1/engine/test_engine_core.py
@@ -139,3 +139,41 @@ def test_engine_core(monkeypatch):
        engine_core.abort_requests([req2.request_id, req0.request_id])
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 0


def test_engine_core_advanced_sampling(monkeypatch):
    """
    A basic end-to-end test to verify that the engine functions correctly
    when additional sampling parameters, such as min_tokens and
    presence_penalty, are set.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        """Setup the EngineCore."""
        engine_args = EngineArgs(model=MODEL_NAME)
        vllm_config = engine_args.create_engine_config(
            usage_context=UsageContext.UNKNOWN_CONTEXT)
        executor_class = AsyncLLM._get_executor_cls(vllm_config)

        engine_core = EngineCore(vllm_config=vllm_config,
                                 executor_class=executor_class,
                                 usage_context=UsageContext.UNKNOWN_CONTEXT)
        """Test basic request lifecycle."""
        # First request.
        request: EngineCoreRequest = make_request()
        request.sampling_params = SamplingParams(
            min_tokens=4,
            presence_penalty=1.0,
            frequency_penalty=1.0,
            repetition_penalty=0.1,
            stop_token_ids=[1001, 1002],
        )
        engine_core.add_request(request)
        assert len(engine_core.scheduler.waiting) == 1
        assert len(engine_core.scheduler.running) == 0
        # Loop through until they are all done.
        while len(engine_core.step()) > 0:
            pass

        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 0
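For orientation, the parameters exercised in this test map directly onto the public SamplingParams fields. A minimal end-user sketch (the model name, prompt, and values below are illustrative and not taken from this PR):

# Illustrative sketch only: drive the new V1 sampling options through the
# public vLLM API. Assumes a V1-enabled build (e.g. VLLM_USE_V1=1); the model
# and prompt are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(
    min_tokens=4,            # do not allow a stop before 4 tokens are generated
    presence_penalty=1.0,    # flat penalty for tokens already in the output
    frequency_penalty=1.0,   # penalty scaled by each token's output count
    repetition_penalty=1.2,  # >1.0 discourages repeating prompt/output tokens
    max_tokens=32,
)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)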
Empty file added tests/v1/sample/__init__.py
Empty file.
331 changes: 331 additions & 0 deletions tests/v1/sample/test_sampler.py
@@ -0,0 +1,331 @@
from typing import List, Set, Tuple

import numpy as np
import pytest
import torch

from vllm.utils import make_tensor_with_pad
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler

VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
MAX_NUM_PROMPT_TOKENS = 64


def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
    fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=torch.float)
    return fake_logits


def _create_penalty_tensor(batch_size: int, penalty_value: float,
                           device: torch.device) -> torch.Tensor:
    return torch.full((batch_size, ),
                      fill_value=penalty_value,
                      dtype=torch.float,
                      device=device)


def _create_prompt_tokens_tensor(
    prompt_token_ids: List[List[int]],
    vocab_size: int,
    device: torch.device,
) -> torch.Tensor:
    return make_tensor_with_pad(
        prompt_token_ids,
        pad=vocab_size,
        device=device,
        dtype=torch.int64,
        pin_memory=False,
    )


def _create_default_sampling_metadata(
    num_output_tokens: int,
    batch_size: int,
    vocab_size: int,
    device: torch.device,
) -> SamplingMetadata:
    output_token_ids: List[List[int]] = []
    prompt_token_ids: List[List[int]] = []
    for _ in range(batch_size):
        output_token_ids.append(
            np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
        prompt_token_ids.append(
            np.random.randint(0,
                              vocab_size,
                              size=np.random.randint(
                                  1, MAX_NUM_PROMPT_TOKENS)).tolist())
    fake_sampling_metadata = SamplingMetadata(
        temperature=torch.full((batch_size, ), 0.0),
        all_greedy=True,
        all_random=False,
        top_p=torch.empty(batch_size, ),
        top_k=torch.empty(batch_size, ),
        no_top_p=True,
        no_top_k=True,
        generators={},
        max_num_logprobs=VOCAB_SIZE,
        prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
                                                      vocab_size, device),
        output_token_ids=output_token_ids,
        frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
        presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
        repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
        no_penalties=True,
        min_tokens=[],
        stop_token_ids=[],
    )
    return fake_sampling_metadata


def _generate_min_token_penalties_and_stop_tokens(
        num_output_tokens: int, batch_size: int, vocab_size: int,
        batch_indices_for_min_token_penalty: List[int]
) -> Tuple[List[int], List[Set[int]]]:
    """
    Generates and returns a list of minimum token penalties (`min_tokens`)
    and a corresponding list of stop token IDs (`stop_token_ids`) for each
    batch.

    If a batch index is included in `batch_indices_for_min_token_penalty`,
    a higher `min_tokens` value is assigned (within a randomized range),
    and a random set of stop token IDs is created. Otherwise, a lower
    `min_tokens` value is assigned, and the stop token IDs set is empty.
    """
    stop_token_ids: List[Set[int]] = []
    min_tokens: List[int] = []
    for index in range(batch_size):
        if index in batch_indices_for_min_token_penalty:
            min_tokens.append(
                np.random.randint(num_output_tokens + 1,
                                  2 * num_output_tokens))
            stop_token_ids.append(
                set(
                    np.random.randint(0, vocab_size - 1)
                    for _ in range(np.random.randint(0, vocab_size))))
        else:
            min_tokens.append(np.random.randint(0, num_output_tokens))
            stop_token_ids.append(set())
    return (min_tokens, stop_token_ids)


def _create_weighted_output_token_list(
        batch_size: int,
        vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    """
    Creates an output token list where each token occurs a distinct
    number of times.

    For each batch, a random subset of token IDs is selected from the
    vocabulary. The selected tokens are then added to the output token
    list, each with a different frequency.

    Returns:
        Tuple[List[List[int]], List[List[int]]]:
            - The first element is the output token list, where each sublist
              corresponds to a batch and contains tokens with weighted
              frequencies.
            - The second element is a list of distinct token IDs for each
              batch, ordered by their frequency in the corresponding output
              list.
    """
    output_token_ids: List[List[int]] = []
    sorted_token_ids_in_output: List[List[int]] = []
    for _ in range(batch_size):
        distinct_token_ids = np.random.choice(vocab_size,
                                              size=np.random.randint(1, 10),
                                              replace=False).tolist()
        sorted_token_ids_in_output.append(distinct_token_ids)
        output_token_ids_for_batch = []
        for index, token_id in enumerate(distinct_token_ids):
            output_token_ids_for_batch.extend(
                [token_id for _ in range(index + 1)])
        output_token_ids.append(output_token_ids_for_batch)
    return (output_token_ids, sorted_token_ids_in_output)
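As a quick illustration of this helper's contract (not part of the PR itself): the i-th distinct token id is repeated i + 1 times, so the last entry of the sorted list is always the most frequent token in the output.

# Illustrative check of _create_weighted_output_token_list's contract.
tokens, sorted_ids = _create_weighted_output_token_list(batch_size=1,
                                                        vocab_size=VOCAB_SIZE)
# The last distinct token id appears len(sorted_ids[0]) times, i.e. most often.
assert tokens[0].count(sorted_ids[0][-1]) == len(sorted_ids[0])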


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
def test_sampler_min_tokens_penalty(device: str, batch_size: int):
    """
    Tests that if the number of output tokens is less than
    SamplingParams.min_tokens then we will set the logits for
    the stop token ids to -inf.
    """
    torch.set_default_device(device)
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    batch_indices_for_min_token_penalty = np.random.randint(
        0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
    min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
        batch_indices_for_min_token_penalty)
    sampling_metadata.min_tokens = min_tokens
    sampling_metadata.stop_token_ids = stop_token_ids
    sampler = Sampler()
    sampler_output = sampler(fake_logits, sampling_metadata)
    for batch_idx in range(batch_size):
        for vocab in range(VOCAB_SIZE):
            # Verify that the logprobs for stop token ids are set
            # to -inf.
            logprob_index = torch.where(
                sampler_output.logprob_token_ids[batch_idx] ==
                vocab)[0].item()
            if vocab in stop_token_ids[batch_idx]:
                assert sampler_output.logprobs[batch_idx][
                    logprob_index] == -float("inf")
            else:
                assert sampler_output.logprobs[batch_idx][
                    logprob_index] != -float("inf")
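The assertions above depend on the sampler masking the configured stop tokens until min_tokens is satisfied. A standalone sketch of that masking step (a simplified version for illustration, not necessarily the exact code this PR ships):

# Sketch: while a request has produced fewer than min_tokens outputs, force
# its stop-token logits to -inf so generation cannot terminate early.
from typing import List, Set

import torch


def apply_min_tokens_penalty(logits: torch.Tensor,
                             output_token_ids: List[List[int]],
                             min_tokens: List[int],
                             stop_token_ids: List[Set[int]]) -> torch.Tensor:
    for i, min_toks in enumerate(min_tokens):
        if len(output_token_ids[i]) < min_toks and stop_token_ids[i]:
            logits[i, list(stop_token_ids[i])] = -float("inf")
    return logits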


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
def test_sampler_presence_penalty(device: str, batch_size: int,
                                  presence_penalty: float):
    """
    Test to verify that if presence penalty is enabled then tokens
    are penalized as per their presence in the existing output.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    output_token_ids = sampling_metadata.output_token_ids
    sampling_metadata.presence_penalties = _create_penalty_tensor(
        batch_size, presence_penalty, torch.device(device))
    sampling_metadata.no_penalties = False
    sampler = Sampler()
    sampler_output = sampler(fake_logits, sampling_metadata)
    for batch_idx in range(batch_size):
        # The logprobs in the SamplerOutput are arranged in descending order.
        # Since all tokens initially have the same logprobs, the non-penalized
        # tokens will appear at the beginning, while the penalized tokens
        # will appear at the end of the list.
        penalized_token_id = sampler_output.logprob_token_ids[batch_idx][
            VOCAB_SIZE - 1]
        penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1]
        non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0]
        non_penalized_log_prod = sampler_output.logprobs[batch_idx][0]
        assert non_penalized_log_prod > penalized_log_prod
        if presence_penalty > 0:
            # If `presence_penalty` is set to a value greater than 0, it
            # indicates a preference for new tokens over those already
            # present in the output.
            # Verify that the penalized token ID exists in the output, while
            # the non-penalized token ID does not.
            assert penalized_token_id in output_token_ids[batch_idx]
            assert non_penalized_token_id not in output_token_ids[batch_idx]
        elif presence_penalty < 0:
            # If `presence_penalty` is set to a value less than 0, it indicates
            # a preference for existing tokens over new ones. Verify that the
            # non-penalized token ID exists in the output, while the penalized
            # token ID does not.
            assert non_penalized_token_id in output_token_ids[batch_idx]
            assert penalized_token_id not in output_token_ids[batch_idx]
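The ordering checked above follows from the standard presence-penalty update: a flat penalty is subtracted from every token that already appears in the output, regardless of how often it occurred. A minimal sketch of that formulation (an assumption for illustration, not the PR's exact kernel):

# Sketch: presence penalty, applied per request over the whole vocabulary.
import torch


def apply_presence_penalty(logits: torch.Tensor, output_mask: torch.Tensor,
                           presence_penalties: torch.Tensor) -> torch.Tensor:
    # output_mask[i, v] is True iff token v already appears in request i's
    # output; presence_penalties has shape (batch_size,).
    return logits - presence_penalties.unsqueeze(1) * output_mask.float()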


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0])
def test_sampler_frequency_penalty(device: str, batch_size: int,
                                   frequency_penalty: float):
    """
    Test to verify that if frequency penalty is enabled then tokens are
    penalized as per their frequency of occurrence.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    sampling_metadata.frequency_penalties = _create_penalty_tensor(
        batch_size, frequency_penalty, torch.device(device))
    output_token_ids, sorted_token_ids_in_output = \
        _create_weighted_output_token_list(batch_size, VOCAB_SIZE)
    sampling_metadata.output_token_ids = output_token_ids
    sampling_metadata.no_penalties = False
    sampler = Sampler()
    sampler_output = sampler(fake_logits, sampling_metadata)
    for batch_idx in range(batch_size):
        logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
        non_penalized_token_id = logprobs_token_ids[0]
        penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
        distinct_sorted_token_ids_in_output = \
            sorted_token_ids_in_output[batch_idx]
        most_frequent_token_id = distinct_sorted_token_ids_in_output[
            len(distinct_sorted_token_ids_in_output) - 1]
        if frequency_penalty > 0:
            # If `frequency_penalty` is set to > 0, it indicates
            # a preference for new tokens over existing ones. Verify that the
            # non-penalized token ID is not present in the output, while the
            # most penalized token is the one that occurs most frequently in
            # the output.
            assert non_penalized_token_id \
                not in distinct_sorted_token_ids_in_output
            assert penalized_token_id == most_frequent_token_id
        elif frequency_penalty < 0:
            # If `frequency_penalty` is set to < 0, it indicates
            # a preference for existing tokens over new ones. Verify that the
            # non-penalized token ID is the one that occurs most frequently
            # in the output, while the penalized token ID is one that has not
            # yet appeared.
            assert non_penalized_token_id == most_frequent_token_id
            assert penalized_token_id \
                not in distinct_sorted_token_ids_in_output
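Likewise, the frequency penalty scales with how many times each token has already been emitted, which is why the most frequent token ends up as the most (or least) penalized one above. A minimal sketch of that formulation (an assumption for illustration):

# Sketch: frequency penalty, weighted by per-token output counts.
import torch


def apply_frequency_penalty(logits: torch.Tensor, output_counts: torch.Tensor,
                            frequency_penalties: torch.Tensor) -> torch.Tensor:
    # output_counts[i, v] is the number of times token v appears in request
    # i's output; frequency_penalties has shape (batch_size,).
    return logits - frequency_penalties.unsqueeze(1) * output_counts.float()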


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("repetition_penalty", [0.1, 1.9])
def test_sampler_repetition_penalty(device: str, batch_size: int,
                                    repetition_penalty: float):
    """
    Test to verify that when the repetition penalty is enabled, tokens
    are penalized based on their presence in the prompt or the existing
    output.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    sampling_metadata.repetition_penalties = _create_penalty_tensor(
        batch_size, repetition_penalty, torch.device(device))
    sampling_metadata.no_penalties = False
    sampler = Sampler()
    sampler_output = sampler(fake_logits, sampling_metadata)
    for batch_idx in range(batch_size):
        logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
        non_penalized_token_id = logprobs_token_ids[0]
        penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
        prompt_tokens = sampling_metadata.prompt_token_ids[
            batch_idx][:].tolist()
        output_tokens = sampling_metadata.output_token_ids[batch_idx]
        if repetition_penalty > 1.0:
            # If `repetition_penalty` > 1.0, verify that the non-penalized
            # token ID has not been seen before, while the penalized token ID
            # exists either in the prompt or the output.
            assert (non_penalized_token_id not in prompt_tokens
                    and non_penalized_token_id not in output_tokens)
            assert (penalized_token_id in prompt_tokens
                    or penalized_token_id in output_tokens)
        elif repetition_penalty < 1.0:
            # If `repetition_penalty` < 1.0, verify that the penalized
            # token ID has not been seen before, while the non-penalized
            # token ID exists either in the prompt or the output.
            assert (penalized_token_id not in prompt_tokens
                    and penalized_token_id not in output_tokens)
            assert (non_penalized_token_id in prompt_tokens
                    or non_penalized_token_id in output_tokens)
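These expectations match the usual HF-style repetition penalty: logits of tokens seen in the prompt or output are divided by the penalty when positive and multiplied by it when negative, so values above 1.0 discourage repetition and values below 1.0 encourage it. A minimal sketch of that formulation (an assumption for illustration, not the PR's exact kernel):

# Sketch: HF-style repetition penalty over tokens seen in prompt or output.
import torch


def apply_repetition_penalty(logits: torch.Tensor, seen_mask: torch.Tensor,
                             repetition_penalties: torch.Tensor
                             ) -> torch.Tensor:
    # seen_mask[i, v] is True iff token v appears in request i's prompt or
    # output; repetition_penalties has shape (batch_size,).
    penalties = repetition_penalties.unsqueeze(1).expand_as(logits)
    scaled = torch.where(logits > 0, logits / penalties, logits * penalties)
    return torch.where(seen_mask, scaled, logits)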
Empty file added tests/v1/worker/__init__.py
Empty file.