MLJ extension #61

Merged · 6 commits · Aug 9, 2023
8 changes: 8 additions & 0 deletions Project.toml
@@ -15,14 +15,22 @@ SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
MLJScikitLearnInterface = "5ae90465-5518-4432-b9d2-8a1def2f0cab"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

[extensions]
ACEfit_PythonCall_ext = "PythonCall"
ACEfit_MLJLinearModels_ext = [ "MLJ", "MLJLinearModels" ]
ACEfit_MLJScikitLearnInterface_ext = ["MLJScikitLearnInterface", "PythonCall", "MLJ"]

[compat]
julia = "1.9"
IterativeSolvers = "0.9.2"
MLJ = "0.19"
MLJLinearModels = "0.9"
MLJScikitLearnInterface = "0.5"
LowRankApprox = "0.5.3"
Optim = "1.7"
ParallelDataTransfer = "0.5.0"
29 changes: 29 additions & 0 deletions docs/src/index.md
@@ -15,6 +15,35 @@ using ACEfit
using PythonCall
```

## MLJ solvers

To use [MLJ](https://github.com/alan-turing-institute/MLJ.jl) solvers you need to load MLJ in addition to ACEfit

```julia
using ACEfit
using MLJ
```

**Review comments on this hunk:**

**Collaborator:** Conceptually, does it make sense to recommend that MLJ is loaded first?

**Collaborator (Author):** Loading order does not matter, so I would leave it as it is.

**Collaborator:** The example in ext/ACEfit_MLJLinearModels_ext.jl loads MLJ first - can we make them consistent?

After that you need to load an appropriate MLJ solver. Take a look at the available MLJ [solvers](https://alan-turing-institute.github.io/MLJ.jl/dev/model_browser/). Note that only [MLJScikitLearnInterface.jl](https://github.com/JuliaAI/MLJScikitLearnInterface.jl) and [MLJLinearModels.jl](https://github.com/JuliaAI/MLJLinearModels.jl) have extensions available. To use other MLJ solvers, please file an issue.

Load the solver and then create a solver instance with the desired parameters:

```julia
# Load ARD solver
ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface

# Create the solver itself and give it parameters
solver = ARDRegressor(
    n_iter = 300,
    tol = 1e-3,
    threshold_lambda = 10000
)
```

After this you can use the MLJ solver like any other solver.
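For example, a minimal end-to-end sketch of solving a linear system with an MLJ solver (this mirrors the synthetic overdetermined system used in `test/test_mlj.jl`; the coefficients are returned under the `"C"` key of the result dictionary):

```julia
using ACEfit
using LinearAlgebra
using MLJ
using MLJScikitLearnInterface

# Synthetic least-squares problem: 10_000 observations, 100 features
A = randn(10_000, 100) / sqrt(10_000)
y = randn(10_000)

# Load and configure the ARD solver from MLJScikitLearnInterface
ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface
solver = ARDRegressor(n_iter = 300, tol = 1e-3, threshold_lambda = 10000)

# Solve and inspect the fitted coefficients
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
```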

## Index

```@index
```

53 changes: 53 additions & 0 deletions ext/ACEfit_MLJLinearModels_ext.jl
@@ -0,0 +1,53 @@
module ACEfit_MLJLinearModels_ext

using MLJ
using ACEfit
using MLJLinearModels

"""
ACEfit.solve(solver, A, y)

Overloads `ACEfit.solve` to use MLJLinearModels solvers
when `solver` is an [MLJLinearModels](https://github.com/JuliaAI/MLJLinearModels.jl) solver.

# Example
```julia
using MLJ
using ACEfit

# Load Lasso solver
LassoRegressor = @load LassoRegressor pkg=MLJLinearModels

# Create the solver itself and give it parameters
solver = LassoRegressor(
    lambda = 0.2,
    fit_intercept = false
    # insert more fit params
)

# fit ACE model
linear_fit(training_data, basis, solver)

# or at a lower level
ACEfit.solve(solver, A, y)
```
"""
function ACEfit.solve(solver::Union{
            MLJLinearModels.ElasticNetRegressor,
            MLJLinearModels.HuberRegressor,
            MLJLinearModels.LADRegressor,
            MLJLinearModels.LassoRegressor,
            MLJLinearModels.LinearRegressor,
            MLJLinearModels.QuantileRegressor,
            MLJLinearModels.RidgeRegressor,
            MLJLinearModels.RobustRegressor,
        },
        A, y)
    Atable = MLJ.table(A)
    mach = machine(solver, Atable, y)
    MLJ.fit!(mach)
    params = fitted_params(mach)
    return Dict{String, Any}("C" => map(x -> x.second, params.coefs))
end

end
46 changes: 46 additions & 0 deletions ext/ACEfit_MLJScikitLearnInterface_ext.jl
@@ -0,0 +1,46 @@
module ACEfit_MLJScikitLearnInterface_ext

using ACEfit
using MLJ
using MLJScikitLearnInterface
using PythonCall


"""
ACEfit.solve(solver, A, y)

Overloads `ACEfit.solve` to use scikit-learn solvers from MLJ.

# Example
```julia
using MLJ
using ACEfit

# Load ARD solver
ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface

# Create the solver itself and give it parameters
solver = ARDRegressor(
    n_iter = 300,
    tol = 1e-3,
    threshold_lambda = 10000
    # more params
)

# fit ACE model
linear_fit(training_data, basis, solver)

# or at a lower level
ACEfit.solve(solver, A, y)
```
"""
function ACEfit.solve(solver, A, y)
    Atable = MLJ.table(A)
    mach = machine(solver, Atable, y)
    MLJ.fit!(mach)
    params = fitted_params(mach)
    c = params.coef
    return Dict{String, Any}("C" => pyconvert(Array, c))
end

end
6 changes: 5 additions & 1 deletion src/solvers.jl
@@ -126,7 +126,10 @@ struct SKLEARN_BRR
n_iter::Integer
end

SKLEARN_BRR(; tol = 1e-3, n_iter = 300) = SKLEARN_BRR(tol, n_iter)
function SKLEARN_BRR(; tol = 1e-3, n_iter = 300)
    @warn "SKLearn will transition to MLJ in the future; please update your script accordingly."
    SKLEARN_BRR(tol, n_iter)
end
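Since the deprecation warning points users toward MLJ, a possible migration sketch for `SKLEARN_BRR` is shown below. It assumes that `BayesianRidgeRegressor` (the MLJ name for scikit-learn's Bayesian ridge regression in MLJScikitLearnInterface.jl) accepts the same `n_iter` and `tol` parameters as the old wrapper; treat the model and parameter names as assumptions to check against the installed package versions.

```julia
using ACEfit
using MLJ
using MLJScikitLearnInterface

# MLJ route to scikit-learn Bayesian ridge regression,
# replacing the deprecated SKLEARN_BRR(tol = 1e-3, n_iter = 300).
# Model and parameter names here are assumptions based on the scikit-learn API.
BayesianRidgeRegressor = @load BayesianRidgeRegressor pkg=MLJScikitLearnInterface
solver = BayesianRidgeRegressor(n_iter = 300, tol = 1e-3)

# Then fit as with any other ACEfit solver:
# results = ACEfit.solve(solver, A, y)
```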

# solve(solver::SKLEARN_BRR, ...) is implemented in ext/

@@ -140,6 +143,7 @@ struct SKLEARN_ARD
end

function SKLEARN_ARD(; n_iter = 300, tol = 1e-3, threshold_lambda = 10000)
    @warn "SKLearn will transition to MLJ in the future; please update your script accordingly."
    SKLEARN_ARD(n_iter, tol, threshold_lambda)
end

3 changes: 3 additions & 0 deletions test/Project.toml
@@ -1,4 +1,7 @@
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
MLJScikitLearnInterface = "5ae90465-5518-4432-b9d2-8a1def2f0cab"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -7,4 +7,6 @@ using Test
@testset "Bayesian Linear" begin include("test_bayesianlinear.jl") end

@testset "Linear Solvers" begin include("test_linearsolvers.jl") end

@testset "MLJ Solvers" begin include("test_mlj.jl") end
end
42 changes: 42 additions & 0 deletions test/test_mlj.jl
@@ -0,0 +1,42 @@
using ACEfit
using LinearAlgebra
using MLJ
using MLJScikitLearnInterface

@info("Test MLJ interface on overdetermined system")
Nobs = 10_000
Nfeat = 100
A = randn(Nobs, Nfeat) / sqrt(Nobs)
y = randn(Nobs)
P = Diagonal(1.0 .+ rand(Nfeat))


@info(" ... MLJLinearModels LinearRegressor")
LinearRegressor = @load LinearRegressor pkg=MLJLinearModels
solver = LinearRegressor()
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)


@info(" ... MLJLinearModels LassoRegressor")
LassoRegressor = @load LassoRegressor pkg=MLJLinearModels
solver = LassoRegressor()
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)


@info(" ... MLJ SKLearn ARD")
ARDRegressor = @load ARDRegressor pkg=MLJScikitLearnInterface
solver = ARDRegressor(
    n_iter = 300,
    tol = 1e-3,
    threshold_lambda = 10000
)
results = ACEfit.solve(solver, A, y)
C = results["C"]
@show norm(A * C - y)
@show norm(C)