Merge pull request #260 from FluxML/optimiser-warning
Add check that Flux optimiser is not being used
ablaom authored Jun 11, 2024
2 parents 0b8be2a + 1d2670a commit 06282cd
Showing 2 changed files with 13 additions and 1 deletion.
src/mlj_model_interface.jl: 9 additions & 1 deletion
@@ -7,6 +7,11 @@ MLJModelInterface.deep_properties(::Type{<:MLJFluxModel}) =

 # # CLEAN METHOD

+const ERR_BAD_OPTIMISER = ArgumentError(
+    "Flux.jl optimiser detected. Only optimisers from Optimisers.jl are supported. "*
+    "For example, use `optimiser=Optimisers.Momentum()` after `import Optimisers`. "
+)
+
 function MLJModelInterface.clean!(model::MLJFluxModel)
     warning = ""
     if model.lambda < 0
@@ -40,6 +45,9 @@ function MLJModelInterface.clean!(model::MLJFluxModel)
         "on an RNG during training, such as `Dropout`. Consider using "*
         " `Random.default_rng()` instead. `"
     end
+    # TODO: This could be removed in next breaking release (0.6.0):
+    model.optimiser isa Flux.Optimise.AbstractOptimiser && throw(ERR_BAD_OPTIMISER)
+
     return warning
 end

@@ -79,7 +87,7 @@ function regularized_optimiser(model, nbatches)
             Optimisers.WeightDecay(λ_weight),
             model.optimiser,
         )
-    end
+    end
 end

 function MLJModelInterface.fit(model::MLJFluxModel,
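
Because the new check in `MLJModelInterface.clean!` throws rather than appending to the returned warning string, a Flux.jl optimiser is now rejected as soon as a model is constructed. A minimal sketch of the user-facing effect, assuming this version of MLJFlux and using `NeuralNetworkClassifier` to stand in for any MLJFlux model:

    using MLJFlux
    import Flux, Optimisers

    # Now rejected: constructing a model with a legacy Flux.jl optimiser
    # throws ERR_BAD_OPTIMISER (an ArgumentError):
    # NeuralNetworkClassifier(optimiser=Flux.Optimise.Adam())

    # Supported: the equivalent rule from Optimisers.jl.
    model = NeuralNetworkClassifier(optimiser=Optimisers.Adam())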
test/mlj_model_interface.jl: 4 additions & 0 deletions
@@ -35,6 +35,10 @@ end
         end
         @test model.acceleration == CUDALibs()
     end
+
+    @test_throws MLJFlux.ERR_BAD_OPTIMISER NeuralNetworkClassifier(
+        optimiser=Flux.Optimise.Adam(),
+    )
 end

 @testset "regularization: logic" begin
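The reason only Optimisers.jl rules are acceptable is visible in the `regularized_optimiser` hunk above: MLJFlux composes the user's optimiser with decoupled weight decay via `Optimisers.OptimiserChain`, and legacy `Flux.Optimise` optimisers cannot participate in such a chain. A minimal stand-alone sketch of that composition pattern (the decay value and the parameter/gradient NamedTuples are hypothetical, chosen only for illustration):

    import Optimisers

    # Chain decoupled weight decay (an L2 penalty) in front of a base rule,
    # mirroring `OptimiserChain(WeightDecay(λ_weight), model.optimiser)`:
    rule = Optimisers.OptimiserChain(
        Optimisers.WeightDecay(0.01),  # hypothetical λ_weight
        Optimisers.Momentum(),         # stands in for `model.optimiser`
    )

    # Standard Optimisers.jl update cycle:
    params = (W = rand(3, 3), b = zeros(3))   # hypothetical parameters
    state = Optimisers.setup(rule, params)
    grads = (W = ones(3, 3), b = ones(3))     # hypothetical gradients
    state, params = Optimisers.update(state, params, grads)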
