Einops.jl is a Julia implementation of einops, providing a concise notation for tensor operations and unifying Julia's `reshape`, `permutedims`, `reduce` and `repeat` functions, with support for automatic differentiation.
The Python implementation uses strings to specify the operation, which is tricky to compile in Julia, so a string macro is exported for parity. For example, `einops"a b -> (b a)"` expands to the form `(:a, :b) --> ((:b, :a),)`, where `-->` is a custom operator that stores the left and right operands as type parameters of a special pattern type. This makes the dimensionality of each pattern known at compile time, ensuring type stability.
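The mechanism described above can be sketched in a few lines of plain Julia. This is a minimal illustration, not Einops.jl's code: `Pattern`, `parse_side`, `input_axes`, `output_axes` and the `pattern"..."` macro are hypothetical names, and `→` stands in for the package's `-->` so the sketch does not depend on how `-->` is defined there. It parses a pattern string into the nested-tuple form and stores both sides as type parameters.

```julia
# Hypothetical sketch of the idea, not Einops.jl's actual implementation.

# A singleton type whose left/right axis names live in its type parameters.
struct Pattern{L,R} end

# `→` is a stand-in for Einops.jl's `-->` operator.
→(left::Tuple, right::Tuple) = Pattern{left,right}()

# Parse one side of an einops string, e.g. "(b a)", into nested tuples of Symbols.
function parse_side(s::AbstractString)
    tokens = split(replace(s, "(" => " ( ", ")" => " ) "))
    out = Any[]
    stack = Vector{Any}[]
    for tok in tokens
        if tok == "("
            push!(stack, out)    # open a new group
            out = Any[]
        elseif tok == ")"
            group = Tuple(out)   # close the group and append it to its parent
            out = pop!(stack)
            push!(out, group)
        else
            push!(out, Symbol(tok))
        end
    end
    return Tuple(out)
end

# Hypothetical counterpart of the einops"..." string macro: the pattern is
# built at macro-expansion time and embedded as a constant.
macro pattern_str(s)
    lhs, rhs = strip.(split(s, "->"))
    return parse_side(lhs) → parse_side(rhs)
end

# The axis names are recovered from the type alone, via dispatch.
input_axes(::Pattern{L,R}) where {L,R} = L
output_axes(::Pattern{L,R}) where {L,R} = R
```

Because `L` and `R` are part of the type, code specialized on a given `Pattern{L,R}` can treat quantities such as `length(L)` as compile-time constants.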
The `rearrange` function combines reshaping and permutation operations into a single, expressive command.
```julia
julia> using Einops

julia> images = randn(3, 40, 30, 32); # channel, width, height, batch

# reorder axes to "w h c b" format:
julia> rearrange(images, (:c, :w, :h, :b) --> (:w, :h, :c, :b)) |> size
(40, 30, 3, 32)

# flatten each image into a vector
julia> rearrange(images, (:c, :w, :h, :b) --> ((:c, :w, :h), :b)) |> size
(3600, 32)

# split each image into 4 smaller images (top-left, top-right, bottom-left, bottom-right); 128 = 32 * 2 * 2
julia> rearrange(images, (:c, (:w, :w2), (:h, :h2), :b) --> (:c, :w, :h, (:w2, :h2, :b)), h2=2, w2=2) |> size
(3, 20, 15, 128)
```
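To make the semantics concrete, the three `rearrange` calls above can be written with Base functions alone. This sketch assumes Einops.jl follows Julia's column-major convention, where the first axis named inside a group varies fastest; it illustrates what each pattern means and is not taken from the package's implementation.

```julia
images = randn(3, 40, 30, 32)  # (c, w, h, b)

# (:c, :w, :h, :b) --> (:w, :h, :c, :b): a pure permutation
whcb = permutedims(images, (2, 3, 1, 4))

# (:c, :w, :h, :b) --> ((:c, :w, :h), :b): c, w, h are already adjacent,
# so merging them is a plain reshape (c varies fastest)
flat = reshape(images, 3 * 40 * 30, 32)

# (:c, (:w, :w2), (:h, :h2), :b) --> (:c, :w, :h, (:w2, :h2, :b)), w2=2, h2=2
split6 = reshape(images, 3, 20, 2, 15, 2, 32)      # split:   (c, w, w2, h, h2, b)
moved  = permutedims(split6, (1, 2, 4, 3, 5, 6))   # permute: (c, w, h, w2, h2, b)
tiles  = reshape(moved, 3, 20, 15, 2 * 2 * 32)     # merge:   (c, w, h, w2*h2*b)
```

With `w` varying fastest inside `(:w, :w2)`, `w2 = 2` selects the second half of the width axis, so each slice of `tiles` is one quadrant of one image.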
The method for `Base.reduce` dispatches on `ArrowPattern`, applying reduction functions (like `sum`, `mean`, `maximum`) along the specified axes. This differs from the usual `Base.reduce`, which folds a collection with a binary operator (e.g. `reduce(+, v)`).
```julia
julia> using Statistics # provides mean

julia> x = randn(64, 32, 100);

# perform max-reduction on the last axis:
# axis t does not appear on the right, so we reduce over t
julia> reduce(maximum, x, (:c, :b, :t) --> (:c, :b)) |> size
(64, 32)

# split t into (t, t5) with t5=5, average over t5, and swap c and b
julia> reduce(mean, x, (:c, :b, (:t, :t5)) --> (:b, :c, :t), t5=5) |> size
(32, 64, 20)
```
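The same reductions can be spelled out with Base and the Statistics standard library, which also shows the contrast with the binary-operator form of `Base.reduce`. As above, this assumes column-major grouping, so in `(:t, :t5)` the axis `t` varies fastest; it is an illustrative equivalent, not Einops.jl's implementation.

```julia
using Statistics  # for mean

x = randn(64, 32, 100)  # (c, b, t)

# (:c, :b, :t) --> (:c, :b): apply a whole-array reduction over axis 3, then drop it
y1 = dropdims(maximum(x; dims=3); dims=3)

# Base.reduce folding with the binary `max` gives the same values
y1_fold = dropdims(reduce(max, x; dims=3); dims=3)

# (:c, :b, (:t, :t5)) --> (:b, :c, :t), t5=5
x4 = reshape(x, 64, 32, 20, 5)           # split t=100 into (t, t5) = (20, 5), t fastest
m  = dropdims(mean(x4; dims=4); dims=4)  # average over t5 -> (64, 32, 20)
y2 = permutedims(m, (2, 1, 3))           # reorder (c, b, t) -> (b, c, t)
```

Note that under this convention each averaged group gathers elements 20 steps apart along the original `t` axis, not 5 consecutive ones.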
The method for `Base.repeat` also dispatches on `ArrowPattern`, and repeats elements along existing or new axes.
```julia
julia> image = randn(40, 30); # a grayscale image (of shape width x height)

# change it to RGB format by repeating in each channel
julia> repeat(image, (:w, :h) --> (:c, :w, :h), c=3) |> size
(3, 40, 30)

# repeat image 2 times along height (vertical axis)
julia> repeat(image, (:w, :h) --> ((:repeat, :h), :w), repeat=2) |> size
(60, 40)

# repeat image 2 times along height and 3 times along width
julia> repeat(image, (:w, :h) --> ((:w, :w3), (:h, :h2)), w3=3, h2=2) |> size
(120, 60)
```
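Likewise, the `repeat` examples correspond roughly to these Base calls, again assuming the first axis in each group varies fastest. Under that assumption, a new axis placed first in a group duplicates individual elements (Base's `inner` keyword), while one placed last tiles whole copies of the array. This is an illustrative sketch, not the package's implementation.

```julia
image = randn(40, 30)  # (w, h)

# (:w, :h) --> (:c, :w, :h), c=3: add a leading axis and repeat along it
rgb = repeat(reshape(image, 1, 40, 30), 3, 1, 1)

# (:w, :h) --> ((:repeat, :h), :w), repeat=2: move h first, then duplicate
# each element along that axis (repeat varies fastest within the group)
tall = repeat(permutedims(image, (2, 1)); inner=(2, 1))

# (:w, :h) --> ((:w, :w3), (:h, :h2)), w3=3, h2=2: w and h vary fastest,
# so whole copies of the image are tiled 3 x 2
tiled = repeat(image, 3, 2)
```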
- Implement `rearrange`.
- Support the Python implementation's string syntax for patterns via a string macro.
- Implement `pack` and `unpack`.
- Implement `parse_shape`.
- Implement `repeat`.
- Implement `reduce`.
- Support automatic differentiation (tested with Zygote.jl).
- Implement `einsum` (or wrap an existing implementation) (see #3).
- Support ellipsis notation (using `..` from EllipsisNotation.jl) (see #9).
- Explore integration with `PermutedDimsArray` or TransmuteDims.jl for lazy and statically inferrable permutations (see #4).
Contributions are welcome! Please feel free to open an issue to report a bug or start a discussion.