diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py
index 02a953da04659..f7bcd4c855799 100644
--- a/tests/samplers/test_logprobs.py
+++ b/tests/samplers/test_logprobs.py
@@ -11,7 +11,8 @@
 
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype",
+                         ["float"])  # needed for comparing logprobs with HF
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
 @pytest.mark.parametrize("num_top_logprobs", [6])  # 32000 == vocab_size
 @pytest.mark.parametrize("detokenize", [True, False])
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index 70b544b608e29..4577d84db18ac 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -687,6 +687,12 @@ def context_attention_fwd(q,
 
     cap = current_platform.get_device_capability()
     BLOCK = 128 if cap[0] >= 8 else 64
+
+    # need to reduce num. blocks when using fp32
+    # due to increased use of GPU shared memory
+    if q.dtype is torch.float32:
+        BLOCK = BLOCK // 2
+
     # shape constraints
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk and Lk == Lv
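
The second hunk's rationale, sketched in isolation: a Triton prefill tile's shared-memory footprint scales roughly with BLOCK * head_dim * element_size, so moving from fp16 (2 bytes) to fp32 (4 bytes) doubles the footprint, and halving BLOCK keeps it within the per-SM limit. The helper below (pick_block_size is a hypothetical name, not part of vLLM) mirrors the block-size choice made in the patch, under those assumptions.

import torch


def pick_block_size(dtype: torch.dtype, compute_capability_major: int) -> int:
    """Illustrative sketch of the block-size heuristic in the patch."""
    # Larger tiles on Ampere (SM 8.x) and newer, smaller tiles otherwise.
    block = 128 if compute_capability_major >= 8 else 64
    # fp32 tiles need roughly twice the shared memory of fp16/bf16 tiles,
    # so halve the block size to keep the footprint comparable.
    if dtype is torch.float32:
        block //= 2
    return block


if __name__ == "__main__":
    for dtype in (torch.float16, torch.float32):
        print(dtype, pick_block_size(dtype, compute_capability_major=8))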