Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update for sklearn 1.3. #77

Merged
merged 3 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:
matrix:
version:
- '1.9'
- '1.10'
- '1'
- 'nightly'
python-version:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ julia = "1.9"
IterativeSolvers = "0.9.2"
MLJ = "0.19"
MLJLinearModels = "0.9"
MLJScikitLearnInterface = "0.5"
MLJScikitLearnInterface = "0.7"
LowRankApprox = "0.5.3"
Optim = "1.7"
ParallelDataTransfer = "0.5.0"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface

# Create the solver itself and give it parameters
solver = ARDRegressor(
n_iter = 300,
max_iter = 300,
tol = 1e-3,
threshold_lambda = 10000
)
Expand Down
4 changes: 2 additions & 2 deletions ext/ACEfit_MLJScikitLearnInterface_ext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface

# Create the solver itself and give it parameters
solver = ARDRegressor(
n_iter = 300,
max_iter = 300,
tol = 1e-3,
threshold_lambda = 10000
# more params
Expand All @@ -43,4 +43,4 @@ function ACEfit.solve(solver, A, y)
return Dict{String, Any}("C" => pyconvert(Array, c) )
end

end
end
14 changes: 7 additions & 7 deletions ext/ACEfit_PythonCall_ext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ using PythonCall
function ACEfit.solve(solver::ACEfit.SKLEARN_BRR, A, y)
@info "Entering SKLEARN_BRR"
BRR = pyimport("sklearn.linear_model")."BayesianRidge"
clf = BRR(n_iter = solver.n_iter, tol = solver.tol, fit_intercept = true,
clf = BRR(max_iter = solver.max_iter, tol = solver.tol, fit_intercept = true,
compute_score = true)
clf.fit(A, y)
if length(clf.scores_) < solver.n_iter
if length(clf.scores_) < solver.max_iter
@info "BRR converged to tol=$(solver.tol) after $(length(clf.scores_)) iterations."
else
@warn "\nBRR did not converge to tol=$(solver.tol) after n_iter=$(solver.n_iter) iterations.\n"
@warn "\nBRR did not converge to tol=$(solver.tol) after max_iter=$(solver.max_iter) iterations.\n"
end
c = clf.coef_
return Dict{String, Any}("C" => pyconvert(Array, c) )
Expand All @@ -22,17 +22,17 @@ end

function ACEfit.solve(solver::ACEfit.SKLEARN_ARD, A, y)
ARD = pyimport("sklearn.linear_model")."ARDRegression"
clf = ARD(n_iter = solver.n_iter, threshold_lambda = solver.threshold_lambda,
clf = ARD(max_iter = solver.max_iter, threshold_lambda = solver.threshold_lambda,
tol = solver.tol,
fit_intercept = true, compute_score = true)
clf.fit(A, y)
if length(clf.scores_) < solver.n_iter
if length(clf.scores_) < solver.max_iter
@info "ARD converged to tol=$(solver.tol) after $(length(clf.scores_)) iterations."
else
@warn "\n\nARD did not converge to tol=$(solver.tol) after n_iter=$(solver.n_iter) iterations.\n\n"
@warn "\n\nARD did not converge to tol=$(solver.tol) after max_iter=$(solver.max_iter) iterations.\n\n"
end
c = clf.coef_
return Dict{String, Any}("C" => pyconvert(Array,c) )
end

end
end
12 changes: 6 additions & 6 deletions src/solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,12 @@ SKLEARN_BRR
"""
struct SKLEARN_BRR
tol::Number
n_iter::Integer
max_iter::Integer
end

function SKLEARN_BRR(; tol = 1e-3, n_iter = 300)
function SKLEARN_BRR(; tol = 1e-3, max_iter = 300)
@warn "SKLearn will transition to MLJ in future, please upgrade your script to reflect this."
SKLEARN_BRR(tol, n_iter)
SKLEARN_BRR(tol, max_iter)
end

# solve(solver::SKLEARN_BRR, ...) is implemented in ext/
Expand All @@ -138,14 +138,14 @@ end
SKLEARN_ARD
"""
struct SKLEARN_ARD
n_iter::Integer
max_iter::Integer
tol::Number
threshold_lambda::Number
end

function SKLEARN_ARD(; n_iter = 300, tol = 1e-3, threshold_lambda = 10000)
function SKLEARN_ARD(; max_iter = 300, tol = 1e-3, threshold_lambda = 10000)
@warn "SKLearn will transition to MLJ in future, please upgrade your script to reflect this."
SKLEARN_ARD(n_iter, tol, threshold_lambda)
SKLEARN_ARD(max_iter, tol, threshold_lambda)
end

# solve(solver::SKLEARN_ARD, ...) is implemented in ext/
Expand Down
4 changes: 2 additions & 2 deletions test/test_mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ C = results["C"]
@info(" ... MLJ SKLearn ARD")
ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface
solver = ARDRegressor(
n_iter = 300,
max_iter = 300,
tol = 1e-3,
threshold_lambda = 10000
)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)
@show norm(C)
Loading