From 2bd32faf028c41a3c905faee778fa058f3065901 Mon Sep 17 00:00:00 2001 From: Chuck Witt Date: Mon, 12 Aug 2024 16:16:31 -0400 Subject: [PATCH] Update for sklearn 1.3. --- docs/src/index.md | 2 +- ext/ACEfit_MLJScikitLearnInterface_ext.jl | 4 ++-- ext/ACEfit_PythonCall_ext.jl | 14 +++++++------- src/solvers.jl | 12 ++++++------ test/test_mlj.jl | 4 ++-- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 3fb9969..d079563 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -34,7 +34,7 @@ ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface # Create the solver itself and give it parameters solver = ARDRegressor( - n_iter = 300, + max_iter = 300, tol = 1e-3, threshold_lambda = 10000 ) diff --git a/ext/ACEfit_MLJScikitLearnInterface_ext.jl b/ext/ACEfit_MLJScikitLearnInterface_ext.jl index 34ae06a..6e2fba7 100644 --- a/ext/ACEfit_MLJScikitLearnInterface_ext.jl +++ b/ext/ACEfit_MLJScikitLearnInterface_ext.jl @@ -21,7 +21,7 @@ ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface # Create the solver itself and give it parameters solver = ARDRegressor( - n_iter = 300, + max_iter = 300, tol = 1e-3, threshold_lambda = 10000 # more params @@ -43,4 +43,4 @@ function ACEfit.solve(solver, A, y) return Dict{String, Any}("C" => pyconvert(Array, c) ) end -end \ No newline at end of file +end diff --git a/ext/ACEfit_PythonCall_ext.jl b/ext/ACEfit_PythonCall_ext.jl index c3521f1..9f39909 100644 --- a/ext/ACEfit_PythonCall_ext.jl +++ b/ext/ACEfit_PythonCall_ext.jl @@ -7,13 +7,13 @@ using PythonCall function ACEfit.solve(solver::ACEfit.SKLEARN_BRR, A, y) @info "Entering SKLEARN_BRR" BRR = pyimport("sklearn.linear_model")."BayesianRidge" - clf = BRR(n_iter = solver.n_iter, tol = solver.tol, fit_intercept = true, + clf = BRR(max_iter = solver.max_iter, tol = solver.tol, fit_intercept = true, compute_score = true) clf.fit(A, y) - if length(clf.scores_) < solver.n_iter + if length(clf.scores_) < solver.max_iter @info "BRR converged to tol=$(solver.tol) after $(length(clf.scores_)) iterations." else - @warn "\nBRR did not converge to tol=$(solver.tol) after n_iter=$(solver.n_iter) iterations.\n" + @warn "\nBRR did not converge to tol=$(solver.tol) after max_iter=$(solver.max_iter) iterations.\n" end c = clf.coef_ return Dict{String, Any}("C" => pyconvert(Array, c) ) @@ -22,17 +22,17 @@ end function ACEfit.solve(solver::ACEfit.SKLEARN_ARD, A, y) ARD = pyimport("sklearn.linear_model")."ARDRegression" - clf = ARD(n_iter = solver.n_iter, threshold_lambda = solver.threshold_lambda, + clf = ARD(max_iter = solver.max_iter, threshold_lambda = solver.threshold_lambda, tol = solver.tol, fit_intercept = true, compute_score = true) clf.fit(A, y) - if length(clf.scores_) < solver.n_iter + if length(clf.scores_) < solver.max_iter @info "ARD converged to tol=$(solver.tol) after $(length(clf.scores_)) iterations." else - @warn "\n\nARD did not converge to tol=$(solver.tol) after n_iter=$(solver.n_iter) iterations.\n\n" + @warn "\n\nARD did not converge to tol=$(solver.tol) after max_iter=$(solver.max_iter) iterations.\n\n" end c = clf.coef_ return Dict{String, Any}("C" => pyconvert(Array,c) ) end -end \ No newline at end of file +end diff --git a/src/solvers.jl b/src/solvers.jl index ca51d0d..077d17b 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -124,12 +124,12 @@ SKLEARN_BRR """ struct SKLEARN_BRR tol::Number - n_iter::Integer + max_iter::Integer end -function SKLEARN_BRR(; tol = 1e-3, n_iter = 300) +function SKLEARN_BRR(; tol = 1e-3, max_iter = 300) @warn "SKLearn will transition to MLJ in future, please upgrade your script to reflect this." - SKLEARN_BRR(tol, n_iter) + SKLEARN_BRR(tol, max_iter) end # solve(solver::SKLEARN_BRR, ...) is implemented in ext/ @@ -138,14 +138,14 @@ end SKLEARN_ARD """ struct SKLEARN_ARD - n_iter::Integer + max_iter::Integer tol::Number threshold_lambda::Number end -function SKLEARN_ARD(; n_iter = 300, tol = 1e-3, threshold_lambda = 10000) +function SKLEARN_ARD(; max_iter = 300, tol = 1e-3, threshold_lambda = 10000) @warn "SKLearn will transition to MLJ in future, please upgrade your script to reflect this." - SKLEARN_ARD(n_iter, tol, threshold_lambda) + SKLEARN_ARD(max_iter, tol, threshold_lambda) end # solve(solver::SKLEARN_ARD, ...) is implemented in ext/ diff --git a/test/test_mlj.jl b/test/test_mlj.jl index aa9324f..b3703d2 100644 --- a/test/test_mlj.jl +++ b/test/test_mlj.jl @@ -32,11 +32,11 @@ C = results["C"] @info(" ... MLJ SKLearn ARD") ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface solver = ARDRegressor( - n_iter = 300, + max_iter = 300, tol = 1e-3, threshold_lambda = 10000 ) results = ACEfit.solve(solver, A, y) C = results["C"] @show norm(A * C - y) -@show norm(C) \ No newline at end of file +@show norm(C)