Skip to content

feat: option to skip fused kernels #128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ import .StringsModule: get_op_name, get_pretty_op_name
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
@reexport import .EvaluateModule:
eval_tree_array, differentiable_eval_tree_array, EvalOptions
import .EvaluateModule: ArrayBuffer
import .EvaluateModule: ArrayBuffer, ResultOk
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
@reexport import .ChainRulesModule: NodeTangent, extract_gradient
@reexport import .SimplifyModule: combine_operators, simplify_tree!
Expand Down
26 changes: 19 additions & 7 deletions src/Evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,34 +93,43 @@ This holds options for expression evaluation, such as evaluation backend.
- `buffer::Union{ArrayBuffer,Nothing}`: If not `nothing`, use this buffer for evaluation.
This should be an instance of `ArrayBuffer` which has an `array` field and an
`index` field used to iterate which buffer slot to use.
- `use_fused::Val{U}=Val(true)`: If `Val(true)` (the default), use fused kernels for
faster evaluation. Setting this to `Val(false)` skips the fused kernels, so that
you only need to overload `deg0_eval`, `deg1_eval`, and `deg2_eval` for custom
evaluation.
"""
struct EvalOptions{T,B,E,BUF<:Union{ArrayBuffer,Nothing}}
struct EvalOptions{T,B,E,BUF<:Union{ArrayBuffer,Nothing},U}
turbo::Val{T}
bumper::Val{B}
early_exit::Val{E}
buffer::BUF
use_fused::Val{U}
end

@unstable function EvalOptions(;
turbo::Union{Bool,Val}=Val(false),
bumper::Union{Bool,Val}=Val(false),
early_exit::Union{Bool,Val}=Val(true),
buffer::Union{ArrayBuffer,Nothing}=nothing,
use_fused::Union{Bool,Val}=Val(true),
)
v_turbo = _to_bool_val(turbo)
v_bumper = _to_bool_val(bumper)
v_early_exit = _to_bool_val(early_exit)
v_use_fused = _to_bool_val(use_fused)

if v_bumper isa Val{true}
@assert buffer === nothing
end

return EvalOptions(v_turbo, v_bumper, v_early_exit, buffer)
return EvalOptions(v_turbo, v_bumper, v_early_exit, buffer, v_use_fused)
end

# Lift a plain `Bool` flag into the type domain: `true` maps to `Val(true)` and
# `false` maps to `Val(false)`.
@unstable @inline function _to_bool_val(x::Bool)
    if x
        return Val(true)
    else
        return Val(false)
    end
end
# A flag that is already a `Val` is rebuilt from its type parameter; the `::Bool`
# assertion rejects parameters that are not booleans.
@inline function _to_bool_val(::Val{T}) where {T}
    return Val(T::Bool)
end

# Report whether the `use_fused` field of `eval_options` is `Val(true)`.
@inline function use_fused(eval_options::EvalOptions)
    return eval_options.use_fused isa Val{true}
end

# `copy` that also accepts `nothing`: any other value is forwarded to `copy`,
# while `nothing` is returned unchanged.
function _copy(x)
    return copy(x)
end
function _copy(::Nothing)
    return nothing
end
function Base.copy(eval_options::EvalOptions)
Expand All @@ -129,6 +138,7 @@ function Base.copy(eval_options::EvalOptions)
bumper=eval_options.bumper,
early_exit=eval_options.early_exit,
buffer=_copy(eval_options.buffer),
use_fused=eval_options.use_fused,
)
end

Expand Down Expand Up @@ -340,19 +350,20 @@ end
end
end
return quote
fused = use_fused(eval_options)
return Base.Cartesian.@nif(
$nbin,
i -> i == op_idx,
i -> let op = operators.binops[i]
if tree.l.degree == 0 && tree.r.degree == 0
if fused && tree.l.degree == 0 && tree.r.degree == 0
deg2_l0_r0_eval(tree, cX, op, eval_options)
elseif tree.r.degree == 0
elseif fused && tree.r.degree == 0
result_l = _eval_tree_array(tree.l, cX, operators, eval_options)
!result_l.ok && return result_l
@return_on_nonfinite_array(eval_options, result_l.x)
# op(x, y), where y is a constant or variable but x is not.
deg2_r0_eval(tree, result_l.x, cX, op, eval_options)
elseif tree.l.degree == 0
elseif fused && tree.l.degree == 0
result_r = _eval_tree_array(tree.r, cX, operators, eval_options)
!result_r.ok && return result_r
@return_on_nonfinite_array(eval_options, result_r.x)
Expand Down Expand Up @@ -392,17 +403,18 @@ end
# This @nif lets us generate an if statement over choice of operator,
# which means the compiler will be able to completely avoid type inference on operators.
return quote
fused = use_fused(eval_options)
Base.Cartesian.@nif(
$nuna,
i -> i == op_idx,
i -> let op = operators.unaops[i]
if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
if fused && tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
# op(op2(x, y)), where x and y are constants or variables.
l_op_idx = tree.l.op
dispatch_deg1_l2_ll0_lr0_eval(
tree, cX, op, l_op_idx, operators.binops, eval_options
)
elseif tree.l.degree == 1 && tree.l.l.degree == 0
elseif fused && tree.l.degree == 1 && tree.l.l.degree == 0
# op(op2(x)), where x is a constant or variable.
l_op_idx = tree.l.op
dispatch_deg1_l1_ll0_eval(
Expand Down
Loading