
implement dot_product_attention #455

Merged
15 commits merged on Feb 3, 2023
8 changes: 8 additions & 0 deletions docs/src/reference.md
@@ -33,6 +33,14 @@ tanhshrink
trelu
```

## Attention

```@docs
dot_product_attention
dot_product_attention_scores
make_causal_mask
```

## Softmax

`Flux`'s `logitcrossentropy` uses `NNlib.softmax` internally.
3 changes: 3 additions & 0 deletions src/NNlib.jl
@@ -41,6 +41,9 @@ for f in ACTIVATIONS
end
export sigmoid, hardsigmoid, logsigmoid, thresholdrelu # Aliases

include("attention.jl")
export dot_product_attention, dot_product_attention_scores, make_causal_mask

include("dropout.jl")
export dropout, dropout!

139 changes: 139 additions & 0 deletions src/attention.jl
@@ -0,0 +1,139 @@
const AA3{T} = AbstractArray{T,3}
const AA4{T} = AbstractArray{T,4}
const AA{N,T} = AbstractArray{T,N}

"""
dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])

Multihead dot product attention used in transformer architectures.

The input arrays must have their first two dimensions given by the number of features
and the sequence length, followed by an arbitrary number of batch dimensions or none.

Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores
of size `(kv_len, q_len, nheads, batch_size...)`.

See also [`dot_product_attention_scores`](@ref) if you only need the attention scores.

# Arguments

- `query`: Query array of size `(qk_dim, q_len, batch_size...)`.
- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`.
- `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
It will be added to the attention scores before applying the softmax. Default `nothing`.
- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax.
Default `identity` (no dropout).
- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
The mask is applied to the attention scores before the softmax.
Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`.
- `nheads`: Number of heads to split the input arrays into. Default `1`.

# Examples

```julia
q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
y, α = dot_product_attention(q, k, v)
```
"""
function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N
batch_size = size(q)[3:end]
batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same."))
q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))

x, α = dot_product_attention(q, k, v, args...; kws...)

x = reshape(x, size(x, 1), size(x, 2), batch_size...)
α = reshape(α, size(α)[1:3]..., batch_size...)
return x, α
end
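A quick sketch (not from the diff) of how the wrapper above handles several trailing batch dimensions by flattening and then restoring them:

```julia
# Editorial sketch: extra trailing batch dimensions are flattened for the 3-arg method
# and restored on the outputs.
q = k = v = rand(Float32, 8, 6, 2, 5)   # (dim, len, 2×5 batch)
y, α = dot_product_attention(q, k, v; nheads=4)
size(y)   # (8, 6, 2, 5)
size(α)   # (6, 6, 4, 2, 5)
```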

function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
fdrop=identity, mask=nothing, nheads=1)

(size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("Batch dimensions have to be the same."))
size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same."))
size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same."))

# Multihead attention. TODO create fastpath for singlehead attention.
q, k, v = split_heads.((q, k, v), nheads)
x, α = _dot_product_attention(q, k, v, bias, fdrop, mask)
return join_heads(x), α
end

function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask)
# [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size]
# [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size]
# [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size]

α = dot_product_attention_scores(q, k, bias; fdrop, mask)
# [α] = [kv_len, q_len, nheads, batch_size]

# The following permutedims and batched_mul are equivalent to
# @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
vt = permutedims(v, (1, 3, 2, 4))
x = batched_mul(vt, α)
x = permutedims(x, (1, 3, 2, 4))
    # [x] = [v_dim ÷ nheads, nheads, q_len, batch_size]
return x, α
end
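For reference, a naive loop equivalent to the `@tullio` expression in the comment above (an editorial sketch; `naive_combine` is not part of the PR):

```julia
# Editorial sketch of x[d, h, i, b] := Σⱼ α[j, i, h, b] * v[d, h, j, b],
# the contraction that the permutedims + batched_mul pair implements.
function naive_combine(α, v)
    d, h, kv_len, b = size(v)
    q_len = size(α, 2)
    x = zeros(eltype(v), d, h, q_len, b)
    for bi in 1:b, i in 1:q_len, hi in 1:h, j in 1:kv_len, di in 1:d
        x[di, hi, i, bi] += α[j, i, hi, bi] * v[di, hi, j, bi]
    end
    return x
end
```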

"""
dot_product_attention_scores(query, key, [bias]; [fdrop, mask])

Return the attention scores for the [`dot_product_attention`](@ref).
Input arrays must have dimensions
`(num_features ÷ nheads, nheads, sequence_length, batch_size)`.

See [`dot_product_attention`](@ref) for more details.
"""
function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
fdrop=identity, mask=nothing) where T

# The following permutedims and batched_mul are equivalent to
# @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim)
kt = permutedims(k, (3, 1, 2, 4))
qt = permutedims(q, (1, 3, 2, 4)) ./ √T(size(q, 1))
logits = batched_mul(kt, qt)
# [logits] = [kv_len, q_len, nheads, batch_size]

if bias !== nothing
logits = logits .+ bias
end

if mask !== nothing
if mask === :causal
mask = make_causal_mask(logits)
end
Review comment (Member):

I think a cleaner API would be to let the mask keyword be a function. The nothing case is mask = identity and the causal case is mask = make_causal_mask (which I feel should be just causal_mask to be succinct).

Is there a reason to construct the mask on the fly? The calling layer in Flux can probably make and store the mask once. Then the other option is to allow nothing or an array. Then the user passes in mask = causal_mask(ntoken).

Review comment (@mcabbott, Member, Jan 9, 2023):

What is the function which you pass, in this proposal?

  • mask = identity means this is applied to the array.

  • mask = make_causal_mask means it constructs a boolean matrix.

Agree that constructing the same matrix every time seems a bit wasteful, although probably not a big cost, there are quite a few larger copies made in this thing.

With mask = identity, the usual masking could be causal_mask! which is basically for i,j in ...; if i<j; x[i,j] = -Inf end; i.e. it just mutates the data array. This should be safe as the gradient of batched_mul does not need the original values.
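A rough sketch of the mutating idea described in this comment (hypothetical; the indexing follows the `(kv_len, q_len, nheads, batch)` layout used by this PR, not necessarily the comment's):

```julia
# Hypothetical sketch: write -Inf directly into the logits wherever a query would
# attend to a later key, instead of materialising a Bool mask.
function causal_mask!(logits::AbstractArray{T,4}) where T
    kv_len, q_len = size(logits, 1), size(logits, 2)
    for i in 1:q_len, j in 1:kv_len
        j > i && (logits[j, i, :, :] .= typemin(T))
    end
    return logits
end
```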

Review comment (Member):

You're right, it shouldn't be identity; it should be trues_like, though I'd be okay with nothing in order to skip computing a mask at all.

My comment about constructing on the fly was not a performance concern. I just think it is more intuitive to pass in exactly the mask array I want used. It's an easier rule to remember and also scalable to whatever masking scheme is desired.

Review comment (@mcabbott, Member, Jan 9, 2023):

The downside is that you have to make an array the right size. If you have several layers and want the same scheme for each, then perhaps it's a pain. Whereas a function like trues_like is told the size automatically.

(The implementation can branch on mask === trues_like to avoid work in the default case. We can also branch on the type of const causal_mask = triu ∘ trues_like if necc.)

While encoding this as a bool array makes some sense, it's also a little weird in that the implementation doesn't directly consume this. Maybe better than my mutating idea above, we can modify softmax to take a mask argument, and fuse it into the broadcast there, I think.
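The mask-aware softmax mentioned here could look roughly like the following (a hedged sketch of the idea explored in #460, not an existing NNlib API; it is semantically equivalent but not actually fused into the softmax broadcast):

```julia
# Hypothetical sketch of a softmax that accepts a Bool mask; masked-out entries get -Inf.
masked_softmax(logits, mask; dims=1) =
    softmax(ifelse.(mask, logits, typemin(eltype(logits))); dims)
```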

Review comment (Member):

That's true, but generally the size of this matrix (# of tokens × # of tokens) is known ahead of time. Even so, I agree that not needing to pass in this info is cleaner.

I mostly wanted to avoid "symbol switches" for arguments.

Review comment (@mcabbott, Member, Jan 9, 2023):

Yes to avoiding symbols. I like this mask = trues_like proposal the best so far.

One question I haven't looked at is what format the CUDNN thing is going to want.

Review comment (Member):

Instead of saying mask is either an array or callable, could we say it should be either an array or marker type for which one can override some init_mask(x, k, v) function? This would allow us to shift the conditionals out of the attention functions, while still allowing for relatively terse syntax like mask = CausalMask() when users don't want to precompute their own. You could imagine nice party tricks like passing mask = I.

Review comment (@mcabbott, Member, Jan 9, 2023):

#460 is a go at this masked softmax idea.

With that, the default of no mask can in fact be mask = Returns(true) here, instead of trues_like. And the terse causal mask can be const causal_mask = triu ∘ trues_like, or a function equivalent to this (maybe it can be more efficient, not sure triu works on CuArrays). No conditionals required.

Edit: making #460 work on GPU too won't be just a few lines. But even without that, mask::Function = trues_like as the interface seems nice, instead of having to independently make something the right size.

Review comment (Member):

triu only works on AbstractMatrix, which is not sufficient for the attention.

Review comment (Member, author):

For this first implementation, I prefer to keep things minimal and just accept nothing or arrays (I will remove :causal).

neginf = typemin(eltype(logits))
logits = ifelse.(mask, logits, neginf)
end

α = softmax(logits, dims=1)
return fdrop(α)
end
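As an illustration (not part of the diff), calling the scores function directly on inputs whose heads are already split out:

```julia
# Editorial sketch: inputs already in (dim ÷ nheads, nheads, len, batch) layout.
q = k = rand(Float32, 4, 2, 6, 3)
α = dot_product_attention_scores(q, k)
size(α)                    # (6, 6, 2, 3) == (kv_len, q_len, nheads, batch)
all(sum(α, dims=1) .≈ 1)   # true: each column is a softmax over key positions
```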

"""
make_causal_mask(x; dims=2)

Return a boolean square matrix `m` of side `size(x, dims)`, with the same array type as `x` but `Bool` elements.
Its elements are set such that `m[i, j] == i ≤ j`.

Can be used to mask the attention scores in [`dot_product_attention`](@ref).
"""
function make_causal_mask(x::AbstractArray; dims::Int=2)
len = size(x, dims)
mask = triu(trues_like(x, (len, len)))
return mask
end
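A small example of the mask this produces (illustrative only):

```julia
x = rand(Float32, 4, 3)     # any array; only size(x, 2) == 3 matters here
make_causal_mask(x)
# 3×3 Matrix{Bool}:
#  1  1  1
#  0  1  1
#  0  0  1
```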

trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true)
falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false)

split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
join_heads(x) = reshape(x, :, size(x)[3:end]...)
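These internal helpers reshape between the flat and per-head layouts; for example (a sketch, not part of the diff):

```julia
# Editorial sketch: split_heads/join_heads are plain reshapes and round-trip exactly.
x  = rand(Float32, 8, 6, 3)
xh = split_heads(x, 2)          # (4, 2, 6, 3): features split across 2 heads
join_heads(xh) == x             # true
```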

@non_differentiable make_causal_mask(x)
@non_differentiable trues_like(::Any...)
@non_differentiable falses_like(::Any...)

15 changes: 13 additions & 2 deletions src/batched/batchedmul.jl
@@ -5,8 +5,10 @@ _unbatch(A::BatchedAdjOrTrans) = parent(A)
batched_mul(A, B) -> C
A ⊠ B # \\boxtimes

Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for all `k`.
If `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`.
Batched matrix multiplication. Result has `C[:,:,k...] == A[:,:,k...] * B[:,:,k...]`, where `k...` represents
any indices into the trailing batch dimensions.

If `ndims(A) == ndims(B) == 3` and `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`.

To transpose each matrix, apply `batched_transpose` to the array,
or `batched_adjoint` for conjugate-transpose:
@@ -42,6 +44,15 @@ This will be copied, as doing so is faster than `batched_mul_generic!`.
Both this `copy` and `batched_mul_generic!` produce `@debug` messages,
and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them.
"""
function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
Review comment (Member):

My vote is to make this an internal _batched_mul_4 or something for now. Partly because I think explaining what does and doesn't work becomes more complicated with this method. And that doesn't have to be solved to add attention.

Review comment (Member, author):

It's a pity to not make things available. Maybe I can leave the previous docstring unchanged and add a new one for the new method?

batch_size = size(x)[3:end]
@assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
x2 = reshape(x, size(x, 1), size(x, 2), :)
y2 = reshape(y, size(y, 1), size(y, 2), :)
z = batched_mul(x2, y2)
return reshape(z, size(z, 1), size(z, 2), batch_size...)
end
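A sketch (not from the diff) of the behaviour this new method adds:

```julia
# Editorial sketch: trailing dimensions beyond the first two are treated as batch dims.
A = rand(Float32, 2, 3, 4, 5)
B = rand(Float32, 3, 6, 4, 5)
C = batched_mul(A, B)
size(C)   # (2, 6, 4, 5); C[:, :, k, l] == A[:, :, k, l] * B[:, :, k, l]
```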

function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2}
size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 ||
throw(DimensionMismatch("batch size mismatch: A != B"))
2 changes: 1 addition & 1 deletion src/gemm.jl
@@ -138,7 +138,7 @@ for (gemm, elt) in gemm_datatype_mappings

end

C
return C
end
end
end
71 changes: 71 additions & 0 deletions test/attention.jl
@@ -0,0 +1,71 @@
@testset "different batchsizes" begin
n = 15
lenq = 3
lenkv = 4
for batch_size in [(), 1, 2, (2,1,3)], nheads in [1, 3, 5]
q = rand(Float32, n, lenq, batch_size...)
k = rand(Float32, n, lenkv, batch_size...)
v = rand(Float32, n, lenkv, batch_size...)
y, α = dot_product_attention(q, k, v; nheads)
@test y isa Array{Float32}
@test size(y) == (n, lenq, batch_size...)
@test size(α) == (lenkv, lenq, nheads, batch_size...)
@test sum(α, dims=1) ≈ ones(1, lenq, nheads, batch_size...)
end
end

@testset "dot_product_attention_scores" begin
q = k = reshape([1:24;], 4, 2, 3, 1) ./ 24
α = dot_product_attention_scores(q, k)
q2, k2 = reshape.((q, k), 8, 3, 1)
y, α2 = dot_product_attention(q2, k2, k2; nheads=2)
@test α ≈ α2
end

@testset "specific results" begin
q = k = v = reshape([1:12;], 4, 3, 1) ./ 12
y, α = dot_product_attention(q, k, v; nheads=2)
ytrue = [0.4297536645089624, 0.5130869978422957, 0.6137914555895531, 0.6971247889228864, 0.46431026790247376, 0.5476436012358071, 0.6478764227436047, 0.731209756076938, 0.49773020657887745, 0.5810635399122107, 0.6804545876711346, 0.763787921004468]
ytrue = reshape(ytrue, 4, 3, 1)
αtrue = [0.3138955704910261, 0.3329478654910607, 0.35315656401791323, 0.264431440679808, 0.32820631493296265, 0.4073622443872293, 0.21921458153690657, 0.31838021718955445, 0.4624052012735389, 0.2886914482847165, 0.33124273666190807, 0.3800658150533755, 0.24123865285082136, 0.3238934260675431, 0.43486792108163547, 0.19843756756539277, 0.31176110185581074, 0.4898013305787966]
αtrue = reshape(αtrue, 3, 3, 2, 1)
@test y ≈ ytrue
@test α ≈ αtrue
end

@testset "mask" begin
q = rand(4, 2, 3, 1)
k = rand(4, 2, 5, 1)

mask = rand(Bool, (5, 3))
α = dot_product_attention_scores(q, k; mask)
@test all((α[:,:,1,1].> 0) .== mask)
@test all((α[:,:,2,1].> 0) .== mask)

@testset "causal" begin
x = rand(4, 2, 3, 1)
mask = make_causal_mask(x, dims=3)
α = dot_product_attention_scores(x, x; mask)
@test all((α[:,:,1,1].> 0) .== mask)
@test all((α[:,:,2,1].> 0) .== mask)

α2 = dot_product_attention_scores(x, x; mask=:causal)
@test α2 ≈ α
end
end

@testset "dropout" begin
q = k = v = rand(10, 10, 10)
fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1-p)
y, α = dot_product_attention(q, k, v; nheads=2, fdrop = x -> fdrop(x, 0.5))
@test 0.6 > mean(>(0), α) > 0.4
end

@testset "bias" begin
q = rand(4, 5, 1)
k = v = rand(4, 3, 1)
bias = randn(3, 5)
y, α = dot_product_attention(q, k, v, bias; nheads=2)
@test size(α) == (3, 5, 2, 1)
@test size(y) == (4, 5, 1)
end
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -39,6 +39,10 @@ include("test_utils.jl")
include("activations.jl")
end

@testset "Attention" begin
include("attention.jl")
end

@testset "Batched Multiplication" begin
include("batchedmul.jl")
end