feat: add RoPE
avik-pal committed Feb 12, 2025
1 parent 96d671c commit 9dd4d1f
Showing 3 changed files with 91 additions and 9 deletions.
3 changes: 2 additions & 1 deletion docs/src/api/Lux/layers.md
@@ -67,7 +67,8 @@ Scale

```@docs
Embedding
-SinusoidalPositionalEncoding
+RotaryPositionalEmbedding
+SinusoidalPositionalEmbedding
```

## Misc. Helper Layers
2 changes: 1 addition & 1 deletion src/Lux.jl
@@ -107,7 +107,7 @@ include("deprecations.jl")
# Layers
export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer
export Bilinear, Dense, Scale
-export Embedding, SinusoidalPositionalEncoding, RotaryPositionalEncoding
+export Embedding, SinusoidalPositionalEmbedding, RotaryPositionalEmbedding
export Conv, ConvTranspose, Upsample, PixelShuffle
export MaxPool, MeanPool, LPPool, GlobalMaxPool, GlobalMeanPool, GlobalLPPool,
    AdaptiveMaxPool, AdaptiveMeanPool, AdaptiveLPPool
95 changes: 88 additions & 7 deletions src/layers/embedding.jl
@@ -85,9 +85,9 @@ function (e::Embedding)(::Tuple{}, _, ::NamedTuple)
end

"""
-    SinusoidalPositionalEncoding(....)
+    SinusoidalPositionalEmbedding(....)

-Sinusoidal Positional Encoding. For details see [1].
+Sinusoidal Positional Embedding. For details see [1].
## Arguments
@@ -106,31 +106,112 @@ Sinusoidal Positional Encoding. For details see [1].
[1] Vaswani, A. "Attention is all you need." Advances in Neural Information Processing
    Systems (2017).
"""
-@concrete struct SinusoidalPositionalEncoding{T} <: AbstractLuxLayer
+@concrete struct SinusoidalPositionalEmbedding{T} <: AbstractLuxLayer
    log_min_freq::T
    log_max_freq::T
    dims <: IntegerType
    scale <: Real
    full_turns::Bool
end

-function SinusoidalPositionalEncoding(dims::IntegerType; min_freq=0.0001f0, max_freq=1.0f0,
+function SinusoidalPositionalEmbedding(
+        dims::IntegerType; min_freq=0.0001f0, max_freq=1.0f0,
        scale=nothing, full_turns::Bool=false)
    T = promote_type(typeof(min_freq), typeof(max_freq))
    scale = scale === nothing ? T((2 / dims)) : T(scale)
-    return SinusoidalPositionalEncoding(
+    return SinusoidalPositionalEmbedding(
        T(log(min_freq)), T(log(max_freq)), dims, scale, full_turns)
end

-function initialstates(::AbstractRNG, spe::SinusoidalPositionalEncoding{T}) where {T}
+function initialstates(::AbstractRNG, spe::SinusoidalPositionalEmbedding{T}) where {T}
    one_zero = range(T(1), T(0); length=spe.dims ÷ 2)
    sigmas = exp.(one_zero .* (spe.log_max_freq - spe.log_min_freq) .+ spe.log_min_freq)
    spe.full_turns && (@. sigmas *= 2π)
    return (; sigmas)
end

-function (spe::SinusoidalPositionalEncoding)(x::AbstractArray, ps, st::NamedTuple)
+function (spe::SinusoidalPositionalEmbedding)(x::AbstractArray, ps, st::NamedTuple)
    y = reshape(match_eltype(spe, ps, st, x), 1, size(x)...) .* st.sigmas
    z = vcat(sin.(y), cos.(y)) .* spe.scale
    return z, st
end

"""
RotaryPositionalEmbedding(....)
Rotary Positional Embedding. For details see [1].
## Arguments
## Keyword Arguments
## Input
## Returns
## Parameters
## States
## References
[1] Su, Jianlin, et al. "Roformer: Enhanced transformer with rotary position embedding.
"Neurocomputing 568 (2024): 127063.
"""
@concrete struct RotaryPositionalEmbedding <: AbstractLuxLayer
dim <: IntegerType
max_sequence_length <: IntegerType
base <: IntegerType
end

function RotaryPositionalEmbedding(
dim::IntegerType; max_sequence_length::IntegerType=4096, base::IntegerType=10000)
return RotaryPositionalEmbedding(dim, max_sequence_length, base)
end

function initialstates(::AbstractRNG, rpe::RotaryPositionalEmbedding)
theta = 1.0f0 ./
Float32.(rpe.base .^
(range(0, rpe.dim - 1; step=2)[1:(rpe.dim ÷ 2)] ./ rpe.dim))

seq_idx = collect(Float32, 0:(rpe.max_sequence_length - 1))
idx_theta = reshape(theta, :, 1) .* reshape(seq_idx, 1, :)
cache = vcat(cos.(idx_theta), sin.(idx_theta))

return (; cache, theta)
end

function (rope::RotaryPositionalEmbedding)(
x::AbstractArray{T, 4}, ps, st::NamedTuple) where {T}
return rope((x, nothing), ps, st)
end

function (rope::RotaryPositionalEmbedding)((x, input_pos)::Tuple, ps, st::NamedTuple)
@assert ndims(x)==4 "Input must be a 4D tensor"

h_d, n_h, seq_len, b = size(x)
y = match_eltype(rope, ps, st, x)

# extract the values based on whether input_pos is set or not
rope_cache = input_pos === nothing ? st.cache[:, 1:seq_len] : st.cache[:, input_pos]

# reshape input; the last dimension is used for computing the output.
# Cast to float to match the reference implementation
xshaped = reshape(float.(y), 2, h_d ÷ 2, n_h, seq_len, b)

# reshape the cache for broadcasting
rope_cache = reshape(rope_cache, 2, h_d ÷ 2, 1, seq_len, :)

xshaped1 = xshaped[1:1, :, :, :, :]
xshaped2 = xshaped[2:2, :, :, :, :]

rope_cache1 = rope_cache[1:1, :, :, :, :]
rope_cache2 = rope_cache[2:2, :, :, :, :]

x_out = vcat(
xshaped1 .* rope_cache1 - xshaped2 .* rope_cache2,
xshaped2 .* rope_cache1 + xshaped1 .* rope_cache2
)

return reshape(x_out, h_d, n_h, seq_len, b), st
end

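The rotary layer treats its input as (head_dim, n_heads, seq_len, batch), per the `h_d, n_h, seq_len, b = size(x)` destructuring, and rotates each consecutive feature pair (x₁, x₂) by a position-dependent angle θ into (x₁ cos θ − x₂ sin θ, x₂ cos θ + x₁ sin θ), which is what the final `vcat` expression computes. A sketch under the same assumptions as above (the concrete sizes are illustrative):

```julia
using Lux, Random

# head_dim = 64; the cos/sin cache is precomputed for up to 4096 positions.
rope = RotaryPositionalEmbedding(64; max_sequence_length=4096)
ps, st = Lux.setup(Random.default_rng(), rope)

x = randn(Float32, 64, 8, 16, 2)   # (head_dim, n_heads, seq_len, batch)
y, _ = rope(x, ps, st)             # rotated features, same shape as x

# Explicit positions (e.g. for cached decoding) go in as a tuple:
y2, _ = rope((x, 1:16), ps, st)
```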