
Inconsistent results between different sequences with sequence lengths less than a single page size #725

Open
fergusfinn opened this issue Jan 8, 2025 · 4 comments


fergusfinn commented Jan 8, 2025

Hi,

I may have misunderstood the API, but with the following script I'd expect the output to be two identical output states (in ragged format). When I run it, however, the first output state (when unpacked from the ragged format) is non-zero, while the second is all zeros.

As I increase the number of pages per sequence (e.g. changing paged_kv_indptr to [0, 4, 8]) while keeping the page size constant, the two outputs converge. But with a single (incomplete) page per sequence, I always get this non-zero/zero behaviour, regardless of page size.

The same behaviour persists with torch.float16 as well as bfloat16.

Are sequences shorter than a single page size not supported for batched prefills?

import flashinfer
import torch
from torch import tensor

torch.manual_seed(42)

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")

num_heads = 32
num_key_value_heads = 8
head_dim = 64
page_size = 2
num_pages = 128

qo_indptr = tensor([0, 1, 2], device="cuda:0", dtype=torch.int32)

paged_kv_indptr = tensor([0, 1, 2], device="cuda:0", dtype=torch.int32)
paged_kv_indices = torch.arange(paged_kv_indptr[-1], device="cuda:0", dtype=torch.int32)
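# NOTE: no dtype is given on the next line, so it defaults to int64
# (this turns out to be the bug diagnosed in the comments below).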
paged_kv_last_page_len = tensor([1, 1], device="cuda:0")

prefill_wrapper.plan(
    qo_indptr=qo_indptr,
    paged_kv_indptr=paged_kv_indptr,
    paged_kv_indices=paged_kv_indices,
    paged_kv_last_page_len=paged_kv_last_page_len,
    num_qo_heads=num_heads,
    num_kv_heads=num_key_value_heads,
    head_dim=head_dim,
    page_size=page_size,
    q_data_type="bfloat16",
    causal=True,
)

k_cache = torch.randn((1, page_size, num_key_value_heads, head_dim), dtype=torch.bfloat16, device="cuda:0").repeat(
    num_pages, 1, 1, 1
)
v_cache = torch.randn((1, page_size, num_key_value_heads, head_dim), dtype=torch.bfloat16, device="cuda:0").repeat(
    num_pages, 1, 1, 1
)
query_states = torch.randn((1, num_heads, head_dim), dtype=torch.bfloat16, device="cuda:0").repeat(qo_indptr[-1], 1, 1)

outputs = prefill_wrapper.run(query_states, (k_cache, v_cache))

print(outputs[0]) # tensor[32, 64] f16 n=2048 (4Kb) x∈[-2.457, 2.725] μ=-0.048 σ=0.958 cuda:0
print(outputs[1]) # tensor[32, 64] f16 n=2048 (4Kb) all_zeros cuda:0
print(torch.allclose(outputs[0], outputs[1]))  # False

yzh119 (Collaborator) commented Jan 8, 2025

Hi @fergusfinn, I checked your script carefully, and it turns out that paged_kv_last_page_len = tensor([1, 1], device="cuda:0") creates a tensor with data type int64, which is reinterpreted as int32 inside the kernel (we should improve the error message).
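To illustrate the failure mode (a minimal sketch, assuming a little-endian machine): viewing an int64 buffer as int32 interleaves each value with a zero word, so every other entry the kernel reads back is 0. That is consistent with the second sequence producing an all-zeros output.

import torch

t = torch.tensor([1, 1], dtype=torch.int64)
print(t.view(torch.int32))  # tensor([1, 0, 1, 0], dtype=torch.int32)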

Converting it to int32 would resolve the issue :)
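Concretely, the one-line fix to the repro script is to pass the dtype explicitly:

paged_kv_last_page_len = tensor([1, 1], device="cuda:0", dtype=torch.int32)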

fergusfinn (Author) commented

Wow, thank you! I should have realised; I made the same mistake earlier with qo_indptr, but got an illegal memory access instead.

Would it be possible to add something to the docs about accepted integer datatypes? (I'm guessing it's int32 for all of these 'pointer' tensors across the library?) Happy to open a PR if that's useful.

yzh119 (Collaborator) commented Jan 9, 2025

Yes, that would be a great idea!

> I'm guessing it's int32 for all of these 'pointer' tensors across the library

On the kernel side we also support kernels for idtype=int64 (though they are not compiled ahead-of-time as part of the wheel), but I think int32 is still the common practice at the moment.
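A defensive pattern in the meantime (a sketch, not a flashinfer API) is to normalize every index tensor to int32 before calling plan():

# Cast all index/indptr tensors to int32, the idtype compiled into the wheel.
def as_i32(t: torch.Tensor) -> torch.Tensor:
    return t.to(torch.int32)

qo_indptr = as_i32(qo_indptr)
paged_kv_indptr = as_i32(paged_kv_indptr)
paged_kv_indices = as_i32(paged_kv_indices)
paged_kv_last_page_len = as_i32(paged_kv_last_page_len)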

fergusfinn (Author) commented

Opened a PR here.
