Skip to content

Commit

Permalink
Merge pull request #1357 from n0rbed/bug_fixes
Browse files Browse the repository at this point in the history
Bug fixes, integration of solve_interms_ofvar and some changes
  • Loading branch information
ChrisRackauckas authored Nov 18, 2024
2 parents 2bc3c54 + 0368c3e commit c31c3fd
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ include("solver/polynomialization.jl")
include("solver/attract.jl")
include("solver/ia_main.jl")
include("solver/main.jl")
include("solver/ia_rules.jl")
include("solver/special_cases.jl")
export symbolic_solve

function symbolics_to_sympy end
Expand Down
24 changes: 16 additions & 8 deletions src/solver/ia_main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
lhs = unwrap(lhs)

old_lhs = nothing

while !isequal(lhs, var)
subs, poly = filter_poly(lhs, var)

if check_poly_inunivar(poly, var)
if check_polynomial(poly, strict=false)
roots = []
new_var = gensym()
new_var = (@variables $new_var)[1]
Expand All @@ -20,7 +21,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
else
a, b, islin = linear_expansion(lhs - new_var, var)
if islin
lhs_roots = [-b / a]
lhs_roots = [-b // a]
else
lhs_roots = [RootsOf(lhs - new_var, var)]
if warns
Expand All @@ -31,15 +32,20 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri

for i in eachindex(lhs_roots)
for j in eachindex(rhs)
push!(roots, substitute(lhs_roots[i], Dict(new_var=>rhs[j]), fold=false))
if iscall(lhs_roots[i]) && operation(lhs_roots[i]) == RootsOf
lhs_roots[i].arguments[1] = substitute(lhs_roots[i].arguments[1], Dict(new_var=>rhs[j]), fold=false)
push!(roots, lhs_roots[i])
else
push!(roots, substitute(lhs_roots[i], Dict(new_var=>rhs[j]), fold=false))
end
end
end
return roots, conditions
end

if isequal(old_lhs, lhs)
warns && @warn("This expression cannot be solved with the methods available to ia_solve. Try a numerical method instead.")
return nothing
return nothing, conditions
end

old_lhs = deepcopy(lhs)
Expand Down Expand Up @@ -76,7 +82,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
else
# 2 / x = y
lhs = args[2]
rhs = map(sol -> args[1] // sol, rhs)
rhs = map(sol -> term(/, args[1], sol), rhs)
end

elseif oper === (^)
Expand Down Expand Up @@ -108,6 +114,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
elseif any(isequal(x, var) for x in get_variables(args[1])) &&
n_occurrences(args[2], var) == 0
lhs = args[1]
s, args[2] = filter_stuff(args[2])
rhs = map(sol -> term(^, sol, 1 // args[2]), rhs)
else
lhs = args[2]
Expand Down Expand Up @@ -169,7 +176,7 @@ function attract(lhs, var; warns = true, complex_roots = true, periodic_roots =
return nothing, conditions
end
end

new_var = collect(keys(sub))[1]
new_var_val = collect(values(sub))[1]

Expand All @@ -178,6 +185,7 @@ function attract(lhs, var; warns = true, complex_roots = true, periodic_roots =
new_roots = []

for root in roots
iscall(root) && operation(root) == RootsOf && continue
new_sol, new_conds = isolate(new_var_val - root, var; warns = warns, complex_roots, periodic_roots)
append!(conditions, new_conds)
push!(new_roots, new_sol)
Expand Down Expand Up @@ -273,9 +281,9 @@ function ia_solve(lhs, var; warns = true, complex_roots = true, periodic_roots =
conditions = []
if nx == 0
warns && @warn("Var not present in given expression")
return []
return nothing
elseif nx == 1
sols, conditions = isolate(lhs, var; warns = warns, complex_roots, periodic_roots)
sols, conditions = isolate(lhs, var; warns = warns, complex_roots, periodic_roots)
elseif nx > 1
sols, conditions = attract(lhs, var; warns = warns, complex_roots, periodic_roots)
end
Expand Down
22 changes: 16 additions & 6 deletions src/solver/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,6 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
expr = Vector{Num}(expr)
end

if expr_univar && !x_univar
expr = [expr]
expr_univar = false
end
if !expr_univar && x_univar
x = [x]
x_univar = false
Expand All @@ -189,8 +185,17 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where
isequal(sols, nothing) && return nothing
sols = map(postprocess_root, sols)
return sols
elseif expr_univar
all_vars = get_variables(expr)
diff_vars = setdiff(wrap.(all_vars), x)
if length(diff_vars) == 1
return solve_interms_ofvar(expr, diff_vars[1], dropmultiplicity=dropmultiplicity, warns=warns)
end

expr = [expr]
end


if !x_univar
for e in expr
for var in x
Expand Down Expand Up @@ -247,6 +252,7 @@ function symbolic_solve(expr; x...)
return symbolic_solve(expr, vars; x...)
end


"""
solve_univar(expression, x; dropmultiplicity=true)
This solver uses analytic solutions up to degree 4 to solve univariate polynomials.
Expand All @@ -266,10 +272,12 @@ implemented in the function `get_roots` and its children.
- dropmultiplicity (optional): Print repeated roots or not?
- strict (optional): Bool that enables/disables strict assert if input expression is a univariate polynomial or not. If strict=true and expression is not a polynomial, `solve_univar` throws an assertion error.
# Examples
"""
function solve_univar(expression, x; dropmultiplicity=true)
function solve_univar(expression, x; dropmultiplicity=true, strict=true)
args = []
mult_n = 1
expression = unwrap(expression)
Expand All @@ -287,6 +295,9 @@ function solve_univar(expression, x; dropmultiplicity=true)
end

subs, filtered_expr, assumptions = filter_poly(expression, x, assumptions=true)
if !strict && !check_polynomial(filtered_expr, strict=false)
return [RootsOf(wrap(expression), wrap(x))]
end
coeffs, constant = polynomial_coeffs(filtered_expr, [x])
degree = sdegree(coeffs, x)

Expand Down Expand Up @@ -325,7 +336,6 @@ function solve_univar(expression, x; dropmultiplicity=true)
end

if isequal(arr_roots, [])
@assert check_polynomial(expression) "This expression could not be solved by `symbolic_solve`."
return [RootsOf(wrap(expression), wrap(x))]
end

Expand Down
12 changes: 8 additions & 4 deletions src/solver/nemo_stuff.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
# Checks that the expression is a polynomial with integer or rational
# coefficients
function check_polynomial(poly)
function check_polynomial(poly; strict=true)
poly = wrap(poly)
vars = get_variables(poly)
distr, rem = polynomial_coeffs(poly, vars)
@assert isequal(rem, 0) "Not a polynomial"
@assert all(c -> c isa Integer || c isa Rational, collect(values(distr))) "Coefficients must be integer or rational"
return true
if strict
@assert isequal(rem, 0) "Not a polynomial"
@assert all(c -> c isa Integer || c isa Rational, collect(values(distr))) "Coefficients must be integer or rational"
return true
else
return isequal(rem, 0)
end
end

# factor(x^2*y + b*x*y - a*x - a*b) -> (x*y - a)*(x + b)
Expand Down
18 changes: 15 additions & 3 deletions src/solver/postprocess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,26 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic)
end
end

args = arguments(x)

# (X)^0 => 1
if oper === (^) && isequal(arguments(x)[2], 0)
if oper === (^) && isequal(args[2], 0) && !isequal(args[1], 0)
return 1
end

# (X)^1 => X
if oper === (^) && isequal(arguments(x)[2], 1)
return arguments(x)[1]
if oper === (^) && isequal(args[2], 1)
return args[1]
end

# (0)^X => 0
if oper === (^) && isequal(args[1], 0) && !isequal(args[2], 0)
return 0
end

# y / 0 => Inf
if oper === (/) && !isequal(args[1], 0) && isequal(args[2], 0)
return Inf
end

# sqrt((N / D)^2 * M) => N / D * sqrt(M)
Expand Down
7 changes: 2 additions & 5 deletions src/solver/preprocess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,10 @@ function clean_f(filtered_expr, var, subs)

if oper === (/)
args = arguments(unwrapped_f)
if any(isequal(var, x) for x in get_variables(args[2]))
filtered_expr = expand(args[1] * args[2])
if !all(isequal(var, x) for x in get_variables(args[2]))
filtered_expr = args[1]
push!(assumptions, substitute(args[2], subs, fold=false))
return filtered_expr, assumptions
end
filtered_expr = args[1]
@info "Assuming $(substitute(args[2], subs, fold=false) != 0)"
end
return filtered_expr, assumptions
end
Expand Down
40 changes: 40 additions & 0 deletions src/solver/ia_rules.jl → src/solver/special_cases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,47 @@ function cross_multiply(eq)
return cross_multiply(eq)
end
end
"""
solve_interms_ofvar(eq, s; dropmultiplicity=true, warns=true)
This special case solver expects a single equation in multiple variables and a
variable `s` (this can be any Num, `s` is used for convenience). The function generates
a system of equations to by observing the coefficients of the powers of `s` present in `eq`.
E.g. a system would look like `a+b = 1`, `a-2b = 3` for the eq `(a+b)s + (a-2b)s^2 - (1)s - (3)s^2 = 0`.
After generating this system, it calls `symbolic_solve`, which uses `solve_multivar`. `symbolic_solve` was chosen
instead of `solve_multivar` because it postprocesses the roots in order to simplify them and make them more user friendly.
Generation of system uses cross multiplication in order to simplify the equation and convert it
to a polynomial like shape.
# Arguments
- eq: Single symbolics Num or SymbolicUtils.BasicSymbolic. This is equated to 0 and then solved. E.g. `expr = x+2`, we solve `x+2 = 0`
- s: Variable to "isolate", i.e. ignore and generate the system of equations based on this variable's coefficients.
- dropmultiplicity (optional): Print repeated roots or not?
- warns (optional, this is not used currently): Warn user when something is wrong or not.
# Examples
```jldoctest
julia> @variables a b x s;
julia> eq = (a*x^2+b)*s^2 - 2s^2 + 2*b*s - 3*s + 2(x^2)*(s^3) + 10*s^3;
julia> Symbolics.solve_interms_ofvar(eq, s)
2-element Vector{Any}:
Dict{Num, Any}(a => -1//10, b => 3//2, x => (0 - 1im)*√(5))
Dict{Num, Any}(a => -1//10, b => 3//2, x => (0 + 1im)*√(5))
```
```jldoctest
julia> eq = ((s^2 + 1)/(s^2 + 2*s + 1)) - ((s^2 + a)/(b*c*s^2 + (b+c)*s + d));
julia> Symbolics.solve_interms_ofvar(eq, s)
1-element Vector{Any}:
Dict{Num, Any}(a => 1, d => 1, b => 1, c => 1)
```
"""
function solve_interms_ofvar(eq, s; dropmultiplicity=true, warns=true)
@assert iscall(unwrap(eq))
vars = Symbolics.get_variables(eq)
Expand Down
27 changes: 19 additions & 8 deletions test/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import Symbolics: ssqrt, slog, scbrt, symbolic_solve, ia_solve, postprocess_root
@test Base.get_extension(Symbolics, :SymbolicsNemoExt) === nothing
@variables x
roots = ia_solve(log(2 + x), x)
@test substitute(roots[1], Dict()) == -1.0
roots = @test_warn ["Nemo", "required"] ia_solve(log(2 + x^2), x)
@test operation(roots[1]) == Symbolics.RootsOf
end
Expand Down Expand Up @@ -69,23 +68,28 @@ end
@testset "Solving in terms of a constant var" begin
eq = ((s^2 + 1)/(s^2 + 2*s + 1)) - ((s^2 + a)/(b*c*s^2 + (b+c)*s + d))
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b,c,d])
solve_roots = sort_arr(symbolic_solve(eq, [a,b,c,d]), [a,b,c,d])
known_roots = sort_arr([Dict(a=>1, b=>1, c=>1, d=>1)], [a,b,c,d])
@test check_approx(calcd_roots, known_roots)
@test check_approx(solve_roots, known_roots)

eq = (a+b)*s^2 - 2s^2 + 2*b*s - 3*s
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b])
solve_roots = sort_arr(symbolic_solve(eq, [a,b]), [a,b])
known_roots = sort_arr([Dict(a=>1/2, b=>3/2)], [a,b])
@test check_approx(calcd_roots, known_roots)
@test check_approx(solve_roots, known_roots)

eq = (a*x^2+b)*s^2 - 2s^2 + 2*b*s - 3*s + 2(x^2)*(s^3) + 10*s^3
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b])
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b,x])
solve_roots = sort_arr(symbolic_solve(eq, [a,b,x]), [a,b,x])
known_roots = sort_arr([Dict(a=>-1/10, b=>3/2, x=>-im*sqrt(5)), Dict(a=>-1/10, b=>3/2, x=>im*sqrt(5))], [a,b,x])
@test check_approx(calcd_roots, known_roots)
@test check_approx(solve_roots, known_roots)
end

@testset "Invalid input" begin
@test_throws AssertionError symbolic_solve(x, x^2)
@test_throws AssertionError symbolic_solve(1/x, x)
end

@testset "Nice univar cases" begin
Expand Down Expand Up @@ -355,14 +359,18 @@ end
@testset "Post Process roots" begin
SymbolicUtils.@syms __x
__symsqrt(x) = SymbolicUtils.term(ssqrt, x)
term = SymbolicUtils.term
@test Symbolics.postprocess_root(2 // 1) == 2 && Symbolics.postprocess_root(2 + 0*im) == 2
@test Symbolics.postprocess_root(__symsqrt(4)) == 2
@test isequal(Symbolics.postprocess_root(__symsqrt(__x)^2), __x)

@test Symbolics.postprocess_root( SymbolicUtils.term(^, __x, 0) ) == 1
@test Symbolics.postprocess_root( SymbolicUtils.term(^, Base.MathConstants.e, 0) ) == 1
@test Symbolics.postprocess_root( SymbolicUtils.term(^, Base.MathConstants.pi, 1) ) == Base.MathConstants.pi
@test isequal(Symbolics.postprocess_root( SymbolicUtils.term(^, __x, 1) ), __x)

@test isequal(Symbolics.postprocess_root(term(^, 0, __x)), 0)
@test_broken isequal(Symbolics.postprocess_root(term(/, __x, 0)), Inf)
@test Symbolics.postprocess_root(term(^, __x, 0) ) == 1
@test Symbolics.postprocess_root(term(^, Base.MathConstants.e, 0) ) == 1
@test Symbolics.postprocess_root(term(^, Base.MathConstants.pi, 1) ) == Base.MathConstants.pi
@test isequal(Symbolics.postprocess_root(term(^, __x, 1) ), __x)

x = Symbolics.term(sqrt, 2)
@test isequal(Symbolics.postprocess_root( expand((x + 1)^4) ), 17 + 12x)
Expand Down Expand Up @@ -426,7 +434,10 @@ end
lhs = ia_solve(a*x^b + c, x)[1]
lhs2 = symbolic_solve(a*x^b + c, x)[1]
rhs = Symbolics.term(^, -c.val/a.val, 1/b.val)
#@test isequal(lhs, rhs)
@test_broken isequal(lhs, rhs)

@test isequal(symbolic_solve(2/x, x)[1], Inf)
@test isequal(symbolic_solve(x^1.5, x)[1], 0)

lhs = symbolic_solve(log(a*x)-b,x)[1]
@test isequal(Symbolics.unwrap(Symbolics.ssubs(lhs, Dict(a=>1, b=>1))), 1E)
Expand Down

0 comments on commit c31c3fd

Please sign in to comment.