Skip to content

Commit

Permalink
bump functors to 0.5
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Nov 5, 2024
1 parent 8b87c2b commit 635b0e5
Show file tree
Hide file tree
Showing 11 changed files with 28 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ FastClosures = "0.3.2"
Flux = "0.14.25"
ForwardDiff = "0.10.36"
FunctionWrappers = "1.1.3"
Functors = "0.4.12"
Functors = "0.4.12, 0.5"
GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
LossFunctions = "0.11.1"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ DocumenterVitepress = "0.1.3"
Enzyme = "0.13.13"
FiniteDiff = "2.23.1"
ForwardDiff = "0.10.36"
Functors = "0.4.12"
Functors = "0.4.12, 0.5"
GPUArraysCore = "0.1, 0.2"
KernelAbstractions = "0.9"
LinearAlgebra = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion examples/BayesianNN/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CairoMakie = "0.12"
Functors = "0.4"
Functors = "0.4, 0.5"
LinearAlgebra = "1"
Lux = "1"
Random = "1"
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ ChainRulesCore = "1.24"
Compat = "4.15.0"
DispatchDoctor = "0.4.10"
EnzymeCore = "0.8.5"
Functors = "0.4.12"
Functors = "0.4.12, 0.5"
MLDataDevices = "1"
Random = "1.10"
Reactant = "0.2.4"
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxCore/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Aqua = "0.8.7"
EnzymeCore = "0.8.5"
ExplicitImports = "1.9.0"
Functors = "0.4.12"
Functors = "0.4.12, 0.5"
MLDataDevices = "1.0.0"
Optimisers = "0.3.3"
Random = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion lib/LuxTestUtils/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ DispatchDoctor = "0.4.12"
Enzyme = "0.13.13"
FiniteDiff = "2.23.1"
ForwardDiff = "0.10.36"
Functors = "0.4.11"
Functors = "0.4.11, 0.5"
JET = "0.9.6"
MLDataDevices = "1.0.0"
ReverseDiff = "1.15.3"
Expand Down
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ CUDA = "5.2"
ChainRulesCore = "1.23"
Compat = "4.15"
FillArrays = "1"
Functors = "0.4.8"
Functors = "0.4.8, 0.5"
GPUArrays = "10, 11"
MLUtils = "0.4.4"
Metal = "1"
Expand Down
22 changes: 5 additions & 17 deletions lib/MLDataDevices/src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,24 +362,12 @@ function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractDe
return set_device!(T, rank)
end

# Dispatches for Different Data Structures
# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability
# For all other types we rely on fmap which means we lose type stability.
# For Lux, typically models only has these 3 datastructures so we should be mostly fine.
for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA)

for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :Reactant)
ldev = Symbol(dev, :Device)
@eval begin
function (D::$(ldev))(x::AbstractArray{T}) where {T}
if isbitstype(T) || Internal.special_aos(x) || x isa Adapt.WrappedArray
return Adapt.adapt(D, x)
end
return map(D, x)
end
(D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x)
function (D::$(ldev))(x)
isleaf(x) && return Adapt.adapt(D, x)
return Functors.fmap(D, x; exclude=isleaf)
end
@eval function (D::$(ldev))(x)
isleaf(x) && return Adapt.adapt(D, x)
return Functors.fmap(D, x; exclude=isleaf)
end
end

Expand Down
2 changes: 1 addition & 1 deletion lib/MLDataDevices/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ ComponentArrays = "0.15.8"
ExplicitImports = "1.9.0"
FillArrays = "1"
ForwardDiff = "0.10.36"
Functors = "0.4.8"
Functors = "0.4.8, 0.5"
MLUtils = "0.4"
Pkg = "1.10"
Random = "1.10"
Expand Down
14 changes: 14 additions & 0 deletions lib/MLDataDevices/test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,17 @@ end

@test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64}
end

@testset "data movement is type stable" begin
cpu = cpu_device()
gpu = gpu_device()

r = [1, 2]
x = (a = r, b = 3, c =(4, (d=5, e=r)))
y = @inferred(gpu(x))
x2 = @inferred(cpu(y))

# identity is preserved
@test y.a === y.c[2].e
@test x2.a === x2.c[2].e
end
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ Documenter = "1.4"
Enzyme = "0.13.13"
ExplicitImports = "1.9.0"
ForwardDiff = "0.10.36"
Functors = "0.4.12"
Functors = "0.4.12, 0.5"
Hwloc = "3.2.0"
InteractiveUtils = "<0.0.1, 1"
LinearAlgebra = "1.10"
Expand Down

0 comments on commit 635b0e5

Please sign in to comment.