How to do KV Cache with FlexAttention and BlockMask by slicing? #60

Open
Leo-T-Zang opened this issue Oct 21, 2024 · 0 comments
Leo-T-Zang commented Oct 21, 2024

Is there any example code for doing this? Should I generate a new BlockMask every time?

Thanks!


Essentially, my problem is how to slice a BlockMask. For example, given a prompt of length 1000, I have the following attention code, which may well be wrong. My question is: when I then need to generate the 1001st token (query position 1000, which falls into query block 1000 // 128 = 7 with the default 128-token blocks), how do I slice out exactly that position from the BlockMask?

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, create_mask
torch.set_default_device('cuda')

# Compile the flex_attention function
flex_attention = torch.compile(flex_attention)

B, H, S = 1, 2, 5000 #5000 is the max model length

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B=B, H=H, Q_LEN=S, KV_LEN=S)
print('Mask Shape (Max Length): ', block_mask.shape)

# The prompt length is 1000; 3500 is the maximum number of tokens we want to reach
query = torch.randn(B, H, 1000, 64, device="cuda", dtype=torch.float32)
key = torch.randn(B, H, 3500, 64, device="cuda", dtype=torch.float32)
value = torch.randn(B, H, 3500, 64, device="cuda", dtype=torch.float32)

# Slice the block mask along the query-block dimension
# (keep the first 3500 // 128 + 1 = 28 query-block rows)
q_slice = torch.arange(0, 3500 // 128 + 1).to(device)
block_mask = block_mask[:, :, q_slice]
print('Mask Shape (Sliced): ', block_mask.shape)

print('Query Shape:', query.shape)
print('Key Shape:', key.shape)
print('Value Shape:', value.shape)

out = flex_attention(query, key, value, block_mask=block_mask)

print('Attention Output Shape:', out.shape)

block_mask = block_mask.to_string(limit=32,)
print(block_mask)

# Generate a new token: q length is 1 and its absolute position is 1000
q = torch.randn(1, H, 1, 64, device="cuda", dtype=torch.float32)
# The problem: how do I select the part of the mask that corresponds to position 1000?
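
The only alternative I can think of (going back to my first question, whether I should generate a new BlockMask every time) is to rebuild a tiny BlockMask at every decode step and shift q_idx by the current cache position inside the mask_mod. This is just my own sketch, not something from an official example: cache_position and decode_causal_mask are names I made up, and rebuilding the mask each step is presumably slow, which is exactly why I am asking about slicing instead.

# Sketch (my own guess): rebuild a one-row BlockMask for the single new token
cache_position = 1000  # absolute position of the new token

def decode_causal_mask(b, h, q_idx, kv_idx):
    # q_idx only runs over the single new token, so shift it to its absolute position
    return (q_idx + cache_position) >= kv_idx

decode_block_mask = create_block_mask(
    decode_causal_mask, B=B, H=H, Q_LEN=1, KV_LEN=cache_position + 1
)

q_step = torch.randn(B, H, 1, 64, device="cuda", dtype=torch.float32)
k_cache = torch.randn(B, H, cache_position + 1, 64, device="cuda", dtype=torch.float32)
v_cache = torch.randn(B, H, cache_position + 1, 64, device="cuda", dtype=torch.float32)

out_step = flex_attention(q_step, k_cache, v_cache, block_mask=decode_block_mask)
print('Decode step output shape:', out_step.shape)  # expected (B, H, 1, 64)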

Another question: if I use a prefix-LM mask for the prompt tokens, it works when I set H=None in create_block_mask, but it fails with errors when I set H=H.

# Compile the flex_attention function
flex_attention = torch.compile(flex_attention)

B, H, S = 1, 2, 500

device = 'cuda' if torch.cuda.is_available() else 'cpu'

full_attention_idx = torch.tensor([[0, 100]], dtype=torch.long).to(device)

def prefix_lm_causal_mask(b, h, q_idx, kv_idx):
    # Tokens inside the per-sequence prefix window attend fully; everything else is causal
    full_mask = (kv_idx >= full_attention_idx[b][0]) & (kv_idx <= full_attention_idx[b][1])
    causal_mask = q_idx >= kv_idx
    return full_mask | causal_mask

# The mask differs per sequence, so B is set to the batch size
block_mask = create_block_mask(prefix_lm_causal_mask, B=B, H=None, Q_LEN=S, KV_LEN=S)
print(block_mask.shape)

query = torch.randn(B, H, 100, 64, device="cuda", dtype=torch.float32)
key = torch.randn(B, H, 350, 64, device="cuda", dtype=torch.float32)
value = torch.randn(B, H, 350, 64, device="cuda", dtype=torch.float32)

# slice block mask
q_slice = torch.arange(0, 350//128 + 1)
block_mask = block_mask[:, :, q_slice]
print(block_mask.shape)

print('Query Shape:', query.shape)
print('Key Shape:', key.shape)
print('Value Shape:', value.shape)

out = flex_attention(query, key, value, block_mask=block_mask)

print('Attention Output Shape:', out.shape)

block_mask = block_mask.to_string(limit=32,)
print(block_mask)

When I set H=H instead:

# Compile the flex_attention function
flex_attention = torch.compile(flex_attention)

B, H, S = 1, 2, 500

device = 'cuda' if torch.cuda.is_available() else 'cpu'

full_attention_idx = torch.tensor([[0, 100]], dtype=torch.long).to(device)

def prefix_lm_causal_mask(b, h, q_idx, kv_idx):
    # Same prefix-LM mask as above
    full_mask = (kv_idx >= full_attention_idx[b][0]) & (kv_idx <= full_attention_idx[b][1])
    causal_mask = q_idx >= kv_idx
    return full_mask | causal_mask

# Same as above, but now with a per-head block mask (H=H)
block_mask = create_block_mask(prefix_lm_causal_mask, B=B, H=H, Q_LEN=S, KV_LEN=S)
print(block_mask.shape)

query = torch.randn(B, H, 100, 64, device="cuda", dtype=torch.float32)
key = torch.randn(B, H, 350, 64, device="cuda", dtype=torch.float32)
value = torch.randn(B, H, 350, 64, device="cuda", dtype=torch.float32)

# slice block mask
q_slice = torch.arange(0, 350//128 + 1)
block_mask = block_mask[:, :, q_slice]
print(block_mask.shape)

print('Query Shape:', query.shape)
print('Key Shape:', key.shape)
print('Value Shape:', value.shape)

out = flex_attention(query, key, value, block_mask=block_mask)

print('Attention Output Shape:', out.shape)

block_mask = block_mask.to_string(limit=32,)
print(block_mask)

Errors

/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [0,32,0], thread: [0,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [0,32,0], thread: [1,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [0,32,0], thread: [2,0,0] Assertion `` failed.
...
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [55,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [56,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [57,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [58,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [59,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [60,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [61,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [62,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [63,0,0] Assertion `` failed.
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/workspace/tc289/project/p3i/test/attention.py", line 208, in <module>
    block_mask = block_mask.to_string(limit=32,)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 513, in to_string
    dense_mask = self.to_dense()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 500, in to_dense
    partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 173, in _ordered_to_dense
    out = create_dense_batched(num_blocks_in_row, col_indices)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 166, in create_dense_one
    dense_mask[row_indices, valid_indices] = 1
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py", line 106, in __torch_function__
    return func(*args, **kwargs)
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
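
To narrow this down, a check I am planning to run is to materialize the same mask_mod densely with create_mask (imported above), which evaluates the mask_mod directly and skips the block-sparse bookkeeping. If that works for H=2, the problem is presumably in how I slice or print the BlockMask rather than in the mask_mod itself. Just a sketch of the check I have in mind:

# Sketch: evaluate the prefix-LM mask_mod densely, outside of BlockMask,
# to see whether the per-head (H=2) setup itself is the problem
dense_mask = create_mask(prefix_lm_causal_mask, B=B, H=H, Q_LEN=S, KV_LEN=S, device=device)
print('Dense mask shape:', dense_mask.shape)  # expected (B, H, S, S)
print(dense_mask[0, 0, 150, :8])              # row 150 should attend to the prefix columns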


Leo-T-Zang changed the title from "How to do KV Cache with FlexAttention and BlockMask?" to "How to do KV Cache with FlexAttention and BlockMask by slicing?" on Oct 21, 2024