From 8df4bfefb7fb5394cfeec1f8af6348cb6a3d46e2 Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Sun, 30 Jun 2024 20:48:44 +0200 Subject: [PATCH 01/41] add early_exit argument --- src/Evaluate.jl | 42 ++++++++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index b41add9b..083cbf7a 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -69,8 +69,10 @@ function eval_tree_array( operators::OperatorEnum; turbo::Union{Bool,Val}=Val(false), bumper::Union{Bool,Val}=Val(false), + early_exit::Union{Bool,Val}=Val(true), ) where {T<:Number} v_turbo = isa(turbo, Val) ? turbo : (turbo ? Val(true) : Val(false)) + v_early_exit = isa(turbo, Val) ? early_exit : (early_exit ? Val(true) : Val(false)) v_bumper = isa(bumper, Val) ? bumper : (bumper ? Val(true) : Val(false)) if v_turbo isa Val{true} || v_bumper isa Val{true} @assert T in (Float32, Float64) @@ -83,8 +85,12 @@ function eval_tree_array( return bumper_eval_tree_array(tree, cX, operators, v_turbo) end - result = _eval_tree_array(tree, cX, operators, v_turbo) - return (result.x, result.ok && !is_bad_array(result.x)) + result = _eval_tree_array(tree, cX, operators, v_turbo, v_early_exit) + if v_early_exit isa Val{true} + return (result.x, result.ok && !is_bad_array(result.x)) + else + return (result.x, result.ok) + end end function eval_tree_array( tree::AbstractExpressionNode{T1}, @@ -92,12 +98,13 @@ function eval_tree_array( operators::OperatorEnum; turbo::Union{Bool,Val}=Val(false), bumper::Union{Bool,Val}=Val(false), + early_exit::Union{Bool,Val}=Val(true), ) where {T1<:Number,T2<:Number} T = promote_type(T1, T2) @warn "Warning: eval_tree_array received mixed types: tree=$(T1) and data=$(T2)." tree = convert(constructorof(typeof(tree)){T}, tree) cX = Base.Fix1(convert, T).(cX) - return eval_tree_array(tree, cX, operators; turbo, bumper) + return eval_tree_array(tree, cX, operators; turbo, bumper, early_exit) end get_nuna(::Type{<:OperatorEnum{B,U}}) where {B,U} = counttuple(U) @@ -108,7 +115,8 @@ function _eval_tree_array( cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{turbo}, -)::ResultOk where {T<:Number,turbo} + ::Val{early_exit}, +)::ResultOk where {T<:Number,turbo,early_exit} # First, we see if there are only constants in the tree - meaning # we can just return the constant result. if tree.degree == 0 @@ -120,12 +128,12 @@ function _eval_tree_array( return ResultOk(fill_similar(const_result.x[], cX, axes(cX, 2)), true) elseif tree.degree == 1 op_idx = tree.op - return dispatch_deg1_eval(tree, cX, op_idx, operators, Val(turbo)) + return dispatch_deg1_eval(tree, cX, op_idx, operators, Val(turbo), Val(early_exit)) else # TODO - add op(op2(x, y), z) and op(x, op2(y, z)) # op(x, y), where x, y are constants or variables. op_idx = tree.op - return dispatch_deg2_eval(tree, cX, op_idx, operators, Val(turbo)) + return dispatch_deg2_eval(tree, cX, op_idx, operators, Val(turbo), Val(early_exit)) end end @@ -165,17 +173,18 @@ end op_idx::Integer, operators::OperatorEnum, ::Val{turbo}, -) where {T<:Number,turbo} + ::Val{early_exit}, +) where {T<:Number,turbo,early_exit} nbin = get_nbin(operators) long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN if long_compilation_time return quote result_l = _eval_tree_array(tree.l, cX, operators, Val(turbo)) !result_l.ok && return result_l - @return_on_nonfinite_array result_l.x + early_exit && @return_on_nonfinite_array result_l.x result_r = _eval_tree_array(tree.r, cX, operators, Val(turbo)) !result_r.ok && return result_r - @return_on_nonfinite_array result_r.x + early_exit && @return_on_nonfinite_array result_r.x # op(x, y), for any x or y deg2_eval(result_l.x, result_r.x, operators.binops[op_idx], Val(turbo)) end @@ -190,22 +199,22 @@ end elseif tree.r.degree == 0 result_l = _eval_tree_array(tree.l, cX, operators, Val(turbo)) !result_l.ok && return result_l - @return_on_nonfinite_array result_l.x + early_exit && @return_on_nonfinite_array result_l.x # op(x, y), where y is a constant or variable but x is not. deg2_r0_eval(tree, result_l.x, cX, op, Val(turbo)) elseif tree.l.degree == 0 result_r = _eval_tree_array(tree.r, cX, operators, Val(turbo)) !result_r.ok && return result_r - @return_on_nonfinite_array result_r.x + early_exit && @return_on_nonfinite_array result_r.x # op(x, y), where x is a constant or variable but y is not. deg2_l0_eval(tree, result_r.x, cX, op, Val(turbo)) else result_l = _eval_tree_array(tree.l, cX, operators, Val(turbo)) !result_l.ok && return result_l - @return_on_nonfinite_array result_l.x + early_exit && @return_on_nonfinite_array result_l.x result_r = _eval_tree_array(tree.r, cX, operators, Val(turbo)) !result_r.ok && return result_r - @return_on_nonfinite_array result_r.x + early_exit && @return_on_nonfinite_array result_r.x # op(x, y), for any x or y deg2_eval(result_l.x, result_r.x, op, Val(turbo)) end @@ -219,14 +228,15 @@ end op_idx::Integer, operators::OperatorEnum, ::Val{turbo}, -) where {T<:Number,turbo} + ::Val{early_exit}, +) where {T<:Number,turbo,early_exit} nuna = get_nuna(operators) long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN if long_compilation_time return quote result = _eval_tree_array(tree.l, cX, operators, Val(turbo)) !result.ok && return result - @return_on_nonfinite_array result.x + early_exit && @return_on_nonfinite_array result.x early_exit deg1_eval(result.x, operators.unaops[op_idx], Val(turbo)) end end @@ -253,7 +263,7 @@ end # op(x), for any x. result = _eval_tree_array(tree.l, cX, operators, Val(turbo)) !result.ok && return result - @return_on_nonfinite_array result.x + early_exit && @return_on_nonfinite_array result.x early_exit deg1_eval(result.x, op, Val(turbo)) end end From 8f1d2bd56576c001513f8cec4e739dac0ba6b497 Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Mon, 1 Jul 2024 10:37:23 +0200 Subject: [PATCH 02/41] make tests pass; add first test for early_exit --- src/Evaluate.jl | 22 +++++++++++----------- test/test_expressions.jl | 13 +++++++++++++ 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 083cbf7a..4dc341fa 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -72,7 +72,7 @@ function eval_tree_array( early_exit::Union{Bool,Val}=Val(true), ) where {T<:Number} v_turbo = isa(turbo, Val) ? turbo : (turbo ? Val(true) : Val(false)) - v_early_exit = isa(turbo, Val) ? early_exit : (early_exit ? Val(true) : Val(false)) + v_early_exit = isa(early_exit, Val) ? early_exit : (early_exit ? Val(true) : Val(false)) v_bumper = isa(bumper, Val) ? bumper : (bumper ? Val(true) : Val(false)) if v_turbo isa Val{true} || v_bumper isa Val{true} @assert T in (Float32, Float64) @@ -179,10 +179,10 @@ end long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN if long_compilation_time return quote - result_l = _eval_tree_array(tree.l, cX, operators, Val(turbo)) + result_l = _eval_tree_array(tree.l, cX, operators, Val(turbo), Val(early_exit)) !result_l.ok && return result_l early_exit && @return_on_nonfinite_array result_l.x - result_r = _eval_tree_array(tree.r, cX, operators, Val(turbo)) + result_r = _eval_tree_array(tree.r, cX, operators, Val(turbo), Val(early_exit)) !result_r.ok && return result_r early_exit && @return_on_nonfinite_array result_r.x # op(x, y), for any x or y @@ -197,22 +197,22 @@ end if tree.l.degree == 0 && tree.r.degree == 0 deg2_l0_r0_eval(tree, cX, op, Val(turbo)) elseif tree.r.degree == 0 - result_l = _eval_tree_array(tree.l, cX, operators, Val(turbo)) + result_l = _eval_tree_array(tree.l, cX, operators, Val(turbo), Val(early_exit)) !result_l.ok && return result_l early_exit && @return_on_nonfinite_array result_l.x # op(x, y), where y is a constant or variable but x is not. deg2_r0_eval(tree, result_l.x, cX, op, Val(turbo)) elseif tree.l.degree == 0 - result_r = _eval_tree_array(tree.r, cX, operators, Val(turbo)) + result_r = _eval_tree_array(tree.r, cX, operators, Val(turbo), Val(early_exit)) !result_r.ok && return result_r early_exit && @return_on_nonfinite_array result_r.x # op(x, y), where x is a constant or variable but y is not. deg2_l0_eval(tree, result_r.x, cX, op, Val(turbo)) else - result_l = _eval_tree_array(tree.l, cX, operators, Val(turbo)) + result_l = _eval_tree_array(tree.l, cX, operators, Val(turbo), Val(early_exit)) !result_l.ok && return result_l early_exit && @return_on_nonfinite_array result_l.x - result_r = _eval_tree_array(tree.r, cX, operators, Val(turbo)) + result_r = _eval_tree_array(tree.r, cX, operators, Val(turbo), Val(early_exit)) !result_r.ok && return result_r early_exit && @return_on_nonfinite_array result_r.x # op(x, y), for any x or y @@ -234,9 +234,9 @@ end long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN if long_compilation_time return quote - result = _eval_tree_array(tree.l, cX, operators, Val(turbo)) + result = _eval_tree_array(tree.l, cX, operators, Val(turbo), Val(early_exit)) !result.ok && return result - early_exit && @return_on_nonfinite_array result.x early_exit + early_exit && @return_on_nonfinite_array result.x deg1_eval(result.x, operators.unaops[op_idx], Val(turbo)) end end @@ -261,9 +261,9 @@ end ) else # op(x), for any x. - result = _eval_tree_array(tree.l, cX, operators, Val(turbo)) + result = _eval_tree_array(tree.l, cX, operators, Val(turbo), Val(early_exit)) !result.ok && return result - early_exit && @return_on_nonfinite_array result.x early_exit + early_exit && @return_on_nonfinite_array result.x deg1_eval(result.x, op, Val(turbo)) end end diff --git a/test/test_expressions.jl b/test/test_expressions.jl index 86b581cd..2dfb72c9 100644 --- a/test/test_expressions.jl +++ b/test/test_expressions.jl @@ -249,3 +249,16 @@ end tree = get_tree(ex) @test_throws ArgumentError get_operators(tree, nothing) end + +@testitem "Disable early exit" begin + using DynamicExpressions + + T = Float64 + x = Node{T}(feature=1) + ops = OperatorEnum(binary_operators=[*]) + expr = x*2 + + X = [1.0 floatmax(T)] + @test all(isnan.(expr(X, ops))) + @test expr(X, ops, early_exit=Val(false)) ≈ [2.0, Inf] +end From 5cbe6a58751e6347a7fe347310c45ee3bb27fb53 Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Wed, 3 Jul 2024 11:50:52 +0200 Subject: [PATCH 03/41] bumper & loopvec --- ext/DynamicExpressionsBumperExt.jl | 36 ++++++++++--------- ext/DynamicExpressionsLoopVectorizationExt.jl | 4 +-- src/Evaluate.jl | 2 +- test/test_evaluation.jl | 32 +++++++++++++++++ test/test_expressions.jl | 13 ------- 5 files changed, 54 insertions(+), 33 deletions(-) diff --git a/ext/DynamicExpressionsBumperExt.jl b/ext/DynamicExpressionsBumperExt.jl index 98934133..b4b89380 100644 --- a/ext/DynamicExpressionsBumperExt.jl +++ b/ext/DynamicExpressionsBumperExt.jl @@ -12,7 +12,8 @@ function bumper_eval_tree_array( cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{turbo}, -) where {T,turbo} + ::Val{early_exit} +) where {T,turbo,early_exit} result = similar(cX, axes(cX, 2)) n = size(cX, 2) all_ok = Ref(false) @@ -25,7 +26,7 @@ function bumper_eval_tree_array( ok = if leaf_node.constant v = leaf_node.val ar .= v - isfinite(v) + early_exit ? isfinite(v) : true else ar .= view(cX, leaf_node.feature, :) true @@ -37,7 +38,7 @@ function bumper_eval_tree_array( # In the evaluation kernel, we combine the branch nodes # with the arrays created by the leaf nodes: ((args::Vararg{Any,M}) where {M}) -> - dispatch_kerns!(operators, args..., Val(turbo)), + dispatch_kerns!(operators, args..., Val(turbo), Val(early_exit)), tree; break_sharing=Val(true), ) @@ -48,55 +49,56 @@ function bumper_eval_tree_array( return (result, all_ok[]) end -function dispatch_kerns!(operators, branch_node, cumulator, ::Val{turbo}) where {turbo} +function dispatch_kerns!(operators, branch_node, cumulator, ::Val{turbo}, ::Val{early_exit}) where {turbo,early_exit} cumulator.ok || return cumulator - out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, Val(turbo)) - return ResultOk(out, !is_bad_array(out)) + out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, Val(turbo), Val(early_exit)) + return early_exit ? ResultOk(out, !is_bad_array(out)) : ResultOk(out, true) end function dispatch_kerns!( - operators, branch_node, cumulator1, cumulator2, ::Val{turbo} -) where {turbo} + operators, branch_node, cumulator1, cumulator2, ::Val{turbo}, ::Val{early_exit} +) where {turbo,early_exit} cumulator1.ok || return cumulator1 cumulator2.ok || return cumulator2 out = dispatch_kern2!( - operators.binops, branch_node.op, cumulator1.x, cumulator2.x, Val(turbo) + operators.binops, branch_node.op, cumulator1.x, cumulator2.x, Val(turbo), Val(early_exit) ) - return ResultOk(out, !is_bad_array(out)) + return early_exit ? ResultOk(out, !is_bad_array(out)) : ResultOk(out, true) end -@generated function dispatch_kern1!(unaops, op_idx, cumulator, ::Val{turbo}) where {turbo} +@generated function dispatch_kern1!(unaops, op_idx, cumulator, ::Val{turbo}, ::Val{early_exit}) where {turbo,early_exit} nuna = counttuple(unaops) quote Base.@nif( $nuna, i -> i == op_idx, i -> let op = unaops[i] - return bumper_kern1!(op, cumulator, Val(turbo)) + return bumper_kern1!(op, cumulator, Val(turbo), Val(early_exit)) end, ) end end @generated function dispatch_kern2!( - binops, op_idx, cumulator1, cumulator2, ::Val{turbo} -) where {turbo} + binops, op_idx, cumulator1, cumulator2, ::Val{turbo}, ::Val{early_exit} +) where {turbo,early_exit} nbin = counttuple(binops) quote Base.@nif( $nbin, i -> i == op_idx, i -> let op = binops[i] - return bumper_kern2!(op, cumulator1, cumulator2, Val(turbo)) + return bumper_kern2!(op, cumulator1, cumulator2, Val(turbo), Val(early_exit)) end, ) end end -function bumper_kern1!(op::F, cumulator, ::Val{false}) where {F} +# FIXME: keeping the early_exit parameter for readability... should it be removed? +function bumper_kern1!(op::F, cumulator, ::Val{false}, ::Val{early_exit}) where {F,early_exit} @. cumulator = op(cumulator) return cumulator end -function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{false}) where {F} +function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{false}, ::Val{early_exit}) where {F,early_exit} @. cumulator1 = op(cumulator1, cumulator2) return cumulator1 end diff --git a/ext/DynamicExpressionsLoopVectorizationExt.jl b/ext/DynamicExpressionsLoopVectorizationExt.jl index ec158320..960b4eb6 100644 --- a/ext/DynamicExpressionsLoopVectorizationExt.jl +++ b/ext/DynamicExpressionsLoopVectorizationExt.jl @@ -203,11 +203,11 @@ function deg2_r0_eval( end ## Interface with Bumper.jl -function bumper_kern1!(op::F, cumulator, ::Val{true}) where {F} +function bumper_kern1!(op::F, cumulator, ::Val{true}, ::Val{early_exit}) where {F,early_exit} @turbo @. cumulator = op(cumulator) return cumulator end -function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{true}) where {F} +function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{true}, ::Val{early_exit}) where {F,early_exit} @turbo @. cumulator1 = op(cumulator1, cumulator2) return cumulator1 end diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 4dc341fa..c3f7feb6 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -82,7 +82,7 @@ function eval_tree_array( error("Please load the LoopVectorization.jl package to use this feature.") end if v_bumper isa Val{true} - return bumper_eval_tree_array(tree, cX, operators, v_turbo) + return bumper_eval_tree_array(tree, cX, operators, v_turbo, v_early_exit) end result = _eval_tree_array(tree, cX, operators, v_turbo, v_early_exit) diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index 66dd4b4d..3e92d4ba 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -218,3 +218,35 @@ end basic_eval ≈ many_ops_eval end end + +@testset "Disable early exit" begin + using DynamicExpressions + + T = Float16 + ex = @parse_expression(2*x, binary_operators=[*], variable_names=["x"], node_type=Node{T}) + X = T[1.0 floatmax(T)] + @test all(isnan.(ex(X))) + @test ex(X, early_exit=Val(false)) ≈ [2.0, Inf] + + + for turbo in [Val(false), Val(true)], + T in [Float32, Float64], + bumper in [Val(false), Val(true)] + + ex = @parse_expression( + (-b - sqrt(b^2 - (4*a)*c)) / (2*c), + binary_operators=[-,*,/,^], + unary_operators=[-,sqrt], + variable_names=["a", "b", "c"], + node_type=Node{T} + ) + X = T[ + -1 -1; + 1 floatmax(T); + 1 1; + ] + y = + @test all(isnan.(ex(X, bumper=bumper, turbo=turbo))) + @test ex(X, bumper=bumper, turbo=turbo, early_exit=Val(false)) ≈ T[-1.618033988749895, -Inf] + end +end diff --git a/test/test_expressions.jl b/test/test_expressions.jl index 2dfb72c9..86b581cd 100644 --- a/test/test_expressions.jl +++ b/test/test_expressions.jl @@ -249,16 +249,3 @@ end tree = get_tree(ex) @test_throws ArgumentError get_operators(tree, nothing) end - -@testitem "Disable early exit" begin - using DynamicExpressions - - T = Float64 - x = Node{T}(feature=1) - ops = OperatorEnum(binary_operators=[*]) - expr = x*2 - - X = [1.0 floatmax(T)] - @test all(isnan.(expr(X, ops))) - @test expr(X, ops, early_exit=Val(false)) ≈ [2.0, Inf] -end From 82c42240b5605adf6b4d816c7061e91b78017a1c Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Wed, 3 Jul 2024 12:01:30 +0200 Subject: [PATCH 04/41] format --- ext/DynamicExpressionsBumperExt.jl | 29 ++++++++++++++----- ext/DynamicExpressionsLoopVectorizationExt.jl | 8 +++-- src/Evaluate.jl | 20 +++++++++---- test/test_evaluation.jl | 29 ++++++++++--------- 4 files changed, 58 insertions(+), 28 deletions(-) diff --git a/ext/DynamicExpressionsBumperExt.jl b/ext/DynamicExpressionsBumperExt.jl index b4b89380..84156f8f 100644 --- a/ext/DynamicExpressionsBumperExt.jl +++ b/ext/DynamicExpressionsBumperExt.jl @@ -12,7 +12,7 @@ function bumper_eval_tree_array( cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{turbo}, - ::Val{early_exit} + ::Val{early_exit}, ) where {T,turbo,early_exit} result = similar(cX, axes(cX, 2)) n = size(cX, 2) @@ -49,10 +49,14 @@ function bumper_eval_tree_array( return (result, all_ok[]) end -function dispatch_kerns!(operators, branch_node, cumulator, ::Val{turbo}, ::Val{early_exit}) where {turbo,early_exit} +function dispatch_kerns!( + operators, branch_node, cumulator, ::Val{turbo}, ::Val{early_exit} +) where {turbo,early_exit} cumulator.ok || return cumulator - out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, Val(turbo), Val(early_exit)) + out = dispatch_kern1!( + operators.unaops, branch_node.op, cumulator.x, Val(turbo), Val(early_exit) + ) return early_exit ? ResultOk(out, !is_bad_array(out)) : ResultOk(out, true) end function dispatch_kerns!( @@ -62,12 +66,19 @@ function dispatch_kerns!( cumulator2.ok || return cumulator2 out = dispatch_kern2!( - operators.binops, branch_node.op, cumulator1.x, cumulator2.x, Val(turbo), Val(early_exit) + operators.binops, + branch_node.op, + cumulator1.x, + cumulator2.x, + Val(turbo), + Val(early_exit), ) return early_exit ? ResultOk(out, !is_bad_array(out)) : ResultOk(out, true) end -@generated function dispatch_kern1!(unaops, op_idx, cumulator, ::Val{turbo}, ::Val{early_exit}) where {turbo,early_exit} +@generated function dispatch_kern1!( + unaops, op_idx, cumulator, ::Val{turbo}, ::Val{early_exit} +) where {turbo,early_exit} nuna = counttuple(unaops) quote Base.@nif( @@ -94,11 +105,15 @@ end end end # FIXME: keeping the early_exit parameter for readability... should it be removed? -function bumper_kern1!(op::F, cumulator, ::Val{false}, ::Val{early_exit}) where {F,early_exit} +function bumper_kern1!( + op::F, cumulator, ::Val{false}, ::Val{early_exit} +) where {F,early_exit} @. cumulator = op(cumulator) return cumulator end -function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{false}, ::Val{early_exit}) where {F,early_exit} +function bumper_kern2!( + op::F, cumulator1, cumulator2, ::Val{false}, ::Val{early_exit} +) where {F,early_exit} @. cumulator1 = op(cumulator1, cumulator2) return cumulator1 end diff --git a/ext/DynamicExpressionsLoopVectorizationExt.jl b/ext/DynamicExpressionsLoopVectorizationExt.jl index 960b4eb6..e671495b 100644 --- a/ext/DynamicExpressionsLoopVectorizationExt.jl +++ b/ext/DynamicExpressionsLoopVectorizationExt.jl @@ -203,11 +203,15 @@ function deg2_r0_eval( end ## Interface with Bumper.jl -function bumper_kern1!(op::F, cumulator, ::Val{true}, ::Val{early_exit}) where {F,early_exit} +function bumper_kern1!( + op::F, cumulator, ::Val{true}, ::Val{early_exit} +) where {F,early_exit} @turbo @. cumulator = op(cumulator) return cumulator end -function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{true}, ::Val{early_exit}) where {F,early_exit} +function bumper_kern2!( + op::F, cumulator1, cumulator2, ::Val{true}, ::Val{early_exit} +) where {F,early_exit} @turbo @. cumulator1 = op(cumulator1, cumulator2) return cumulator1 end diff --git a/src/Evaluate.jl b/src/Evaluate.jl index c3f7feb6..74ea5bb5 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -197,22 +197,30 @@ end if tree.l.degree == 0 && tree.r.degree == 0 deg2_l0_r0_eval(tree, cX, op, Val(turbo)) elseif tree.r.degree == 0 - result_l = _eval_tree_array(tree.l, cX, operators, Val(turbo), Val(early_exit)) + result_l = _eval_tree_array( + tree.l, cX, operators, Val(turbo), Val(early_exit) + ) !result_l.ok && return result_l early_exit && @return_on_nonfinite_array result_l.x # op(x, y), where y is a constant or variable but x is not. deg2_r0_eval(tree, result_l.x, cX, op, Val(turbo)) elseif tree.l.degree == 0 - result_r = _eval_tree_array(tree.r, cX, operators, Val(turbo), Val(early_exit)) + result_r = _eval_tree_array( + tree.r, cX, operators, Val(turbo), Val(early_exit) + ) !result_r.ok && return result_r early_exit && @return_on_nonfinite_array result_r.x # op(x, y), where x is a constant or variable but y is not. deg2_l0_eval(tree, result_r.x, cX, op, Val(turbo)) else - result_l = _eval_tree_array(tree.l, cX, operators, Val(turbo), Val(early_exit)) + result_l = _eval_tree_array( + tree.l, cX, operators, Val(turbo), Val(early_exit) + ) !result_l.ok && return result_l early_exit && @return_on_nonfinite_array result_l.x - result_r = _eval_tree_array(tree.r, cX, operators, Val(turbo), Val(early_exit)) + result_r = _eval_tree_array( + tree.r, cX, operators, Val(turbo), Val(early_exit) + ) !result_r.ok && return result_r early_exit && @return_on_nonfinite_array result_r.x # op(x, y), for any x or y @@ -261,7 +269,9 @@ end ) else # op(x), for any x. - result = _eval_tree_array(tree.l, cX, operators, Val(turbo), Val(early_exit)) + result = _eval_tree_array( + tree.l, cX, operators, Val(turbo), Val(early_exit) + ) !result.ok && return result early_exit && @return_on_nonfinite_array result.x deg1_eval(result.x, op, Val(turbo)) diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index 3e92d4ba..79b1d71b 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -223,30 +223,31 @@ end using DynamicExpressions T = Float16 - ex = @parse_expression(2*x, binary_operators=[*], variable_names=["x"], node_type=Node{T}) + ex = @parse_expression( + 2 * x, binary_operators = [*], variable_names = ["x"], node_type = Node{T} + ) X = T[1.0 floatmax(T)] @test all(isnan.(ex(X))) - @test ex(X, early_exit=Val(false)) ≈ [2.0, Inf] - + @test ex(X; early_exit=Val(false)) ≈ [2.0, Inf] for turbo in [Val(false), Val(true)], T in [Float32, Float64], bumper in [Val(false), Val(true)] ex = @parse_expression( - (-b - sqrt(b^2 - (4*a)*c)) / (2*c), - binary_operators=[-,*,/,^], - unary_operators=[-,sqrt], - variable_names=["a", "b", "c"], - node_type=Node{T} + (-b - sqrt(b^2 - (4 * a) * c)) / (2 * c), + binary_operators = [-, *, /, ^], + unary_operators = [-, sqrt], + variable_names = ["a", "b", "c"], + node_type = Node{T} ) X = T[ - -1 -1; - 1 floatmax(T); - 1 1; + -1 -1 + 1 floatmax(T) + 1 1 ] - y = - @test all(isnan.(ex(X, bumper=bumper, turbo=turbo))) - @test ex(X, bumper=bumper, turbo=turbo, early_exit=Val(false)) ≈ T[-1.618033988749895, -Inf] + y = @test all(isnan.(ex(X; bumper=bumper, turbo=turbo))) + @test ex(X; bumper=bumper, turbo=turbo, early_exit=Val(false)) ≈ + T[-1.618033988749895, -Inf] end end From 1ee6884f709beb8700bd5dad51ff8f569deddc30 Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Mon, 8 Jul 2024 12:36:58 +0200 Subject: [PATCH 05/41] introduce EvaluationOptions --- ext/DynamicExpressionsBumperExt.jl | 37 ++-- ext/DynamicExpressionsLoopVectorizationExt.jl | 20 +-- src/DynamicExpressions.jl | 2 +- src/Evaluate.jl | 161 +++++++++--------- src/precompile.jl | 12 +- test/test_evaluation.jl | 13 +- test/test_initial_errors.jl | 4 +- 7 files changed, 119 insertions(+), 130 deletions(-) diff --git a/ext/DynamicExpressionsBumperExt.jl b/ext/DynamicExpressionsBumperExt.jl index 84156f8f..550cf561 100644 --- a/ext/DynamicExpressionsBumperExt.jl +++ b/ext/DynamicExpressionsBumperExt.jl @@ -1,7 +1,7 @@ module DynamicExpressionsBumperExt using Bumper: @no_escape, @alloc -using DynamicExpressions: OperatorEnum, AbstractExpressionNode, tree_mapreduce +using DynamicExpressions: OperatorEnum, AbstractExpressionNode, tree_mapreduce, EvaluationOptions using DynamicExpressions.UtilsModule: ResultOk, counttuple, is_bad_array import DynamicExpressions.ExtensionInterfaceModule: @@ -11,8 +11,7 @@ function bumper_eval_tree_array( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, - ::Val{turbo}, - ::Val{early_exit}, + options::EvaluationOptions{turbo,true,early_exit} ) where {T,turbo,early_exit} result = similar(cX, axes(cX, 2)) n = size(cX, 2) @@ -38,7 +37,7 @@ function bumper_eval_tree_array( # In the evaluation kernel, we combine the branch nodes # with the arrays created by the leaf nodes: ((args::Vararg{Any,M}) where {M}) -> - dispatch_kerns!(operators, args..., Val(turbo), Val(early_exit)), + dispatch_kerns!(operators, args..., options), tree; break_sharing=Val(true), ) @@ -50,34 +49,25 @@ function bumper_eval_tree_array( end function dispatch_kerns!( - operators, branch_node, cumulator, ::Val{turbo}, ::Val{early_exit} + operators, branch_node, cumulator, options::EvaluationOptions{turbo,true,early_exit} ) where {turbo,early_exit} cumulator.ok || return cumulator - out = dispatch_kern1!( - operators.unaops, branch_node.op, cumulator.x, Val(turbo), Val(early_exit) - ) + out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, options) return early_exit ? ResultOk(out, !is_bad_array(out)) : ResultOk(out, true) end function dispatch_kerns!( - operators, branch_node, cumulator1, cumulator2, ::Val{turbo}, ::Val{early_exit} + operators, branch_node, cumulator1, cumulator2, options::EvaluationOptions{turbo,true,early_exit} ) where {turbo,early_exit} cumulator1.ok || return cumulator1 cumulator2.ok || return cumulator2 - out = dispatch_kern2!( - operators.binops, - branch_node.op, - cumulator1.x, - cumulator2.x, - Val(turbo), - Val(early_exit), - ) + out = dispatch_kern2!(operators.binops, branch_node.op, cumulator1.x, cumulator2.x, options) return early_exit ? ResultOk(out, !is_bad_array(out)) : ResultOk(out, true) end @generated function dispatch_kern1!( - unaops, op_idx, cumulator, ::Val{turbo}, ::Val{early_exit} + unaops, op_idx, cumulator, options::EvaluationOptions{turbo,true,early_exit} ) where {turbo,early_exit} nuna = counttuple(unaops) quote @@ -85,13 +75,13 @@ end $nuna, i -> i == op_idx, i -> let op = unaops[i] - return bumper_kern1!(op, cumulator, Val(turbo), Val(early_exit)) + return bumper_kern1!(op, cumulator, options) end, ) end end @generated function dispatch_kern2!( - binops, op_idx, cumulator1, cumulator2, ::Val{turbo}, ::Val{early_exit} + binops, op_idx, cumulator1, cumulator2, options::EvaluationOptions{turbo,true,early_exit} ) where {turbo,early_exit} nbin = counttuple(binops) quote @@ -99,20 +89,19 @@ end $nbin, i -> i == op_idx, i -> let op = binops[i] - return bumper_kern2!(op, cumulator1, cumulator2, Val(turbo), Val(early_exit)) + return bumper_kern2!(op, cumulator1, cumulator2, options) end, ) end end -# FIXME: keeping the early_exit parameter for readability... should it be removed? function bumper_kern1!( - op::F, cumulator, ::Val{false}, ::Val{early_exit} + op::F, cumulator, ::EvaluationOptions{false,true,early_exit} ) where {F,early_exit} @. cumulator = op(cumulator) return cumulator end function bumper_kern2!( - op::F, cumulator1, cumulator2, ::Val{false}, ::Val{early_exit} + op::F, cumulator1, cumulator2, ::EvaluationOptions{false,true,early_exit} ) where {F,early_exit} @. cumulator1 = op(cumulator1, cumulator2) return cumulator1 diff --git a/ext/DynamicExpressionsLoopVectorizationExt.jl b/ext/DynamicExpressionsLoopVectorizationExt.jl index e671495b..c397b9fc 100644 --- a/ext/DynamicExpressionsLoopVectorizationExt.jl +++ b/ext/DynamicExpressionsLoopVectorizationExt.jl @@ -3,7 +3,7 @@ module DynamicExpressionsLoopVectorizationExt using LoopVectorization: @turbo using DynamicExpressions: AbstractExpressionNode using DynamicExpressions.UtilsModule: ResultOk, fill_similar -using DynamicExpressions.EvaluateModule: @return_on_check +using DynamicExpressions.EvaluateModule: @return_on_check, EvaluationOptions import DynamicExpressions.EvaluateModule: deg1_eval, deg2_eval, @@ -18,7 +18,7 @@ import DynamicExpressions.ExtensionInterfaceModule: _is_loopvectorization_loaded(::Int) = true function deg2_eval( - cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, ::Val{true} + cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, ::EvaluationOptions{true} )::ResultOk where {T<:Number,F} @turbo for j in eachindex(cumulator_l) x = op(cumulator_l[j], cumulator_r[j]) @@ -28,7 +28,7 @@ function deg2_eval( end function deg1_eval( - cumulator::AbstractVector{T}, op::F, ::Val{true} + cumulator::AbstractVector{T}, op::F, ::EvaluationOptions{true} )::ResultOk where {T<:Number,F} @turbo for j in eachindex(cumulator) x = op(cumulator[j]) @@ -38,7 +38,7 @@ function deg1_eval( end function deg1_l2_ll0_lr0_eval( - tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{true} + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::EvaluationOptions{true} ) where {T<:Number,F,F2} if tree.l.l.constant && tree.l.r.constant val_ll = tree.l.l.val @@ -86,7 +86,7 @@ function deg1_l2_ll0_lr0_eval( end function deg1_l1_ll0_eval( - tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{true} + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::EvaluationOptions{true} ) where {T<:Number,F,F2} if tree.l.l.constant val_ll = tree.l.l.val @@ -109,7 +109,7 @@ function deg1_l1_ll0_eval( end function deg2_l0_r0_eval( - tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::Val{true} + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::EvaluationOptions{true} ) where {T<:Number,F} if tree.l.constant && tree.r.constant val_l = tree.l.val @@ -157,7 +157,7 @@ function deg2_l0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::Val{true}, + ::EvaluationOptions{true} ) where {T<:Number,F} if tree.l.constant val = tree.l.val @@ -182,7 +182,7 @@ function deg2_r0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::Val{true}, + ::EvaluationOptions{true} ) where {T<:Number,F} if tree.r.constant val = tree.r.val @@ -204,13 +204,13 @@ end ## Interface with Bumper.jl function bumper_kern1!( - op::F, cumulator, ::Val{true}, ::Val{early_exit} + op::F, cumulator, ::EvaluationOptions{true,true,early_exit} ) where {F,early_exit} @turbo @. cumulator = op(cumulator) return cumulator end function bumper_kern2!( - op::F, cumulator1, cumulator2, ::Val{true}, ::Val{early_exit} + op::F, cumulator1, cumulator2, ::EvaluationOptions{true,true,early_exit} ) where {F,early_exit} @turbo @. cumulator1 = op(cumulator1, cumulator2) return cumulator1 diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 0425367b..77b3b887 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -59,7 +59,7 @@ import .NodeModule: @reexport import .OperatorEnumModule: AbstractOperatorEnum @reexport import .OperatorEnumConstructionModule: OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names! -@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array +@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array, EvaluationOptions @reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array @reexport import .ChainRulesModule: NodeTangent, extract_gradient @reexport import .SimplifyModule: combine_operators, simplify_tree! diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 74ea5bb5..be146476 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -26,6 +26,22 @@ macro return_on_nonfinite_array(array) end ) end +struct EvaluationOptions{T,B,E} + turbo::Val{T} + bumper::Val{B} + early_exit::Val{E} +end +function EvaluationOptions(; turbo=false, bumper=false, early_exit=true) + v_turbo = isa(turbo, Val) ? turbo : (turbo ? Val(true) : Val(false)) + v_bumper = isa(bumper, Val) ? bumper : (bumper ? Val(true) : Val(false)) + v_early_exit = isa(early_exit, Val) ? early_exit : (early_exit ? Val(true) : Val(false)) + return EvaluationOptions(v_turbo, v_bumper, v_early_exit) +end +function EvaluationOptions{T,B,E}(; + turbo=Val(false), bumper=Val(false), early_exit=Val(true) +) where {T,B,E} + return EvaluationOptions{T,B,E}(turbo, bumper, early_exit) +end """ eval_tree_array(tree::AbstractExpressionNode, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Union{Bool,Val}=Val(false), bumper::Union{Bool,Val}=Val(false)) @@ -67,26 +83,21 @@ function eval_tree_array( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; - turbo::Union{Bool,Val}=Val(false), - bumper::Union{Bool,Val}=Val(false), - early_exit::Union{Bool,Val}=Val(true), + options::EvaluationOptions = EvaluationOptions() ) where {T<:Number} - v_turbo = isa(turbo, Val) ? turbo : (turbo ? Val(true) : Val(false)) - v_early_exit = isa(early_exit, Val) ? early_exit : (early_exit ? Val(true) : Val(false)) - v_bumper = isa(bumper, Val) ? bumper : (bumper ? Val(true) : Val(false)) - if v_turbo isa Val{true} || v_bumper isa Val{true} + if options.turbo isa Val{true} || options.bumper isa Val{true} @assert T in (Float32, Float64) end - if v_turbo isa Val{true} + if options.turbo isa Val{true} _is_loopvectorization_loaded(0) || error("Please load the LoopVectorization.jl package to use this feature.") end - if v_bumper isa Val{true} - return bumper_eval_tree_array(tree, cX, operators, v_turbo, v_early_exit) + if options.bumper isa Val{true} + return bumper_eval_tree_array(tree, cX, operators, options) end - result = _eval_tree_array(tree, cX, operators, v_turbo, v_early_exit) - if v_early_exit isa Val{true} + result = _eval_tree_array(tree, cX, operators, options) + if options.early_exit isa Val{true} return (result.x, result.ok && !is_bad_array(result.x)) else return (result.x, result.ok) @@ -96,15 +107,13 @@ function eval_tree_array( tree::AbstractExpressionNode{T1}, cX::AbstractMatrix{T2}, operators::OperatorEnum; - turbo::Union{Bool,Val}=Val(false), - bumper::Union{Bool,Val}=Val(false), - early_exit::Union{Bool,Val}=Val(true), + options::EvaluationOptions=EvaluationOptions() ) where {T1<:Number,T2<:Number} T = promote_type(T1, T2) @warn "Warning: eval_tree_array received mixed types: tree=$(T1) and data=$(T2)." tree = convert(constructorof(typeof(tree)){T}, tree) cX = Base.Fix1(convert, T).(cX) - return eval_tree_array(tree, cX, operators; turbo, bumper, early_exit) + return eval_tree_array(tree, cX, operators; options=options) end get_nuna(::Type{<:OperatorEnum{B,U}}) where {B,U} = counttuple(U) @@ -114,9 +123,8 @@ function _eval_tree_array( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, - ::Val{turbo}, - ::Val{early_exit}, -)::ResultOk where {T<:Number,turbo,early_exit} + options::EvaluationOptions +)::ResultOk where {T<:Number} # First, we see if there are only constants in the tree - meaning # we can just return the constant result. if tree.degree == 0 @@ -128,17 +136,17 @@ function _eval_tree_array( return ResultOk(fill_similar(const_result.x[], cX, axes(cX, 2)), true) elseif tree.degree == 1 op_idx = tree.op - return dispatch_deg1_eval(tree, cX, op_idx, operators, Val(turbo), Val(early_exit)) + return dispatch_deg1_eval(tree, cX, op_idx, operators, options) else # TODO - add op(op2(x, y), z) and op(x, op2(y, z)) # op(x, y), where x, y are constants or variables. op_idx = tree.op - return dispatch_deg2_eval(tree, cX, op_idx, operators, Val(turbo), Val(early_exit)) + return dispatch_deg2_eval(tree, cX, op_idx, operators, options) end end function deg2_eval( - cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, ::Val{false} + cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, )::ResultOk where {T<:Number,F} @inbounds @simd for j in eachindex(cumulator_l) x = op(cumulator_l[j], cumulator_r[j])::T @@ -148,7 +156,7 @@ function deg2_eval( end function deg1_eval( - cumulator::AbstractVector{T}, op::F, ::Val{false} + cumulator::AbstractVector{T}, op::F, )::ResultOk where {T<:Number,F} @inbounds @simd for j in eachindex(cumulator) x = op(cumulator[j])::T @@ -172,21 +180,20 @@ end cX::AbstractMatrix{T}, op_idx::Integer, operators::OperatorEnum, - ::Val{turbo}, - ::Val{early_exit}, -) where {T<:Number,turbo,early_exit} + options::EvaluationOptions +) where {T<:Number} nbin = get_nbin(operators) long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN if long_compilation_time return quote - result_l = _eval_tree_array(tree.l, cX, operators, Val(turbo), Val(early_exit)) + result_l = _eval_tree_array(tree.l, cX, operators, options) !result_l.ok && return result_l - early_exit && @return_on_nonfinite_array result_l.x - result_r = _eval_tree_array(tree.r, cX, operators, Val(turbo), Val(early_exit)) + options.early_exit isa Val{true} && @return_on_nonfinite_array result_l.x + result_r = _eval_tree_array(tree.r, cX, operators, options) !result_r.ok && return result_r - early_exit && @return_on_nonfinite_array result_r.x + options.early_exit isa Val{true} && @return_on_nonfinite_array result_r.x # op(x, y), for any x or y - deg2_eval(result_l.x, result_r.x, operators.binops[op_idx], Val(turbo)) + deg2_eval(result_l.x, result_r.x, operators.binops[op_idx], options) end end return quote @@ -195,36 +202,28 @@ end i -> i == op_idx, i -> let op = operators.binops[i] if tree.l.degree == 0 && tree.r.degree == 0 - deg2_l0_r0_eval(tree, cX, op, Val(turbo)) + deg2_l0_r0_eval(tree, cX, op, options) elseif tree.r.degree == 0 - result_l = _eval_tree_array( - tree.l, cX, operators, Val(turbo), Val(early_exit) - ) + result_l = _eval_tree_array(tree.l, cX, operators, options) !result_l.ok && return result_l - early_exit && @return_on_nonfinite_array result_l.x + options.early_exit isa Val{true} && @return_on_nonfinite_array result_l.x # op(x, y), where y is a constant or variable but x is not. - deg2_r0_eval(tree, result_l.x, cX, op, Val(turbo)) + deg2_r0_eval(tree, result_l.x, cX, op, options) elseif tree.l.degree == 0 - result_r = _eval_tree_array( - tree.r, cX, operators, Val(turbo), Val(early_exit) - ) + result_r = _eval_tree_array( tree.r, cX, operators, options) !result_r.ok && return result_r - early_exit && @return_on_nonfinite_array result_r.x + options.early_exit isa Val{true} && @return_on_nonfinite_array result_r.x # op(x, y), where x is a constant or variable but y is not. - deg2_l0_eval(tree, result_r.x, cX, op, Val(turbo)) + deg2_l0_eval(tree, result_r.x, cX, op, options) else - result_l = _eval_tree_array( - tree.l, cX, operators, Val(turbo), Val(early_exit) - ) + result_l = _eval_tree_array( tree.l, cX, operators, options) !result_l.ok && return result_l - early_exit && @return_on_nonfinite_array result_l.x - result_r = _eval_tree_array( - tree.r, cX, operators, Val(turbo), Val(early_exit) - ) + options.early_exit isa Val{true} && @return_on_nonfinite_array result_l.x + result_r = _eval_tree_array( tree.r, cX, operators, options) !result_r.ok && return result_r - early_exit && @return_on_nonfinite_array result_r.x + options.early_exit isa Val{true} && @return_on_nonfinite_array result_r.x # op(x, y), for any x or y - deg2_eval(result_l.x, result_r.x, op, Val(turbo)) + deg2_eval(result_l.x, result_r.x, op) end end ) @@ -235,17 +234,16 @@ end cX::AbstractMatrix{T}, op_idx::Integer, operators::OperatorEnum, - ::Val{turbo}, - ::Val{early_exit}, -) where {T<:Number,turbo,early_exit} + options::EvaluationOptions +) where {T<:Number} nuna = get_nuna(operators) long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN if long_compilation_time return quote - result = _eval_tree_array(tree.l, cX, operators, Val(turbo), Val(early_exit)) + result = _eval_tree_array(tree.l, cX, operators, options) !result.ok && return result - early_exit && @return_on_nonfinite_array result.x - deg1_eval(result.x, operators.unaops[op_idx], Val(turbo)) + options.early_exit isa Val{true} && @return_on_nonfinite_array result.x + deg1_eval(result.x, operators.unaops[op_idx]) end end # This @nif lets us generate an if statement over choice of operator, @@ -259,22 +257,20 @@ end # op(op2(x, y)), where x, y, z are constants or variables. l_op_idx = tree.l.op dispatch_deg1_l2_ll0_lr0_eval( - tree, cX, op, l_op_idx, operators.binops, Val(turbo) + tree, cX, op, l_op_idx, operators.binops, options ) elseif tree.l.degree == 1 && tree.l.l.degree == 0 # op(op2(x)), where x is a constant or variable. l_op_idx = tree.l.op dispatch_deg1_l1_ll0_eval( - tree, cX, op, l_op_idx, operators.unaops, Val(turbo) + tree, cX, op, l_op_idx, operators.unaops, options ) else # op(x), for any x. - result = _eval_tree_array( - tree.l, cX, operators, Val(turbo), Val(early_exit) - ) + result = _eval_tree_array( tree.l, cX, operators, options) !result.ok && return result - early_exit && @return_on_nonfinite_array result.x - deg1_eval(result.x, op, Val(turbo)) + options.early_exit isa Val{true} && @return_on_nonfinite_array result.x + deg1_eval(result.x, op) end end ) @@ -286,8 +282,8 @@ end op::F, l_op_idx::Integer, binops, - ::Val{turbo}, -) where {T<:Number,F,turbo} + options::EvaluationOptions +) where {T<:Number,F} nbin = counttuple(binops) # (Note this is only called from dispatch_deg1_eval, which has already # checked for long compilation times, so we don't need to check here) @@ -296,7 +292,7 @@ end $nbin, j -> j == l_op_idx, j -> let op_l = binops[j] - deg1_l2_ll0_lr0_eval(tree, cX, op, op_l, Val(turbo)) + deg1_l2_ll0_lr0_eval(tree, cX, op, op_l, options) end, ) end @@ -307,23 +303,23 @@ end op::F, l_op_idx::Integer, unaops, - ::Val{turbo}, -)::ResultOk where {T<:Number,F,turbo} + options::EvaluationOptions +)::ResultOk where {T<:Number,F} nuna = counttuple(unaops) quote Base.Cartesian.@nif( $nuna, j -> j == l_op_idx, j -> let op_l = unaops[j] - deg1_l1_ll0_eval(tree, cX, op, op_l, Val(turbo)) + deg1_l1_ll0_eval(tree, cX, op, op_l, options) end, ) end end function deg1_l2_ll0_lr0_eval( - tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{false} -) where {T<:Number,F,F2} + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::EvaluationOptions{false,false,E} +) where {T<:Number,F,F2,E} if tree.l.l.constant && tree.l.r.constant val_ll = tree.l.l.val val_lr = tree.l.r.val @@ -371,8 +367,8 @@ end # op(op2(x)) for x variable or constant function deg1_l1_ll0_eval( - tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{false} -) where {T<:Number,F,F2} + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::EvaluationOptions{false,false,E} +) where {T<:Number,F,F2,E} if tree.l.l.constant val_ll = tree.l.l.val @return_on_check val_ll cX @@ -395,8 +391,8 @@ end # op(x, y) for x and y variable/constant function deg2_l0_r0_eval( - tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::Val{false} -) where {T<:Number,F} + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::EvaluationOptions{false,false,E} +) where {T<:Number,F,E} if tree.l.constant && tree.r.constant val_l = tree.l.val @return_on_check val_l cX @@ -443,8 +439,8 @@ function deg2_l0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::Val{false}, -) where {T<:Number,F} + ::EvaluationOptions{false,false,E} +) where {T<:Number,F,E} if tree.l.constant val = tree.l.val @return_on_check val cX @@ -469,8 +465,8 @@ function deg2_r0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::Val{false}, -) where {T<:Number,F} + ::EvaluationOptions{false,false,E} +) where {T<:Number,F,E} if tree.r.constant val = tree.r.val @return_on_check val cX @@ -684,10 +680,11 @@ function eval(current_node) cX::AbstractArray, operators::GenericOperatorEnum; throw_errors::Bool=true, + options::EvaluationOptions=EvaluationOptions() ) - !throw_errors && return _eval_tree_array_generic(tree, cX, operators, Val(false)) + !throw_errors && return _eval_tree_array_generic(tree, cX, operators, options) try - return _eval_tree_array_generic(tree, cX, operators, Val(true)) + return _eval_tree_array_generic(tree, cX, operators, options) catch e tree_s = string_tree(tree, operators) error_msg = "Failed to evaluate tree $(tree_s)." diff --git a/src/precompile.jl b/src/precompile.jl index 9265c854..d2642f57 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -37,25 +37,25 @@ function test_all_combinations(; binary_operators, unary_operators, turbo, types # Trivial: for l in (x, c) - @ignore_domain_error eval_tree_array(l, X, operators; turbo=use_turbo) + @ignore_domain_error eval_tree_array(l, X, operators; options=EvaluationOptions(turbo=use_turbo)) end # Binary operators for i in eachindex(binops), l in (x, c), r in (x, c) tree = Node(i, l, r) tree = convert(Node{T}, tree) - @ignore_domain_error eval_tree_array(tree, X, operators; turbo=use_turbo) + @ignore_domain_error eval_tree_array(tree, X, operators; options=EvaluationOptions(turbo=use_turbo)) end # Unary operators for j in eachindex(unaops), k in eachindex(unaops), l in (x, c) tree = Node(j, l) tree = convert(Node{T}, tree) - @ignore_domain_error eval_tree_array(tree, X, operators; turbo=use_turbo) + @ignore_domain_error eval_tree_array(tree, X, operators; options=EvaluationOptions(turbo=use_turbo)) tree = Node(j, Node(k, l)) tree = convert(Node{T}, tree) - @ignore_domain_error eval_tree_array(tree, X, operators; turbo=use_turbo) + @ignore_domain_error eval_tree_array(tree, X, operators; options=EvaluationOptions(turbo=use_turbo)) end # Both operators @@ -67,11 +67,11 @@ function test_all_combinations(; binary_operators, unary_operators, turbo, types tree = Node(i, Node(j1, l), Node(j2, r)) tree = convert(Node{T}, tree) - @ignore_domain_error eval_tree_array(tree, X, operators; turbo=use_turbo) + @ignore_domain_error eval_tree_array(tree, X, operators; options=EvaluationOptions(turbo=use_turbo)) tree = Node(j1, Node(i, l, r)) tree = convert(Node{T}, tree) - @ignore_domain_error eval_tree_array(tree, X, operators; turbo=use_turbo) + @ignore_domain_error eval_tree_array(tree, X, operators; options=EvaluationOptions(turbo=use_turbo)) end end return nothing diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index 79b1d71b..0c454fe7 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -219,8 +219,9 @@ end end end -@testset "Disable early exit" begin +@testitem "Disable early exit" begin using DynamicExpressions + using Bumper, LoopVectorization T = Float16 ex = @parse_expression( @@ -228,7 +229,7 @@ end ) X = T[1.0 floatmax(T)] @test all(isnan.(ex(X))) - @test ex(X; early_exit=Val(false)) ≈ [2.0, Inf] + @test ex(X; options=EvaluationOptions(early_exit=Val(false))) ≈ [2.0, Inf] for turbo in [Val(false), Val(true)], T in [Float32, Float64], @@ -246,8 +247,10 @@ end 1 floatmax(T) 1 1 ] - y = @test all(isnan.(ex(X; bumper=bumper, turbo=turbo))) - @test ex(X; bumper=bumper, turbo=turbo, early_exit=Val(false)) ≈ - T[-1.618033988749895, -Inf] + @test all(isnan.(ex(X; options=EvaluationOptions(bumper=bumper, turbo=turbo)))) + y = ex(X; options=EvaluationOptions(bumper=bumper, turbo=turbo, early_exit=false)) + @test y[1] == T(-1.618033988749895) + # FIXME: this is NaN on macOS and -Inf on windows/ubuntu... + @test !isfinite(y[2]) end end diff --git a/test/test_initial_errors.jl b/test/test_initial_errors.jl index f5b19710..9bc1c951 100644 --- a/test/test_initial_errors.jl +++ b/test/test_initial_errors.jl @@ -39,11 +39,11 @@ if VERSION >= v"1.9" @test_throws( "Please load the Bumper.jl package", - allow_unstable(() -> tree(ones(2, 10), operators; bumper=Val(true))) + allow_unstable(() -> tree(ones(2, 10), operators; options=EvaluationOptions(bumper=Val(true)))) ) @test_throws( "Please load the LoopVectorization.jl package", - allow_unstable(() -> tree(ones(2, 10), operators; turbo=Val(true))) + allow_unstable(() -> tree(ones(2, 10), operators; options=EvaluationOptions(turbo=Val(true)))) ) end From 6cab04704e35ab830ccfb76b0f00cf3d7c991763 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 19 Jul 2024 21:58:18 +0100 Subject: [PATCH 06/41] style: formatting --- src/Evaluate.jl | 73 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 24 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 2ee3ac95..151e190e 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -114,7 +114,8 @@ function eval_tree_array( _is_loopvectorization_loaded(0) || error("Please load the LoopVectorization.jl package to use this feature.") end - if (_eval_options.turbo isa Val{true} || _eval_options.bumper isa Val{true}) && !(T <: Number) + if (_eval_options.turbo isa Val{true} || _eval_options.bumper isa Val{true}) && + !(T <: Number) error( "Bumper and LoopVectorization features are only compatible with numeric element types", ) @@ -124,7 +125,10 @@ function eval_tree_array( end result = _eval_tree_array(tree, cX, operators, _eval_options) - return (result.x, result.ok && (_eval_options.early_exit isa Val{true} || is_valid_array(result.x))) + return ( + result.x, + result.ok && (_eval_options.early_exit isa Val{false} || is_valid_array(result.x)), + ) end function eval_tree_array( @@ -137,7 +141,7 @@ function eval_tree_array( tree::AbstractExpressionNode{T1}, cX::AbstractMatrix{T2}, operators::OperatorEnum; - kws... + kws..., ) where {T1,T2} T = promote_type(T1, T2) @warn "Warning: eval_tree_array received mixed types: tree=$(T1) and data=$(T2)." @@ -153,7 +157,7 @@ function _eval_tree_array( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, - eval_options::EvaluationOptions + eval_options::EvaluationOptions, )::ResultOk where {T} # First, we see if there are only constants in the tree - meaning # we can just return the constant result. @@ -176,7 +180,10 @@ function _eval_tree_array( end function deg2_eval( - cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, ::EvaluationOptions{false} + cumulator_l::AbstractVector{T}, + cumulator_r::AbstractVector{T}, + op::F, + ::EvaluationOptions{false}, )::ResultOk where {T,F} @inbounds @simd for j in eachindex(cumulator_l) x = op(cumulator_l[j], cumulator_r[j])::T @@ -185,7 +192,9 @@ function deg2_eval( return ResultOk(cumulator_l, true) end -function deg1_eval(cumulator::AbstractVector{T}, op::F, ::EvaluationOptions{false})::ResultOk where {T,F} +function deg1_eval( + cumulator::AbstractVector{T}, op::F, ::EvaluationOptions{false} +)::ResultOk where {T,F} @inbounds @simd for j in eachindex(cumulator) x = op(cumulator[j])::T cumulator[j] = x @@ -208,7 +217,7 @@ end cX::AbstractMatrix{T}, op_idx::Integer, operators::OperatorEnum, - eval_options::EvaluationOptions + eval_options::EvaluationOptions, ) where {T} nbin = get_nbin(operators) long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN @@ -234,22 +243,26 @@ end elseif tree.r.degree == 0 result_l = _eval_tree_array(tree.l, cX, operators, eval_options) !result_l.ok && return result_l - eval_options.early_exit isa Val{true} && @return_on_nonfinite_array result_l.x + eval_options.early_exit isa Val{true} && + @return_on_nonfinite_array result_l.x # op(x, y), where y is a constant or variable but x is not. deg2_r0_eval(tree, result_l.x, cX, op, eval_options) elseif tree.l.degree == 0 - result_r = _eval_tree_array( tree.r, cX, operators, eval_options) + result_r = _eval_tree_array(tree.r, cX, operators, eval_options) !result_r.ok && return result_r - eval_options.early_exit isa Val{true} && @return_on_nonfinite_array result_r.x + eval_options.early_exit isa Val{true} && + @return_on_nonfinite_array result_r.x # op(x, y), where x is a constant or variable but y is not. deg2_l0_eval(tree, result_r.x, cX, op, eval_options) else - result_l = _eval_tree_array( tree.l, cX, operators, eval_options) + result_l = _eval_tree_array(tree.l, cX, operators, eval_options) !result_l.ok && return result_l - eval_options.early_exit isa Val{true} && @return_on_nonfinite_array result_l.x - result_r = _eval_tree_array( tree.r, cX, operators, eval_options) + eval_options.early_exit isa Val{true} && + @return_on_nonfinite_array result_l.x + result_r = _eval_tree_array(tree.r, cX, operators, eval_options) !result_r.ok && return result_r - eval_options.early_exit isa Val{true} && @return_on_nonfinite_array result_r.x + eval_options.early_exit isa Val{true} && + @return_on_nonfinite_array result_r.x # op(x, y), for any x or y deg2_eval(result_l.x, result_r.x, op) end @@ -262,7 +275,7 @@ end cX::AbstractMatrix{T}, op_idx::Integer, operators::OperatorEnum, - eval_options::EvaluationOptions + eval_options::EvaluationOptions, ) where {T} nuna = get_nuna(operators) long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN @@ -295,9 +308,10 @@ end ) else # op(x), for any x. - result = _eval_tree_array( tree.l, cX, operators, eval_options) + result = _eval_tree_array(tree.l, cX, operators, eval_options) !result.ok && return result - eval_options.early_exit isa Val{true} && @return_on_nonfinite_array result.x + eval_options.early_exit isa Val{true} && + @return_on_nonfinite_array result.x deg1_eval(result.x, op) end end @@ -310,7 +324,7 @@ end op::F, l_op_idx::Integer, binops, - eval_options::EvaluationOptions + eval_options::EvaluationOptions, ) where {T,F} nbin = counttuple(binops) # (Note this is only called from dispatch_deg1_eval, which has already @@ -331,7 +345,7 @@ end op::F, l_op_idx::Integer, unaops, - eval_options::EvaluationOptions + eval_options::EvaluationOptions, )::ResultOk where {T,F} nuna = counttuple(unaops) quote @@ -346,7 +360,11 @@ end end function deg1_l2_ll0_lr0_eval( - tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::EvaluationOptions{false,false} + tree::AbstractExpressionNode{T}, + cX::AbstractMatrix{T}, + op::F, + op_l::F2, + ::EvaluationOptions{false,false}, ) where {T,F,F2} if tree.l.l.constant && tree.l.r.constant val_ll = tree.l.l.val @@ -395,7 +413,11 @@ end # op(op2(x)) for x variable or constant function deg1_l1_ll0_eval( - tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::EvaluationOptions{false,false} + tree::AbstractExpressionNode{T}, + cX::AbstractMatrix{T}, + op::F, + op_l::F2, + ::EvaluationOptions{false,false}, ) where {T,F,F2} if tree.l.l.constant val_ll = tree.l.l.val @@ -419,7 +441,10 @@ end # op(x, y) for x and y variable/constant function deg2_l0_r0_eval( - tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::EvaluationOptions{false,false} + tree::AbstractExpressionNode{T}, + cX::AbstractMatrix{T}, + op::F, + ::EvaluationOptions{false,false}, ) where {T,F} if tree.l.constant && tree.r.constant val_l = tree.l.val @@ -467,7 +492,7 @@ function deg2_l0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::EvaluationOptions{false,false} + ::EvaluationOptions{false,false}, ) where {T,F} if tree.l.constant val = tree.l.val @@ -493,7 +518,7 @@ function deg2_r0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::EvaluationOptions{false,false} + ::EvaluationOptions{false,false}, ) where {T,F} if tree.r.constant val = tree.r.val From 6d46df91ae228ab01a763353dc61d48c19082f00 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 19 Jul 2024 22:03:14 +0100 Subject: [PATCH 07/41] style: more formatting --- ext/DynamicExpressionsBumperExt.jl | 22 +++++++++++------ ext/DynamicExpressionsLoopVectorizationExt.jl | 21 ++++++++++++---- src/DynamicExpressions.jl | 3 ++- src/precompile.jl | 24 ++++++++++++++----- test/test_evaluation.jl | 6 ++--- test/test_initial_errors.jl | 9 +++++-- 6 files changed, 61 insertions(+), 24 deletions(-) diff --git a/ext/DynamicExpressionsBumperExt.jl b/ext/DynamicExpressionsBumperExt.jl index 300d1169..a991a2b9 100644 --- a/ext/DynamicExpressionsBumperExt.jl +++ b/ext/DynamicExpressionsBumperExt.jl @@ -12,7 +12,7 @@ function bumper_eval_tree_array( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, - options::EvaluationOptions{turbo,true,early_exit} + options::EvaluationOptions{turbo,true,early_exit}, ) where {T,turbo,early_exit} result = similar(cX, axes(cX, 2)) n = size(cX, 2) @@ -58,12 +58,18 @@ function dispatch_kerns!( return early_exit ? ResultOk(out, is_valid_array(out)) : ResultOk(out, true) end function dispatch_kerns!( - operators, branch_node, cumulator1, cumulator2, options::EvaluationOptions{turbo,true,early_exit} + operators, + branch_node, + cumulator1, + cumulator2, + options::EvaluationOptions{turbo,true,early_exit}, ) where {turbo,early_exit} cumulator1.ok || return cumulator1 cumulator2.ok || return cumulator2 - out = dispatch_kern2!(operators.binops, branch_node.op, cumulator1.x, cumulator2.x, options) + out = dispatch_kern2!( + operators.binops, branch_node.op, cumulator1.x, cumulator2.x, options + ) return early_exit ? ResultOk(out, is_valid_array(out)) : ResultOk(out, true) end @@ -73,16 +79,18 @@ end nuna = counttuple(unaops) quote Base.@nif( - $nuna, - i -> i == op_idx, - i -> let op = unaops[i] + $nuna, i -> i == op_idx, i -> let op = unaops[i] return bumper_kern1!(op, cumulator, options) end, ) end end @generated function dispatch_kern2!( - binops, op_idx, cumulator1, cumulator2, options::EvaluationOptions{turbo,true,early_exit} + binops, + op_idx, + cumulator1, + cumulator2, + options::EvaluationOptions{turbo,true,early_exit}, ) where {turbo,early_exit} nbin = counttuple(binops) quote diff --git a/ext/DynamicExpressionsLoopVectorizationExt.jl b/ext/DynamicExpressionsLoopVectorizationExt.jl index c397b9fc..e78666ed 100644 --- a/ext/DynamicExpressionsLoopVectorizationExt.jl +++ b/ext/DynamicExpressionsLoopVectorizationExt.jl @@ -18,7 +18,10 @@ import DynamicExpressions.ExtensionInterfaceModule: _is_loopvectorization_loaded(::Int) = true function deg2_eval( - cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, ::EvaluationOptions{true} + cumulator_l::AbstractVector{T}, + cumulator_r::AbstractVector{T}, + op::F, + ::EvaluationOptions{true}, )::ResultOk where {T<:Number,F} @turbo for j in eachindex(cumulator_l) x = op(cumulator_l[j], cumulator_r[j]) @@ -38,7 +41,11 @@ function deg1_eval( end function deg1_l2_ll0_lr0_eval( - tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::EvaluationOptions{true} + tree::AbstractExpressionNode{T}, + cX::AbstractMatrix{T}, + op::F, + op_l::F2, + ::EvaluationOptions{true}, ) where {T<:Number,F,F2} if tree.l.l.constant && tree.l.r.constant val_ll = tree.l.l.val @@ -86,7 +93,11 @@ function deg1_l2_ll0_lr0_eval( end function deg1_l1_ll0_eval( - tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::EvaluationOptions{true} + tree::AbstractExpressionNode{T}, + cX::AbstractMatrix{T}, + op::F, + op_l::F2, + ::EvaluationOptions{true}, ) where {T<:Number,F,F2} if tree.l.l.constant val_ll = tree.l.l.val @@ -157,7 +168,7 @@ function deg2_l0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::EvaluationOptions{true} + ::EvaluationOptions{true}, ) where {T<:Number,F} if tree.l.constant val = tree.l.val @@ -182,7 +193,7 @@ function deg2_r0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::EvaluationOptions{true} + ::EvaluationOptions{true}, ) where {T<:Number,F} if tree.r.constant val = tree.r.val diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index e00b175d..9af6a476 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -68,7 +68,8 @@ import .NodeModule: @reexport import .OperatorEnumModule: AbstractOperatorEnum @reexport import .OperatorEnumConstructionModule: OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names! -@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array, EvaluationOptions +@reexport import .EvaluateModule: + eval_tree_array, differentiable_eval_tree_array, EvaluationOptions @reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array @reexport import .ChainRulesModule: NodeTangent, extract_gradient @reexport import .SimplifyModule: combine_operators, simplify_tree! diff --git a/src/precompile.jl b/src/precompile.jl index 63fc1821..81fac517 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -37,25 +37,33 @@ function test_all_combinations(; binary_operators, unary_operators, turbo, types # Trivial: for l in (x, c) - @ignore_domain_error eval_tree_array(l, X, operators; options=EvaluationOptions(turbo=use_turbo)) + @ignore_domain_error eval_tree_array( + l, X, operators; options=EvaluationOptions(; turbo=use_turbo) + ) end # Binary operators for i in eachindex(binops), l in (x, c), r in (x, c) tree = Node(i, l, r) tree = convert(Node{T}, tree) - @ignore_domain_error eval_tree_array(tree, X, operators; options=EvaluationOptions(turbo=use_turbo)) + @ignore_domain_error eval_tree_array( + tree, X, operators; options=EvaluationOptions(; turbo=use_turbo) + ) end # Unary operators for j in eachindex(unaops), k in eachindex(unaops), l in (x, c) tree = Node(j, l) tree = convert(Node{T}, tree) - @ignore_domain_error eval_tree_array(tree, X, operators; options=EvaluationOptions(turbo=use_turbo)) + @ignore_domain_error eval_tree_array( + tree, X, operators; options=EvaluationOptions(; turbo=use_turbo) + ) tree = Node(j, Node(k, l)) tree = convert(Node{T}, tree) - @ignore_domain_error eval_tree_array(tree, X, operators; options=EvaluationOptions(turbo=use_turbo)) + @ignore_domain_error eval_tree_array( + tree, X, operators; options=EvaluationOptions(; turbo=use_turbo) + ) end # Both operators @@ -67,11 +75,15 @@ function test_all_combinations(; binary_operators, unary_operators, turbo, types tree = Node(i, Node(j1, l), Node(j2, r)) tree = convert(Node{T}, tree) - @ignore_domain_error eval_tree_array(tree, X, operators; options=EvaluationOptions(turbo=use_turbo)) + @ignore_domain_error eval_tree_array( + tree, X, operators; options=EvaluationOptions(; turbo=use_turbo) + ) tree = Node(j1, Node(i, l, r)) tree = convert(Node{T}, tree) - @ignore_domain_error eval_tree_array(tree, X, operators; options=EvaluationOptions(turbo=use_turbo)) + @ignore_domain_error eval_tree_array( + tree, X, operators; options=EvaluationOptions(; turbo=use_turbo) + ) end end return nothing diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index d8554947..91fca624 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -230,7 +230,7 @@ end ) X = T[1.0 floatmax(T)] @test all(isnan.(ex(X))) - @test ex(X; options=EvaluationOptions(early_exit=Val(false))) ≈ [2.0, Inf] + @test ex(X; options=EvaluationOptions(; early_exit=Val(false))) ≈ [2.0, Inf] for turbo in [Val(false), Val(true)], T in [Float32, Float64], @@ -248,8 +248,8 @@ end 1 floatmax(T) 1 1 ] - @test all(isnan.(ex(X; options=EvaluationOptions(bumper=bumper, turbo=turbo)))) - y = ex(X; options=EvaluationOptions(bumper=bumper, turbo=turbo, early_exit=false)) + @test all(isnan.(ex(X; options=EvaluationOptions(; bumper=bumper, turbo=turbo)))) + y = ex(X; options=EvaluationOptions(; bumper=bumper, turbo=turbo, early_exit=false)) @test y[1] == T(-1.618033988749895) # FIXME: this is NaN on macOS and -Inf on windows/ubuntu... @test !isfinite(y[2]) diff --git a/test/test_initial_errors.jl b/test/test_initial_errors.jl index 9bc1c951..4aedaf67 100644 --- a/test/test_initial_errors.jl +++ b/test/test_initial_errors.jl @@ -39,11 +39,16 @@ if VERSION >= v"1.9" @test_throws( "Please load the Bumper.jl package", - allow_unstable(() -> tree(ones(2, 10), operators; options=EvaluationOptions(bumper=Val(true)))) + allow_unstable( + () -> + tree(ones(2, 10), operators; options=EvaluationOptions(; bumper=Val(true))), + ) ) @test_throws( "Please load the LoopVectorization.jl package", - allow_unstable(() -> tree(ones(2, 10), operators; options=EvaluationOptions(turbo=Val(true)))) + allow_unstable( + () -> tree(ones(2, 10), operators; options=EvaluationOptions(; turbo=Val(true))) + ) ) end From b8f10871020ca9adf1adce3452f00b017ff8b78e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 19 Jul 2024 22:05:14 +0100 Subject: [PATCH 08/41] style: clean up redundant options --- src/Evaluate.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 151e190e..ef59a4fc 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -27,22 +27,18 @@ macro return_on_nonfinite_array(array) end ) end + struct EvaluationOptions{T,B,E} turbo::Val{T} bumper::Val{B} early_exit::Val{E} end -function EvaluationOptions(; turbo=false, bumper=false, early_exit=true) +function EvaluationOptions(; turbo=Val(false), bumper=Val(false), early_exit=Val(true)) v_turbo = isa(turbo, Val) ? turbo : (turbo ? Val(true) : Val(false)) v_bumper = isa(bumper, Val) ? bumper : (bumper ? Val(true) : Val(false)) v_early_exit = isa(early_exit, Val) ? early_exit : (early_exit ? Val(true) : Val(false)) return EvaluationOptions(v_turbo, v_bumper, v_early_exit) end -function EvaluationOptions{T,B,E}(; - turbo=Val(false), bumper=Val(false), early_exit=Val(true) -) where {T,B,E} - return EvaluationOptions{T,B,E}(turbo, bumper, early_exit) -end """ eval_tree_array(tree::AbstractExpressionNode, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Union{Bool,Val}=Val(false), bumper::Union{Bool,Val}=Val(false)) @@ -102,8 +98,8 @@ function eval_tree_array( :eval_tree_array, ) EvaluationOptions(; - turbo = turbo === nothing ? Val(false) : (turbo isa Val ? turbo : Val(turbo)), - bumper = bumper === nothing ? Val(false) : (bumper isa Val ? bumper : Val(bumper)), + turbo = turbo === nothing ? Val(false) : turbo, + bumper = bumper === nothing ? Val(false) : bumper, ) end #! format: on From ee7d7c1ce2731f2d15eb00b4836a3d67bdbf4d04 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 19 Jul 2024 22:13:23 +0100 Subject: [PATCH 09/41] style: rename to `eval_options` --- ext/DynamicExpressionsBumperExt.jl | 51 +++++++++++++++--------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/ext/DynamicExpressionsBumperExt.jl b/ext/DynamicExpressionsBumperExt.jl index a991a2b9..7b8e5eb1 100644 --- a/ext/DynamicExpressionsBumperExt.jl +++ b/ext/DynamicExpressionsBumperExt.jl @@ -12,7 +12,7 @@ function bumper_eval_tree_array( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, - options::EvaluationOptions{turbo,true,early_exit}, + eval_options::EvaluationOptions{turbo,true,early_exit}, ) where {T,turbo,early_exit} result = similar(cX, axes(cX, 2)) n = size(cX, 2) @@ -38,7 +38,7 @@ function bumper_eval_tree_array( # In the evaluation kernel, we combine the branch nodes # with the arrays created by the leaf nodes: ((args::Vararg{Any,M}) where {M}) -> - dispatch_kerns!(operators, args..., options), + dispatch_kerns!(operators, args..., eval_options), tree; break_sharing=Val(true), ) @@ -50,68 +50,67 @@ function bumper_eval_tree_array( end function dispatch_kerns!( - operators, branch_node, cumulator, options::EvaluationOptions{turbo,true,early_exit} -) where {turbo,early_exit} + operators, + branch_node, + cumulator, + eval_options::EvaluationOptions{<:Any,true,early_exit}, +) where {early_exit} cumulator.ok || return cumulator - out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, options) - return early_exit ? ResultOk(out, is_valid_array(out)) : ResultOk(out, true) + out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, eval_options) + return ResultOk(out, early_exit ? is_valid_array(out) : true) end function dispatch_kerns!( operators, branch_node, cumulator1, cumulator2, - options::EvaluationOptions{turbo,true,early_exit}, -) where {turbo,early_exit} + eval_options::EvaluationOptions{<:Any,true,early_exit}, +) where {early_exit} cumulator1.ok || return cumulator1 cumulator2.ok || return cumulator2 out = dispatch_kern2!( - operators.binops, branch_node.op, cumulator1.x, cumulator2.x, options + operators.binops, branch_node.op, cumulator1.x, cumulator2.x, eval_options ) - return early_exit ? ResultOk(out, is_valid_array(out)) : ResultOk(out, true) + return ResultOk(out, early_exit ? is_valid_array(out) : true) end @generated function dispatch_kern1!( - unaops, op_idx, cumulator, options::EvaluationOptions{turbo,true,early_exit} -) where {turbo,early_exit} + unaops, op_idx, cumulator, eval_options::EvaluationOptions +) nuna = counttuple(unaops) quote Base.@nif( - $nuna, i -> i == op_idx, i -> let op = unaops[i] - return bumper_kern1!(op, cumulator, options) + $nuna, + i -> i == op_idx, + i -> let op = unaops[i] + return bumper_kern1!(op, cumulator, eval_options) end, ) end end @generated function dispatch_kern2!( - binops, - op_idx, - cumulator1, - cumulator2, - options::EvaluationOptions{turbo,true,early_exit}, -) where {turbo,early_exit} + binops, op_idx, cumulator1, cumulator2, eval_options::EvaluationOptions +) nbin = counttuple(binops) quote Base.@nif( $nbin, i -> i == op_idx, i -> let op = binops[i] - return bumper_kern2!(op, cumulator1, cumulator2, options) + return bumper_kern2!(op, cumulator1, cumulator2, eval_options) end, ) end end -function bumper_kern1!( - op::F, cumulator, ::EvaluationOptions{false,true,early_exit} -) where {F,early_exit} +function bumper_kern1!(op::F, cumulator, ::EvaluationOptions{false,true}) where {F} @. cumulator = op(cumulator) return cumulator end function bumper_kern2!( - op::F, cumulator1, cumulator2, ::EvaluationOptions{false,true,early_exit} -) where {F,early_exit} + op::F, cumulator1, cumulator2, ::EvaluationOptions{false,true} +) where {F} @. cumulator1 = op(cumulator1, cumulator2) return cumulator1 end From c0b5a461d0a42d789624e1e128c2ddf829cef7a7 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 19 Jul 2024 22:17:02 +0100 Subject: [PATCH 10/41] fix: merge edits to eval options --- src/Evaluate.jl | 6 +++--- src/precompile.jl | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index ef59a4fc..3b0e12dd 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -260,7 +260,7 @@ end eval_options.early_exit isa Val{true} && @return_on_nonfinite_array result_r.x # op(x, y), for any x or y - deg2_eval(result_l.x, result_r.x, op) + deg2_eval(result_l.x, result_r.x, op, eval_options) end end ) @@ -280,7 +280,7 @@ end result = _eval_tree_array(tree.l, cX, operators, eval_options) !result.ok && return result eval_options.early_exit isa Val{true} && @return_on_nonfinite_array result.x - deg1_eval(result.x, operators.unaops[op_idx]) + deg1_eval(result.x, operators.unaops[op_idx], eval_options) end end # This @nif lets us generate an if statement over choice of operator, @@ -308,7 +308,7 @@ end !result.ok && return result eval_options.early_exit isa Val{true} && @return_on_nonfinite_array result.x - deg1_eval(result.x, op) + deg1_eval(result.x, op, eval_options) end end ) diff --git a/src/precompile.jl b/src/precompile.jl index 81fac517..f7a72b19 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -38,7 +38,7 @@ function test_all_combinations(; binary_operators, unary_operators, turbo, types # Trivial: for l in (x, c) @ignore_domain_error eval_tree_array( - l, X, operators; options=EvaluationOptions(; turbo=use_turbo) + l, X, operators; eval_options=EvaluationOptions(; turbo=use_turbo) ) end @@ -47,7 +47,7 @@ function test_all_combinations(; binary_operators, unary_operators, turbo, types tree = Node(i, l, r) tree = convert(Node{T}, tree) @ignore_domain_error eval_tree_array( - tree, X, operators; options=EvaluationOptions(; turbo=use_turbo) + tree, X, operators; eval_options=EvaluationOptions(; turbo=use_turbo) ) end @@ -56,13 +56,13 @@ function test_all_combinations(; binary_operators, unary_operators, turbo, types tree = Node(j, l) tree = convert(Node{T}, tree) @ignore_domain_error eval_tree_array( - tree, X, operators; options=EvaluationOptions(; turbo=use_turbo) + tree, X, operators; eval_options=EvaluationOptions(; turbo=use_turbo) ) tree = Node(j, Node(k, l)) tree = convert(Node{T}, tree) @ignore_domain_error eval_tree_array( - tree, X, operators; options=EvaluationOptions(; turbo=use_turbo) + tree, X, operators; eval_options=EvaluationOptions(; turbo=use_turbo) ) end @@ -76,13 +76,13 @@ function test_all_combinations(; binary_operators, unary_operators, turbo, types tree = Node(i, Node(j1, l), Node(j2, r)) tree = convert(Node{T}, tree) @ignore_domain_error eval_tree_array( - tree, X, operators; options=EvaluationOptions(; turbo=use_turbo) + tree, X, operators; eval_options=EvaluationOptions(; turbo=use_turbo) ) tree = Node(j1, Node(i, l, r)) tree = convert(Node{T}, tree) @ignore_domain_error eval_tree_array( - tree, X, operators; options=EvaluationOptions(; turbo=use_turbo) + tree, X, operators; eval_options=EvaluationOptions(; turbo=use_turbo) ) end end From 17ae595b6e8e979f637239628ce2a4258eda4cad Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 19 Jul 2024 22:26:11 +0100 Subject: [PATCH 11/41] fix: fix generic eval errors --- src/Evaluate.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 3b0e12dd..892b19e8 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -728,11 +728,13 @@ function eval(current_node) tree::AbstractExpressionNode{T1}, cX::AbstractArray{T2,N}, operators::GenericOperatorEnum; - throw_errors::Bool=true, + throw_errors::Union{Val,Bool}=Val(true), ) where {T1,T2,N} - !throw_errors && return _eval_tree_array_generic(tree, cX, operators) + v_throw_errors = throw_errors isa Val ? throw_errors : Val(throw_errors) + v_throw_errors isa Val{false} && + return _eval_tree_array_generic(tree, cX, operators, v_throw_errors) try - return _eval_tree_array_generic(tree, cX, operators) + return _eval_tree_array_generic(tree, cX, operators, v_throw_errors) catch e if !throw_errors return nothing, false From 2b98acf4671208bd8d93f4a40717649f0cb37320 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 19 Jul 2024 22:35:03 +0100 Subject: [PATCH 12/41] refactor: test_evaluation.jl --- test/test_evaluation.jl | 49 ++++++++++++++++++++++++++--------------- test/unittest.jl | 4 +--- 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index 91fca624..3568895f 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -1,8 +1,7 @@ -using DynamicExpressions -using Bumper -using LoopVectorization -using Random -using Test +#! format: off +@testitem "Test validity of expression evaluation" begin +using DynamicExpressions, Bumper, LoopVectorization, Random + include("test_params.jl") include("tree_gen_utils.jl") @@ -81,8 +80,14 @@ for turbo in [Val(false), Val(true)], end end end +end +#! format: on + +@testitem "Test specific branches of evaluation" begin + using DynamicExpressions, DynamicExpressions, Bumper, LoopVectorization + + include("test_params.jl") -@testset "Test specific branches of evaluation" begin for turbo in [false, true], T in [Float16, Float32, Float64, ComplexF32, ComplexF64] turbo && !(T in (Float32, Float64)) && continue # Test specific branches of evaluation code: @@ -126,8 +131,10 @@ end end # Check if julia version >= 1.7: -if VERSION >= v"1.7" - @testset "Test error catching for GenericOperatorEnum" begin +@testitem "Test error catching for GenericOperatorEnum" begin + using DynamicExpressions + + if VERSION >= v"1.7" # And, with generic operator enum, this should be an actual error: operators = GenericOperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin] @@ -168,7 +175,11 @@ if VERSION >= v"1.7" end end -@testset "Test many operators" begin +@testitem "Test many operators" begin + using DynamicExpressions + + include("tree_gen_utils.jl") + # Since we use `@nif` in evaluating expressions, # we can see if there are any issues with LARGE numbers of operators. num_ops = 100 @@ -224,13 +235,15 @@ end using DynamicExpressions using Bumper, LoopVectorization - T = Float16 - ex = @parse_expression( - 2 * x, binary_operators = [*], variable_names = ["x"], node_type = Node{T} - ) - X = T[1.0 floatmax(T)] - @test all(isnan.(ex(X))) - @test ex(X; options=EvaluationOptions(; early_exit=Val(false))) ≈ [2.0, Inf] + let + T = Float16 + ex = @parse_expression( + 2 * x, binary_operators = [*], variable_names = ["x"], node_type = Node{T} + ) + X = T[1.0 floatmax(T)] + @test all(isnan.(ex(X))) + @test ex(X; eval_options=EvaluationOptions(; early_exit=Val(false))) ≈ [2.0, Inf] + end for turbo in [Val(false), Val(true)], T in [Float32, Float64], @@ -248,8 +261,8 @@ end 1 floatmax(T) 1 1 ] - @test all(isnan.(ex(X; options=EvaluationOptions(; bumper=bumper, turbo=turbo)))) - y = ex(X; options=EvaluationOptions(; bumper=bumper, turbo=turbo, early_exit=false)) + @test all(isnan.(ex(X; eval_options=EvaluationOptions(; bumper, turbo)))) + y = ex(X; eval_options=EvaluationOptions(; bumper, turbo, early_exit=false)) @test y[1] == T(-1.618033988749895) # FIXME: this is NaN on macOS and -Inf on windows/ubuntu... @test !isfinite(y[2]) diff --git a/test/unittest.jl b/test/unittest.jl index e6765710..11244b14 100644 --- a/test/unittest.jl +++ b/test/unittest.jl @@ -46,9 +46,7 @@ end include("test_print.jl") end -@testitem "Test validity of expression evaluation" begin - include("test_evaluation.jl") -end +include("test_evaluation.jl") @testitem "Test validity of integer expression evaluation" begin include("test_integer_evaluation.jl") From 539962589cd33b05267608868a18f23bc07572ac Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 19 Jul 2024 22:36:25 +0100 Subject: [PATCH 13/41] fix: specific branch calls --- test/test_evaluation.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index 3568895f..b9337ec4 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -85,6 +85,7 @@ end @testitem "Test specific branches of evaluation" begin using DynamicExpressions, DynamicExpressions, Bumper, LoopVectorization + using DynamicExpressions.EvaluateModule: EvaluationOptions include("test_params.jl") @@ -100,7 +101,7 @@ end @test repr(tree) == "cos(cos(3.0))" tree = convert(Node{T}, tree) truth = cos(cos(T(3.0f0))) - @test DynamicExpressions.EvaluateModule.deg1_l1_ll0_eval(tree, [zero(T)]', cos, cos, Val(turbo)).x[1] ≈ + @test DynamicExpressions.EvaluateModule.deg1_l1_ll0_eval(tree, [zero(T)]', cos, cos, EvaluationOptions(; turbo)).x[1] ≈ truth # op(, ) @@ -108,7 +109,7 @@ end @test repr(tree) == "3.0 + 4.0" tree = convert(Node{T}, tree) truth = T(3.0f0) + T(4.0f0) - @test DynamicExpressions.EvaluateModule.deg2_l0_r0_eval(tree, [zero(T)]', (+), Val(turbo)).x[1] ≈ + @test DynamicExpressions.EvaluateModule.deg2_l0_r0_eval(tree, [zero(T)]', (+), EvaluationOptions(; turbo)).x[1] ≈ truth # op(op(, )) @@ -116,7 +117,7 @@ end @test repr(tree) == "cos(3.0 + 4.0)" tree = convert(Node{T}, tree) truth = cos(T(3.0f0) + T(4.0f0)) - @test DynamicExpressions.EvaluateModule.deg1_l2_ll0_lr0_eval(tree, [zero(T)]', cos, (+), Val(turbo)).x[1] ≈ + @test DynamicExpressions.EvaluateModule.deg1_l2_ll0_lr0_eval(tree, [zero(T)]', cos, (+), EvaluationOptions(; turbo)).x[1] ≈ truth # Test for presence of NaNs: From 204a9df7bf53a020885a20fdee450cd654e743e4 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 19 Jul 2024 22:40:16 +0100 Subject: [PATCH 14/41] fix: `v_throw_errors` typo --- src/Evaluate.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 892b19e8..c9e37e40 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -736,7 +736,7 @@ function eval(current_node) try return _eval_tree_array_generic(tree, cX, operators, v_throw_errors) catch e - if !throw_errors + if v_throw_errors isa Val{false} return nothing, false end tree_s = string_tree(tree, operators) From 2fc5e8792f61eba0344e7ee56452cc1ad817036e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 19 Jul 2024 23:44:20 +0100 Subject: [PATCH 15/41] fix: error catching for generic eval --- src/Evaluate.jl | 2 -- test/test_evaluation.jl | 25 +++++++++++++------------ 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index c9e37e40..2ac95f72 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -731,8 +731,6 @@ function eval(current_node) throw_errors::Union{Val,Bool}=Val(true), ) where {T1,T2,N} v_throw_errors = throw_errors isa Val ? throw_errors : Val(throw_errors) - v_throw_errors isa Val{false} && - return _eval_tree_array_generic(tree, cX, operators, v_throw_errors) try return _eval_tree_array_generic(tree, cX, operators, v_throw_errors) catch e diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index b9337ec4..78a12bf9 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -135,16 +135,17 @@ end @testitem "Test error catching for GenericOperatorEnum" begin using DynamicExpressions - if VERSION >= v"1.7" + @static if VERSION >= v"1.7" # And, with generic operator enum, this should be an actual error: + @eval my_fnc(x::Real) = x operators = GenericOperatorEnum(; - binary_operators=[+, -, *, /], unary_operators=[cos, sin] + binary_operators=[+, -, *, /], unary_operators=[cos, sin, my_fnc] ) + @extend_operators operators x1 = Node(Float64; feature=1) tree = sin(x1 / 0.0) X = randn(Float32, 10) let - local stack try tree(X, operators)[1] @test false @@ -153,25 +154,25 @@ end # Check that "Failed to evaluate" is in the message: @test occursin("Failed to evaluate", e.msg) stack = current_exceptions() + @test length(stack) == 2 + @test stack[1].exception isa DomainError end - @test length(stack) == 2 - @test stack[1].exception isa DomainError # If a method is not defined, we should get a nothing: - X = randn(Float32, 1, 10) - @test tree(X, operators; throw_errors=false) === nothing + X2 = randn(ComplexF64, 1, 10) + tree2 = my_fnc(x1) + @test tree2(X2, operators; throw_errors=false) === nothing # or a MethodError: try - tree(X, operators; throw_errors=true) + tree2(X2, operators; throw_errors=true) @test false catch e @test e isa ErrorException @test occursin("Failed to evaluate", e.msg) - stack = current_exceptions() + stack2 = current_exceptions() + @test length(stack2) == 2 + @test stack2[1].exception isa MethodError end - @test length(stack) == 2 - # Dividing by 0 should not be an MethodError - # @test stack[1].exception isa MethodError end end end From 660d6f8bbcc3011ed571edee598fe18d39ab8e3f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 20 Jul 2024 00:00:47 +0100 Subject: [PATCH 16/41] style: rename `EvaluationOptions` to `EvalOptions` --- ext/DynamicExpressionsBumperExt.jl | 23 +++++-------- ext/DynamicExpressionsLoopVectorizationExt.jl | 20 +++++------ src/DynamicExpressions.jl | 2 +- src/Evaluate.jl | 34 +++++++++---------- src/precompile.jl | 12 +++---- test/test_evaluation.jl | 14 ++++---- test/test_initial_errors.jl | 6 ++-- 7 files changed, 52 insertions(+), 59 deletions(-) diff --git a/ext/DynamicExpressionsBumperExt.jl b/ext/DynamicExpressionsBumperExt.jl index 7b8e5eb1..6e99927b 100644 --- a/ext/DynamicExpressionsBumperExt.jl +++ b/ext/DynamicExpressionsBumperExt.jl @@ -2,7 +2,7 @@ module DynamicExpressionsBumperExt using Bumper: @no_escape, @alloc using DynamicExpressions: - OperatorEnum, AbstractExpressionNode, tree_mapreduce, is_valid_array, EvaluationOptions + OperatorEnum, AbstractExpressionNode, tree_mapreduce, is_valid_array, EvalOptions using DynamicExpressions.UtilsModule: ResultOk, counttuple import DynamicExpressions.ExtensionInterfaceModule: @@ -12,7 +12,7 @@ function bumper_eval_tree_array( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, - eval_options::EvaluationOptions{turbo,true,early_exit}, + eval_options::EvalOptions{turbo,true,early_exit}, ) where {T,turbo,early_exit} result = similar(cX, axes(cX, 2)) n = size(cX, 2) @@ -50,10 +50,7 @@ function bumper_eval_tree_array( end function dispatch_kerns!( - operators, - branch_node, - cumulator, - eval_options::EvaluationOptions{<:Any,true,early_exit}, + operators, branch_node, cumulator, eval_options::EvalOptions{<:Any,true,early_exit} ) where {early_exit} cumulator.ok || return cumulator @@ -65,7 +62,7 @@ function dispatch_kerns!( branch_node, cumulator1, cumulator2, - eval_options::EvaluationOptions{<:Any,true,early_exit}, + eval_options::EvalOptions{<:Any,true,early_exit}, ) where {early_exit} cumulator1.ok || return cumulator1 cumulator2.ok || return cumulator2 @@ -76,9 +73,7 @@ function dispatch_kerns!( return ResultOk(out, early_exit ? is_valid_array(out) : true) end -@generated function dispatch_kern1!( - unaops, op_idx, cumulator, eval_options::EvaluationOptions -) +@generated function dispatch_kern1!(unaops, op_idx, cumulator, eval_options::EvalOptions) nuna = counttuple(unaops) quote Base.@nif( @@ -91,7 +86,7 @@ end end end @generated function dispatch_kern2!( - binops, op_idx, cumulator1, cumulator2, eval_options::EvaluationOptions + binops, op_idx, cumulator1, cumulator2, eval_options::EvalOptions ) nbin = counttuple(binops) quote @@ -104,13 +99,11 @@ end ) end end -function bumper_kern1!(op::F, cumulator, ::EvaluationOptions{false,true}) where {F} +function bumper_kern1!(op::F, cumulator, ::EvalOptions{false,true}) where {F} @. cumulator = op(cumulator) return cumulator end -function bumper_kern2!( - op::F, cumulator1, cumulator2, ::EvaluationOptions{false,true} -) where {F} +function bumper_kern2!(op::F, cumulator1, cumulator2, ::EvalOptions{false,true}) where {F} @. cumulator1 = op(cumulator1, cumulator2) return cumulator1 end diff --git a/ext/DynamicExpressionsLoopVectorizationExt.jl b/ext/DynamicExpressionsLoopVectorizationExt.jl index e78666ed..35da7de0 100644 --- a/ext/DynamicExpressionsLoopVectorizationExt.jl +++ b/ext/DynamicExpressionsLoopVectorizationExt.jl @@ -3,7 +3,7 @@ module DynamicExpressionsLoopVectorizationExt using LoopVectorization: @turbo using DynamicExpressions: AbstractExpressionNode using DynamicExpressions.UtilsModule: ResultOk, fill_similar -using DynamicExpressions.EvaluateModule: @return_on_check, EvaluationOptions +using DynamicExpressions.EvaluateModule: @return_on_check, EvalOptions import DynamicExpressions.EvaluateModule: deg1_eval, deg2_eval, @@ -21,7 +21,7 @@ function deg2_eval( cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, - ::EvaluationOptions{true}, + ::EvalOptions{true}, )::ResultOk where {T<:Number,F} @turbo for j in eachindex(cumulator_l) x = op(cumulator_l[j], cumulator_r[j]) @@ -31,7 +31,7 @@ function deg2_eval( end function deg1_eval( - cumulator::AbstractVector{T}, op::F, ::EvaluationOptions{true} + cumulator::AbstractVector{T}, op::F, ::EvalOptions{true} )::ResultOk where {T<:Number,F} @turbo for j in eachindex(cumulator) x = op(cumulator[j]) @@ -45,7 +45,7 @@ function deg1_l2_ll0_lr0_eval( cX::AbstractMatrix{T}, op::F, op_l::F2, - ::EvaluationOptions{true}, + ::EvalOptions{true}, ) where {T<:Number,F,F2} if tree.l.l.constant && tree.l.r.constant val_ll = tree.l.l.val @@ -97,7 +97,7 @@ function deg1_l1_ll0_eval( cX::AbstractMatrix{T}, op::F, op_l::F2, - ::EvaluationOptions{true}, + ::EvalOptions{true}, ) where {T<:Number,F,F2} if tree.l.l.constant val_ll = tree.l.l.val @@ -120,7 +120,7 @@ function deg1_l1_ll0_eval( end function deg2_l0_r0_eval( - tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::EvaluationOptions{true} + tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::EvalOptions{true} ) where {T<:Number,F} if tree.l.constant && tree.r.constant val_l = tree.l.val @@ -168,7 +168,7 @@ function deg2_l0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::EvaluationOptions{true}, + ::EvalOptions{true}, ) where {T<:Number,F} if tree.l.constant val = tree.l.val @@ -193,7 +193,7 @@ function deg2_r0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::EvaluationOptions{true}, + ::EvalOptions{true}, ) where {T<:Number,F} if tree.r.constant val = tree.r.val @@ -215,13 +215,13 @@ end ## Interface with Bumper.jl function bumper_kern1!( - op::F, cumulator, ::EvaluationOptions{true,true,early_exit} + op::F, cumulator, ::EvalOptions{true,true,early_exit} ) where {F,early_exit} @turbo @. cumulator = op(cumulator) return cumulator end function bumper_kern2!( - op::F, cumulator1, cumulator2, ::EvaluationOptions{true,true,early_exit} + op::F, cumulator1, cumulator2, ::EvalOptions{true,true,early_exit} ) where {F,early_exit} @turbo @. cumulator1 = op(cumulator1, cumulator2) return cumulator1 diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 9af6a476..4ed62ae8 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -69,7 +69,7 @@ import .NodeModule: @reexport import .OperatorEnumConstructionModule: OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names! @reexport import .EvaluateModule: - eval_tree_array, differentiable_eval_tree_array, EvaluationOptions + eval_tree_array, differentiable_eval_tree_array, EvalOptions @reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array @reexport import .ChainRulesModule: NodeTangent, extract_gradient @reexport import .SimplifyModule: combine_operators, simplify_tree! diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 2ac95f72..bb061b74 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -28,16 +28,16 @@ macro return_on_nonfinite_array(array) ) end -struct EvaluationOptions{T,B,E} +struct EvalOptions{T,B,E} turbo::Val{T} bumper::Val{B} early_exit::Val{E} end -function EvaluationOptions(; turbo=Val(false), bumper=Val(false), early_exit=Val(true)) +function EvalOptions(; turbo=Val(false), bumper=Val(false), early_exit=Val(true)) v_turbo = isa(turbo, Val) ? turbo : (turbo ? Val(true) : Val(false)) v_bumper = isa(bumper, Val) ? bumper : (bumper ? Val(true) : Val(false)) v_early_exit = isa(early_exit, Val) ? early_exit : (early_exit ? Val(true) : Val(false)) - return EvaluationOptions(v_turbo, v_bumper, v_early_exit) + return EvalOptions(v_turbo, v_bumper, v_early_exit) end """ @@ -80,7 +80,7 @@ function eval_tree_array( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; - eval_options::Union{EvaluationOptions,Nothing}=nothing, + eval_options::Union{EvalOptions,Nothing}=nothing, turbo::Union{Bool,Val,Nothing}=nothing, bumper::Union{Bool,Val,Nothing}=nothing, ) where {T} @@ -97,7 +97,7 @@ function eval_tree_array( "The `turbo` and `bumper` keyword arguments are deprecated. Please use `eval_options` instead.", :eval_tree_array, ) - EvaluationOptions(; + EvalOptions(; turbo = turbo === nothing ? Val(false) : turbo, bumper = bumper === nothing ? Val(false) : bumper, ) @@ -153,7 +153,7 @@ function _eval_tree_array( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, - eval_options::EvaluationOptions, + eval_options::EvalOptions, )::ResultOk where {T} # First, we see if there are only constants in the tree - meaning # we can just return the constant result. @@ -179,7 +179,7 @@ function deg2_eval( cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, - ::EvaluationOptions{false}, + ::EvalOptions{false}, )::ResultOk where {T,F} @inbounds @simd for j in eachindex(cumulator_l) x = op(cumulator_l[j], cumulator_r[j])::T @@ -189,7 +189,7 @@ function deg2_eval( end function deg1_eval( - cumulator::AbstractVector{T}, op::F, ::EvaluationOptions{false} + cumulator::AbstractVector{T}, op::F, ::EvalOptions{false} )::ResultOk where {T,F} @inbounds @simd for j in eachindex(cumulator) x = op(cumulator[j])::T @@ -213,7 +213,7 @@ end cX::AbstractMatrix{T}, op_idx::Integer, operators::OperatorEnum, - eval_options::EvaluationOptions, + eval_options::EvalOptions, ) where {T} nbin = get_nbin(operators) long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN @@ -271,7 +271,7 @@ end cX::AbstractMatrix{T}, op_idx::Integer, operators::OperatorEnum, - eval_options::EvaluationOptions, + eval_options::EvalOptions, ) where {T} nuna = get_nuna(operators) long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN @@ -320,7 +320,7 @@ end op::F, l_op_idx::Integer, binops, - eval_options::EvaluationOptions, + eval_options::EvalOptions, ) where {T,F} nbin = counttuple(binops) # (Note this is only called from dispatch_deg1_eval, which has already @@ -341,7 +341,7 @@ end op::F, l_op_idx::Integer, unaops, - eval_options::EvaluationOptions, + eval_options::EvalOptions, )::ResultOk where {T,F} nuna = counttuple(unaops) quote @@ -360,7 +360,7 @@ function deg1_l2_ll0_lr0_eval( cX::AbstractMatrix{T}, op::F, op_l::F2, - ::EvaluationOptions{false,false}, + ::EvalOptions{false,false}, ) where {T,F,F2} if tree.l.l.constant && tree.l.r.constant val_ll = tree.l.l.val @@ -413,7 +413,7 @@ function deg1_l1_ll0_eval( cX::AbstractMatrix{T}, op::F, op_l::F2, - ::EvaluationOptions{false,false}, + ::EvalOptions{false,false}, ) where {T,F,F2} if tree.l.l.constant val_ll = tree.l.l.val @@ -440,7 +440,7 @@ function deg2_l0_r0_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, - ::EvaluationOptions{false,false}, + ::EvalOptions{false,false}, ) where {T,F} if tree.l.constant && tree.r.constant val_l = tree.l.val @@ -488,7 +488,7 @@ function deg2_l0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::EvaluationOptions{false,false}, + ::EvalOptions{false,false}, ) where {T,F} if tree.l.constant val = tree.l.val @@ -514,7 +514,7 @@ function deg2_r0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::EvaluationOptions{false,false}, + ::EvalOptions{false,false}, ) where {T,F} if tree.r.constant val = tree.r.val diff --git a/src/precompile.jl b/src/precompile.jl index f7a72b19..d16bc6b7 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -38,7 +38,7 @@ function test_all_combinations(; binary_operators, unary_operators, turbo, types # Trivial: for l in (x, c) @ignore_domain_error eval_tree_array( - l, X, operators; eval_options=EvaluationOptions(; turbo=use_turbo) + l, X, operators; eval_options=EvalOptions(; turbo=use_turbo) ) end @@ -47,7 +47,7 @@ function test_all_combinations(; binary_operators, unary_operators, turbo, types tree = Node(i, l, r) tree = convert(Node{T}, tree) @ignore_domain_error eval_tree_array( - tree, X, operators; eval_options=EvaluationOptions(; turbo=use_turbo) + tree, X, operators; eval_options=EvalOptions(; turbo=use_turbo) ) end @@ -56,13 +56,13 @@ function test_all_combinations(; binary_operators, unary_operators, turbo, types tree = Node(j, l) tree = convert(Node{T}, tree) @ignore_domain_error eval_tree_array( - tree, X, operators; eval_options=EvaluationOptions(; turbo=use_turbo) + tree, X, operators; eval_options=EvalOptions(; turbo=use_turbo) ) tree = Node(j, Node(k, l)) tree = convert(Node{T}, tree) @ignore_domain_error eval_tree_array( - tree, X, operators; eval_options=EvaluationOptions(; turbo=use_turbo) + tree, X, operators; eval_options=EvalOptions(; turbo=use_turbo) ) end @@ -76,13 +76,13 @@ function test_all_combinations(; binary_operators, unary_operators, turbo, types tree = Node(i, Node(j1, l), Node(j2, r)) tree = convert(Node{T}, tree) @ignore_domain_error eval_tree_array( - tree, X, operators; eval_options=EvaluationOptions(; turbo=use_turbo) + tree, X, operators; eval_options=EvalOptions(; turbo=use_turbo) ) tree = Node(j1, Node(i, l, r)) tree = convert(Node{T}, tree) @ignore_domain_error eval_tree_array( - tree, X, operators; eval_options=EvaluationOptions(; turbo=use_turbo) + tree, X, operators; eval_options=EvalOptions(; turbo=use_turbo) ) end end diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index 78a12bf9..57779554 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -85,7 +85,7 @@ end @testitem "Test specific branches of evaluation" begin using DynamicExpressions, DynamicExpressions, Bumper, LoopVectorization - using DynamicExpressions.EvaluateModule: EvaluationOptions + using DynamicExpressions.EvaluateModule: EvalOptions include("test_params.jl") @@ -101,7 +101,7 @@ end @test repr(tree) == "cos(cos(3.0))" tree = convert(Node{T}, tree) truth = cos(cos(T(3.0f0))) - @test DynamicExpressions.EvaluateModule.deg1_l1_ll0_eval(tree, [zero(T)]', cos, cos, EvaluationOptions(; turbo)).x[1] ≈ + @test DynamicExpressions.EvaluateModule.deg1_l1_ll0_eval(tree, [zero(T)]', cos, cos, EvalOptions(; turbo)).x[1] ≈ truth # op(, ) @@ -109,7 +109,7 @@ end @test repr(tree) == "3.0 + 4.0" tree = convert(Node{T}, tree) truth = T(3.0f0) + T(4.0f0) - @test DynamicExpressions.EvaluateModule.deg2_l0_r0_eval(tree, [zero(T)]', (+), EvaluationOptions(; turbo)).x[1] ≈ + @test DynamicExpressions.EvaluateModule.deg2_l0_r0_eval(tree, [zero(T)]', (+), EvalOptions(; turbo)).x[1] ≈ truth # op(op(, )) @@ -117,7 +117,7 @@ end @test repr(tree) == "cos(3.0 + 4.0)" tree = convert(Node{T}, tree) truth = cos(T(3.0f0) + T(4.0f0)) - @test DynamicExpressions.EvaluateModule.deg1_l2_ll0_lr0_eval(tree, [zero(T)]', cos, (+), EvaluationOptions(; turbo)).x[1] ≈ + @test DynamicExpressions.EvaluateModule.deg1_l2_ll0_lr0_eval(tree, [zero(T)]', cos, (+), EvalOptions(; turbo)).x[1] ≈ truth # Test for presence of NaNs: @@ -244,7 +244,7 @@ end ) X = T[1.0 floatmax(T)] @test all(isnan.(ex(X))) - @test ex(X; eval_options=EvaluationOptions(; early_exit=Val(false))) ≈ [2.0, Inf] + @test ex(X; eval_options=EvalOptions(; early_exit=Val(false))) ≈ [2.0, Inf] end for turbo in [Val(false), Val(true)], @@ -263,8 +263,8 @@ end 1 floatmax(T) 1 1 ] - @test all(isnan.(ex(X; eval_options=EvaluationOptions(; bumper, turbo)))) - y = ex(X; eval_options=EvaluationOptions(; bumper, turbo, early_exit=false)) + @test all(isnan.(ex(X; eval_options=EvalOptions(; bumper, turbo)))) + y = ex(X; eval_options=EvalOptions(; bumper, turbo, early_exit=false)) @test y[1] == T(-1.618033988749895) # FIXME: this is NaN on macOS and -Inf on windows/ubuntu... @test !isfinite(y[2]) diff --git a/test/test_initial_errors.jl b/test/test_initial_errors.jl index 4aedaf67..fb4f5974 100644 --- a/test/test_initial_errors.jl +++ b/test/test_initial_errors.jl @@ -1,4 +1,5 @@ using DynamicExpressions +using DynamicExpressions: EvalOptions using DispatchDoctor: allow_unstable using Test @@ -40,15 +41,14 @@ if VERSION >= v"1.9" @test_throws( "Please load the Bumper.jl package", allow_unstable( - () -> - tree(ones(2, 10), operators; options=EvaluationOptions(; bumper=Val(true))), + () -> tree(ones(2, 10), operators; options=EvalOptions(; bumper=Val(true))) ) ) @test_throws( "Please load the LoopVectorization.jl package", allow_unstable( - () -> tree(ones(2, 10), operators; options=EvaluationOptions(; turbo=Val(true))) + () -> tree(ones(2, 10), operators; options=EvalOptions(; turbo=Val(true))) ) ) end From dd24df654ff89df001dfd69733c3066b44c8cdd5 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 20 Jul 2024 00:23:53 +0100 Subject: [PATCH 17/41] test: fix initial errors test --- test/test_initial_errors.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_initial_errors.jl b/test/test_initial_errors.jl index fb4f5974..56e69cb2 100644 --- a/test/test_initial_errors.jl +++ b/test/test_initial_errors.jl @@ -41,14 +41,14 @@ if VERSION >= v"1.9" @test_throws( "Please load the Bumper.jl package", allow_unstable( - () -> tree(ones(2, 10), operators; options=EvalOptions(; bumper=Val(true))) + () -> tree(ones(2, 10), operators; eval_options=EvalOptions(; bumper=Val(true))) ) ) @test_throws( "Please load the LoopVectorization.jl package", allow_unstable( - () -> tree(ones(2, 10), operators; options=EvalOptions(; turbo=Val(true))) + () -> tree(ones(2, 10), operators; eval_options=EvalOptions(; turbo=Val(true))) ) ) end From 63020121e946b18e1cfb39e8e0f32ab02275171f Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Mon, 22 Jul 2024 22:09:06 +0200 Subject: [PATCH 18/41] fix: type unstalbe tests --- test/test_evaluation.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index 57779554..49dad640 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -89,8 +89,8 @@ end include("test_params.jl") - for turbo in [false, true], T in [Float16, Float32, Float64, ComplexF32, ComplexF64] - turbo && !(T in (Float32, Float64)) && continue + for turbo in [Val(false), Val(true)], T in [Float16, Float32, Float64, ComplexF32, ComplexF64] + turbo isa Val{true} && !(T in (Float32, Float64)) && continue # Test specific branches of evaluation code: # op(op()) local tree, operators @@ -264,9 +264,8 @@ end 1 1 ] @test all(isnan.(ex(X; eval_options=EvalOptions(; bumper, turbo)))) - y = ex(X; eval_options=EvalOptions(; bumper, turbo, early_exit=false)) + y = ex(X; eval_options=EvalOptions(; bumper, turbo, early_exit=Val(false))) @test y[1] == T(-1.618033988749895) - # FIXME: this is NaN on macOS and -Inf on windows/ubuntu... @test !isfinite(y[2]) end end From a73a04fb41e74350a672dbdf8fe04f0623a4086b Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Mon, 22 Jul 2024 22:09:32 +0200 Subject: [PATCH 19/41] add doc strings --- src/Evaluate.jl | 41 +++++++++++++++++++++++++++++++++++++--- src/EvaluationHelpers.jl | 18 ++++++------------ 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index bb061b74..fd28761f 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -28,6 +28,31 @@ macro return_on_nonfinite_array(array) ) end +""" + EvalOptions{T,B,E} + +EvalOptions contain flags for the different modes to evaluate an expression. + +# Fields + +- `turbo::Val`: If `Val{true}`, use LoopVectorization.jl for faster + evaluation. +- `bumper::Val`: If `Val{true}, use Bumper.jl for faster evaluation. +- `early_exit::Val`: If `Val{true}`, any element of any step becoming + `NaN` or `Inf` will terminate the computation and the whole buffer will be + returned with `NaN`s. This makes sure that expressions with singularities + don't wast compute cycles. Setting `Val{false}` will continue the computation + as usual and thus result in `NaN`s only in the elements that actually have + `NaN`s. + +# Constructors + + EvalOptions(; turbo=Val(false), bumper=Val(false), early_exit=Val(true)) + +Construct EvalOptions with defaults. Can also be called with boolean values +instead of `Val`s for convenience, although this should be avoided as it +introduces a type instability. +""" struct EvalOptions{T,B,E} turbo::Val{T} bumper::Val{B} @@ -41,7 +66,14 @@ function EvalOptions(; turbo=Val(false), bumper=Val(false), early_exit=Val(true) end """ - eval_tree_array(tree::AbstractExpressionNode, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Union{Bool,Val}=Val(false), bumper::Union{Bool,Val}=Val(false)) + eval_tree_array( + tree::AbstractExpressionNode{T}, + cX::AbstractMatrix{T}, + operators::OperatorEnum; + eval_options::Union{EvalOptions,Nothing}=nothing, + turbo::Union{Bool,Val,Nothing}=nothing, + bumper::Union{Bool,Val,Nothing}=nothing, + ) where {T} Evaluate a binary tree (equation) over a given input data matrix. The operators contain all of the operators used. This function fuses doublets @@ -51,8 +83,11 @@ and triplets of operations for lower memory usage. - `tree::AbstractExpressionNode`: The root node of the tree to evaluate. - `cX::AbstractMatrix{T}`: The input data to evaluate the tree on. - `operators::OperatorEnum`: The operators used in the tree. -- `turbo::Union{Bool,Val}`: Use LoopVectorization.jl for faster evaluation. -- `bumper::Union{Bool,Val}`: Use Bumper.jl for faster evaluation. +- `eval_options::Union{EvalOptions,Nothing}`: See EvalOptions for documenation + on the different evaluation modes. +- `turbo::Union{Bool,Val,Nothing}`: Deprecated. Part of EvalOptions now. +- `bumper::Union{Bool,Val,Nothing}`: Deprecated. Part of EvalOptions now. + # Returns - `(output, complete)::Tuple{AbstractVector{T}, Bool}`: the result, diff --git a/src/EvaluationHelpers.jl b/src/EvaluationHelpers.jl index 5e9a0b12..3762bcd0 100644 --- a/src/EvaluationHelpers.jl +++ b/src/EvaluationHelpers.jl @@ -8,7 +8,7 @@ import ..EvaluateDerivativeModule: eval_grad_tree_array # Evaluation: """ - (tree::AbstractExpressionNode)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Union{Bool,Val}=false, bumper::Union{Bool,Val}=Val(false)) + (tree::AbstractExpressionNode)(X, operators::OperatorEnum; kws...) Evaluate a binary tree (equation) over a given input data matrix. The operators contain all of the operators used. This function fuses doublets @@ -16,10 +16,10 @@ and triplets of operations for lower memory usage. # Arguments - `tree::AbstractExpressionNode`: The root node of the tree to evaluate. -- `cX::AbstractMatrix{T}`: The input data to evaluate the tree on. +- `X::AbstractMatrix{T}`: The input data to evaluate the tree on. - `operators::OperatorEnum`: The operators used in the tree. -- `turbo::Union{Bool,Val}`: Use LoopVectorization.jl for faster evaluation. -- `bumper::Union{Bool,Val}`: Use Bumper.jl for faster evaluation. +- `kws...`: Passed to `eval_tree_array`. + # Returns - `output::AbstractVector{T}`: the result, which is a 1D array. @@ -32,18 +32,12 @@ function (tree::AbstractExpressionNode)(X, operators::OperatorEnum; kws...) return out end """ - (tree::AbstractExpressionNode)(X::AbstractMatrix, operators::GenericOperatorEnum; throw_errors::Bool=true) + (tree::AbstractExpressionNode)(X, operators::GenericOperatorEnum; kws...) # Arguments - `X::AbstractArray`: The input data to evaluate the tree on. - `operators::GenericOperatorEnum`: The operators used in the tree. -- `throw_errors::Bool=true`: Whether to throw errors - if they occur during evaluation. Otherwise, - MethodErrors will be caught before they happen and - evaluation will return `nothing`, - rather than throwing an error. This is useful in cases - where you are unsure if a particular tree is valid or not, - and would prefer to work with `nothing` as an output. +- `kws...`: Passed to `eval_tree_array`. # Returns - `output`: the result of the evaluation. From fe30e8ba9d2e2461b29401bed16f553ea4c92d82 Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Mon, 22 Jul 2024 22:56:10 +0200 Subject: [PATCH 20/41] update docs --- docs/src/eval.md | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/docs/src/eval.md b/docs/src/eval.md index 82adec7b..1ca1c793 100644 --- a/docs/src/eval.md +++ b/docs/src/eval.md @@ -6,14 +6,21 @@ Given an expression tree specified with a `Node` type, you may evaluate the expr over an array of data with the following command: ```@docs -eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum) where {T<:Number} +eval_tree_array( + tree::AbstractExpressionNode{T}, + cX::AbstractMatrix{T}, + operators::OperatorEnum; + eval_options::Union{EvalOptions,Nothing}=nothing, + turbo::Union{Bool,Val,Nothing}=nothing, + bumper::Union{Bool,Val,Nothing}=nothing, +) where {T} ``` Assuming you are only using a single `OperatorEnum`, you can also use the following shorthand by using the expression as a function: ``` - (tree::AbstractExpressionNode)(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Union{Bool,Val}=false, bumper::Union{Bool,Val}=Val(false)) + (tree::AbstractExpressionNode)(X, operators::OperatorEnum; kws...) Evaluate a binary tree (equation) over a given input data matrix. The operators contain all of the operators used. This function fuses doublets @@ -23,8 +30,7 @@ and triplets of operations for lower memory usage. - `tree::AbstractExpressionNode`: The root node of the tree to evaluate. - `cX::AbstractMatrix{T}`: The input data to evaluate the tree on. - `operators::OperatorEnum`: The operators used in the tree. -- `turbo::Union{Bool,Val}`: Use LoopVectorization.jl for faster evaluation. -- `bumper::Union{Bool,Val}`: Use Bumper.jl for faster evaluation. +- `kws...`: Passed to `eval_tree_array`. # Returns - `output::AbstractVector{T}`: the result, which is a 1D array. From 944b2e817184f7ee74d9af2a2f00164b5adc4027 Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Mon, 22 Jul 2024 22:57:39 +0200 Subject: [PATCH 21/41] format --- test/test_evaluation.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index 49dad640..8f71537b 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -89,7 +89,9 @@ end include("test_params.jl") - for turbo in [Val(false), Val(true)], T in [Float16, Float32, Float64, ComplexF32, ComplexF64] + for turbo in [Val(false), Val(true)], + T in [Float16, Float32, Float64, ComplexF32, ComplexF64] + turbo isa Val{true} && !(T in (Float32, Float64)) && continue # Test specific branches of evaluation code: # op(op()) From 905f5e185dcd82fc167aaefd66dfb61ee6533d47 Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Tue, 23 Jul 2024 12:03:42 +0200 Subject: [PATCH 22/41] approx equal --- test/test_evaluation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index 8f71537b..dbf9f224 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -267,7 +267,7 @@ end ] @test all(isnan.(ex(X; eval_options=EvalOptions(; bumper, turbo)))) y = ex(X; eval_options=EvalOptions(; bumper, turbo, early_exit=Val(false))) - @test y[1] == T(-1.618033988749895) + @test y[1] ≈ T(-1.618033988749895) @test !isfinite(y[2]) end end From 0a2bb96415d6efa0f5173aa715715d9cb090afd7 Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Wed, 24 Jul 2024 14:59:50 +0200 Subject: [PATCH 23/41] fix enzyme test --- test/test_enzyme.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/test_enzyme.jl b/test/test_enzyme.jl index d5dae094..d0182e22 100644 --- a/test/test_enzyme.jl +++ b/test/test_enzyme.jl @@ -5,9 +5,6 @@ using DynamicExpressions operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(cos, sin)) # TODO: More operators will trigger a segfault in Enzyme -# These options are required for Enzyme to work: -const eval_options = (turbo=Val(false),) - x1, x2, x3 = (i -> Node(Float64; feature=i)).(1:3) tree = Node(1, x1, Node(1, x2)) # == x1 + cos(x2) @@ -16,7 +13,7 @@ X = randn(3, 100); dX = zero(X) function f(tree, X, operators, output) - output[] = sum(eval_tree_array(tree, X, operators; eval_options...)[1]) + output[] = sum(eval_tree_array(tree, X, operators)[1]) return nothing end From 6bd504bae3eb4144864c45d544e3a29b03221e58 Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Thu, 25 Jul 2024 14:26:42 +0200 Subject: [PATCH 24/41] Update docs/src/eval.md Co-authored-by: Miles Cranmer --- docs/src/eval.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/eval.md b/docs/src/eval.md index 1ca1c793..1d92874f 100644 --- a/docs/src/eval.md +++ b/docs/src/eval.md @@ -30,7 +30,7 @@ and triplets of operations for lower memory usage. - `tree::AbstractExpressionNode`: The root node of the tree to evaluate. - `cX::AbstractMatrix{T}`: The input data to evaluate the tree on. - `operators::OperatorEnum`: The operators used in the tree. -- `kws...`: Passed to `eval_tree_array`. +- `kws...`: Passed to [`eval_tree_array`](@ref). # Returns - `output::AbstractVector{T}`: the result, which is a 1D array. From 54c639828602484a2a0801786e8682f5698a6f02 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 24 Jul 2024 23:34:56 +0100 Subject: [PATCH 25/41] test: disable enzyme test --- test/test_enzyme.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_enzyme.jl b/test/test_enzyme.jl index d0182e22..a3cd9f09 100644 --- a/test/test_enzyme.jl +++ b/test/test_enzyme.jl @@ -33,6 +33,8 @@ true_dX = cat(ones(100), -sin.(X[2, :]), zeros(100); dims=2)' @test true_dX ≈ dX +#! format: off +@static if false # Broken test (see https://github.com/EnzymeAD/Enzyme.jl/issues/1241) function my_loss_function(tree, X, operators) # Get the outputs @@ -65,3 +67,6 @@ d_tree = begin end @test_broken get_scalar_constants(d_tree) ≈ [1.0, 0.717356] +end + +#! format: on From 958b9aff595a2cfb4f05f58eedfa913d9e42e624 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 25 Jul 2024 14:10:36 +0100 Subject: [PATCH 26/41] test: skip Enzyme test completely --- .github/workflows/CI.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 55d376e3..f37fa390 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -83,6 +83,7 @@ jobs: - uses: julia-actions/cache@v1 - uses: julia-actions/julia-buildpkg@v1 - name: Run tests + continue-on-error: ${{ matrix.test_name == 'enzyme' }} run: | julia --color=yes -e 'import Pkg; Pkg.add("Coverage")' SR_TEST=${{ matrix.test_name }} julia --color=yes --threads=auto --check-bounds=yes --depwarn=yes --code-coverage=user -e 'import Coverage; import Pkg; Pkg.activate("."); Pkg.test(coverage=true)' From 1365a55bd9053d278dad4ca0224a3eb48af16af1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 26 Jul 2024 19:14:40 +0100 Subject: [PATCH 27/41] test: fix Enzyme test --- test/test_enzyme.jl | 46 ++++++++++++++++++++++----------------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/test/test_enzyme.jl b/test/test_enzyme.jl index a3cd9f09..287ccea9 100644 --- a/test/test_enzyme.jl +++ b/test/test_enzyme.jl @@ -20,25 +20,24 @@ end output = [0.0] doutput = [1.0] -autodiff( - Reverse, - f, - Const(tree), - Duplicated(X, dX), - Const(operators), - Duplicated(output, doutput), -) +fetch(schedule(Task(64 * 1024^2) do + autodiff( + Reverse, + f, + Const(tree), + Duplicated(X, dX), + Const(operators), + Duplicated(output, doutput), + ) +end)) true_dX = cat(ones(100), -sin.(X[2, :]), zeros(100); dims=2)' @test true_dX ≈ dX -#! format: off -@static if false -# Broken test (see https://github.com/EnzymeAD/Enzyme.jl/issues/1241) function my_loss_function(tree, X, operators) # Get the outputs - y = tree(X, operators) + y, _ = eval_tree_array(tree, X, operators) # Sum them (so we can take a gradient, rather than a jacobian) return sum(y) end @@ -55,18 +54,17 @@ d_tree = begin node.val = 0.0 end end - autodiff( - Reverse, - my_loss_function, - Active, - Duplicated(tree, storage_tree), - Const(X), - Const(operators), - ) + fetch(schedule(Task(64 * 1024^2) do + autodiff( + Reverse, + my_loss_function, + Active, + Duplicated(tree, storage_tree), + Const(X), + Const(operators), + ) + end)) storage_tree end -@test_broken get_scalar_constants(d_tree) ≈ [1.0, 0.717356] -end - -#! format: on +@test isapprox(first(get_scalar_constants(d_tree)), [1.0, 0.717356]; atol=1e-3) From cbcd221f8c31271a3da05536c4c0c3bbf8cce6e4 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 26 Jul 2024 19:54:03 +0100 Subject: [PATCH 28/41] style: fix formatting --- test/test_enzyme.jl | 48 ++++++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/test/test_enzyme.jl b/test/test_enzyme.jl index 287ccea9..b35b9361 100644 --- a/test/test_enzyme.jl +++ b/test/test_enzyme.jl @@ -20,16 +20,20 @@ end output = [0.0] doutput = [1.0] -fetch(schedule(Task(64 * 1024^2) do - autodiff( - Reverse, - f, - Const(tree), - Duplicated(X, dX), - Const(operators), - Duplicated(output, doutput), - ) -end)) +fetch( + schedule( + Task(64 * 1024^2) do + autodiff( + Reverse, + f, + Const(tree), + Duplicated(X, dX), + Const(operators), + Duplicated(output, doutput), + ) + end, + ), +) true_dX = cat(ones(100), -sin.(X[2, :]), zeros(100); dims=2)' @@ -54,16 +58,20 @@ d_tree = begin node.val = 0.0 end end - fetch(schedule(Task(64 * 1024^2) do - autodiff( - Reverse, - my_loss_function, - Active, - Duplicated(tree, storage_tree), - Const(X), - Const(operators), - ) - end)) + fetch( + schedule( + Task(64 * 1024^2) do + autodiff( + Reverse, + my_loss_function, + Active, + Duplicated(tree, storage_tree), + Const(X), + Const(operators), + ) + end, + ), + ) storage_tree end From 87c6225b76c8bf863f53be83de2858d3b1320cb6 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 26 Jul 2024 20:27:47 +0100 Subject: [PATCH 29/41] ci: install fixed Enzyme --- .github/workflows/CI.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f37fa390..c667f621 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -82,8 +82,10 @@ jobs: version: ${{ matrix.julia-version }} - uses: julia-actions/cache@v1 - uses: julia-actions/julia-buildpkg@v1 + - name: Add fixed version of Enzyme + if: ${{ matrix.test_name == 'enzyme' }} + run: julia --color=yes --project=test/ -e 'import Pkg; Pkg.add(url="https://github.com/EnzymeAD/Enzyme.jl"; rev="ebc27a17dcae0d69f788c1647ef643dbcf20913d")' - name: Run tests - continue-on-error: ${{ matrix.test_name == 'enzyme' }} run: | julia --color=yes -e 'import Pkg; Pkg.add("Coverage")' SR_TEST=${{ matrix.test_name }} julia --color=yes --threads=auto --check-bounds=yes --depwarn=yes --code-coverage=user -e 'import Coverage; import Pkg; Pkg.activate("."); Pkg.test(coverage=true)' From 86d209660b6e3785aa695200cddd39968408633e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 26 Jul 2024 21:04:13 +0100 Subject: [PATCH 30/41] fix issue due to https://github.com/JuliaLang/Pkg.jl/issues/1585 --- .github/workflows/CI.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c667f621..5f9b27ad 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -84,7 +84,9 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - name: Add fixed version of Enzyme if: ${{ matrix.test_name == 'enzyme' }} - run: julia --color=yes --project=test/ -e 'import Pkg; Pkg.add(url="https://github.com/EnzymeAD/Enzyme.jl"; rev="ebc27a17dcae0d69f788c1647ef643dbcf20913d")' + run: | + julia --color=yes --project=test/ -e 'import Pkg; Pkg.add(url="https://github.com/EnzymeAD/Enzyme.jl"; rev="ebc27a17dcae0d69f788c1647ef643dbcf20913d")' + julia --color=yes --project=. -e 'import Pkg; Pkg.add(url="https://github.com/EnzymeAD/Enzyme.jl"; rev="ebc27a17dcae0d69f788c1647ef643dbcf20913d")' - name: Run tests run: | julia --color=yes -e 'import Pkg; Pkg.add("Coverage")' From 796cbae76397c3f98a9a4b64a1eeccfcc47b7634 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 26 Jul 2024 21:24:39 +0100 Subject: [PATCH 31/41] fix custom enzyme install --- .github/workflows/CI.yml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5f9b27ad..b7d9813d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -82,15 +82,10 @@ jobs: version: ${{ matrix.julia-version }} - uses: julia-actions/cache@v1 - uses: julia-actions/julia-buildpkg@v1 - - name: Add fixed version of Enzyme - if: ${{ matrix.test_name == 'enzyme' }} - run: | - julia --color=yes --project=test/ -e 'import Pkg; Pkg.add(url="https://github.com/EnzymeAD/Enzyme.jl"; rev="ebc27a17dcae0d69f788c1647ef643dbcf20913d")' - julia --color=yes --project=. -e 'import Pkg; Pkg.add(url="https://github.com/EnzymeAD/Enzyme.jl"; rev="ebc27a17dcae0d69f788c1647ef643dbcf20913d")' - name: Run tests run: | julia --color=yes -e 'import Pkg; Pkg.add("Coverage")' - SR_TEST=${{ matrix.test_name }} julia --color=yes --threads=auto --check-bounds=yes --depwarn=yes --code-coverage=user -e 'import Coverage; import Pkg; Pkg.activate("."); Pkg.test(coverage=true)' + SR_TEST=${{ matrix.test_name }} julia --color=yes --threads=auto --check-bounds=yes --depwarn=yes --code-coverage=user -e 'import Coverage; import Pkg; Pkg.activate("."); Pkg.add(url="https://github.com/EnzymeAD/Enzyme.jl"; rev="ebc27a17dcae0d69f788c1647ef643dbcf20913d"); Pkg.test(coverage=true)' julia --color=yes coverage.jl shell: bash - name: Coveralls From 5a30d473affca7345b8fd944f0cd51e4fdeedb66 Mon Sep 17 00:00:00 2001 From: Miles Cranmer Date: Sun, 28 Jul 2024 07:54:44 +0900 Subject: [PATCH 32/41] ci: remove Enzyme revision test --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 1909ce51..82164262 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -87,7 +87,7 @@ jobs: continue-on-error: ${{ matrix.test_name == 'enzyme' }} run: | julia --color=yes -e 'import Pkg; Pkg.add("Coverage")' - SR_TEST=${{ matrix.test_name }} julia --color=yes --threads=auto --check-bounds=yes --depwarn=yes --code-coverage=user -e 'import Coverage; import Pkg; Pkg.activate("."); Pkg.add(url="https://github.com/EnzymeAD/Enzyme.jl"; rev="ebc27a17dcae0d69f788c1647ef643dbcf20913d"); Pkg.test(coverage=true)' + SR_TEST=${{ matrix.test_name }} julia --color=yes --threads=auto --check-bounds=yes --depwarn=yes --code-coverage=user -e 'import Coverage; import Pkg; Pkg.activate("."); Pkg.test(coverage=true)' julia --color=yes coverage.jl shell: bash - name: Coveralls From 64c797c09629acccd9804c5ff3b6c6e409a5be46 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 28 Jul 2024 00:15:36 +0100 Subject: [PATCH 33/41] refactor: reduce code complexity of eval options --- src/Evaluate.jl | 87 ++++++++++++++++++++++++------------------------- 1 file changed, 42 insertions(+), 45 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index fd28761f..087f0d9f 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -36,33 +36,53 @@ EvalOptions contain flags for the different modes to evaluate an expression. # Fields - `turbo::Val`: If `Val{true}`, use LoopVectorization.jl for faster - evaluation. + evaluation. - `bumper::Val`: If `Val{true}, use Bumper.jl for faster evaluation. - `early_exit::Val`: If `Val{true}`, any element of any step becoming - `NaN` or `Inf` will terminate the computation and the whole buffer will be - returned with `NaN`s. This makes sure that expressions with singularities - don't wast compute cycles. Setting `Val{false}` will continue the computation - as usual and thus result in `NaN`s only in the elements that actually have - `NaN`s. - -# Constructors - - EvalOptions(; turbo=Val(false), bumper=Val(false), early_exit=Val(true)) - -Construct EvalOptions with defaults. Can also be called with boolean values -instead of `Val`s for convenience, although this should be avoided as it -introduces a type instability. + `NaN` or `Inf` will terminate the computation and the whole buffer will be + returned with `NaN`s. This makes sure that expressions with singularities + don't wast compute cycles. Setting `Val{false}` will continue the computation + as usual and thus result in `NaN`s only in the elements that actually have + `NaN`s. """ struct EvalOptions{T,B,E} turbo::Val{T} bumper::Val{B} early_exit::Val{E} end -function EvalOptions(; turbo=Val(false), bumper=Val(false), early_exit=Val(true)) - v_turbo = isa(turbo, Val) ? turbo : (turbo ? Val(true) : Val(false)) - v_bumper = isa(bumper, Val) ? bumper : (bumper ? Val(true) : Val(false)) - v_early_exit = isa(early_exit, Val) ? early_exit : (early_exit ? Val(true) : Val(false)) - return EvalOptions(v_turbo, v_bumper, v_early_exit) + +@inline _to_bool_val(x::Bool) = x ? Val(true) : Val(false) +@inline _to_bool_val(x::Val{T}) where {T} = Val(T::Bool) + +function EvalOptions(; + turbo::Union{Bool,Val}=Val(false), + bumper::Union{Bool,Val}=Val(false), + early_exit::Union{Bool,Val}=Val(true), +) + return EvalOptions(_to_bool_val(turbo), _to_bool_val(bumper), _to_bool_val(early_exit)) +end + +function _process_deprecated_kws(eval_options, deprecated_kws) + turbo = get(deprecated_kws, :turbo, nothing) + bumper = get(deprecated_kws, :bumper, nothing) + if any(Base.Fix2(∉, (:turbo, :bumper)), keys(deprecated_kws)) + throw(ArgumentError("Invalid keyword argument(s): $(keys(deprecated_kws))")) + end + if !isempty(deprecated_kws) + @assert eval_options === nothing "Cannot use both `eval_options` and deprecated flags `turbo` and `bumper`." + Base.depwarn( + "The `turbo` and `bumper` keyword arguments are deprecated. Please use `eval_options` instead.", + :eval_tree_array, + ) + end + if eval_options !== nothing + return eval_options + else + return EvalOptions(; + turbo=turbo === nothing ? Val(false) : turbo, + bumper=bumper === nothing ? Val(false) : bumper, + ) + end end """ @@ -71,8 +91,6 @@ end cX::AbstractMatrix{T}, operators::OperatorEnum; eval_options::Union{EvalOptions,Nothing}=nothing, - turbo::Union{Bool,Val,Nothing}=nothing, - bumper::Union{Bool,Val,Nothing}=nothing, ) where {T} Evaluate a binary tree (equation) over a given input data matrix. The @@ -84,9 +102,7 @@ and triplets of operations for lower memory usage. - `cX::AbstractMatrix{T}`: The input data to evaluate the tree on. - `operators::OperatorEnum`: The operators used in the tree. - `eval_options::Union{EvalOptions,Nothing}`: See EvalOptions for documenation - on the different evaluation modes. -- `turbo::Union{Bool,Val,Nothing}`: Deprecated. Part of EvalOptions now. -- `bumper::Union{Bool,Val,Nothing}`: Deprecated. Part of EvalOptions now. + on the different evaluation modes. # Returns @@ -116,28 +132,9 @@ function eval_tree_array( cX::AbstractMatrix{T}, operators::OperatorEnum; eval_options::Union{EvalOptions,Nothing}=nothing, - turbo::Union{Bool,Val,Nothing}=nothing, - bumper::Union{Bool,Val,Nothing}=nothing, + _deprecated_kws..., ) where {T} - @assert( - eval_options === nothing || (turbo === nothing && bumper === nothing), - "Cannot use both `eval_options` and deprecated flags `turbo` and `bumper`." - ) - #! format: off - _eval_options = - if eval_options !== nothing - eval_options - else - (turbo !== nothing || bumper !== nothing) && Base.depwarn( - "The `turbo` and `bumper` keyword arguments are deprecated. Please use `eval_options` instead.", - :eval_tree_array, - ) - EvalOptions(; - turbo = turbo === nothing ? Val(false) : turbo, - bumper = bumper === nothing ? Val(false) : bumper, - ) - end - #! format: on + _eval_options = _process_deprecated_kws(eval_options, _deprecated_kws) if _eval_options.turbo isa Val{true} || _eval_options.bumper isa Val{true} @assert T in (Float32, Float64) end From a87016e19eff3eb6bceb9f0b6d615d4ac85d2dbb Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 28 Jul 2024 00:32:27 +0100 Subject: [PATCH 34/41] docs: render `EvalOptions` in docs --- docs/src/eval.md | 13 +++++++++---- src/Evaluate.jl | 10 +++++----- src/Node.jl | 2 +- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/docs/src/eval.md b/docs/src/eval.md index 1d92874f..370cda80 100644 --- a/docs/src/eval.md +++ b/docs/src/eval.md @@ -11,13 +11,10 @@ eval_tree_array( cX::AbstractMatrix{T}, operators::OperatorEnum; eval_options::Union{EvalOptions,Nothing}=nothing, - turbo::Union{Bool,Val,Nothing}=nothing, - bumper::Union{Bool,Val,Nothing}=nothing, ) where {T} ``` -Assuming you are only using a single `OperatorEnum`, you can also use -the following shorthand by using the expression as a function: +You can also use the following shorthand by using the expression as a function: ``` (tree::AbstractExpressionNode)(X, operators::OperatorEnum; kws...) @@ -59,6 +56,14 @@ It also re-defines `print`, `show`, and the various operators, to work with the Thus, if you define an expression with one `OperatorEnum`, and then try to evaluate it or print it with a different `OperatorEnum`, you will get undefined behavior! + For safer behavior, you should use [`Expression`](@ref) objects. + +Evaluation options are specified using `EvalOptions`: + +```@docs +EvalOptions +``` + You can also work with arbitrary types, by defining a `GenericOperatorEnum` instead. The notation is the same for `eval_tree_array`, though it will return `nothing` when it can't find a method, and not do any NaN checks: diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 087f0d9f..7323aaac 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -31,14 +31,14 @@ end """ EvalOptions{T,B,E} -EvalOptions contain flags for the different modes to evaluate an expression. +This holds options for expression evaluation, such as evaluation backend. # Fields -- `turbo::Val`: If `Val{true}`, use LoopVectorization.jl for faster +- `turbo::Val{T}`: If `Val{true}`, use LoopVectorization.jl for faster evaluation. -- `bumper::Val`: If `Val{true}, use Bumper.jl for faster evaluation. -- `early_exit::Val`: If `Val{true}`, any element of any step becoming +- `bumper::Val{B}`: If `Val{true}, use Bumper.jl for faster evaluation. +- `early_exit::Val{E}`: If `Val{true}`, any element of any step becoming `NaN` or `Inf` will terminate the computation and the whole buffer will be returned with `NaN`s. This makes sure that expressions with singularities don't wast compute cycles. Setting `Val{false}` will continue the computation @@ -101,7 +101,7 @@ and triplets of operations for lower memory usage. - `tree::AbstractExpressionNode`: The root node of the tree to evaluate. - `cX::AbstractMatrix{T}`: The input data to evaluate the tree on. - `operators::OperatorEnum`: The operators used in the tree. -- `eval_options::Union{EvalOptions,Nothing}`: See EvalOptions for documenation +- `eval_options::Union{EvalOptions,Nothing}`: See [`EvalOptions`](@ref) for documentation on the different evaluation modes. diff --git a/src/Node.jl b/src/Node.jl index 095a6bc5..4355aea5 100644 --- a/src/Node.jl +++ b/src/Node.jl @@ -142,7 +142,7 @@ be performed with this assumption, to preserve structure of the graph. ```julia julia> operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos, sin] - ); + ); julia> x = GraphNode(feature=1) x1 From ace5c19c7b1cf12a811807bf8330b1b5fe51d86f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 28 Jul 2024 01:08:55 +0100 Subject: [PATCH 35/41] test: more coverage of `EvalOptions` branches --- src/Evaluate.jl | 76 +++++++++++++++++++++++---------------- test/test_deprecations.jl | 12 +++++-- test/test_evaluation.jl | 12 +++++++ 3 files changed, 66 insertions(+), 34 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 7323aaac..29a25177 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -1,6 +1,6 @@ module EvaluateModule -using DispatchDoctor: @unstable +using DispatchDoctor: @stable, @unstable import ..NodeModule: AbstractExpressionNode, constructorof import ..StringsModule: string_tree @@ -51,39 +51,53 @@ struct EvalOptions{T,B,E} early_exit::Val{E} end -@inline _to_bool_val(x::Bool) = x ? Val(true) : Val(false) -@inline _to_bool_val(x::Val{T}) where {T} = Val(T::Bool) - -function EvalOptions(; - turbo::Union{Bool,Val}=Val(false), - bumper::Union{Bool,Val}=Val(false), - early_exit::Union{Bool,Val}=Val(true), +@stable( + default_mode = "disable", + default_union_limit = 2, + @inline _to_bool_val(x::Bool) = x ? Val(true) : Val(false) ) - return EvalOptions(_to_bool_val(turbo), _to_bool_val(bumper), _to_bool_val(early_exit)) -end - -function _process_deprecated_kws(eval_options, deprecated_kws) - turbo = get(deprecated_kws, :turbo, nothing) - bumper = get(deprecated_kws, :bumper, nothing) - if any(Base.Fix2(∉, (:turbo, :bumper)), keys(deprecated_kws)) - throw(ArgumentError("Invalid keyword argument(s): $(keys(deprecated_kws))")) - end - if !isempty(deprecated_kws) - @assert eval_options === nothing "Cannot use both `eval_options` and deprecated flags `turbo` and `bumper`." - Base.depwarn( - "The `turbo` and `bumper` keyword arguments are deprecated. Please use `eval_options` instead.", - :eval_tree_array, - ) - end - if eval_options !== nothing - return eval_options - else - return EvalOptions(; - turbo=turbo === nothing ? Val(false) : turbo, - bumper=bumper === nothing ? Val(false) : bumper, +@stable(default_mode = "disable", @inline _to_bool_val(x::Val{T}) where {T} = Val(T::Bool)) + +@stable( + default_mode = "disable", + default_union_limit = 4, + begin + function EvalOptions(; + turbo::Union{Bool,Val}=Val(false), + bumper::Union{Bool,Val}=Val(false), + early_exit::Union{Bool,Val}=Val(true), ) + return EvalOptions( + _to_bool_val(turbo), _to_bool_val(bumper), _to_bool_val(early_exit) + ) + end + + function _process_deprecated_kws(eval_options, deprecated_kws) + turbo = get(deprecated_kws, :turbo, nothing) + bumper = get(deprecated_kws, :bumper, nothing) + if any(Base.Fix2(∉, (:turbo, :bumper)), keys(deprecated_kws)) + throw( + ArgumentError("Invalid keyword argument(s): $(keys(deprecated_kws))") + ) + end + if !isempty(deprecated_kws) + @assert eval_options === nothing "Cannot use both `eval_options` and deprecated flags `turbo` and `bumper`." + Base.depwarn( + "The `turbo` and `bumper` keyword arguments are deprecated. Please use `eval_options` instead.", + :eval_tree_array, + ) + end + if eval_options !== nothing + return eval_options + else + return EvalOptions(; + turbo=turbo === nothing ? Val(false) : turbo, + bumper=bumper === nothing ? Val(false) : bumper, + ) + end + end end -end +) """ eval_tree_array( diff --git a/test/test_deprecations.jl b/test/test_deprecations.jl index 0b1af26b..45decb0f 100644 --- a/test/test_deprecations.jl +++ b/test/test_deprecations.jl @@ -1,6 +1,4 @@ -using DynamicExpressions -using Test -using Zygote +using Test, DynamicExpressions, Zygote, LoopVectorization using Suppressor: @capture_err using DispatchDoctor: allow_unstable @@ -47,6 +45,14 @@ if VERSION >= v"1.9" ) end +# Old usage of evaluation options +if VERSION >= v"1.9-" + ex = Expression(Node{Float64}(; feature=1)) + @test_logs (:warn, r"The `turbo` and `bumper` keyword arguments are deprecated.*") (ex( + randn(Float64, 1, 10), OperatorEnum(); turbo=true + )) +end + # Test deprecated modules logs = @capture_err begin @eval using DynamicExpressions.EquationModule diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index dbf9f224..2ed0d771 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -271,3 +271,15 @@ end @test !isfinite(y[2]) end end + +@testitem "Test EvalOptions constructor" begin + using DynamicExpressions, LoopVectorization + + @test EvalOptions(; turbo=true) isa EvalOptions{true} + @test EvalOptions(; turbo=Val(true)) isa EvalOptions{true} + @test EvalOptions(; turbo=false) isa EvalOptions{false} + @test EvalOptions(; turbo=Val(false)) isa EvalOptions{false} + + ex = Expression(Node{Float64}(; feature=1)) + @test_throws ArgumentError ex(randn(1, 5), OperatorEnum(); bad_arg=1) +end From 2c34b4ec0bd20f1a40066986e0b9fae3932ba45e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 28 Jul 2024 01:12:36 +0100 Subject: [PATCH 36/41] test: prevent soft scope problem --- test/test_evaluation.jl | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index 2ed0d771..587fba2e 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -226,12 +226,16 @@ end num_tests = 100 n_features = 3 for _ in 1:num_tests - tree = gen_random_tree_fixed_size(20, only_basic_ops_operator, n_features, Float64) - X = randn(Float64, n_features, 10) - basic_eval = tree(X, only_basic_ops_operator) - many_ops_eval = tree(X, many_ops_operators) - @test (all(isnan, basic_eval) && all(isnan, many_ops_eval)) || - basic_eval ≈ many_ops_eval + let tree = gen_random_tree_fixed_size( + 20, only_basic_ops_operator, n_features, Float64 + ), + X = randn(Float64, n_features, 10), + basic_eval = tree(X, only_basic_ops_operator), + many_ops_eval = tree(X, many_ops_operators) + + @test (all(isnan, basic_eval) && all(isnan, many_ops_eval)) || + basic_eval ≈ many_ops_eval + end end end From 3113499ce5478602b7505979eeac46d6eb0fa381 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 28 Jul 2024 01:15:19 +0100 Subject: [PATCH 37/41] refactor: clean up evaluation --- src/Evaluate.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 29a25177..707c8a1c 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -776,7 +776,7 @@ function eval(current_node) operators::GenericOperatorEnum; throw_errors::Union{Val,Bool}=Val(true), ) where {T1,T2,N} - v_throw_errors = throw_errors isa Val ? throw_errors : Val(throw_errors) + v_throw_errors = _to_bool_val(throw_errors) try return _eval_tree_array_generic(tree, cX, operators, v_throw_errors) catch e From 586baf8625f67e798e47b74371cd99e6f348b202 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 28 Jul 2024 01:51:59 +0100 Subject: [PATCH 38/41] benchmarks: fix benchmark eval options --- benchmark/benchmarks.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 5c9331c9..05ea4678 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -74,11 +74,18 @@ function benchmark_evaluation() extra_kws... ) suite[T]["evaluation$(extra_key)"] = @benchmarkable( - [eval_tree_array(tree, X, $operators; turbo=$turbo, $extra_kws...) for tree in trees], + [eval_tree_array(tree, X, $operators; kws...) for tree in trees], setup=( X=randn(MersenneTwister(0), $T, 5, $n); treesize=20; ntrees=100; + kws=$( + if @isdefined(EvalOptions) + (; eval_options=EvalOptions(; turbo=turbo, extra_kws...)) + else + (; turbo, extra_kws...) + end + ); trees=[gen_random_tree_fixed_size(treesize, $operators, 5, $T) for _ in 1:ntrees] ) ) From 3d7b52941e45b715dc714e73c8ff635b1f6889ae Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 28 Jul 2024 02:00:33 +0100 Subject: [PATCH 39/41] fix: note instability in kw deprecation --- src/Evaluate.jl | 68 +++++++++++++++++++++---------------------------- 1 file changed, 29 insertions(+), 39 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 707c8a1c..f8b78cbf 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -56,48 +56,38 @@ end default_union_limit = 2, @inline _to_bool_val(x::Bool) = x ? Val(true) : Val(false) ) -@stable(default_mode = "disable", @inline _to_bool_val(x::Val{T}) where {T} = Val(T::Bool)) +@inline _to_bool_val(x::Val{T}) where {T} = Val(T::Bool) -@stable( - default_mode = "disable", - default_union_limit = 4, - begin - function EvalOptions(; - turbo::Union{Bool,Val}=Val(false), - bumper::Union{Bool,Val}=Val(false), - early_exit::Union{Bool,Val}=Val(true), - ) - return EvalOptions( - _to_bool_val(turbo), _to_bool_val(bumper), _to_bool_val(early_exit) - ) - end +@unstable function EvalOptions(; + turbo::Union{Bool,Val}=Val(false), + bumper::Union{Bool,Val}=Val(false), + early_exit::Union{Bool,Val}=Val(true), +) + return EvalOptions(_to_bool_val(turbo), _to_bool_val(bumper), _to_bool_val(early_exit)) +end - function _process_deprecated_kws(eval_options, deprecated_kws) - turbo = get(deprecated_kws, :turbo, nothing) - bumper = get(deprecated_kws, :bumper, nothing) - if any(Base.Fix2(∉, (:turbo, :bumper)), keys(deprecated_kws)) - throw( - ArgumentError("Invalid keyword argument(s): $(keys(deprecated_kws))") - ) - end - if !isempty(deprecated_kws) - @assert eval_options === nothing "Cannot use both `eval_options` and deprecated flags `turbo` and `bumper`." - Base.depwarn( - "The `turbo` and `bumper` keyword arguments are deprecated. Please use `eval_options` instead.", - :eval_tree_array, - ) - end - if eval_options !== nothing - return eval_options - else - return EvalOptions(; - turbo=turbo === nothing ? Val(false) : turbo, - bumper=bumper === nothing ? Val(false) : bumper, - ) - end - end +@unstable function _process_deprecated_kws(eval_options, deprecated_kws) + turbo = get(deprecated_kws, :turbo, nothing) + bumper = get(deprecated_kws, :bumper, nothing) + if any(Base.Fix2(∉, (:turbo, :bumper)), keys(deprecated_kws)) + throw(ArgumentError("Invalid keyword argument(s): $(keys(deprecated_kws))")) end -) + if !isempty(deprecated_kws) + @assert eval_options === nothing "Cannot use both `eval_options` and deprecated flags `turbo` and `bumper`." + Base.depwarn( + "The `turbo` and `bumper` keyword arguments are deprecated. Please use `eval_options` instead.", + :eval_tree_array, + ) + end + if eval_options !== nothing + return eval_options + else + return EvalOptions(; + turbo=turbo === nothing ? Val(false) : turbo, + bumper=bumper === nothing ? Val(false) : bumper, + ) + end +end """ eval_tree_array( From 09b7a3d1acc48821786a024fe5e38f9f7356382f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 28 Jul 2024 02:16:58 +0100 Subject: [PATCH 40/41] feat: also include `early_exit` in scalar checks --- src/Evaluate.jl | 71 +++++++++++++++++++++++-------------------------- 1 file changed, 33 insertions(+), 38 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index f8b78cbf..670a983b 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -12,17 +12,17 @@ import ..ValueInterfaceModule: is_valid, is_valid_array const OPERATOR_LIMIT_BEFORE_SLOWDOWN = 15 -macro return_on_check(val, X) +macro return_on_nonfinite_val(eval_options, val, X) :( - if !is_valid($(esc(val))) + if $(esc(eval_options)).early_exit isa Val{true} && !is_valid($(esc(val))) return $(ResultOk)(similar($(esc(X)), axes($(esc(X)), 2)), false) end ) end -macro return_on_nonfinite_array(array) +macro return_on_nonfinite_array(eval_options, array) :( - if !is_valid_array($(esc(array))) + if $(esc(eval_options)).early_exit isa Val{true} && !is_valid_array($(esc(array))) return $(ResultOk)($(esc(array)), false) end ) @@ -257,10 +257,10 @@ end return quote result_l = _eval_tree_array(tree.l, cX, operators, eval_options) !result_l.ok && return result_l - eval_options.early_exit isa Val{true} && @return_on_nonfinite_array result_l.x + @return_on_nonfinite_array(eval_options, result_l.x) result_r = _eval_tree_array(tree.r, cX, operators, eval_options) !result_r.ok && return result_r - eval_options.early_exit isa Val{true} && @return_on_nonfinite_array result_r.x + @return_on_nonfinite_array(eval_options, result_r.x) # op(x, y), for any x or y deg2_eval(result_l.x, result_r.x, operators.binops[op_idx], eval_options) end @@ -275,26 +275,22 @@ end elseif tree.r.degree == 0 result_l = _eval_tree_array(tree.l, cX, operators, eval_options) !result_l.ok && return result_l - eval_options.early_exit isa Val{true} && - @return_on_nonfinite_array result_l.x + @return_on_nonfinite_array(eval_options, result_l.x) # op(x, y), where y is a constant or variable but x is not. deg2_r0_eval(tree, result_l.x, cX, op, eval_options) elseif tree.l.degree == 0 result_r = _eval_tree_array(tree.r, cX, operators, eval_options) !result_r.ok && return result_r - eval_options.early_exit isa Val{true} && - @return_on_nonfinite_array result_r.x + @return_on_nonfinite_array(eval_options, result_r.x) # op(x, y), where x is a constant or variable but y is not. deg2_l0_eval(tree, result_r.x, cX, op, eval_options) else result_l = _eval_tree_array(tree.l, cX, operators, eval_options) !result_l.ok && return result_l - eval_options.early_exit isa Val{true} && - @return_on_nonfinite_array result_l.x + @return_on_nonfinite_array(eval_options, result_l.x) result_r = _eval_tree_array(tree.r, cX, operators, eval_options) !result_r.ok && return result_r - eval_options.early_exit isa Val{true} && - @return_on_nonfinite_array result_r.x + @return_on_nonfinite_array(eval_options, result_r.x) # op(x, y), for any x or y deg2_eval(result_l.x, result_r.x, op, eval_options) end @@ -315,7 +311,7 @@ end return quote result = _eval_tree_array(tree.l, cX, operators, eval_options) !result.ok && return result - eval_options.early_exit isa Val{true} && @return_on_nonfinite_array result.x + @return_on_nonfinite_array(eval_options, result.x) deg1_eval(result.x, operators.unaops[op_idx], eval_options) end end @@ -342,8 +338,7 @@ end # op(x), for any x. result = _eval_tree_array(tree.l, cX, operators, eval_options) !result.ok && return result - eval_options.early_exit isa Val{true} && - @return_on_nonfinite_array result.x + @return_on_nonfinite_array(eval_options, result.x) deg1_eval(result.x, op, eval_options) end end @@ -396,21 +391,21 @@ function deg1_l2_ll0_lr0_eval( cX::AbstractMatrix{T}, op::F, op_l::F2, - ::EvalOptions{false,false}, + eval_options::EvalOptions{false,false}, ) where {T,F,F2} if tree.l.l.constant && tree.l.r.constant val_ll = tree.l.l.val val_lr = tree.l.r.val - @return_on_check val_ll cX - @return_on_check val_lr cX + @return_on_nonfinite_val(eval_options, val_ll, cX) + @return_on_nonfinite_val(eval_options, val_lr, cX) x_l = op_l(val_ll, val_lr)::T - @return_on_check x_l cX + @return_on_nonfinite_val(eval_options, x_l, cX) x = op(x_l)::T - @return_on_check x cX + @return_on_nonfinite_val(eval_options, x, cX) return ResultOk(fill_similar(x, cX, axes(cX, 2)), true) elseif tree.l.l.constant val_ll = tree.l.l.val - @return_on_check val_ll cX + @return_on_nonfinite_val(eval_options, val_ll, cX) feature_lr = tree.l.r.feature cumulator = similar(cX, axes(cX, 2)) @inbounds @simd for j in axes(cX, 2) @@ -422,7 +417,7 @@ function deg1_l2_ll0_lr0_eval( elseif tree.l.r.constant feature_ll = tree.l.l.feature val_lr = tree.l.r.val - @return_on_check val_lr cX + @return_on_nonfinite_val(eval_options, val_lr, cX) cumulator = similar(cX, axes(cX, 2)) @inbounds @simd for j in axes(cX, 2) x_l = op_l(cX[feature_ll, j], val_lr)::T @@ -449,15 +444,15 @@ function deg1_l1_ll0_eval( cX::AbstractMatrix{T}, op::F, op_l::F2, - ::EvalOptions{false,false}, + eval_options::EvalOptions{false,false}, ) where {T,F,F2} if tree.l.l.constant val_ll = tree.l.l.val - @return_on_check val_ll cX + @return_on_nonfinite_val(eval_options, val_ll, cX) x_l = op_l(val_ll)::T - @return_on_check x_l cX + @return_on_nonfinite_val(eval_options, x_l, cX) x = op(x_l)::T - @return_on_check x cX + @return_on_nonfinite_val(eval_options, x, cX) return ResultOk(fill_similar(x, cX, axes(cX, 2)), true) else feature_ll = tree.l.l.feature @@ -476,20 +471,20 @@ function deg2_l0_r0_eval( tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, - ::EvalOptions{false,false}, + eval_options::EvalOptions{false,false}, ) where {T,F} if tree.l.constant && tree.r.constant val_l = tree.l.val - @return_on_check val_l cX + @return_on_nonfinite_val(eval_options, val_l, cX) val_r = tree.r.val - @return_on_check val_r cX + @return_on_nonfinite_val(eval_options, val_r, cX) x = op(val_l, val_r)::T - @return_on_check x cX + @return_on_nonfinite_val(eval_options, x, cX) return ResultOk(fill_similar(x, cX, axes(cX, 2)), true) elseif tree.l.constant cumulator = similar(cX, axes(cX, 2)) val_l = tree.l.val - @return_on_check val_l cX + @return_on_nonfinite_val(eval_options, val_l, cX) feature_r = tree.r.feature @inbounds @simd for j in axes(cX, 2) x = op(val_l, cX[feature_r, j])::T @@ -500,7 +495,7 @@ function deg2_l0_r0_eval( cumulator = similar(cX, axes(cX, 2)) feature_l = tree.l.feature val_r = tree.r.val - @return_on_check val_r cX + @return_on_nonfinite_val(eval_options, val_r, cX) @inbounds @simd for j in axes(cX, 2) x = op(cX[feature_l, j], val_r)::T cumulator[j] = x @@ -524,11 +519,11 @@ function deg2_l0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::EvalOptions{false,false}, + eval_options::EvalOptions{false,false}, ) where {T,F} if tree.l.constant val = tree.l.val - @return_on_check val cX + @return_on_nonfinite_val(eval_options, val, cX) @inbounds @simd for j in eachindex(cumulator) x = op(val, cumulator[j])::T cumulator[j] = x @@ -550,11 +545,11 @@ function deg2_r0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::EvalOptions{false,false}, + eval_options::EvalOptions{false,false}, ) where {T,F} if tree.r.constant val = tree.r.val - @return_on_check val cX + @return_on_nonfinite_val(eval_options, val, cX) @inbounds @simd for j in eachindex(cumulator) x = op(cumulator[j], val)::T cumulator[j] = x From 17a4a24ef4e93f68fe8cebc080fa55cfedf87ada Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 28 Jul 2024 02:20:03 +0100 Subject: [PATCH 41/41] fix: incorporate `@return_on_nonfinite_val` in LoopVectorization extension --- ext/DynamicExpressionsLoopVectorizationExt.jl | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/ext/DynamicExpressionsLoopVectorizationExt.jl b/ext/DynamicExpressionsLoopVectorizationExt.jl index 35da7de0..7edbd704 100644 --- a/ext/DynamicExpressionsLoopVectorizationExt.jl +++ b/ext/DynamicExpressionsLoopVectorizationExt.jl @@ -3,7 +3,7 @@ module DynamicExpressionsLoopVectorizationExt using LoopVectorization: @turbo using DynamicExpressions: AbstractExpressionNode using DynamicExpressions.UtilsModule: ResultOk, fill_similar -using DynamicExpressions.EvaluateModule: @return_on_check, EvalOptions +using DynamicExpressions.EvaluateModule: @return_on_nonfinite_val, EvalOptions import DynamicExpressions.EvaluateModule: deg1_eval, deg2_eval, @@ -45,21 +45,21 @@ function deg1_l2_ll0_lr0_eval( cX::AbstractMatrix{T}, op::F, op_l::F2, - ::EvalOptions{true}, + eval_options::EvalOptions{true}, ) where {T<:Number,F,F2} if tree.l.l.constant && tree.l.r.constant val_ll = tree.l.l.val val_lr = tree.l.r.val - @return_on_check val_ll cX - @return_on_check val_lr cX + @return_on_nonfinite_val(eval_options, val_ll, cX) + @return_on_nonfinite_val(eval_options, val_lr, cX) x_l = op_l(val_ll, val_lr)::T - @return_on_check x_l cX + @return_on_nonfinite_val(eval_options, x_l, cX) x = op(x_l)::T - @return_on_check x cX + @return_on_nonfinite_val(eval_options, x, cX) return ResultOk(fill_similar(x, cX, axes(cX, 2)), true) elseif tree.l.l.constant val_ll = tree.l.l.val - @return_on_check val_ll cX + @return_on_nonfinite_val(eval_options, val_ll, cX) feature_lr = tree.l.r.feature cumulator = similar(cX, axes(cX, 2)) @turbo for j in axes(cX, 2) @@ -71,7 +71,7 @@ function deg1_l2_ll0_lr0_eval( elseif tree.l.r.constant feature_ll = tree.l.l.feature val_lr = tree.l.r.val - @return_on_check val_lr cX + @return_on_nonfinite_val(eval_options, val_lr, cX) cumulator = similar(cX, axes(cX, 2)) @turbo for j in axes(cX, 2) x_l = op_l(cX[feature_ll, j], val_lr) @@ -97,15 +97,15 @@ function deg1_l1_ll0_eval( cX::AbstractMatrix{T}, op::F, op_l::F2, - ::EvalOptions{true}, + eval_options::EvalOptions{true}, ) where {T<:Number,F,F2} if tree.l.l.constant val_ll = tree.l.l.val - @return_on_check val_ll cX + @return_on_nonfinite_val(eval_options, val_ll, cX) x_l = op_l(val_ll)::T - @return_on_check x_l cX + @return_on_nonfinite_val(eval_options, x_l, cX) x = op(x_l)::T - @return_on_check x cX + @return_on_nonfinite_val(eval_options, x, cX) return ResultOk(fill_similar(x, cX, axes(cX, 2)), true) else feature_ll = tree.l.l.feature @@ -120,20 +120,23 @@ function deg1_l1_ll0_eval( end function deg2_l0_r0_eval( - tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, op::F, ::EvalOptions{true} + tree::AbstractExpressionNode{T}, + cX::AbstractMatrix{T}, + op::F, + eval_options::EvalOptions{true}, ) where {T<:Number,F} if tree.l.constant && tree.r.constant val_l = tree.l.val - @return_on_check val_l cX + @return_on_nonfinite_val(eval_options, val_l, cX) val_r = tree.r.val - @return_on_check val_r cX + @return_on_nonfinite_val(eval_options, val_r, cX) x = op(val_l, val_r)::T - @return_on_check x cX + @return_on_nonfinite_val(eval_options, x, cX) return ResultOk(fill_similar(x, cX, axes(cX, 2)), true) elseif tree.l.constant cumulator = similar(cX, axes(cX, 2)) val_l = tree.l.val - @return_on_check val_l cX + @return_on_nonfinite_val(eval_options, val_l, cX) feature_r = tree.r.feature @turbo for j in axes(cX, 2) x = op(val_l, cX[feature_r, j]) @@ -144,7 +147,7 @@ function deg2_l0_r0_eval( cumulator = similar(cX, axes(cX, 2)) feature_l = tree.l.feature val_r = tree.r.val - @return_on_check val_r cX + @return_on_nonfinite_val(eval_options, val_r, cX) @turbo for j in axes(cX, 2) x = op(cX[feature_l, j], val_r) cumulator[j] = x @@ -168,11 +171,11 @@ function deg2_l0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::EvalOptions{true}, + eval_options::EvalOptions{true}, ) where {T<:Number,F} if tree.l.constant val = tree.l.val - @return_on_check val cX + @return_on_nonfinite_val(eval_options, val, cX) @turbo for j in eachindex(cumulator) x = op(val, cumulator[j]) cumulator[j] = x @@ -193,11 +196,11 @@ function deg2_r0_eval( cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, - ::EvalOptions{true}, + eval_options::EvalOptions{true}, ) where {T<:Number,F} if tree.r.constant val = tree.r.val - @return_on_check val cX + @return_on_nonfinite_val(eval_options, val, cX) @turbo for j in eachindex(cumulator) x = op(cumulator[j], val) cumulator[j] = x