Skip to content

Commit f992be2

Browse files
fixed import error
Signed-off-by: AdityaPandeyCN <[email protected]>
1 parent 5317b49 commit f992be2

File tree

1 file changed

+25
-21
lines changed

1 file changed

+25
-21
lines changed

lib/NonlinearSolveSciPy/src/NonlinearSolveSciPy.jl

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,28 @@ module NonlinearSolveSciPy
33
using ConcreteStructs: @concrete
44
using Reexport: @reexport
55

6-
using PythonCall: pyimport, pyfunc
7-
const scipy_optimize = try
8-
pyimport("scipy.optimize")
9-
catch
10-
nothing
6+
using PythonCall: pyimport, pyfunc, Py
7+
8+
const scipy_optimize = Ref{Union{Py, Nothing}}(nothing)
9+
const PY_NONE = Ref{Union{Py, Nothing}}(nothing)
10+
const _SCIPY_AVAILABLE = Ref{Bool}(false)
11+
12+
function __init__()
13+
try
14+
scipy_optimize[] = pyimport("scipy.optimize")
15+
PY_NONE[] = pyimport("builtins").None
16+
_SCIPY_AVAILABLE[] = true
17+
catch
18+
19+
_SCIPY_AVAILABLE[] = false
20+
end
1121
end
12-
const _SCIPY_AVAILABLE = scipy_optimize !== nothing
13-
const PY_NONE = pyimport("builtins").None
1422

1523
using SciMLBase
1624
using NonlinearSolveBase: AbstractNonlinearSolveAlgorithm, construct_extension_function_wrapper
1725

1826
"""
1927
SciPyLeastSquares(; method="trf", loss="linear")
20-
2128
Wrapper over `scipy.optimize.least_squares` (via PythonCall) for solving
2229
`NonlinearLeastSquaresProblem`s. The keyword arguments correspond to the
2330
`method` ("trf", "dogbox", "lm") and the robust loss function ("linear",
@@ -30,7 +37,7 @@ Wrapper over `scipy.optimize.least_squares` (via PythonCall) for solving
3037
end
3138

3239
function SciPyLeastSquares(; method::String = "trf", loss::String = "linear")
33-
_SCIPY_AVAILABLE || error("`SciPyLeastSquares` requires the Python package `scipy` to be available to PythonCall.")
40+
_SCIPY_AVAILABLE[] || error("`SciPyLeastSquares` requires the Python package `scipy` to be available to PythonCall.")
3441
valid_methods = ("trf", "dogbox", "lm")
3542
valid_losses = ("linear", "soft_l1", "huber", "cauchy", "arctan")
3643
method in valid_methods ||
@@ -46,7 +53,6 @@ SciPyLeastSquaresLM() = SciPyLeastSquares(method = "lm")
4653

4754
"""
4855
SciPyRoot(; method="hybr")
49-
5056
Wrapper over `scipy.optimize.root` for solving `NonlinearProblem`s. Available
5157
methods include "hybr" (default), "lm", "broyden1", "broyden2", "anderson",
5258
"diagbroyden", "linearmixing", "excitingmixing", "krylov", "df-sane" – any
@@ -58,13 +64,12 @@ method accepted by SciPy.
5864
end
5965

6066
function SciPyRoot(; method::String = "hybr")
61-
_SCIPY_AVAILABLE || error("`SciPyRoot` requires the Python package `scipy` to be available to PythonCall.")
67+
_SCIPY_AVAILABLE[] || error("`SciPyRoot` requires the Python package `scipy` to be available to PythonCall.")
6268
return SciPyRoot(method, :SciPyRoot)
6369
end
6470

6571
"""
6672
SciPyRootScalar(; method="brentq")
67-
6873
Wrapper over `scipy.optimize.root_scalar` for scalar `IntervalNonlinearProblem`s
6974
(bracketing problems). The default method uses Brent's algorithm ("brentq").
7075
Other valid options: "bisect", "brentq", "brenth", "ridder", "toms748",
@@ -76,11 +81,10 @@ Other valid options: "bisect", "brentq", "brenth", "ridder", "toms748",
7681
end
7782

7883
function SciPyRootScalar(; method::String = "brentq")
79-
_SCIPY_AVAILABLE || error("`SciPyRootScalar` requires the Python package `scipy` to be available to PythonCall.")
84+
_SCIPY_AVAILABLE[] || error("`SciPyRootScalar` requires the Python package `scipy` to be available to PythonCall.")
8085
return SciPyRootScalar(method, :SciPyRootScalar)
8186
end
8287

83-
8488
""" Internal: wrap a Julia residual function into a Python callable """
8589
function _make_py_residual(f, p)
8690
return pyfunc(x_py -> begin
@@ -98,7 +102,6 @@ function _make_py_scalar(f, p)
98102
end)
99103
end
100104

101-
102105
function SciMLBase.__solve(prob::SciMLBase.NonlinearLeastSquaresProblem, alg::SciPyLeastSquares;
103106
abstol = nothing, maxiters = 10_000, alias_u0::Bool = false,
104107
kwargs...)
@@ -116,11 +119,11 @@ function SciMLBase.__solve(prob::SciMLBase.NonlinearLeastSquaresProblem, alg::Sc
116119
bounds = nothing
117120
end
118121

119-
res = scipy_optimize.least_squares(py_f, collect(prob.u0);
122+
res = scipy_optimize[].least_squares(py_f, collect(prob.u0);
120123
method = alg.method,
121124
loss = alg.loss,
122125
max_nfev = maxiters,
123-
bounds = bounds === nothing ? PY_NONE : bounds,
126+
bounds = bounds === nothing ? PY_NONE[] : bounds,
124127
kwargs...)
125128

126129
u_vec = Vector{Float64}(res.x)
@@ -143,7 +146,7 @@ end
143146
function SciMLBase.__solve(prob::SciMLBase.NonlinearProblem, alg::SciPyRoot;
144147
abstol = nothing, maxiters = 10_000, alias_u0::Bool = false,
145148
kwargs...)
146-
149+
147150
f!, u0, resid = construct_extension_function_wrapper(prob; alias_u0)
148151

149152
py_f = pyfunc(x_py -> begin
@@ -154,7 +157,7 @@ function SciMLBase.__solve(prob::SciMLBase.NonlinearProblem, alg::SciPyRoot;
154157

155158
tol = abstol === nothing ? nothing : abstol
156159

157-
res = scipy_optimize.root(py_f, collect(u0);
160+
res = scipy_optimize[].root(py_f, collect(u0);
158161
method = alg.method,
159162
tol = tol,
160163
options = Dict("maxiter" => maxiters),
@@ -182,7 +185,7 @@ function SciMLBase.__solve(prob::SciMLBase.IntervalNonlinearProblem, alg::SciPyR
182185

183186
a, b = prob.tspan
184187

185-
res = scipy_optimize.root_scalar(py_f;
188+
res = scipy_optimize[].root_scalar(py_f;
186189
method = alg.method,
187190
bracket = (a, b),
188191
maxiter = maxiters,
@@ -206,4 +209,5 @@ end
206209
export SciPyLeastSquares, SciPyLeastSquaresTRF, SciPyLeastSquaresDogbox, SciPyLeastSquaresLM,
207210
SciPyRoot, SciPyRootScalar
208211

209-
end # module
212+
end
213+

0 commit comments

Comments
 (0)