Skip to content

Commit

Permalink
fix: incorporate @return_on_nonfinite_val in LoopVectorization exte…
Browse files Browse the repository at this point in the history
…nsion
  • Loading branch information
MilesCranmer committed Jul 28, 2024
1 parent 09b7a3d commit 2831152
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions ext/DynamicExpressionsLoopVectorizationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module DynamicExpressionsLoopVectorizationExt
using LoopVectorization: @turbo
using DynamicExpressions: AbstractExpressionNode
using DynamicExpressions.UtilsModule: ResultOk, fill_similar
using DynamicExpressions.EvaluateModule: @return_on_check, EvalOptions
using DynamicExpressions.EvaluateModule: @return_on_nonfinite_val, EvalOptions
import DynamicExpressions.EvaluateModule:
deg1_eval,
deg2_eval,
Expand Down Expand Up @@ -45,21 +45,21 @@ function deg1_l2_ll0_lr0_eval(
cX::AbstractMatrix{T},
op::F,
op_l::F2,
::EvalOptions{true},
eval_options::EvalOptions{true},
) where {T<:Number,F,F2}
if tree.l.l.constant && tree.l.r.constant
val_ll = tree.l.l.val
val_lr = tree.l.r.val
@return_on_check val_ll cX
@return_on_check val_lr cX
@return_on_nonfinite_val(eval_options, val_ll, cX)
@return_on_nonfinite_val(eval_options, val_lr, cX)
x_l = op_l(val_ll, val_lr)::T
@return_on_check x_l cX
@return_on_nonfinite_val(eval_options, x_l, cX)
x = op(x_l)::T
@return_on_check x cX
@return_on_nonfinite_val(eval_options, x, cX)
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
elseif tree.l.l.constant
val_ll = tree.l.l.val
@return_on_check val_ll cX
@return_on_nonfinite_val(eval_options, val_ll, cX)
feature_lr = tree.l.r.feature
cumulator = similar(cX, axes(cX, 2))
@turbo for j in axes(cX, 2)
Expand All @@ -71,7 +71,7 @@ function deg1_l2_ll0_lr0_eval(
elseif tree.l.r.constant
feature_ll = tree.l.l.feature
val_lr = tree.l.r.val
@return_on_check val_lr cX
@return_on_nonfinite_val(eval_options, val_lr, cX)
cumulator = similar(cX, axes(cX, 2))
@turbo for j in axes(cX, 2)
x_l = op_l(cX[feature_ll, j], val_lr)
Expand Down Expand Up @@ -101,11 +101,11 @@ function deg1_l1_ll0_eval(
) where {T<:Number,F,F2}
if tree.l.l.constant
val_ll = tree.l.l.val
@return_on_check val_ll cX
@return_on_nonfinite_val(eval_options, val_ll, cX)
x_l = op_l(val_ll)::T
@return_on_check x_l cX
@return_on_nonfinite_val(eval_options, x_l, cX)
x = op(x_l)::T
@return_on_check x cX
@return_on_nonfinite_val(eval_options, x, cX)
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
else
feature_ll = tree.l.l.feature
Expand All @@ -124,16 +124,16 @@ function deg2_l0_r0_eval(
) where {T<:Number,F}
if tree.l.constant && tree.r.constant
val_l = tree.l.val
@return_on_check val_l cX
@return_on_nonfinite_val(eval_options, val_l, cX)
val_r = tree.r.val
@return_on_check val_r cX
@return_on_nonfinite_val(eval_options, val_r, cX)
x = op(val_l, val_r)::T
@return_on_check x cX
@return_on_nonfinite_val(eval_options, x, cX)
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
elseif tree.l.constant
cumulator = similar(cX, axes(cX, 2))
val_l = tree.l.val
@return_on_check val_l cX
@return_on_nonfinite_val(eval_options, val_l, cX)
feature_r = tree.r.feature
@turbo for j in axes(cX, 2)
x = op(val_l, cX[feature_r, j])
Expand All @@ -144,7 +144,7 @@ function deg2_l0_r0_eval(
cumulator = similar(cX, axes(cX, 2))
feature_l = tree.l.feature
val_r = tree.r.val
@return_on_check val_r cX
@return_on_nonfinite_val(eval_options, val_r, cX)
@turbo for j in axes(cX, 2)
x = op(cX[feature_l, j], val_r)
cumulator[j] = x
Expand Down Expand Up @@ -172,7 +172,7 @@ function deg2_l0_eval(
) where {T<:Number,F}
if tree.l.constant
val = tree.l.val
@return_on_check val cX
@return_on_nonfinite_val(eval_options, val, cX)
@turbo for j in eachindex(cumulator)
x = op(val, cumulator[j])
cumulator[j] = x
Expand All @@ -197,7 +197,7 @@ function deg2_r0_eval(
) where {T<:Number,F}
if tree.r.constant
val = tree.r.val
@return_on_check val cX
@return_on_nonfinite_val(eval_options, val, cX)
@turbo for j in eachindex(cumulator)
x = op(cumulator[j], val)
cumulator[j] = x
Expand Down

0 comments on commit 2831152

Please sign in to comment.