From 9dd4d1f009d479a55ae5308933000f6b235ef713 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 11 Feb 2025 23:19:06 -0500
Subject: [PATCH] feat: add RoPE

---
 docs/src/api/Lux/layers.md |  3 +-
 src/Lux.jl                 |  2 +-
 src/layers/embedding.jl    | 95 +++++++++++++++++++++++++++++++++++---
 3 files changed, 91 insertions(+), 9 deletions(-)

diff --git a/docs/src/api/Lux/layers.md b/docs/src/api/Lux/layers.md
index 845255f87..4f5377c7d 100644
--- a/docs/src/api/Lux/layers.md
+++ b/docs/src/api/Lux/layers.md
@@ -67,7 +67,8 @@ Scale
 
 ```@docs
 Embedding
-SinusoidalPositionalEncoding
+RotaryPositionalEmbedding
+SinusoidalPositionalEmbedding
 ```
 
 ## Misc. Helper Layers
diff --git a/src/Lux.jl b/src/Lux.jl
index 6e0cc655c..16c56dcf9 100644
--- a/src/Lux.jl
+++ b/src/Lux.jl
@@ -107,7 +107,7 @@ include("deprecations.jl")
 # Layers
 export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer
 export Bilinear, Dense, Scale
-export Embedding, SinusoidalPositionalEncoding, RotaryPositionalEncoding
+export Embedding, SinusoidalPositionalEmbedding, RotaryPositionalEmbedding
 export Conv, ConvTranspose, Upsample, PixelShuffle
 export MaxPool, MeanPool, LPPool, GlobalMaxPool, GlobalMeanPool, GlobalLPPool,
        AdaptiveMaxPool, AdaptiveMeanPool, AdaptiveLPPool
diff --git a/src/layers/embedding.jl b/src/layers/embedding.jl
index 9f612866f..e316f3f3c 100644
--- a/src/layers/embedding.jl
+++ b/src/layers/embedding.jl
@@ -85,9 +85,9 @@ function (e::Embedding)(::Tuple{}, _, ::NamedTuple)
 end
 
 """
-    SinusoidalPositionalEncoding(....)
+    SinusoidalPositionalEmbedding(....)
 
-Sinusoidal Positional Encoding. For details see [1].
+Sinusoidal Positional Embedding. For details see [1].
 
 ## Arguments
 
@@ -106,7 +106,7 @@ Sinusoidal Positional Encoding. For details see [1].
 [1] Vaswani, A. "Attention is all you need." Advances in Neural Information Processing
 Systems (2017).
 """
-@concrete struct SinusoidalPositionalEncoding{T} <: AbstractLuxLayer
+@concrete struct SinusoidalPositionalEmbedding{T} <: AbstractLuxLayer
     log_min_freq::T
     log_max_freq::T
     dims <: IntegerType
@@ -114,23 +114,104 @@ Systems (2017).
     scale::T
     full_turns::Bool
 end
 
-function SinusoidalPositionalEncoding(dims::IntegerType; min_freq=0.0001f0, max_freq=1.0f0,
+function SinusoidalPositionalEmbedding(
+        dims::IntegerType; min_freq=0.0001f0, max_freq=1.0f0,
         scale=nothing, full_turns::Bool=false)
     T = promote_type(typeof(min_freq), typeof(max_freq))
     scale = scale === nothing ? T(√(2 / dims)) : T(scale)
-    return SinusoidalPositionalEncoding(
+    return SinusoidalPositionalEmbedding(
         T(log(min_freq)), T(log(max_freq)), dims, scale, full_turns)
 end
 
-function initialstates(::AbstractRNG, spe::SinusoidalPositionalEncoding{T}) where {T}
+function initialstates(::AbstractRNG, spe::SinusoidalPositionalEmbedding{T}) where {T}
     one_zero = range(T(1), T(0); length=spe.dims ÷ 2)
     sigmas = exp.(one_zero .* (spe.log_max_freq - spe.log_min_freq) .+ spe.log_min_freq)
     spe.full_turns && (@. sigmas *= 2π)
     return (; sigmas)
 end
 
-function (spe::SinusoidalPositionalEncoding)(x::AbstractArray, ps, st::NamedTuple)
+function (spe::SinusoidalPositionalEmbedding)(x::AbstractArray, ps, st::NamedTuple)
     y = reshape(match_eltype(spe, ps, st, x), 1, size(x)...) .* st.sigmas
     z = vcat(sin.(y), cos.(y)) .* spe.scale
     return z, st
 end
+
+"""
+    RotaryPositionalEmbedding(dim::IntegerType; max_sequence_length=4096, base=10000)
+
+Rotary Positional Embedding. For details see [1].
+
+## Arguments
+
+  - `dim`: feature dimension to be rotated (the attention head dimension); must be even.
+
+## Keyword Arguments
+
+  - `max_sequence_length`: maximum sequence length for which the rotation cache is
+    precomputed. Defaults to `4096`.
+  - `base`: base used to compute the rotation frequencies. Defaults to `10000`.
+
+## Input
+
+  - A 4D array of size `(dim, n_heads, seq_len, batch)`, or a tuple `(x, input_pos)` where
+    `input_pos` selects the positions to read from the precomputed cache.
+
+## Returns
+
+  - Array of the same size as the input with the rotary embedding applied, and the
+    (unchanged) state.
+
+## Parameters
+
+  - None.
+
+## States
+
+  - `cache`: precomputed cosine/sine rotation table of size `(dim, max_sequence_length)`.
+  - `theta`: rotation frequencies of length `dim ÷ 2`.
+
+## References
+
+[1] Su, Jianlin, et al. "Roformer: Enhanced transformer with rotary position embedding."
+Neurocomputing 568 (2024): 127063.
+"""
+@concrete struct RotaryPositionalEmbedding <: AbstractLuxLayer
+    dim <: IntegerType
+    max_sequence_length <: IntegerType
+    base <: IntegerType
+end
+
+function RotaryPositionalEmbedding(
+        dim::IntegerType; max_sequence_length::IntegerType=4096, base::IntegerType=10000)
+    return RotaryPositionalEmbedding(dim, max_sequence_length, base)
+end
+
+function initialstates(::AbstractRNG, rpe::RotaryPositionalEmbedding)
+    theta = 1.0f0 ./
+            Float32.(rpe.base .^
+                     (range(0, rpe.dim - 1; step=2)[1:(rpe.dim ÷ 2)] ./ rpe.dim))
+
+    seq_idx = collect(Float32, 0:(rpe.max_sequence_length - 1))
+    idx_theta = reshape(theta, :, 1) .* reshape(seq_idx, 1, :)
+    # interleave cos and sin so that each pair of rows holds (cos(θᵢ t), sin(θᵢ t)) for a
+    # single frequency, matching the 2-row pairing used in the forward pass below
+    cache = reshape(
+        vcat(reshape(cos.(idx_theta), 1, :), reshape(sin.(idx_theta), 1, :)), rpe.dim, :)
+
+    return (; cache, theta)
+end
+
+function (rope::RotaryPositionalEmbedding)(
+        x::AbstractArray{T, 4}, ps, st::NamedTuple) where {T}
+    return rope((x, nothing), ps, st)
+end
+
+function (rope::RotaryPositionalEmbedding)((x, input_pos)::Tuple, ps, st::NamedTuple)
+    @assert ndims(x)==4 "Input must be a 4D tensor"
+
+    h_d, n_h, seq_len, b = size(x)
+    y = match_eltype(rope, ps, st, x)
+
+    # extract the cache entries based on whether input_pos is set or not
+    rope_cache = input_pos === nothing ? st.cache[:, 1:seq_len] : st.cache[:, input_pos]
+
+    # reshape the input so that consecutive pairs along the feature dimension can be
+    # rotated. Cast to float to match the reference implementation
+    xshaped = reshape(float.(y), 2, h_d ÷ 2, n_h, seq_len, b)
+
+    # reshape the cache for broadcasting
+    rope_cache = reshape(rope_cache, 2, h_d ÷ 2, 1, seq_len, :)
+
+    xshaped1 = xshaped[1:1, :, :, :, :]
+    xshaped2 = xshaped[2:2, :, :, :, :]
+
+    rope_cache1 = rope_cache[1:1, :, :, :, :]
+    rope_cache2 = rope_cache[2:2, :, :, :, :]
+
+    x_out = vcat(
+        xshaped1 .* rope_cache1 - xshaped2 .* rope_cache2,
+        xshaped2 .* rope_cache1 + xshaped1 .* rope_cache2
+    )
+
+    return reshape(x_out, h_d, n_h, seq_len, b), st
+end