Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parameter to disable early exit of expression evaluation #91

Merged
merged 44 commits into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
8df4bfe
add early_exit argument
nmheim Jun 30, 2024
8f1d2bd
make tests pass; add first test for early_exit
nmheim Jul 1, 2024
5cbe6a5
bumper & loopvec
nmheim Jul 3, 2024
82c4224
format
nmheim Jul 3, 2024
984a1ba
Merge branch 'SymbolicML:master' into nh/early-exit
nmheim Jul 3, 2024
1ee6884
introduce EvaluationOptions
nmheim Jul 8, 2024
86f8316
Merge branch 'master' into pr/nmheim/91
MilesCranmer Jul 19, 2024
6cab047
style: formatting
MilesCranmer Jul 19, 2024
6d46df9
style: more formatting
MilesCranmer Jul 19, 2024
b8f1087
style: clean up redundant options
MilesCranmer Jul 19, 2024
ee7d7c1
style: rename to `eval_options`
MilesCranmer Jul 19, 2024
c0b5a46
fix: merge edits to eval options
MilesCranmer Jul 19, 2024
17ae595
fix: fix generic eval errors
MilesCranmer Jul 19, 2024
2b98acf
refactor: test_evaluation.jl
MilesCranmer Jul 19, 2024
5399625
fix: specific branch calls
MilesCranmer Jul 19, 2024
204a9df
fix: `v_throw_errors` typo
MilesCranmer Jul 19, 2024
2fc5e87
fix: error catching for generic eval
MilesCranmer Jul 19, 2024
660d6f8
style: rename `EvaluationOptions` to `EvalOptions`
MilesCranmer Jul 19, 2024
dd24df6
test: fix initial errors test
MilesCranmer Jul 19, 2024
6302012
fix: type unstalbe tests
nmheim Jul 22, 2024
a73a04f
add doc strings
nmheim Jul 22, 2024
fe30e8b
update docs
nmheim Jul 22, 2024
944b2e8
format
nmheim Jul 22, 2024
905f5e1
approx equal
nmheim Jul 23, 2024
0a2bb96
fix enzyme test
nmheim Jul 24, 2024
6bd504b
Update docs/src/eval.md
nmheim Jul 25, 2024
54c6398
test: disable enzyme test
MilesCranmer Jul 24, 2024
958b9af
test: skip Enzyme test completely
MilesCranmer Jul 25, 2024
1365a55
test: fix Enzyme test
MilesCranmer Jul 26, 2024
cbcd221
style: fix formatting
MilesCranmer Jul 26, 2024
87c6225
ci: install fixed Enzyme
MilesCranmer Jul 26, 2024
86d2096
fix issue due to https://github.com/JuliaLang/Pkg.jl/issues/1585
MilesCranmer Jul 26, 2024
796cbae
fix custom enzyme install
MilesCranmer Jul 26, 2024
8a1ce63
Merge branch 'master' into nh/early-exit
MilesCranmer Jul 27, 2024
5a30d47
ci: remove Enzyme revision test
MilesCranmer Jul 27, 2024
64c797c
refactor: reduce code complexity of eval options
MilesCranmer Jul 27, 2024
a87016e
docs: render `EvalOptions` in docs
MilesCranmer Jul 27, 2024
ace5c19
test: more coverage of `EvalOptions` branches
MilesCranmer Jul 28, 2024
2c34b4e
test: prevent soft scope problem
MilesCranmer Jul 28, 2024
3113499
refactor: clean up evaluation
MilesCranmer Jul 28, 2024
586baf8
benchmarks: fix benchmark eval options
MilesCranmer Jul 28, 2024
3d7b529
fix: note instability in kw deprecation
MilesCranmer Jul 28, 2024
09b7a3d
feat: also include `early_exit` in scalar checks
MilesCranmer Jul 28, 2024
17a4a24
fix: incorporate `@return_on_nonfinite_val` in LoopVectorization exte…
MilesCranmer Jul 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 27 additions & 21 deletions ext/DynamicExpressionsBumperExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module DynamicExpressionsBumperExt

using Bumper: @no_escape, @alloc
using DynamicExpressions: OperatorEnum, AbstractExpressionNode, tree_mapreduce
using DynamicExpressions: OperatorEnum, AbstractExpressionNode, tree_mapreduce, EvaluationOptions
using DynamicExpressions.UtilsModule: ResultOk, counttuple, is_bad_array

import DynamicExpressions.ExtensionInterfaceModule:
Expand All @@ -11,8 +11,8 @@ function bumper_eval_tree_array(
tree::AbstractExpressionNode{T},
cX::AbstractMatrix{T},
operators::OperatorEnum,
::Val{turbo},
) where {T,turbo}
options::EvaluationOptions{turbo,true,early_exit}
) where {T,turbo,early_exit}
result = similar(cX, axes(cX, 2))
n = size(cX, 2)
all_ok = Ref(false)
Expand All @@ -25,7 +25,7 @@ function bumper_eval_tree_array(
ok = if leaf_node.constant
v = leaf_node.val
ar .= v
isfinite(v)
early_exit ? isfinite(v) : true
else
ar .= view(cX, leaf_node.feature, :)
true
Expand All @@ -37,7 +37,7 @@ function bumper_eval_tree_array(
# In the evaluation kernel, we combine the branch nodes
# with the arrays created by the leaf nodes:
((args::Vararg{Any,M}) where {M}) ->
dispatch_kerns!(operators, args..., Val(turbo)),
dispatch_kerns!(operators, args..., options),
tree;
break_sharing=Val(true),
)
Expand All @@ -48,55 +48,61 @@ function bumper_eval_tree_array(
return (result, all_ok[])
end

function dispatch_kerns!(operators, branch_node, cumulator, ::Val{turbo}) where {turbo}
function dispatch_kerns!(
operators, branch_node, cumulator, options::EvaluationOptions{turbo,true,early_exit}
) where {turbo,early_exit}
cumulator.ok || return cumulator

out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, Val(turbo))
return ResultOk(out, !is_bad_array(out))
out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, options)
return early_exit ? ResultOk(out, !is_bad_array(out)) : ResultOk(out, true)
end
function dispatch_kerns!(
operators, branch_node, cumulator1, cumulator2, ::Val{turbo}
) where {turbo}
operators, branch_node, cumulator1, cumulator2, options::EvaluationOptions{turbo,true,early_exit}
) where {turbo,early_exit}
cumulator1.ok || return cumulator1
cumulator2.ok || return cumulator2

out = dispatch_kern2!(
operators.binops, branch_node.op, cumulator1.x, cumulator2.x, Val(turbo)
)
return ResultOk(out, !is_bad_array(out))
out = dispatch_kern2!(operators.binops, branch_node.op, cumulator1.x, cumulator2.x, options)
return early_exit ? ResultOk(out, !is_bad_array(out)) : ResultOk(out, true)
end

@generated function dispatch_kern1!(unaops, op_idx, cumulator, ::Val{turbo}) where {turbo}
@generated function dispatch_kern1!(
unaops, op_idx, cumulator, options::EvaluationOptions{turbo,true,early_exit}
) where {turbo,early_exit}
nuna = counttuple(unaops)
quote
Base.@nif(
$nuna,
i -> i == op_idx,
i -> let op = unaops[i]
return bumper_kern1!(op, cumulator, Val(turbo))
return bumper_kern1!(op, cumulator, options)
end,
)
end
end
@generated function dispatch_kern2!(
binops, op_idx, cumulator1, cumulator2, ::Val{turbo}
) where {turbo}
binops, op_idx, cumulator1, cumulator2, options::EvaluationOptions{turbo,true,early_exit}
) where {turbo,early_exit}
nbin = counttuple(binops)
quote
Base.@nif(
$nbin,
i -> i == op_idx,
i -> let op = binops[i]
return bumper_kern2!(op, cumulator1, cumulator2, Val(turbo))
return bumper_kern2!(op, cumulator1, cumulator2, options)
end,
)
end
end
function bumper_kern1!(op::F, cumulator, ::Val{false}) where {F}
function bumper_kern1!(
op::F, cumulator, ::EvaluationOptions{false,true,early_exit}
) where {F,early_exit}
@. cumulator = op(cumulator)
return cumulator
end
function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{false}) where {F}
function bumper_kern2!(
op::F, cumulator1, cumulator2, ::EvaluationOptions{false,true,early_exit}
) where {F,early_exit}
@. cumulator1 = op(cumulator1, cumulator2)
return cumulator1
end
Expand Down
24 changes: 14 additions & 10 deletions ext/DynamicExpressionsLoopVectorizationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module DynamicExpressionsLoopVectorizationExt
using LoopVectorization: @turbo
using DynamicExpressions: AbstractExpressionNode
using DynamicExpressions.UtilsModule: ResultOk, fill_similar
using DynamicExpressions.EvaluateModule: @return_on_check
using DynamicExpressions.EvaluateModule: @return_on_check, EvaluationOptions
import DynamicExpressions.EvaluateModule:
deg1_eval,
deg2_eval,
Expand All @@ -18,7 +18,7 @@ import DynamicExpressions.ExtensionInterfaceModule:
_is_loopvectorization_loaded(::Int) = true

function deg2_eval(
cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, ::Val{true}
cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, ::EvaluationOptions{true}
)::ResultOk where {T<:Number,F}
@turbo for j in eachindex(cumulator_l)
x = op(cumulator_l[j], cumulator_r[j])
Expand All @@ -28,7 +28,7 @@ function deg2_eval(
end

function deg1_eval(
cumulator::AbstractVector{T}, op::F, ::Val{true}
cumulator::AbstractVector{T}, op::F, ::EvaluationOptions{true}
)::ResultOk where {T<:Number,F}
@turbo for j in eachindex(cumulator)
x = op(cumulator[j])
Expand All @@ -38,7 +38,7 @@ function deg1_eval(
end

function deg1_l2_ll0_lr0_eval(
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{true}
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::EvaluationOptions{true}
) where {T<:Number,F,F2}
if tree.l.l.constant && tree.l.r.constant
val_ll = tree.l.l.val
Expand Down Expand Up @@ -86,7 +86,7 @@ function deg1_l2_ll0_lr0_eval(
end

function deg1_l1_ll0_eval(
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{true}
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::EvaluationOptions{true}
) where {T<:Number,F,F2}
if tree.l.l.constant
val_ll = tree.l.l.val
Expand All @@ -109,7 +109,7 @@ function deg1_l1_ll0_eval(
end

function deg2_l0_r0_eval(
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::Val{true}
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::EvaluationOptions{true}
) where {T<:Number,F}
if tree.l.constant && tree.r.constant
val_l = tree.l.val
Expand Down Expand Up @@ -157,7 +157,7 @@ function deg2_l0_eval(
cumulator::AbstractVector{T},
cX::AbstractArray{T},
op::F,
::Val{true},
::EvaluationOptions{true}
) where {T<:Number,F}
if tree.l.constant
val = tree.l.val
Expand All @@ -182,7 +182,7 @@ function deg2_r0_eval(
cumulator::AbstractVector{T},
cX::AbstractArray{T},
op::F,
::Val{true},
::EvaluationOptions{true}
) where {T<:Number,F}
if tree.r.constant
val = tree.r.val
Expand All @@ -203,11 +203,15 @@ function deg2_r0_eval(
end

## Interface with Bumper.jl
function bumper_kern1!(op::F, cumulator, ::Val{true}) where {F}
function bumper_kern1!(
op::F, cumulator, ::EvaluationOptions{true,true,early_exit}
) where {F,early_exit}
@turbo @. cumulator = op(cumulator)
return cumulator
end
function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{true}) where {F}
function bumper_kern2!(
op::F, cumulator1, cumulator2, ::EvaluationOptions{true,true,early_exit}
) where {F,early_exit}
@turbo @. cumulator1 = op(cumulator1, cumulator2)
return cumulator1
end
Expand Down
2 changes: 1 addition & 1 deletion src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ import .NodeModule:
@reexport import .OperatorEnumModule: AbstractOperatorEnum
@reexport import .OperatorEnumConstructionModule:
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array
@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array, EvaluationOptions
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
@reexport import .ChainRulesModule: NodeTangent, extract_gradient
@reexport import .SimplifyModule: combine_operators, simplify_tree!
Expand Down
Loading
Loading