Skip to content

Commit

Permalink
samplers/test_logprobs.py: use float32 precision.
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Parnell <[email protected]>
  • Loading branch information
tdoublep committed Jul 13, 2024
1 parent d80aef3 commit e3b40a7
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/samplers/test_logprobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dtype",
["float"]) # needed for comparing logprobs with HF
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size
@pytest.mark.parametrize("detokenize", [True, False])
Expand Down

0 comments on commit e3b40a7

Please sign in to comment.