Commit

[Bugfix] use float32 precision in samplers/test_logprobs.py for comparing with HF (vllm-project#6409)

Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Alvant <[email protected]>
tdoublep authored and Alvant committed Oct 26, 2024
1 parent 7fa5726 commit 0fd4d88
Showing 2 changed files with 8 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/samplers/test_logprobs.py
@@ -11,7 +11,8 @@
 
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype",
+                         ["float"])  # needed for comparing logprobs with HF
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
 @pytest.mark.parametrize("num_top_logprobs", [6])  # 32000 == vocab_size
 @pytest.mark.parametrize("detokenize", [True, False])
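
For context, a minimal sketch (not part of the commit) of why half precision is too coarse for comparing logprobs against a HuggingFace reference: quantizing the logits to float16 shifts the resulting log-probabilities far more than storing them in float32 does. The vocabulary size and logit scale below are illustrative assumptions.

import torch

torch.manual_seed(0)
logits = torch.randn(32000, dtype=torch.float64) * 10  # vocab-sized logits, arbitrary scale

ref = torch.log_softmax(logits, dim=-1)                     # float64 reference
fp32 = torch.log_softmax(logits.float().double(), dim=-1)   # logits stored in float32
fp16 = torch.log_softmax(logits.half().double(), dim=-1)    # logits stored in float16

print("max |fp32 - ref|:", (fp32 - ref).abs().max().item())  # typically ~1e-6
print("max |fp16 - ref|:", (fp16 - ref).abs().max().item())  # typically orders of magnitude larger

With errors in the float16 case comparable to the tolerances used when checking logprobs, running the test model in "float" (float32) keeps the comparison against HF stable.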
6 changes: 6 additions & 0 deletions vllm/attention/ops/prefix_prefill.py
@@ -687,6 +687,12 @@ def context_attention_fwd(q,
 
     cap = current_platform.get_device_capability()
     BLOCK = 128 if cap[0] >= 8 else 64
+
+    # need to reduce num. blocks when using fp32
+    # due to increased use of GPU shared memory
+    if q.dtype is torch.float32:
+        BLOCK = BLOCK // 2
+
     # shape constraints
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk and Lk == Lv
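
A back-of-the-envelope sketch (not from the commit) of the shared-memory reasoning behind halving BLOCK for fp32: the Triton kernel stages tiles of shape (BLOCK, head_dim) on chip, so their footprint scales with the element size, and halving BLOCK keeps the fp32 footprint near the fp16 one. The tile count, head dimension, and 164 KB per-SM limit are illustrative assumptions, not figures from the kernel.

def tile_bytes(block: int, head_dim: int, elem_size: int, num_tiles: int = 3) -> int:
    # roughly one (block, head_dim) tile each for Q, K and V
    return num_tiles * block * head_dim * elem_size

HEAD_DIM = 128
SHARED_MEM_LIMIT = 164 * 1024  # per-SM shared memory on an A100-class GPU (assumed)

print("fp16, BLOCK=128:", tile_bytes(128, HEAD_DIM, 2))  # 98304 bytes
print("fp32, BLOCK=128:", tile_bytes(128, HEAD_DIM, 4),  # 196608 bytes
      "over limit:", tile_bytes(128, HEAD_DIM, 4) > SHARED_MEM_LIMIT)
print("fp32, BLOCK=64: ", tile_bytes(64, HEAD_DIM, 4))   # 98304 bytes, back to fp16 level

Doubling the element size at a fixed BLOCK roughly doubles shared-memory use, so `BLOCK // 2` restores the original budget.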
