Skip to content

Commit

Permalink
test: re-enable flux testing (#1123)
Browse files Browse the repository at this point in the history
* test: reenable flux testing

* chore: bump minimum optimisers version
  • Loading branch information
avik-pal authored Dec 6, 2024
1 parent ef0d450 commit fb44a48
Show file tree
Hide file tree
Showing 17 changed files with 40 additions and 48 deletions.
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.4.0"
version = "1.4.1-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -79,7 +79,7 @@ DispatchDoctor = "0.4.12"
Enzyme = "0.13.16"
EnzymeCore = "0.8.6"
FastClosures = "0.3.2"
Flux = "0.14.25"
Flux = "0.15"
ForwardDiff = "0.10.36"
FunctionWrappers = "1.1.3"
Functors = "0.5"
Expand All @@ -95,7 +95,7 @@ MacroTools = "0.5.13"
Markdown = "1.10"
NCCL = "0.1.1"
NNlib = "0.9.24"
Optimisers = "0.3.4, 0.4"
Optimisers = "0.4.1"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2.6"
Expand All @@ -107,7 +107,7 @@ SimpleChains = "0.4.7"
Static = "1.1.1"
StaticArraysCore = "1.4.3"
Statistics = "1.10"
Tracker = "0.2.36"
Tracker = "0.2.37"
WeightInitializers = "1"
Zygote = "0.6.70"
julia = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ LuxLib = "1.3.4"
LuxTestUtils = "1.5"
MLDataDevices = "1.6"
NNlib = "0.9.24"
Optimisers = "0.3.4, 0.4"
Optimisers = "0.4.1"
Pkg = "1.10"
Printf = "1.10"
Random = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion examples/Basics/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ ComponentArrays = "0.15.18"
ForwardDiff = "0.10"
Lux = "1"
LuxCUDA = "0.3"
Optimisers = "0.3.4, 0.4"
Optimisers = "0.4.1"
Zygote = "0.6"
2 changes: 1 addition & 1 deletion examples/BayesianNN/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ Functors = "0.4, 0.5"
LinearAlgebra = "1"
Lux = "1.2"
Random = "1"
Tracker = "0.2.36"
Tracker = "0.2.37"
Turing = "0.34, 0.35"
Zygote = "0.6.69"
2 changes: 1 addition & 1 deletion examples/ConvMixer/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ LuxCUDA = "0.3.2"
MLDatasets = "0.7.14"
MLUtils = "0.4.4"
OneHotArrays = "0.2.5"
Optimisers = "0.4"
Optimisers = "0.4.1"
PreferenceTools = "0.1.2"
Printf = "1.10"
ProgressBars = "1.5.1"
Expand Down
2 changes: 1 addition & 1 deletion examples/DDIM/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ JLD2 = "0.4.48, 0.5"
Lux = "1"
LuxCUDA = "0.3"
MLUtils = "0.4"
Optimisers = "0.3, 0.4"
Optimisers = "0.4.1"
ParameterSchedulers = "0.4.1"
ProgressBars = "1"
Random = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion examples/HyperNet/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ LuxCUDA = "0.3"
MLDatasets = "0.7"
MLUtils = "0.4"
OneHotArrays = "0.2.5"
Optimisers = "0.3.4, 0.4"
Optimisers = "0.4.1"
Setfield = "1"
Statistics = "1"
Zygote = "0.6"
2 changes: 1 addition & 1 deletion examples/ImageNet/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ MLUtils = "0.4.4"
MPI = "0.20.21"
NCCL = "0.1.1"
OneHotArrays = "0.2.5"
Optimisers = "0.3.4, 0.4"
Optimisers = "0.4.1"
ParameterSchedulers = "0.4.2"
Random = "1.10"
Setfield = "1.1.1"
Expand Down
2 changes: 1 addition & 1 deletion examples/NeuralODE/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ LuxCUDA = "0.3"
MLDatasets = "0.7"
MLUtils = "0.4"
OneHotArrays = "0.2.5"
Optimisers = "0.3.4, 0.4"
Optimisers = "0.4.1"
OrdinaryDiffEqTsit5 = "1"
SciMLSensitivity = "7.63"
Statistics = "1"
Expand Down
2 changes: 1 addition & 1 deletion examples/PINN2DPDE/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Lux = "1"
LuxCUDA = "0.3.3"
MLUtils = "0.4.4"
OnlineStats = "1.7.1"
Optimisers = "0.3.4, 0.4"
Optimisers = "0.4.1"
Printf = "1.10"
Random = "1.10"
Statistics = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion examples/PolynomialFitting/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ ADTypes = "1.10"
CairoMakie = "0.12"
Lux = "1"
LuxCUDA = "0.3"
Optimisers = "0.3.4, 0.4"
Optimisers = "0.4.1"
Statistics = "1"
Zygote = "0.6"
2 changes: 1 addition & 1 deletion examples/SimpleChains/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Lux = "1"
MLDatasets = "0.7.14"
MLUtils = "0.4"
OneHotArrays = "0.2.5"
Optimisers = "0.3.4, 0.4"
Optimisers = "0.4.1"
Random = "1"
SimpleChains = "0.4.6"
Zygote = "0.6.69"
2 changes: 1 addition & 1 deletion examples/SimpleRNN/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ JLD2 = "0.5"
Lux = "1"
LuxCUDA = "0.3"
MLUtils = "0.4"
Optimisers = "0.3.4, 0.4"
Optimisers = "0.4.1"
Statistics = "1"
Zygote = "0.6"
6 changes: 0 additions & 6 deletions ext/LuxFluxExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,6 @@ function Lux.convert_flux_model(
return Lux.GroupNorm(l.chs, l.G, l.λ; l.affine, epsilon=l.ϵ)
end

const _INVALID_TRANSFORMATION_TYPES = Union{<:Flux.Recur}

function Lux.convert_flux_model(l::T; kwargs...) where {T <: _INVALID_TRANSFORMATION_TYPES}
throw(FluxModelConversionException("Transformation of type $(T) is not supported."))
end

for cell in (:RNNCell, :LSTMCell, :GRUCell)
msg = "Recurrent Cell: $(cell) for Flux has semantical difference with Lux, \
mostly in-terms of how the bias term is dealt with. Lux aligns with the Pytorch \
Expand Down
4 changes: 2 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ MLUtils = "0.4.3"
NNlib = "0.9.24"
Octavian = "0.3.28"
OneHotArrays = "0.2.5"
Optimisers = "0.3.4, 0.4"
Optimisers = "0.4.1"
Pkg = "1.10"
Preferences = "1.4.3"
Random = "1.10"
Expand All @@ -79,5 +79,5 @@ Static = "1"
StaticArrays = "1.9"
Statistics = "1.10"
Test = "1.10"
Tracker = "0.2.36"
Tracker = "0.2.37"
Zygote = "0.6.70"
3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ if ("all" in LUX_TEST_GROUP || "misc" in LUX_TEST_GROUP)
push!(EXTRA_PKGS, Pkg.PackageSpec("MPI"))
(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") &&
push!(EXTRA_PKGS, Pkg.PackageSpec("NCCL"))
# XXX: Reactivate once Flux is compatible with Functors 0.5
# push!(EXTRA_PKGS, Pkg.PackageSpec("Flux"))
push!(EXTRA_PKGS, Pkg.PackageSpec("Flux"))
end

if !Sys.iswindows()
Expand Down
43 changes: 21 additions & 22 deletions test/transform/flux_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testitem "FromFluxAdaptor" setup=[SharedTestSetup] tags=[:misc] skip=:(true) begin
@testitem "FromFluxAdaptor" setup=[SharedTestSetup] tags=[:misc] begin
import Flux

toluxpsst = FromFluxAdaptor(; preserve_ps_st=true)
Expand All @@ -8,10 +8,10 @@
@testset "$mode" for (mode, aType, dev, ongpu) in MODES
@testset "Containers" begin
@testset "Chain" begin
models = [Flux.Chain(Flux.Dense(2 => 5), Flux.Dense(5 => 1)),
Flux.Chain(; l1=Flux.Dense(2 => 5), l2=Flux.Dense(5 => 1))] |> dev

for model in models
for model in [
Flux.Chain(Flux.Dense(2 => 5), Flux.Dense(5 => 1)) |> dev,
Flux.Chain(; l1=Flux.Dense(2 => 5), l2=Flux.Dense(5 => 1)) |> dev
]
x = rand(Float32, 2, 1) |> aType

model_lux = toluxpsst(model)
Expand Down Expand Up @@ -57,10 +57,10 @@
end

@testset "Parallel" begin
models = [Flux.Parallel(+, Flux.Dense(2 => 2), Flux.Dense(2 => 2)),
Flux.Parallel(+; l1=Flux.Dense(2 => 2), l2=Flux.Dense(2 => 2))] |> dev

for model in models
for model in [
Flux.Parallel(+, Flux.Dense(2 => 2), Flux.Dense(2 => 2)) |> dev,
Flux.Parallel(+; l1=Flux.Dense(2 => 2), l2=Flux.Dense(2 => 2)) |> dev
]
x = rand(Float32, 2, 1) |> aType

model_lux = toluxpsst(model)
Expand Down Expand Up @@ -94,8 +94,10 @@

@testset "Linear" begin
@testset "Dense" begin
for model in [Flux.Dense(2 => 4) |> dev,
Flux.Dense(2 => 4; bias=false) |> dev]
for model in [
Flux.Dense(2 => 4) |> dev,
Flux.Dense(2 => 4; bias=false) |> dev
]
x = randn(Float32, 2, 4) |> aType

model_lux = toluxpsst(model)
Expand All @@ -112,7 +114,9 @@

@testset "Scale" begin
for model in [
Flux.Scale(2) |> dev, Flux.Scale(2; bias=false) |> dev]
Flux.Scale(2) |> dev,
Flux.Scale(2; bias=false) |> dev
]
x = randn(Float32, 2, 4) |> aType

model_lux = toluxpsst(model)
Expand All @@ -128,8 +132,10 @@
end

@testset "Bilinear" begin
for model in [Flux.Bilinear((2, 3) => 5) |> dev,
Flux.Bilinear((2, 3) => 5; bias=false) |> dev]
for model in [
Flux.Bilinear((2, 3) => 5) |> dev,
Flux.Bilinear((2, 3) => 5; bias=false) |> dev
]
x = randn(Float32, 2, 4) |> aType
y = randn(Float32, 3, 4) |> aType

Expand Down Expand Up @@ -447,7 +453,7 @@
bias
end

Flux.@functor CustomFluxLayer
Flux.@layer CustomFluxLayer

(c::CustomFluxLayer)(x) = c.weight .* x .+ c.bias

Expand All @@ -466,12 +472,5 @@
@test tolux(identity) isa Lux.NoOpLayer
@test tolux(+) isa Lux.WrappedFunction
end

@testset "Unsupported Layers" begin
accum(h, x) = (h + x, x)
rnn = Flux.Recur(accum, 0)

@test_throws Lux.FluxModelConversionException tolux(rnn)
end
end
end

0 comments on commit fb44a48

Please sign in to comment.