Writing to a globally scoped tensor from score_mod function #19

Open · jeffwillette opened this issue Aug 15, 2024 · 1 comment

@jeffwillette commented Aug 15, 2024

I have a question about a use case of flex attention.

I am wondering whether it is possible to write to a globally scoped tensor in the same way that the alibi bias example reads from one. According to the FlexAttention intro page (https://pytorch.org/blog/flexattention/), an alibi bias should be implemented as follows:

alibi_bias = generate_alibi_bias() # [num_heads]

def alibi(score, b, h, q_idx, kv_idx):
    bias = alibi_bias[h] * (q_idx - kv_idx)
    return score + bias
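
For completeness, generate_alibi_bias is not defined in the snippet above; a minimal sketch, assuming the standard ALiBi slope schedule and the 8 heads used below (my assumption, not taken from the blog post), could look like this:

import torch

def generate_alibi_bias(num_heads: int = 8) -> torch.Tensor:
    # One negative slope per head, following the usual ALiBi schedule
    # slope_h = 2 ** (-8 * (h + 1) / num_heads); the sign is flipped so that
    # alibi_bias[h] * (q_idx - kv_idx) penalizes attention to distant keys.
    return -torch.exp2(-8.0 * (torch.arange(num_heads) + 1) / num_heads)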

The example initializes a globally scoped tensor and then reads from it within the score_mod function. But suppose I wanted to retrieve all of the attention scores in order to plot the attention matrix, or to sum the columns of the attention matrix to drive a KV cache eviction policy, or some other use case that requires writing some function of the scores to a tensor.

Is it possible to accomplish this by writing to a globally scoped tensor in the same way that the alibi example reads from one? I tried the following, but it didn't work. Is there a way to accomplish this with flex attention?

import torch
from torch.nn.attention.flex_attention import flex_attention

query = torch.randn(1, 8, 256, 128)
key = torch.randn(1, 8, 128, 128)
value = torch.randn(1, 8, 128, 128)

# Globally scoped tensor intended to capture the post-score_mod scores
# (shape [B, H, Q_LEN, KV_LEN]).
scores_out = torch.zeros(1, 8, 256, 128)

def score_mod(score, b, h, q_idx, kv_idx):
    # Attempt to write each score into the global tensor before returning it.
    scores_out[b, h, q_idx, kv_idx] = score
    return score

out = flex_attention(query, key, value, score_mod=score_mod)

print(out.size())
print(scores_out)

The code above fails with the following trace:

Traceback (most recent call last):
  File "/home/jeff/test/flex_attention.py", line 14, in <module>
    out = flex_attention(query, key, value, score_mod=noop)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/nn/attention/flex_attention.py", line 767, in flex_attention
    out, _ = torch.compile(
             ^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_dynamo/eval_frame.py", line 448, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_dynamo/external_utils.py", line 36, in inner
    @functools.wraps(fn)
  File "/home/jeff/python/pytorch/torch/_dynamo/eval_frame.py", line 615, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.3", line 24, in forward
  File "/home/jeff/python/pytorch/torch/_higher_order_ops/flex_attention.py", line 60, in __call__
    return super().__call__(
           ^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_ops.py", line 437, in __call__
    return wrapper()
           ^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_dynamo/eval_frame.py", line 615, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_ops.py", line 433, in wrapper
    return self_.dispatch(
           ^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_ops.py", line 417, in dispatch
    return kernel(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_higher_order_ops/flex_attention.py", line 635, in flex_attention_autograd
    out, logsumexp = FlexAttentionAutogradOp.apply(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_higher_order_ops/flex_attention.py", line 528, in forward
    out, logsumexp = flex_attention(
                     ^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_higher_order_ops/flex_attention.py", line 60, in __call__
    return super().__call__(
           ^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_ops.py", line 437, in __call__
    return wrapper()
           ^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_dynamo/eval_frame.py", line 615, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_ops.py", line 428, in wrapper
    return torch.overrides.handle_torch_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/overrides.py", line 1715, in handle_torch_function
    result = mode.__torch_function__(public_api, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_higher_order_ops/flex_attention.py", line 37, in __torch_function__
    return func(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_higher_order_ops/flex_attention.py", line 60, in __call__
    return super().__call__(
           ^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_ops.py", line 437, in __call__
    return wrapper()
           ^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_dynamo/eval_frame.py", line 615, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_ops.py", line 433, in wrapper
    return self_.dispatch(
           ^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_ops.py", line 417, in dispatch
    return kernel(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_higher_order_ops/flex_attention.py", line 216, in sdpa_dense
    out, lse = math_attention(
               ^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_higher_order_ops/flex_attention.py", line 185, in math_attention
    _, post_mod_scores = _math_attention_inner(
                         ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_higher_order_ops/flex_attention.py", line 155, in _math_attention_inner
    score_mod(scores, b, h, m, n, *score_mod_other_buffers),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_functorch/apis.py", line 202, in wrapped
    return vmap_impl(
           ^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
           ^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_functorch/apis.py", line 202, in wrapped
    return vmap_impl(
           ^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
           ^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_functorch/apis.py", line 202, in wrapped
    return vmap_impl(
           ^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
           ^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_functorch/apis.py", line 202, in wrapped
    return vmap_impl(
           ^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
           ^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/fx/graph_module.py", line 739, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/fx/graph_module.py", line 317, in __call__
    raise e
  File "/home/jeff/python/pytorch/torch/fx/graph_module.py", line 304, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeff/python/pytorch/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.0", line 5, in forward
  File "/home/jeff/python/pytorch/torch/_higher_order_ops/flex_attention.py", line 37, in __torch_function__
    return func(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.
@drisspg (Contributor) commented Aug 15, 2024

Thanks for opening the issue; currently this isn't supported.

I have a hack that can get this to work in eager (no Triton kernels), but it will break other score_mods: https://gist.github.com/drisspg/c66d79d51b5dd1895a552cef0820ba2e

I think this is a valuable feature, but it will require some bigger changes. I will open an issue on PyTorch for tracking and keep you posted.
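
In the meantime, a rough eager-only sketch (illustrative only, not the contents of the gist above) is to recompute the post-score_mod attention matrix outside of flex_attention, assuming the score_mod is written with broadcast-friendly ops:

import torch

def reference_attention_scores(query, key, score_mod):
    # Recompute the post-score_mod, post-softmax attention matrix in plain
    # eager PyTorch so it can be plotted or column-summed. This materializes
    # the full [B, H, Q, KV] matrix and does not use the fused kernel.
    B, H, Q, D = query.shape
    KV = key.shape[-2]
    scores = (query @ key.transpose(-2, -1)) / D**0.5  # default 1/sqrt(D) scaling
    b = torch.arange(B).view(B, 1, 1, 1)
    h = torch.arange(H).view(1, H, 1, 1)
    q_idx = torch.arange(Q).view(1, 1, Q, 1)
    kv_idx = torch.arange(KV).view(1, 1, 1, KV)
    scores = score_mod(scores, b, h, q_idx, kv_idx)
    return scores.softmax(dim=-1)

For the repro above, reference_attention_scores(query, key, score_mod) returns the full attention matrix directly, without needing to write into a global tensor from inside flex_attention.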
