Is there any example code for doing this? Should I generate a new BlockMask every time?
Thanks!
Essentially, my problem is slicing a BlockMask. For example, suppose the prompt is 1000 tokens long. I have the following attention code, which may be wrong; my question is, when I go on to generate the 1001st token, how do I slice out the exact position in the BlockMask for it?
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, create_mask
torch.set_default_device('cuda')
# Compile the flex_attention function
flex_attention = torch.compile(flex_attention)
B, H, S = 1, 2, 5000  # 5000 is the max model length
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B=B, H=H, Q_LEN=S, KV_LEN=S)
print('Mask Shape (Max Length): ', block_mask.shape)
# The input prompt length is 1000; 3500 is the max token length we want to reach
query = torch.randn(B, H, 1000, 64, device="cuda", dtype=torch.float32)
key = torch.randn(B, H, 3500, 64, device="cuda", dtype=torch.float32)
value = torch.randn(B, H, 3500, 64, device="cuda", dtype=torch.float32)
# Slice the block mask along the query dimension (128 is the default block size)
q_slice = torch.arange(0, 3500 // 128 + 1).to(device)
block_mask = block_mask[:, :, q_slice]
print('Mask Shape (Sliced): ', block_mask.shape)
print('Query Shape:', query.shape)
print('Key Shape:', key.shape)
print('Value Shape:', value.shape)
out = flex_attention(query, key, value, block_mask=block_mask)
print('Attention Output Shape:', out.shape)
print(block_mask.to_string(limit=32))
# Generate a new token: q length is 1 and its absolute position is 1000
q = torch.randn(1, H, 1, 64, device="cuda", dtype=torch.float32)
# The problem: how do I select position 1000 in the mask?
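For illustration, here are the two options I can think of. Both are guesses, not confirmed usage of the API: Option 1 slices out the query-block row containing position 1000 (assuming the default block size of 128), and Option 2 rebuilds a small BlockMask each step with the position offset baked into the mask_mod. The names pos, causal_mask_at, step_mask, and step_mask_2 are mine.

# Option 1 (guess): slice out the query-block row that contains position pos.
# Assumes the default query block size of 128.
pos = 1000
q_block = pos // 128
step_mask = block_mask[:, :, q_block:q_block + 1]
# The sliced row still spans 128 query positions while my q length is 1,
# so I am unsure how flex_attention is supposed to line these up.

# Option 2 (guess): rebuild a fresh BlockMask every step with Q_LEN=1,
# shifting q_idx by the absolute position inside the mask_mod.
def causal_mask_at(b, h, q_idx, kv_idx):
    return q_idx + pos >= kv_idx

step_mask_2 = create_block_mask(causal_mask_at, B=B, H=H, Q_LEN=1, KV_LEN=3500)
out_step = flex_attention(q, key, value, block_mask=step_mask_2)

Option 2 is what I meant above by generating a new BlockMask every time; I do not know whether that is the intended pattern (or whether q length 1 is even supported here), or whether slicing as in Option 1 can be made to work.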
Another question: when I use a prefix mask for the prompt tokens, it works if I set H=None, but errors out if I set H=H.
# Compile the flex_attention function
flex_attention = torch.compile(flex_attention)
B, H, S = 1, 2, 500
device = 'cuda' if torch.cuda.is_available() else 'cpu'
full_attention_idx = torch.tensor([[0, 100]], dtype=torch.long).to(device)
def prefix_lm_causal_mask(b, h, q_idx, kv_idx):
    full_mask = (kv_idx <= full_attention_idx[b][1]) & (kv_idx >= full_attention_idx[b][0])
    causal_mask = q_idx >= kv_idx
    return full_mask | causal_mask

# In this case the mask differs per sequence, so we set B equal to our batch size
block_mask = create_block_mask(prefix_lm_causal_mask, B=B, H=None, Q_LEN=S, KV_LEN=S)
print(block_mask.shape)
query = torch.randn(B, H, 100, 64, device="cuda", dtype=torch.float32)
key = torch.randn(B, H, 350, 64, device="cuda", dtype=torch.float32)
value = torch.randn(B, H, 350, 64, device="cuda", dtype=torch.float32)
# Slice the block mask along the query dimension
q_slice = torch.arange(0, 350 // 128 + 1)
block_mask = block_mask[:, :, q_slice]
print(block_mask.shape)
print('Query Shape:', query.shape)
print('Key Shape:', key.shape)
print('Value Shape:', value.shape)
out = flex_attention(query, key, value, block_mask=block_mask)
print('Attention Output Shape:', out.shape)
print(block_mask.to_string(limit=32))
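For context, my understanding (which may be wrong) is that the only difference between the two runs is the head dimension of the mask: H=None builds a single mask that is broadcast across all heads, while H=H builds an independent mask per head. A quick check of that assumption:

# My assumption about the semantics of H=None vs. H=H:
mask_shared = create_block_mask(prefix_lm_causal_mask, B=B, H=None, Q_LEN=S, KV_LEN=S)
mask_per_head = create_block_mask(prefix_lm_causal_mask, B=B, H=H, Q_LEN=S, KV_LEN=S)
print(mask_shared.shape)    # head dimension is 1: one mask broadcast to every head
print(mask_per_head.shape)  # head dimension is H: one mask per head

Since prefix_lm_causal_mask never uses h, I would expect the two to be equivalent, which is why the failure below with H=H confuses me.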
When H=H:
# Compile the flex_attention function
flex_attention = torch.compile(flex_attention)
B, H, S = 1, 2, 500
device = 'cuda' if torch.cuda.is_available() else 'cpu'
full_attention_idx = torch.tensor([[0, 100]], dtype=torch.long).to(device)
def prefix_lm_causal_mask(b, h, q_idx, kv_idx):
    full_mask = (kv_idx <= full_attention_idx[b][1]) & (kv_idx >= full_attention_idx[b][0])
    causal_mask = q_idx >= kv_idx
    return full_mask | causal_mask

# Same setup as above, but with H set to the actual number of heads
block_mask = create_block_mask(prefix_lm_causal_mask, B=B, H=H, Q_LEN=S, KV_LEN=S)
print(block_mask.shape)
query = torch.randn(B, H, 100, 64, device="cuda", dtype=torch.float32)
key = torch.randn(B, H, 350, 64, device="cuda", dtype=torch.float32)
value = torch.randn(B, H, 350, 64, device="cuda", dtype=torch.float32)
# Slice the block mask along the query dimension
q_slice = torch.arange(0, 350 // 128 + 1)
block_mask = block_mask[:, :, q_slice]
print(block_mask.shape)
print('Query Shape:', query.shape)
print('Key Shape:', key.shape)
print('Value Shape:', value.shape)
out = flex_attention(query, key, value, block_mask=block_mask)
print('Attention Output Shape:', out.shape)
print(block_mask.to_string(limit=32))
Errors:
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [0,32,0], thread: [0,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [0,32,0], thread: [1,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [0,32,0], thread: [2,0,0] Assertion `` failed.
...
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [55,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [56,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [57,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [58,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [59,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [60,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [61,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [62,0,0] Assertion `` failed.
/usr/local/lib/python3.10/dist-packages/torch/_inductor/runtime/compile_tasks.py:45: <module>: block: [1,20,0], thread: [63,0,0] Assertion `` failed.
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/workspace/tc289/project/p3i/test/attention.py", line 208, in <module>
    block_mask = block_mask.to_string(limit=32,)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 513, in to_string
    dense_mask = self.to_dense()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 500, in to_dense
    partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 173, in _ordered_to_dense
    out = create_dense_batched(num_blocks_in_row, col_indices)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/attention/flex_attention.py", line 166, in create_dense_one
    dense_mask[row_indices, valid_indices] = 1
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py", line 106, in __torch_function__
    return func(*args, **kwargs)
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.