Skip to content

Commit

Permalink
feat: generalize indexing to all wrappers (#146)
Browse files Browse the repository at this point in the history
* feat: generalize indexing to all wrappers

* test: use `PermuteDimsArray` to test parentindices
  • Loading branch information
avik-pal authored Oct 4, 2024
1 parent fd9b469 commit 8435c5e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@ ancestor(x::TracedRArray) = x
ancestor(x::WrappedTracedRArray) = ancestor(parent(x))

get_ancestor_indices(::TracedRArray, indices...) = indices
function get_ancestor_indices(
x::SubArray{T,N,<:AnyTracedRArray{T,N}}, indices...
) where {T,N}
return get_ancestor_indices(parent(x), Base.reindex(x.indices, indices)...)
function get_ancestor_indices(x::WrappedTracedRArray, indices...)
return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...)
end

Base.getindex(a::AnyTracedRScalar{T}) where {T} = a
Expand Down
14 changes: 14 additions & 0 deletions test/wrapped_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,17 @@ end

@test btranspose_badjoint_compiled(x_ra) btranspose_badjoint(x)
end

function bypass_permutedims(x)
x = PermutedDimsArray(x, (2, 1, 3)) # Don't use permutedims here
return view(x, 2:3, 1:2, :)
end

@testset "PermutedDimsArray" begin
x = rand(4, 4, 3)
x_ra = Reactant.to_rarray(x)

bypass_permutedims_compiled = @compile bypass_permutedims(x_ra)

@test bypass_permutedims_compiled(x_ra) bypass_permutedims(x)
end

1 comment on commit 8435c5e

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: 8435c5e Previous: fd9b469 Ratio
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1368821584 ns 1342236532 ns 1.02
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux 218425022 ns 217334200 ns 1.01
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 6050106855 ns 7478176179 ns 0.81
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 15488677309 ns 18457674181 ns 0.84
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1325902484 ns 1240775517.5 ns 1.07
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux 9023661 ns 8429118 ns 1.07
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1609420383 ns 1705665564 ns 0.94
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 2794718192 ns 2181011809.5 ns 1.28
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1294577037 ns 1270663182 ns 1.02
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux 95904983 ns 87661764.5 ns 1.09
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2174332771 ns 2264035064 ns 0.96
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 6821449717 ns 12598148585 ns 0.54
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1325701874 ns 1296031343.5 ns 1.02
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux 7473938 ns 7435714 ns 1.01
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1465073203 ns 1522842192 ns 0.96
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 1463367147.5 ns 1593852319 ns 0.92
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1326350975.5 ns 1290749033 ns 1.03
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux 11629254 ns 11587944 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1757970313 ns 1829919441 ns 0.96
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 3444375485.5 ns 2589622924 ns 1.33
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1298097547.5 ns 1278886629 ns 1.02
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux 88846123 ns 88840313 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2210026137 ns 2317384176 ns 0.95
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 4242489815 ns 3974355306 ns 1.07
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1276575769.5 ns 1286135780.5 ns 0.99
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux 115004565 ns 115595273.5 ns 0.99
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 3089232880 ns 3066839756 ns 1.01
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 17522163039 ns 7963080791 ns 2.20
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1322144951.5 ns 1298156213 ns 1.02
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux 125463399 ns 120812566 ns 1.04
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3222833870 ns 3345928889 ns 0.96
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 6170636132 ns 12142011305 ns 0.51
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1280369557 ns 1361020077 ns 0.94
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux 84708910.5 ns 87717808.5 ns 0.97
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 2107430045 ns 1957704050 ns 1.08
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 2642122582 ns 2423933291 ns 1.09

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.