remove paste, parametrize gen token
jeffreyftang committed Feb 9, 2024
1 parent 4864539 commit 2c13886
Showing 1 changed file with 9 additions and 33 deletions.
42 changes: 9 additions & 33 deletions server/tests/models/test_causal_lm.py
@@ -134,36 +134,6 @@ def test_batch_from_pb(pb_batch, causal_lm_batch, request):
     assert batch.max_input_length == batch.input_lengths[0]
 
 
-def test_batch_with_schema_from_pb(default_pb_batch, default_causal_lm_batch):
-    batch = default_causal_lm_batch
-
-    assert batch.batch_id == default_pb_batch.id
-    assert batch.requests == default_pb_batch.requests
-
-    assert len(batch.input_ids) == default_pb_batch.size
-    assert batch.input_ids[0][-1] == 14402
-    assert torch.all(batch.input_ids[0][:-1] == 50256)
-
-    assert batch.attention_mask[0, 0] == 1
-    assert torch.all(batch.attention_mask[0, 1:] == 0)
-
-    assert batch.past_key_values is None
-
-    assert all(
-        [
-            torch.equal(input_ids, all_input_ids[:, 0])
-            for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)
-        ]
-    )
-
-    assert batch.input_lengths == [1]
-
-    assert len(batch) == default_pb_batch.size
-    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)
-
-    assert batch.max_input_length == batch.input_lengths[0]
-
-
 def test_batch_concatenate_no_prefill(default_causal_lm_batch):
     with pytest.raises(ValueError):
         CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch])
@@ -173,9 +143,15 @@ def test_causal_lm_batch_type(default_causal_lm):
     assert default_causal_lm.batch_type == CausalLMBatch
 
 
-def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
-    sequence_length = len(default_causal_lm_batch.all_input_ids[0])
-    generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
+@pytest.mark.parametrize("causal_lm_batch", [
+    "default_causal_lm_batch",
+    "schema_constrained_causal_lm_batch",
+])
+def test_causal_lm_generate_token(default_causal_lm, causal_lm_batch, request):
+    causal_lm_batch = request.getfixturevalue(causal_lm_batch)
+
+    sequence_length = len(causal_lm_batch.all_input_ids[0])
+    generations, next_batch = default_causal_lm.generate_token(causal_lm_batch)
 
     assert len(generations) == len(next_batch)
     assert isinstance(next_batch, CausalLMBatch)
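
The rewritten test follows pytest's pattern of parametrizing over fixture names and resolving each name at runtime with request.getfixturevalue, so one test body can run against both the default batch and the schema-constrained batch while each fixture stays lazy (only the fixture requested by a given parametrization is built). A minimal, self-contained sketch of that pattern, using hypothetical stand-in fixtures rather than the real CausalLMBatch fixtures from this repository:

import pytest

# Stand-in fixtures for illustration only; the real suite builds
# CausalLMBatch objects here.
@pytest.fixture
def default_causal_lm_batch():
    return {"kind": "default"}

@pytest.fixture
def schema_constrained_causal_lm_batch():
    return {"kind": "schema_constrained"}

# Parametrize over fixture *names*, then look up the actual fixture value
# inside the test body via the built-in `request` fixture.
@pytest.mark.parametrize("causal_lm_batch", [
    "default_causal_lm_batch",
    "schema_constrained_causal_lm_batch",
])
def test_batch_kind(causal_lm_batch, request):
    batch = request.getfixturevalue(causal_lm_batch)
    assert batch["kind"] in ("default", "schema_constrained")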
