Skip to content

Commit

Permalink
fix: better caching
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 12, 2025
1 parent 2eae8fb commit 9c7c15c
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/layers/embedding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,7 @@ function initialstates(::AbstractRNG, rpe::RotaryPositionalEmbedding)

seq_idx = collect(Float32, 0:(rpe.max_sequence_length - 1))
idx_theta = reshape(theta, :, 1) .* reshape(seq_idx, 1, :)
cache = vcat(cos.(idx_theta), sin.(idx_theta))

return (; cache, theta)
return (; cos_cache=cos.(idx_theta), sin_cache=sin.(idx_theta))

Check warning on line 179 in src/layers/embedding.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/embedding.jl#L177-L179

Added lines #L177 - L179 were not covered by tests
end

function (rope::RotaryPositionalEmbedding)(

Check warning on line 182 in src/layers/embedding.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/embedding.jl#L182

Added line #L182 was not covered by tests
Expand All @@ -193,24 +191,28 @@ function (rope::RotaryPositionalEmbedding)((x, input_pos)::Tuple, ps, st::NamedT
y = match_eltype(rope, ps, st, x)

Check warning on line 191 in src/layers/embedding.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/embedding.jl#L190-L191

Added lines #L190 - L191 were not covered by tests

# extract the values based on whether input_pos is set or not
rope_cache = input_pos === nothing ? st.cache[:, 1:seq_len] : st.cache[:, input_pos]
if input_pos === nothing
cos_cache = st.cos_cache[:, 1:seq_len]
sin_cache = st.sin_cache[:, 1:seq_len]

Check warning on line 196 in src/layers/embedding.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/embedding.jl#L194-L196

Added lines #L194 - L196 were not covered by tests
else
cos_cache = st.cos_cache[:, input_pos]
sin_cache = st.sin_cache[:, input_pos]

Check warning on line 199 in src/layers/embedding.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/embedding.jl#L198-L199

Added lines #L198 - L199 were not covered by tests
end

# reshape input; the last dimension is used for computing the output.
# Cast to float to match the reference implementation
xshaped = reshape(float.(y), 2, h_d ÷ 2, n_h, seq_len, b)

Check warning on line 204 in src/layers/embedding.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/embedding.jl#L204

Added line #L204 was not covered by tests

# reshape the cache for broadcasting
rope_cache = reshape(rope_cache, 2, h_d ÷ 2, 1, seq_len, :)
cos_cache = reshape(cos_cache, 1, h_d ÷ 2, 1, seq_len, :)
sin_cache = reshape(sin_cache, 1, h_d ÷ 2, 1, seq_len, :)

Check warning on line 208 in src/layers/embedding.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/embedding.jl#L207-L208

Added lines #L207 - L208 were not covered by tests

xshaped1 = xshaped[1:1, :, :, :, :]
xshaped2 = xshaped[2:2, :, :, :, :]

Check warning on line 211 in src/layers/embedding.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/embedding.jl#L210-L211

Added lines #L210 - L211 were not covered by tests

rope_cache1 = rope_cache[1:1, :, :, :, :]
rope_cache2 = rope_cache[2:2, :, :, :, :]

x_out = vcat(

Check warning on line 213 in src/layers/embedding.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/embedding.jl#L213

Added line #L213 was not covered by tests
xshaped1 .* rope_cache1 - xshaped2 .* rope_cache2,
xshaped2 .* rope_cache1 + xshaped1 .* rope_cache2
xshaped1 .* cos_cache - xshaped2 .* sin_cache,
xshaped2 .* cos_cache + xshaped1 .* sin_cache
)

return reshape(x_out, h_d, n_h, seq_len, b), st

Check warning on line 218 in src/layers/embedding.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/embedding.jl#L218

Added line #L218 was not covered by tests
Expand Down

0 comments on commit 9c7c15c

Please sign in to comment.