diff --git a/src/layers/embedding.jl b/src/layers/embedding.jl
index e316f3f3c..b821241c8 100644
--- a/src/layers/embedding.jl
+++ b/src/layers/embedding.jl
@@ -176,9 +176,7 @@ function initialstates(::AbstractRNG, rpe::RotaryPositionalEmbedding)
     seq_idx = collect(Float32, 0:(rpe.max_sequence_length - 1))
     idx_theta = reshape(theta, :, 1) .* reshape(seq_idx, 1, :)
 
-    cache = vcat(cos.(idx_theta), sin.(idx_theta))
-
-    return (; cache, theta)
+    return (; cos_cache=cos.(idx_theta), sin_cache=sin.(idx_theta))
 end
 
 function (rope::RotaryPositionalEmbedding)(
@@ -193,24 +191,28 @@ function (rope::RotaryPositionalEmbedding)((x, input_pos)::Tuple, ps, st::NamedT
     y = match_eltype(rope, ps, st, x)
 
     # extract the values based on whether input_pos is set or not
-    rope_cache = input_pos === nothing ? st.cache[:, 1:seq_len] : st.cache[:, input_pos]
+    if input_pos === nothing
+        cos_cache = st.cos_cache[:, 1:seq_len]
+        sin_cache = st.sin_cache[:, 1:seq_len]
+    else
+        cos_cache = st.cos_cache[:, input_pos]
+        sin_cache = st.sin_cache[:, input_pos]
+    end
 
     # reshape input; the last dimension is used for computing the output.
     # Cast to float to match the reference implementation
     xshaped = reshape(float.(y), 2, h_d ÷ 2, n_h, seq_len, b)
 
     # reshape the cache for broadcasting
-    rope_cache = reshape(rope_cache, 2, h_d ÷ 2, 1, seq_len, :)
+    cos_cache = reshape(cos_cache, 1, h_d ÷ 2, 1, seq_len, :)
+    sin_cache = reshape(sin_cache, 1, h_d ÷ 2, 1, seq_len, :)
 
     xshaped1 = xshaped[1:1, :, :, :, :]
     xshaped2 = xshaped[2:2, :, :, :, :]
-    rope_cache1 = rope_cache[1:1, :, :, :, :]
-    rope_cache2 = rope_cache[2:2, :, :, :, :]
-
     x_out = vcat(
-        xshaped1 .* rope_cache1 - xshaped2 .* rope_cache2,
-        xshaped2 .* rope_cache1 + xshaped1 .* rope_cache2
+        xshaped1 .* cos_cache - xshaped2 .* sin_cache,
+        xshaped2 .* cos_cache + xshaped1 .* sin_cache
     )
 
     return reshape(x_out, h_d, n_h, seq_len, b), st
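
Below is a minimal standalone sketch (not part of the patch) of the math this diff implements. The names head_dim, max_seq_len, and base are illustrative, not Lux API; it rebuilds the split cos/sin caches exactly as the new initialstates does and checks RoPE's defining property, that rotated dot products depend only on relative position.

using LinearAlgebra

head_dim, max_seq_len, base = 8, 16, 10_000.0f0

# Pair frequencies theta_k = base^(-2(k-1)/d), one per feature pair
theta = base .^ (-(0:2:(head_dim - 2)) ./ Float32(head_dim))
seq_idx = collect(Float32, 0:(max_seq_len - 1))
idx_theta = reshape(theta, :, 1) .* reshape(seq_idx, 1, :)  # (d ÷ 2, seq_len)

# Split caches, mirroring the new initialstates
cos_cache, sin_cache = cos.(idx_theta), sin.(idx_theta)

# Rotate consecutive feature pairs (x1, x2) of one head vector at cache
# column p, matching the vcat in the forward pass: each pair maps to
# (x1*cos - x2*sin, x2*cos + x1*sin), i.e. a 2x2 rotation per pair
function rotate(x::AbstractVector, p::Integer)
    x1, x2 = x[1:2:end], x[2:2:end]
    c, s = cos_cache[:, p], sin_cache[:, p]
    out = similar(x)
    out[1:2:end] = x1 .* c .- x2 .* s
    out[2:2:end] = x2 .* c .+ x1 .* s
    return out
end

# Rotations are orthogonal, so <R(m)q, R(n)k> depends only on n - m;
# column pairs (3, 5) and (4, 6) both differ by 2 positions
q, k = randn(Float32, head_dim), randn(Float32, head_dim)
@assert isapprox(dot(rotate(q, 3), rotate(k, 5)),
                 dot(rotate(q, 4), rotate(k, 6)); atol=1f-4)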
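
And a hypothetical usage sketch under the usual Lux calling convention. The constructor arguments (head dimension 64, max_sequence_length=512) and the use of 1-based column indices for input_pos are assumptions read off the fields this diff touches, not a confirmed public signature; the tuple form (x, input_pos) is the method shown in the second hunk.

using Lux, Random

# Assumed constructor; only the max_sequence_length field is confirmed by the diff
rope = RotaryPositionalEmbedding(64; max_sequence_length=512)
ps, st = Lux.setup(Random.default_rng(), rope)  # st now carries cos_cache/sin_cache

h_d, n_h, seq_len, b = 64, 8, 10, 2
x = randn(Float32, h_d, n_h, seq_len, b)

# Full-sequence path: input_pos === nothing, cache columns 1:seq_len are read
y, _ = rope((x, nothing), ps, st)

# Incremental-decoding path: rotate only the newest token by passing its
# explicit position, so just that column of each cache is gathered
y_step, _ = rope((x[:, :, end:end, :], [seq_len]), ps, st)

Keeping cos and sin in separate caches also removes the old interleaved reshape of a single vcat'd cache, so the broadcast shapes in the forward pass are explicit.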