
Commit

fix one token ID
jeffreyftang committed Feb 10, 2024
1 parent 2c13886 commit ad78616
Showing 1 changed file with 8 additions and 5 deletions: server/tests/models/test_causal_lm.py
@@ -143,11 +143,11 @@ def test_causal_lm_batch_type(default_causal_lm):
     assert default_causal_lm.batch_type == CausalLMBatch
 
 
-@pytest.mark.parametrize("causal_lm_batch", [
-    "default_causal_lm_batch",
-    "schema_constrained_causal_lm_batch",
+@pytest.mark.parametrize("causal_lm_batch, generated_token_id", [
+    ("default_causal_lm_batch", 13),
+    ("schema_constrained_causal_lm_batch", 90),
 ])
-def test_causal_lm_generate_token(default_causal_lm, causal_lm_batch, request):
+def test_causal_lm_generate_token(default_causal_lm, causal_lm_batch, generated_token_id, request):
     causal_lm_batch = request.getfixturevalue(causal_lm_batch)
 
     sequence_length = len(causal_lm_batch.all_input_ids[0])
@@ -159,7 +159,10 @@ def test_causal_lm_generate_token(default_causal_lm, causal_lm_batch, request):
     assert len(next_batch.all_input_ids) == len(next_batch)
     assert len(next_batch.all_input_ids[0]) == sequence_length + 1
     assert len(next_batch.attention_mask[0]) == 11
-    assert next_batch.all_input_ids[0][-1] == 13
+    assert next_batch.all_input_ids[0][-1] == generated_token_id
+
+    print(f"\n\ngen_token: {default_causal_lm.tokenizer.decode(next_batch.all_input_ids[0][-1])}")
+
     assert next_batch.all_input_ids[0][-2] == 14402
     assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)
 
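For context, the change above relies on pytest's request.getfixturevalue to resolve a fixture by name, so each parametrize case can pair a batch fixture with the token ID that batch is expected to generate. Below is a minimal, self-contained sketch of that pattern; the fixture bodies and the test name test_generated_token_id are hypothetical stand-ins, not the real CausalLMBatch fixtures from the repository.

import pytest


@pytest.fixture
def default_causal_lm_batch():
    # Hypothetical stand-in for the real default batch fixture.
    return {"last_token_id": 13}


@pytest.fixture
def schema_constrained_causal_lm_batch():
    # Hypothetical stand-in for the schema-constrained batch fixture.
    return {"last_token_id": 90}


@pytest.mark.parametrize("causal_lm_batch, generated_token_id", [
    ("default_causal_lm_batch", 13),
    ("schema_constrained_causal_lm_batch", 90),
])
def test_generated_token_id(causal_lm_batch, generated_token_id, request):
    # Resolve the fixture named by the parametrize case at runtime,
    # then check the token ID expected for that particular batch.
    batch = request.getfixturevalue(causal_lm_batch)
    assert batch["last_token_id"] == generated_token_id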
