implement dot_product_attention #455
I think a cleaner API would be to let the `mask` keyword be a function. The `nothing` case is `mask = identity` and the causal case is `mask = make_causal_mask` (which I feel should be just `causal_mask`, to be succinct).

Is there a reason to construct the mask on the fly? The calling layer in Flux can probably make and store the mask once. Then the other option is to allow `nothing` or an array, and the user passes in `mask = causal_mask(ntoken)`.
What is the function which you pass, in this proposal? `mask = identity` means it is applied to the array; `mask = make_causal_mask` means it constructs a boolean matrix. Agreed that constructing the same matrix every time seems a bit wasteful, although probably not a big cost; there are quite a few larger copies made in this thing.

With `mask = identity`, the usual masking could be a `causal_mask!` which is basically `for i, j in ...; if i < j; x[i, j] = -Inf; end`, i.e. it just mutates the data array. This should be safe, as the gradient of `batched_mul` does not need the original values.
You're right, it shouldn't be `identity`; it should be `trues_like`, though I'd be okay with `nothing` in order to skip computing a mask at all.

My comment about constructing on the fly was not a performance concern. I just think it is more intuitive to pass in exactly the mask array I want used. It's an easier rule to remember, and it also scales to whatever masking scheme is desired.
The downside is that you have to make an array of the right size. If you have several layers and want the same scheme for each, then perhaps it's a pain, whereas a function like `trues_like` is told the size automatically.

(The implementation can branch on `mask === trues_like` to avoid work in the default case. We can also branch on the type of `const causal_mask = triu ∘ trues_like` if necessary.)

While encoding this as a bool array makes some sense, it's also a little weird in that the implementation doesn't directly consume it. Maybe better than my mutating idea above: we can modify `softmax` to take a mask argument, and fuse it into the broadcast there, I think.
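A softmax that consumes a boolean mask directly might look roughly like this. This is only a sketch of the fusion idea, not the actual NNlib implementation (that discussion continues in #460); `masked_softmax` is a name invented here:

```julia
# Rough sketch: fuse masking into the softmax broadcasts instead of
# mutating the score array first. Masked-out positions get zero weight.
function masked_softmax(x::AbstractArray, mask; dims = 1)
    xm = ifelse.(mask, x, typemin(eltype(x)))   # masked entries act like -Inf
    m = maximum(xm; dims = dims)                 # subtract max for stability
    e = ifelse.(mask, exp.(xm .- m), zero(eltype(x)))
    return e ./ sum(e; dims = dims)
end

x = Float32[1 2; 3 4]
mask = [true false; true true]
masked_softmax(x, mask)
# column 1 is an ordinary softmax of [1, 3];
# column 2 puts all of its weight on x[2, 2]
```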
That's true, but generally the size of this matrix, which is `# of tokens × # of tokens`, is known ahead of time. Even so, I agree that not needing to pass in this info is cleaner. I mostly wanted to avoid "symbol switches" for arguments.
Yes to avoiding symbols. I like this `mask = trues_like` proposal the best so far. One question I haven't looked at is what format the CUDNN thing is going to want.
Instead of saying `mask` is either an array or a callable, could we say it should be either an array or a marker type for which one can override some `init_mask(x, k, v)` function? This would allow us to shift the conditionals out of the attention functions, while still allowing relatively terse syntax like `mask = CausalMask()` when users don't want to precompute their own. You could imagine nice party tricks like passing `mask = I`.
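The marker-type proposal could be sketched as below. Every name here (`CausalMask`, `init_mask`, the exact argument list) is hypothetical, taken from the comment above rather than from any existing NNlib API, and the `i <= j` convention follows `triu`-style masks where key `i` is visible to query `j`:

```julia
# Hypothetical sketch: dispatch on the mask argument's type, instead of
# branching inside the attention function itself.
struct CausalMask end

init_mask(m::AbstractArray, x, k, v) = m          # precomputed array: use as-is
init_mask(::Nothing, x, k, v) = nothing            # no masking at all
init_mask(::CausalMask, x, k, v) =                 # build the causal mask on demand
    [i <= j for i in 1:size(k, 2), j in 1:size(x, 2)]

x = rand(8, 3);  k = rand(8, 3);  v = rand(8, 3)   # (features, tokens)
init_mask(CausalMask(), x, k, v)                   # 3×3 upper-triangular Bool matrix
```

Users who want an exotic scheme define their own marker type plus one `init_mask` method, which keeps the attention function itself free of conditionals.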
#460 is a go at this masked-softmax idea. With that, the default of no mask can in fact be `mask = Returns(true)` here, instead of `trues_like`. And the terse causal mask can be `const causal_mask = triu ∘ trues_like`, or a function equivalent to this (maybe it can be more efficient; not sure `triu` works on CuArrays). No conditionals required.

Edit: making #460 work on GPU too won't be just a few lines. But even without that, `mask::Function = trues_like` as the interface seems nice, instead of having to independently make something of the right size.
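In plain Julia the composed causal mask above can be tried out as follows. Note that `trues_like` was still only a proposed helper at this point, so it is defined by hand here; `triu` comes from LinearAlgebra and `Returns` is in Base from Julia 1.7:

```julia
using LinearAlgebra  # for triu

# trues_like is a proposed helper, not an existing one; a stand-in for this sketch:
trues_like(x::AbstractMatrix) = trues(size(x))

const causal_mask = triu ∘ trues_like

scores = randn(3, 3)
causal_mask(scores)
# 3×3 Bool mask, true on and above the diagonal:
# entry (i, j) is true iff i <= j, i.e. key i is visible to query j

# The proposed no-mask default ignores its argument entirely:
Returns(true)(scores)  # true, regardless of scores
```

Because the mask function receives the array, it learns the size automatically, which is the advantage over passing a precomputed array of the right shape.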
`triu` only works on `AbstractMatrix`, which is not sufficient for the attention.
For this first implementation, I prefer to keep it more minimalistic and just accept `nothing` or arrays (I will remove `:causal`).
My vote is to make this an internal `_batched_mul_4` or something for now, partly because I think explaining what does and doesn't work becomes more complicated with this method. And that doesn't have to be solved to add attention.
It's a pity not to make things available. Maybe I can leave the previous docstring unchanged and add a new one for the new method?