@@ -3,21 +3,28 @@ module NonlinearSolveSciPy
3
3
using ConcreteStructs: @concrete
4
4
using Reexport: @reexport
5
5
6
- using PythonCall: pyimport, pyfunc
7
- const scipy_optimize = try
8
- pyimport (" scipy.optimize" )
9
- catch
10
- nothing
6
+ using PythonCall: pyimport, pyfunc, Py
7
+
8
+ const scipy_optimize = Ref {Union{Py, Nothing}} (nothing )
9
+ const PY_NONE = Ref {Union{Py, Nothing}} (nothing )
10
+ const _SCIPY_AVAILABLE = Ref {Bool} (false )
11
+
12
+ function __init__ ()
13
+ try
14
+ scipy_optimize[] = pyimport (" scipy.optimize" )
15
+ PY_NONE[] = pyimport (" builtins" ). None
16
+ _SCIPY_AVAILABLE[] = true
17
+ catch
18
+
19
+ _SCIPY_AVAILABLE[] = false
20
+ end
11
21
end
12
- const _SCIPY_AVAILABLE = scipy_optimize != = nothing
13
- const PY_NONE = pyimport (" builtins" ). None
14
22
15
23
using SciMLBase
16
24
using NonlinearSolveBase: AbstractNonlinearSolveAlgorithm, construct_extension_function_wrapper
17
25
18
26
"""
19
27
SciPyLeastSquares(; method="trf", loss="linear")
20
-
21
28
Wrapper over `scipy.optimize.least_squares` (via PythonCall) for solving
22
29
`NonlinearLeastSquaresProblem`s. The keyword arguments correspond to the
23
30
`method` ("trf", "dogbox", "lm") and the robust loss function ("linear",
@@ -30,7 +37,7 @@ Wrapper over `scipy.optimize.least_squares` (via PythonCall) for solving
30
37
end
31
38
32
39
function SciPyLeastSquares (; method:: String = " trf" , loss:: String = " linear" )
33
- _SCIPY_AVAILABLE || error (" `SciPyLeastSquares` requires the Python package `scipy` to be available to PythonCall." )
40
+ _SCIPY_AVAILABLE[] || error (" `SciPyLeastSquares` requires the Python package `scipy` to be available to PythonCall." )
34
41
valid_methods = (" trf" , " dogbox" , " lm" )
35
42
valid_losses = (" linear" , " soft_l1" , " huber" , " cauchy" , " arctan" )
36
43
method in valid_methods ||
@@ -46,7 +53,6 @@ SciPyLeastSquaresLM() = SciPyLeastSquares(method = "lm")
46
53
47
54
"""
48
55
SciPyRoot(; method="hybr")
49
-
50
56
Wrapper over `scipy.optimize.root` for solving `NonlinearProblem`s. Available
51
57
methods include "hybr" (default), "lm", "broyden1", "broyden2", "anderson",
52
58
"diagbroyden", "linearmixing", "excitingmixing", "krylov", "df-sane" – any
@@ -58,13 +64,12 @@ method accepted by SciPy.
58
64
end
59
65
60
66
function SciPyRoot (; method:: String = " hybr" )
61
- _SCIPY_AVAILABLE || error (" `SciPyRoot` requires the Python package `scipy` to be available to PythonCall." )
67
+ _SCIPY_AVAILABLE[] || error (" `SciPyRoot` requires the Python package `scipy` to be available to PythonCall." )
62
68
return SciPyRoot (method, :SciPyRoot )
63
69
end
64
70
65
71
"""
66
72
SciPyRootScalar(; method="brentq")
67
-
68
73
Wrapper over `scipy.optimize.root_scalar` for scalar `IntervalNonlinearProblem`s
69
74
(bracketing problems). The default method uses Brent's algorithm ("brentq").
70
75
Other valid options: "bisect", "brentq", "brenth", "ridder", "toms748",
@@ -76,11 +81,10 @@ Other valid options: "bisect", "brentq", "brenth", "ridder", "toms748",
76
81
end
77
82
78
83
function SciPyRootScalar (; method:: String = " brentq" )
79
- _SCIPY_AVAILABLE || error (" `SciPyRootScalar` requires the Python package `scipy` to be available to PythonCall." )
84
+ _SCIPY_AVAILABLE[] || error (" `SciPyRootScalar` requires the Python package `scipy` to be available to PythonCall." )
80
85
return SciPyRootScalar (method, :SciPyRootScalar )
81
86
end
82
87
83
-
84
88
""" Internal: wrap a Julia residual function into a Python callable """
85
89
function _make_py_residual (f, p)
86
90
return pyfunc (x_py -> begin
@@ -98,7 +102,6 @@ function _make_py_scalar(f, p)
98
102
end )
99
103
end
100
104
101
-
102
105
function SciMLBase. __solve (prob:: SciMLBase.NonlinearLeastSquaresProblem , alg:: SciPyLeastSquares ;
103
106
abstol = nothing , maxiters = 10_000 , alias_u0:: Bool = false ,
104
107
kwargs... )
@@ -116,11 +119,11 @@ function SciMLBase.__solve(prob::SciMLBase.NonlinearLeastSquaresProblem, alg::Sc
116
119
bounds = nothing
117
120
end
118
121
119
- res = scipy_optimize. least_squares (py_f, collect (prob. u0);
122
+ res = scipy_optimize[] . least_squares (py_f, collect (prob. u0);
120
123
method = alg. method,
121
124
loss = alg. loss,
122
125
max_nfev = maxiters,
123
- bounds = bounds === nothing ? PY_NONE : bounds,
126
+ bounds = bounds === nothing ? PY_NONE[] : bounds,
124
127
kwargs... )
125
128
126
129
u_vec = Vector {Float64} (res. x)
143
146
function SciMLBase. __solve (prob:: SciMLBase.NonlinearProblem , alg:: SciPyRoot ;
144
147
abstol = nothing , maxiters = 10_000 , alias_u0:: Bool = false ,
145
148
kwargs... )
146
-
149
+
147
150
f!, u0, resid = construct_extension_function_wrapper (prob; alias_u0)
148
151
149
152
py_f = pyfunc (x_py -> begin
@@ -154,7 +157,7 @@ function SciMLBase.__solve(prob::SciMLBase.NonlinearProblem, alg::SciPyRoot;
154
157
155
158
tol = abstol === nothing ? nothing : abstol
156
159
157
- res = scipy_optimize. root (py_f, collect (u0);
160
+ res = scipy_optimize[] . root (py_f, collect (u0);
158
161
method = alg. method,
159
162
tol = tol,
160
163
options = Dict (" maxiter" => maxiters),
@@ -182,7 +185,7 @@ function SciMLBase.__solve(prob::SciMLBase.IntervalNonlinearProblem, alg::SciPyR
182
185
183
186
a, b = prob. tspan
184
187
185
- res = scipy_optimize. root_scalar (py_f;
188
+ res = scipy_optimize[] . root_scalar (py_f;
186
189
method = alg. method,
187
190
bracket = (a, b),
188
191
maxiter = maxiters,
206
209
export SciPyLeastSquares, SciPyLeastSquaresTRF, SciPyLeastSquaresDogbox, SciPyLeastSquaresLM,
207
210
SciPyRoot, SciPyRootScalar
208
211
209
- end # module
212
+ end
213
+
0 commit comments