Skip to content

Commit

Permalink
Update for sklearn 1.3.
Browse files Browse the repository at this point in the history
  • Loading branch information
wcwitt committed Aug 12, 2024
1 parent 272270c commit 2bd32fa
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 18 deletions.
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface

# Create the solver itself and give it parameters
solver = ARDRegressor(
n_iter = 300,
max_iter = 300,
tol = 1e-3,
threshold_lambda = 10000
)
Expand Down
4 changes: 2 additions & 2 deletions ext/ACEfit_MLJScikitLearnInterface_ext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface
# Create the solver itself and give it parameters
solver = ARDRegressor(
n_iter = 300,
max_iter = 300,
tol = 1e-3,
threshold_lambda = 10000
# more params
Expand All @@ -43,4 +43,4 @@ function ACEfit.solve(solver, A, y)
return Dict{String, Any}("C" => pyconvert(Array, c) )
end

end
end
14 changes: 7 additions & 7 deletions ext/ACEfit_PythonCall_ext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ using PythonCall
function ACEfit.solve(solver::ACEfit.SKLEARN_BRR, A, y)
@info "Entering SKLEARN_BRR"
BRR = pyimport("sklearn.linear_model")."BayesianRidge"
clf = BRR(n_iter = solver.n_iter, tol = solver.tol, fit_intercept = true,
clf = BRR(max_iter = solver.max_iter, tol = solver.tol, fit_intercept = true,
compute_score = true)
clf.fit(A, y)
if length(clf.scores_) < solver.n_iter
if length(clf.scores_) < solver.max_iter
@info "BRR converged to tol=$(solver.tol) after $(length(clf.scores_)) iterations."
else
@warn "\nBRR did not converge to tol=$(solver.tol) after n_iter=$(solver.n_iter) iterations.\n"
@warn "\nBRR did not converge to tol=$(solver.tol) after max_iter=$(solver.max_iter) iterations.\n"
end
c = clf.coef_
return Dict{String, Any}("C" => pyconvert(Array, c) )
Expand All @@ -22,17 +22,17 @@ end

function ACEfit.solve(solver::ACEfit.SKLEARN_ARD, A, y)
ARD = pyimport("sklearn.linear_model")."ARDRegression"
clf = ARD(n_iter = solver.n_iter, threshold_lambda = solver.threshold_lambda,
clf = ARD(max_iter = solver.max_iter, threshold_lambda = solver.threshold_lambda,
tol = solver.tol,
fit_intercept = true, compute_score = true)
clf.fit(A, y)
if length(clf.scores_) < solver.n_iter
if length(clf.scores_) < solver.max_iter
@info "ARD converged to tol=$(solver.tol) after $(length(clf.scores_)) iterations."
else
@warn "\n\nARD did not converge to tol=$(solver.tol) after n_iter=$(solver.n_iter) iterations.\n\n"
@warn "\n\nARD did not converge to tol=$(solver.tol) after max_iter=$(solver.max_iter) iterations.\n\n"
end
c = clf.coef_
return Dict{String, Any}("C" => pyconvert(Array,c) )
end

end
end
12 changes: 6 additions & 6 deletions src/solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,12 @@ SKLEARN_BRR
"""
struct SKLEARN_BRR
tol::Number
n_iter::Integer
max_iter::Integer
end

function SKLEARN_BRR(; tol = 1e-3, n_iter = 300)
function SKLEARN_BRR(; tol = 1e-3, max_iter = 300)
@warn "SKLearn will transition to MLJ in future, please upgrade your script to reflect this."
SKLEARN_BRR(tol, n_iter)
SKLEARN_BRR(tol, max_iter)
end

# solve(solver::SKLEARN_BRR, ...) is implemented in ext/
Expand All @@ -138,14 +138,14 @@ end
SKLEARN_ARD
"""
struct SKLEARN_ARD
n_iter::Integer
max_iter::Integer
tol::Number
threshold_lambda::Number
end

function SKLEARN_ARD(; n_iter = 300, tol = 1e-3, threshold_lambda = 10000)
function SKLEARN_ARD(; max_iter = 300, tol = 1e-3, threshold_lambda = 10000)
@warn "SKLearn will transition to MLJ in future, please upgrade your script to reflect this."
SKLEARN_ARD(n_iter, tol, threshold_lambda)
SKLEARN_ARD(max_iter, tol, threshold_lambda)
end

# solve(solver::SKLEARN_ARD, ...) is implemented in ext/
Expand Down
4 changes: 2 additions & 2 deletions test/test_mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ C = results["C"]
@info(" ... MLJ SKLearn ARD")
ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface
solver = ARDRegressor(
n_iter = 300,
max_iter = 300,
tol = 1e-3,
threshold_lambda = 10000
)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C)

0 comments on commit 2bd32fa

Please sign in to comment.