MurrellGroup/Einops.jl

Einops.jl

Einops.jl is a Julia implementation of einops, providing a concise notation for tensor operations, and unifying Julia's reshape, permutedims, reduce and repeat functions, with support for automatic differentiation.

The Python implementation specifies operations with strings, which are hard to resolve at compile time in Julia. Patterns are therefore written as tuples of symbols, and a string macro is exported for parity: einops"a b -> (b a)" expands to the form (:a, :b) --> ((:b, :a),), where --> is a custom operator that stores its left and right operands as type parameters of a special pattern type. Because the pattern is part of the type, dimensionalities are known at compile time, which ensures type stability.
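The idea of carrying a pattern in type parameters can be sketched in plain Julia. This is only an illustration, not the package's implementation: `SketchPattern`, `sketch_pattern`, `left_axes`, `right_axes` and `output_ndims` are hypothetical names, and a regular function stands in for the --> operator.

```julia
# Hypothetical stand-in for a pattern type: both sides of the pattern are
# stored as type parameters, so the struct itself has no fields.
struct SketchPattern{L,R} end

# Plays the role of `left --> right` in this sketch.
sketch_pattern(left, right) = SketchPattern{left,right}()

# Everything about the pattern is recoverable from the type alone,
# which is what lets the compiler specialize on it.
left_axes(::SketchPattern{L,R}) where {L,R} = L
right_axes(::SketchPattern{L,R}) where {L,R} = R
output_ndims(::SketchPattern{L,R}) where {L,R} = length(R)

p = sketch_pattern((:a, :b), ((:b, :a),))
```

Here `output_ndims(p)` is computed from the type parameter `R`, so the number of output dimensions is a compile-time constant rather than a runtime value.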

Operations

rearrange

The rearrange function combines reshaping and permutation into a single, expressive command.

julia> images = randn(3, 40, 30, 32); # channel, width, height, batch

# reorder axes to "w h c b" format:
julia> rearrange(images, (:c, :w, :h, :b) --> (:w, :h, :c, :b)) |> size
(40, 30, 3, 32)

# flatten each image into a vector
julia> rearrange(images, (:c, :w, :h, :b) --> ((:c, :w, :h), :b)) |> size
(3600, 32)

# split each image into 4 smaller images (top-left, top-right, bottom-left, bottom-right); 128 = 32 * 2 * 2
julia> rearrange(images, (:c, (:w, :w2), (:h, :h2), :b) --> (:c, :w, :h, (:w2, :h2, :b)), h2=2, w2=2) |> size
(3, 20, 15, 128)
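For comparison, the same shapes can be produced with Base's reshape and permutedims alone. This is a sketch rather than a description of how the package works internally: it assumes that grouped axes such as (:w, :w2) follow Julia's column-major order (first name varies fastest), which this README does not state explicitly.

```julia
images = randn(3, 40, 30, 32)  # channel, width, height, batch

# (:c, :w, :h, :b) --> (:w, :h, :c, :b) is a pure axis permutation
reordered = permutedims(images, (2, 3, 1, 4))

# (:c, :w, :h, :b) --> ((:c, :w, :h), :b) merges the first three axes
flattened = reshape(images, 3 * 40 * 30, 32)

# (:c, (:w, :w2), (:h, :h2), :b) --> (:c, :w, :h, (:w2, :h2, :b)), w2=2, h2=2
# takes three steps: split the axes, permute, then merge.
split = reshape(images, 3, 20, 2, 15, 2, 32)         # c, w, w2, h, h2, b
permuted = permutedims(split, (1, 2, 4, 3, 5, 6))    # c, w, h, w2, h2, b
quadrants = reshape(permuted, 3, 20, 15, 2 * 2 * 32)
```

The last example shows what the pattern notation saves: one expression replaces a reshape, a permutation and a second reshape, along with the manual bookkeeping of axis positions.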

reduce

The method for Base.reduce dispatches on ArrowPattern, applying reduction operations (like sum, mean, maximum) along specified axes. This is different from typical Base.reduce functionality, which reduces using binary operations.

julia> x = randn(64, 32, 100);

# perform max-reduction on the last axis
# Axis t does not appear on the right - thus we reduce over t
julia> reduce(maximum, x, (:c, :b, :t) --> (:c, :b)) |> size
(64, 32)

# mean comes from the Statistics standard library (using Statistics)
julia> reduce(mean, x, (:c, :b, (:t, :t5)) --> (:b, :c, :t), t5=5) |> size
(32, 64, 20)
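The two reductions above correspond roughly to the following Base and Statistics calls. As before, this is a sketch under the assumption that the grouped axis (:t, :t5) is split in column-major order; it is not taken from the package's source.

```julia
using Statistics  # standard library, provides mean

x = randn(64, 32, 100)  # c, b, t

# (:c, :b, :t) --> (:c, :b) with maximum: reduce over axis 3, then drop it
max_over_t = dropdims(maximum(x; dims=3); dims=3)

# (:c, :b, (:t, :t5)) --> (:b, :c, :t), t5=5 with mean:
# split the last axis, reduce over t5, then swap c and b
split = reshape(x, 64, 32, 20, 5)                  # c, b, t, t5
pooled = dropdims(mean(split; dims=4); dims=4)     # c, b, t
result = permutedims(pooled, (2, 1, 3))            # b, c, t
```

Note that reductions with `dims` keep the reduced axis at length 1, so `dropdims` is needed to obtain the shapes that the patterns produce directly.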

repeat

The method for Base.repeat also dispatches on ArrowPattern, and repeats elements along existing or new axes.

julia> image = randn(40, 30); # a grayscale image (of shape width x height)

# change it to RGB format by repeating in each channel
julia> repeat(image, (:w, :h) --> (:c, :w, :h), c=3) |> size
(3, 40, 30)

# repeat image 2 times along height (vertical axis)
julia> repeat(image, (:w, :h) --> ((:repeat, :h), :w), repeat=2) |> size
(60, 40)

# repeat image 2 times along height and 3 times along width
julia> repeat(image, (:w, :h) --> ((:w, :w3), (:h, :h2)), w3=3, h2=2) |> size
(120, 60)
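The same results can be approximated with Base's positional and keyword forms of repeat. This sketch assumes column-major ordering inside grouped axes, so that a new axis listed first in a group, as in (:repeat, :h), repeats each element in place, while one listed last, as in (:w, :w3), tiles the whole axis. The README does not spell out this convention.

```julia
image = randn(40, 30)  # width, height

# (:w, :h) --> (:c, :w, :h), c=3: add a leading axis and repeat along it
rgb = repeat(reshape(image, 1, 40, 30), 3, 1, 1)

# (:w, :h) --> ((:repeat, :h), :w), repeat=2: transpose, then repeat each row in place
tall = repeat(permutedims(image); inner=(2, 1))

# (:w, :h) --> ((:w, :w3), (:h, :h2)), w3=3, h2=2: tile 3 times along w, 2 along h
tiled = repeat(image, 3, 2)
```

With the pattern form, the new axis and the way it combines with existing axes are named in one place instead of being spread across reshape, permutedims and the inner or outer arguments of repeat.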

Roadmap

  • Implement rearrange.
  • Support Python implementation's string syntax for patterns with string macro.
  • Implement pack and unpack.
  • Implement parse_shape.
  • Implement repeat.
  • Implement reduce.
  • Support automatic differentiation (tested with Zygote.jl).
  • Implement einsum (or wrap existing implementation) (see #3).
  • Support ellipsis notation (using .. from EllipsisNotation.jl) (see #9).
  • Explore integration with PermutedDimsArray or TransmuteDims.jl for lazy and statically inferrable permutations (see #4).

Contributing

Contributions are welcome! Please feel free to open an issue to report a bug or start a discussion.
