From 2831152245b004162f5091d91fa5f10ef9e3b1a3 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 28 Jul 2024 02:20:03 +0100 Subject: [PATCH] fix: incorporate `@return_on_nonfinite_val` in LoopVectorization extension --- ext/DynamicExpressionsLoopVectorizationExt.jl | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/ext/DynamicExpressionsLoopVectorizationExt.jl b/ext/DynamicExpressionsLoopVectorizationExt.jl index 35da7de0..5a7bf579 100644 --- a/ext/DynamicExpressionsLoopVectorizationExt.jl +++ b/ext/DynamicExpressionsLoopVectorizationExt.jl @@ -3,7 +3,7 @@ module DynamicExpressionsLoopVectorizationExt using LoopVectorization: @turbo using DynamicExpressions: AbstractExpressionNode using DynamicExpressions.UtilsModule: ResultOk, fill_similar -using DynamicExpressions.EvaluateModule: @return_on_check, EvalOptions +using DynamicExpressions.EvaluateModule: @return_on_nonfinite_val, EvalOptions import DynamicExpressions.EvaluateModule: deg1_eval, deg2_eval, @@ -45,21 +45,21 @@ function deg1_l2_ll0_lr0_eval( cX::AbstractMatrix{T}, op::F, op_l::F2, - ::EvalOptions{true}, + eval_options::EvalOptions{true}, ) where {T<:Number,F,F2} if tree.l.l.constant && tree.l.r.constant val_ll = tree.l.l.val val_lr = tree.l.r.val - @return_on_check val_ll cX - @return_on_check val_lr cX + @return_on_nonfinite_val(eval_options, val_ll, cX) + @return_on_nonfinite_val(eval_options, val_lr, cX) x_l = op_l(val_ll, val_lr)::T - @return_on_check x_l cX + @return_on_nonfinite_val(eval_options, x_l, cX) x = op(x_l)::T - @return_on_check x cX + @return_on_nonfinite_val(eval_options, x, cX) return ResultOk(fill_similar(x, cX, axes(cX, 2)), true) elseif tree.l.l.constant val_ll = tree.l.l.val - @return_on_check val_ll cX + @return_on_nonfinite_val(eval_options, val_ll, cX) feature_lr = tree.l.r.feature cumulator = similar(cX, axes(cX, 2)) @turbo for j in axes(cX, 2) @@ -71,7 +71,7 @@ function deg1_l2_ll0_lr0_eval( elseif tree.l.r.constant feature_ll = tree.l.l.feature val_lr = tree.l.r.val - @return_on_check val_lr cX + @return_on_nonfinite_val(eval_options, val_lr, cX) cumulator = similar(cX, axes(cX, 2)) @turbo for j in axes(cX, 2) x_l = op_l(cX[feature_ll, j], val_lr) @@ -101,11 +101,11 @@ function deg1_l1_ll0_eval( ) where {T<:Number,F,F2} if tree.l.l.constant val_ll = tree.l.l.val - @return_on_check val_ll cX + @return_on_nonfinite_val(eval_options, val_ll, cX) x_l = op_l(val_ll)::T - @return_on_check x_l cX + @return_on_nonfinite_val(eval_options, x_l, cX) x = op(x_l)::T - @return_on_check x cX + @return_on_nonfinite_val(eval_options, x, cX) return ResultOk(fill_similar(x, cX, axes(cX, 2)), true) else feature_ll = tree.l.l.feature @@ -124,16 +124,16 @@ function deg2_l0_r0_eval( ) where {T<:Number,F} if tree.l.constant && tree.r.constant val_l = tree.l.val - @return_on_check val_l cX + @return_on_nonfinite_val(eval_options, val_l, cX) val_r = tree.r.val - @return_on_check val_r cX + @return_on_nonfinite_val(eval_options, val_r, cX) x = op(val_l, val_r)::T - @return_on_check x cX + @return_on_nonfinite_val(eval_options, x, cX) return ResultOk(fill_similar(x, cX, axes(cX, 2)), true) elseif tree.l.constant cumulator = similar(cX, axes(cX, 2)) val_l = tree.l.val - @return_on_check val_l cX + @return_on_nonfinite_val(eval_options, val_l, cX) feature_r = tree.r.feature @turbo for j in axes(cX, 2) x = op(val_l, cX[feature_r, j]) @@ -144,7 +144,7 @@ function deg2_l0_r0_eval( cumulator = similar(cX, axes(cX, 2)) feature_l = tree.l.feature val_r = tree.r.val - @return_on_check val_r cX + @return_on_nonfinite_val(eval_options, val_r, cX) @turbo for j in axes(cX, 2) x = op(cX[feature_l, j], val_r) cumulator[j] = x @@ -172,7 +172,7 @@ function deg2_l0_eval( ) where {T<:Number,F} if tree.l.constant val = tree.l.val - @return_on_check val cX + @return_on_nonfinite_val(eval_options, val, cX) @turbo for j in eachindex(cumulator) x = op(val, cumulator[j]) cumulator[j] = x @@ -197,7 +197,7 @@ function deg2_r0_eval( ) where {T<:Number,F} if tree.r.constant val = tree.r.val - @return_on_check val cX + @return_on_nonfinite_val(eval_options, val, cX) @turbo for j in eachindex(cumulator) x = op(cumulator[j], val) cumulator[j] = x