diff --git a/Project.toml b/Project.toml
index 8e2552d6..63b5bbee 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,7 @@ version = "1.8.0"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
 Interfaces = "85a1e053-f937-4924-92a5-1367d23b7b87"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -15,6 +16,7 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
 
 [weakdeps]
 Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
@@ -22,6 +24,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [extensions]
 DynamicExpressionsBumperExt = "Bumper"
+DynamicExpressionsCUDAExt = "CUDA"
 DynamicExpressionsLoopVectorizationExt = "LoopVectorization"
 DynamicExpressionsOptimExt = "Optim"
 DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
@@ -29,7 +32,9 @@ DynamicExpressionsZygoteExt = "Zygote"
 
 [compat]
 Bumper = "0.6"
+CUDA = "4, 5"
 ChainRulesCore = "1"
+Compat = "4.16"
 DispatchDoctor = "0.4"
 Interfaces = "0.3"
 LoopVectorization = "0.12"
@@ -43,6 +48,7 @@ julia = "1.10"
 
 [extras]
 Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
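Because CUDA sits under [weakdeps] and [extensions] rather than [deps], the new GPU code ships as a package extension: it compiles and loads only when the user's environment also contains CUDA.jl. A quick sketch of the expected loading behavior (illustrative, not part of the diff):

    julia> using DynamicExpressions

    julia> Base.get_extension(DynamicExpressions, :DynamicExpressionsCUDAExt) === nothing
    true

    julia> using CUDA   # triggers loading of DynamicExpressionsCUDAExt

    julia> Base.get_extension(DynamicExpressions, :DynamicExpressionsCUDAExt) isa Module
    true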
diff --git a/ext/DynamicExpressionsCUDAExt.jl b/ext/DynamicExpressionsCUDAExt.jl
new file mode 100644
index 00000000..83cea775
--- /dev/null
+++ b/ext/DynamicExpressionsCUDAExt.jl
@@ -0,0 +1,215 @@
+module DynamicExpressionsCUDAExt
+
+# TODO: Switch to KernelAbstractions.jl (once they hit v1.0)
+using CUDA: @cuda, CuArray, blockDim, blockIdx, threadIdx
+using DynamicExpressions: OperatorEnum, AbstractExpressionNode
+using DynamicExpressions.EvaluateModule: get_nbin, get_nuna
+using DynamicExpressions.AsArrayModule:
+    as_array,
+    IDX_DEGREE,
+    IDX_FEATURE,
+    IDX_OP,
+    IDX_EXECUTION_ORDER,
+    IDX_SELF,
+    IDX_L,
+    IDX_R,
+    IDX_CONSTANT
+using DispatchDoctor: @stable
+
+import DynamicExpressions.EvaluateModule: eval_tree_array
+
+# Array type used exclusively for testing purposes
+struct FakeCuArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
+    a::A
+end
+Base.similar(x::FakeCuArray, dims::Integer...) = FakeCuArray(similar(x.a, dims...))
+Base.getindex(x::FakeCuArray, i::Int...) = getindex(x.a, i...)
+Base.setindex!(x::FakeCuArray, v, i::Int...) = setindex!(x.a, v, i...)
+Base.size(x::FakeCuArray) = size(x.a)
+
+const MaybeCuArray{T,N} = Union{CuArray{T,N},FakeCuArray{T,N}}
+
+@stable default_mode = "disable" begin
+    to_device(a, ::CuArray) = CuArray(a)
+    to_device(a, ::FakeCuArray) = FakeCuArray(a)
+end
+
+@stable default_mode = "disable" function eval_tree_array(
+    tree::AbstractExpressionNode{T}, gcX::MaybeCuArray{T,2}, operators::OperatorEnum; kws...
+) where {T<:Number}
+    (outs, is_good) = eval_tree_array((tree,), gcX, operators; kws...)
+    return (only(outs), only(is_good))
+end
+
+@stable default_mode = "disable" function eval_tree_array(
+    trees::Union{Tuple{N,Vararg{N}},AbstractVector{N}},
+    gcX::MaybeCuArray{T,2},
+    operators::OperatorEnum;
+    buffer=nothing,
+    gpu_workspace=nothing,
+    gpu_buffer=nothing,
+    roots=nothing,
+    num_nodes=nothing,
+    num_launches=nothing,
+    update_buffers::Val{_update_buffers}=Val(true),
+    kws...,
+) where {T<:Number,N<:AbstractExpressionNode{T},_update_buffers}
+    local val
+    if _update_buffers
+        (; val, roots, buffer, num_nodes, num_launches) = as_array(Int32, trees; buffer)
+    end
+    # TODO: Fix this type instability?
+    num_elem = size(gcX, 2)
+
+    num_launches = num_launches isa Integer ? num_launches : num_launches[]
+
+    ## The following array is our "workspace" for
+    ## the GPU kernel, with size equal to the number of rows
+    ## in the input data by the number of nodes in the tree.
+    ## It has one extra row to store the constant values.
+    gworkspace = @something(gpu_workspace, similar(gcX, num_elem + 1, num_nodes))
+    if _update_buffers
+        copyto!(@view(gworkspace[end, :]), val)
+    end
+    val_idx = size(gworkspace, 1)
+
+    gbuffer = if !_update_buffers
+        gpu_buffer
+    elseif gpu_buffer === nothing
+        to_device(buffer, gcX)
+    else
+        copyto!(gpu_buffer, buffer)
+    end
+
+    # Rather than creating a separate view for each field (degree, feature, etc.),
+    # we index directly into gbuffer using the IDX_* constants above.
+
+    num_threads = 256
+    num_blocks = nextpow(2, ceil(Int, num_elem * num_nodes / num_threads))
+
+    #! format: off
+    _launch_gpu_kernel!(
+        num_threads, num_blocks, num_launches, gworkspace,
+        # Thread info:
+        num_elem, num_nodes,
+        # We pass gbuffer directly to the kernel:
+        operators, gcX, gbuffer, val_idx,
+    )
+    #! format: on
+
+    out = map(r -> @view(gworkspace[begin:(end - 1), r]), roots)
+    is_good = map(Returns(true), trees)
+
+    return (out, is_good)
+end
+
+#! format: off
+@stable default_mode = "disable" function _launch_gpu_kernel!(
+    num_threads, num_blocks, num_launches::Integer, buffer::AbstractArray{T,2},
+    # Thread info:
+    num_elem::Integer, num_nodes::Integer,
+    operators::OperatorEnum, cX::AbstractArray{T,2}, gbuffer::AbstractArray{Int32,2},
+    val_idx::Integer
+) where {T}
+    #! format: on
+    nuna = get_nuna(typeof(operators))
+    nbin = get_nbin(typeof(operators))
+    (nuna > 10 || nbin > 10) &&
+        error("Too many operators. Kernels are only compiled up to 10.")
+    gpu_kernel! = create_gpu_kernel(operators, Val(nuna), Val(nbin))
+    for launch in one(Int32):Int32(num_launches)
+        #! format: off
+        if buffer isa CuArray
+            @cuda threads=num_threads blocks=num_blocks gpu_kernel!(
+                buffer,
+                launch, num_elem, num_nodes,
+                cX, gbuffer, val_idx
+            )
+        else
+            Threads.@threads for i in 1:(num_threads * num_blocks)
+                gpu_kernel!(
+                    buffer,
+                    launch, num_elem, num_nodes,
+                    cX, gbuffer, val_idx, i
+                )
+            end
+        end
+        #! format: on
+    end
+    return nothing
+end
+
+# Need to pre-compute the GPU kernels with an `@eval` for each number of operators
+# 1. We need to use an `@nif` over operators, as GPU kernels
+#    can't index into arrays of operators.
+# 2. `@nif` is evaluated at parse time and needs to know the number of
+#    ifs to generate at that time, so we can't simply use specialization.
+# 3. We can't use `@generated` because we can't create closures in those.
+for nuna in 0:10, nbin in 0:10
+    @eval function create_gpu_kernel(operators::OperatorEnum, ::Val{$nuna}, ::Val{$nbin})
+        #! format: off
+        function (
+            buffer,
+            launch::Integer, num_elem::Integer, num_nodes::Integer,
+            cX::AbstractArray, gbuffer::AbstractArray{Int32,2},
+            val_idx::Integer,
+            i=nothing,
+        )
+            i = @something(i, (blockIdx().x - 1) * blockDim().x + threadIdx().x)
+            if i > num_elem * num_nodes
+                return nothing
+            end
+
+            node = (i - 1) % num_nodes + 1
+            elem = (i - node) ÷ num_nodes + 1
+
+            @inbounds begin
+                if gbuffer[IDX_EXECUTION_ORDER, node] != launch
+                    return nothing
+                end
+
+                # Use constants to index gbuffer:
+                cur_degree = gbuffer[IDX_DEGREE, node]
+                cur_idx = gbuffer[IDX_SELF, node]
+
+                if cur_degree == 0
+                    if gbuffer[IDX_CONSTANT, node] == 1
+                        cur_val = buffer[val_idx, node]
+                        buffer[elem, cur_idx] = cur_val
+                    else
+                        cur_feature = gbuffer[IDX_FEATURE, node]
+                        buffer[elem, cur_idx] = cX[cur_feature, elem]
+                    end
+                else
+                    if cur_degree == 1 && $nuna > 0
+                        cur_op = gbuffer[IDX_OP, node]
+                        l_idx = gbuffer[IDX_L, node]
+                        Base.Cartesian.@nif(
+                            $nuna,
+                            i -> i == cur_op,
+                            i -> let op = operators.unaops[i]
+                                buffer[elem, cur_idx] = op(buffer[elem, l_idx])
+                            end
+                        )
+                    elseif $nbin > 0
+                        cur_op = gbuffer[IDX_OP, node]
+                        l_idx = gbuffer[IDX_L, node]
+                        r_idx = gbuffer[IDX_R, node]
+                        Base.Cartesian.@nif(
+                            $nbin,
+                            i -> i == cur_op,
+                            i -> let op = operators.binops[i]
+                                buffer[elem, cur_idx] = op(buffer[elem, l_idx], buffer[elem, r_idx])
+                            end
+                        )
+                    end
+                end
+            end
+            #! format: on
+            return nothing
+        end
+    end
+end
+
+end
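With the extension loaded, the same eval_tree_array call used on the CPU accepts a CuArray (or the FakeCuArray test wrapper) for the data matrix and evaluates the tree level by level on the device, one kernel launch per level. A minimal usage sketch; the operators, tree, and data are illustrative, and a functional CUDA device is assumed:

    using CUDA, DynamicExpressions

    operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos])
    x1, x2 = Node(Float32; feature=1), Node(Float32; feature=2)
    tree = cos(x1 * 2.5f0) + x2 * x2

    X = CUDA.randn(Float32, 2, 1_000)   # 2 features x 1000 rows, already on the GPU
    y, completed = eval_tree_array(tree, X, operators)
    # `y` is a device-side vector (a view into the GPU workspace); use Array(y) to copy it back.

Passing a tuple or vector of trees instead of a single tree evaluates the whole batch in one call, sharing the workspace and launch schedule.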
diff --git a/src/AsArray.jl b/src/AsArray.jl
new file mode 100644
index 00000000..c63f79da
--- /dev/null
+++ b/src/AsArray.jl
@@ -0,0 +1,179 @@
+module AsArrayModule
+
+using Compat: Fix
+
+using ..NodeModule: AbstractExpressionNode, tree_mapreduce, count_nodes
+using ..EvaluateModule: ArrayBuffer, get_array, get_filled_array
+
+function as_array(
+    ::Type{I}, trees::N; buffer::Union{ArrayBuffer,Nothing}=nothing
+) where {T,N<:AbstractExpressionNode{T},I}
+    return as_array(I, (trees,); buffer=buffer)
+end
+
+Base.@kwdef struct TreeBuffer{
+    T,I,A<:AbstractArray{I},B<:AbstractArray{T},C,D<:AbstractArray{I}
+}
+    # Corresponds to the `Node` fields
+    degree::A
+    constant::A
+    val::B
+    feature::A
+    op::A
+    idx_l::A
+    idx_r::A
+
+    # Indexing information
+    execution_order::A
+    idx_self::A
+    num_launches::Base.RefValue{I}
+    cursor::Base.RefValue{I}
+
+    # Segment information
+    roots::C
+    num_nodes::I
+
+    # Original buffer
+    buffer::D
+end
+
+const IDX_DEGREE = 1
+const IDX_FEATURE = 2
+const IDX_OP = 3
+const IDX_EXECUTION_ORDER = 4
+const IDX_SELF = 5
+const IDX_L = 6
+const IDX_R = 7
+const IDX_CONSTANT = 8
+
+function as_array(
+    ::Type{I},
+    trees::Union{Tuple{N,Vararg{N}},AbstractVector{N}};
+    buffer::Union{AbstractArray{I},Nothing}=nothing,
+) where {T,N<:AbstractExpressionNode{T},I}
+    each_num_nodes = map(t -> count_nodes(t; break_sharing=Val(true)), trees)
+    num_nodes = sum(each_num_nodes)
+
+    # Compute the roots array for indexing.
+    roots = cumsum(
+        if each_num_nodes isa Tuple
+            tuple(one(I), each_num_nodes[begin:(end - 1)]...)
+        else
+            vcat(one(I), @view(each_num_nodes[begin:(end - 1)]))
+        end,
+    )
+
+    val = Array{T}(undef, num_nodes)
+
+    # If no buffer is provided, allocate a fresh 8 x num_nodes array
+    buffer = @something(buffer, Array{I}(undef, 8, num_nodes))
+
+    # Take one row-view into the buffer per `Node` field, using the IDX_* constants.
+    #! format: off
+    degree          = @view buffer[IDX_DEGREE, :]
+    feature         = @view buffer[IDX_FEATURE, :]
+    op              = @view buffer[IDX_OP, :]
+    execution_order = @view buffer[IDX_EXECUTION_ORDER, :]
+    idx_self        = @view buffer[IDX_SELF, :]
+    idx_l           = @view buffer[IDX_L, :]
+    idx_r           = @view buffer[IDX_R, :]
+    constant        = @view buffer[IDX_CONSTANT, :]
+    #! format: on
+
+    tree_buffers = TreeBuffer(;
+        degree=degree,
+        constant=constant,
+        val=val,
+        feature=feature,
+        op=op,
+        idx_l=idx_l,
+        idx_r=idx_r,
+
+        # Indexing information
+        execution_order=execution_order,
+        idx_self=idx_self,
+        num_launches=Ref(zero(I)),
+        cursor=Ref(zero(I)),
+
+        # Segment information
+        roots=roots,
+        num_nodes=I(num_nodes),
+
+        # Original buffer
+        buffer=buffer,
+    )
+
+    fill_tree_buffer!(tree_buffers, trees)
+
+    return tree_buffers
+end
+
+function fill_tree_buffer!(
+    tree_buffers::TreeBuffer{T,I}, trees::Union{Tuple{N,Vararg{N}},AbstractVector{N}}
+) where {T,I,N<:AbstractExpressionNode{T}}
+    return foreach(Fix{1}(fill_single_tree!, tree_buffers), trees)
+end
+
+function fill_single_tree!(
+    tree_buffers::TreeBuffer{T,I}, tree::N
+) where {T,I,N<:AbstractExpressionNode{T}}
+    return tree_mapreduce(
+        Fix{1}(fill_single_leaf!, tree_buffers),
+        Fix{1}(fill_single_branch!, tree_buffers),
+        Fix{1}(link_parent_and_children!, tree_buffers),
+        tree;
+        break_sharing=Val(true),
+    )
+end
+
+function fill_single_leaf!(
+    tree_buffers::TreeBuffer{T,I}, leaf::N
+) where {T,I,N<:AbstractExpressionNode{T}}
+    self = (tree_buffers.cursor[] += one(I))
+    tree_buffers.idx_self[self] = self
+    tree_buffers.degree[self] = 0
+    tree_buffers.execution_order[self] = one(I)
+    tree_buffers.constant[self] = leaf.constant
+    if leaf.constant
+        tree_buffers.val[self] = leaf.val::T
+    else
+        tree_buffers.feature[self] = leaf.feature
+    end
+
+    return (id=self, order=one(I))
+end
+
+function fill_single_branch!(
+    tree_buffers::TreeBuffer{T,I}, branch::N
+) where {T,I,N<:AbstractExpressionNode{T}}
+    self = (tree_buffers.cursor[] += one(I))
+    tree_buffers.idx_self[self] = self
+    tree_buffers.op[self] = branch.op
+    tree_buffers.degree[self] = branch.degree
+
+    return (id=self, order=one(I))
+end
+
+function link_parent_and_children!(
+    tree_buffers::TreeBuffer{T,I}, parent, children::Vararg{Any,C}
+) where {T,I,C}
+    tree_buffers.idx_l[parent.id] = children[1].id
+    if C == 2
+        tree_buffers.idx_r[parent.id] = children[2].id
+    end
+    parent_execution_order = if C == 1
+        children[1].order + one(I)
+    else
+        max(children[1].order, children[2].order) + one(I)
+    end
+    tree_buffers.execution_order[parent.id] = parent_execution_order
+
+    if parent_execution_order > tree_buffers.num_launches[]
+        tree_buffers.num_launches[] = parent_execution_order
+    end
+
+    return (id=parent.id, order=parent_execution_order)
+end
+
+end
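as_array flattens one or more trees into the structure-of-arrays layout the GPU kernel expects, and the returned TreeBuffer carries both the per-node fields and the launch bookkeeping. A short sketch of the fields consumed downstream (the tree is illustrative):

    using DynamicExpressions
    using DynamicExpressions.AsArrayModule: as_array

    operators = OperatorEnum(; binary_operators=[+, *], unary_operators=[cos])
    x1 = Node(Float64; feature=1)
    tree = cos(x1 * 2.0) + 1.5

    tb = as_array(Int32, tree)
    tb.num_nodes        # total number of nodes (6 for this tree)
    tb.num_launches[]   # sequential kernel launches needed: leaves run first, the root last
    tb.roots            # position of each tree's root in the flattened layout
    tb.buffer           # 8 x num_nodes Int32 matrix: degree, feature, op, order, self, l, r, constant
    tb.val              # constant values, one slot per node

The CUDA extension destructures exactly these fields (val, roots, buffer, num_nodes, num_launches) before launching.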
diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl
index 6c0ba5f8..dc4bd181 100644
--- a/src/DynamicExpressions.jl
+++ b/src/DynamicExpressions.jl
@@ -24,6 +24,7 @@ using DispatchDoctor: @stable, @unstable
     include("ParametricExpression.jl")
     include("ReadOnlyNode.jl")
     include("StructuredExpression.jl")
+    include("AsArray.jl")
 end
 
 import Reexport: @reexport
@@ -100,6 +101,7 @@ import .ParseModule: parse_leaf
 import .ReadOnlyNodeModule: ReadOnlyNode
 @reexport import .StructuredExpressionModule: StructuredExpression
 import .StructuredExpressionModule: AbstractStructuredExpression
+@reexport import .AsArrayModule: as_array
 
 @stable default_mode = "disable" begin
     include("Interfaces.jl")
diff --git a/src/Evaluate.jl b/src/Evaluate.jl
index cf2ee2e0..0c36947a 100644
--- a/src/Evaluate.jl
+++ b/src/Evaluate.jl
@@ -2,7 +2,7 @@ module EvaluateModule
 
 using DispatchDoctor: @stable, @unstable
 
-import ..NodeModule: AbstractExpressionNode, constructorof
+import ..NodeModule: AbstractExpressionNode, constructorof, with_type_parameters
 import ..StringsModule: string_tree
 import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
 import ..UtilsModule: fill_similar, counttuple, ResultOk
@@ -233,6 +233,15 @@ function eval_tree_array(
     cX = Base.Fix1(convert, T).(cX)
     return eval_tree_array(tree, cX, operators; kws...)
 end
+function eval_tree_array(
+    trees::Union{Tuple{N,Vararg{N}},AbstractVector{N}},
+    cX::AbstractMatrix{T},
+    operators::OperatorEnum;
+    kws...,
+) where {T<:Number,N<:AbstractExpressionNode{T}}
+    outs = map(t -> eval_tree_array(t, cX, operators; kws...), trees)
+    return map(first, outs), map(last, outs)
+end
 
 # These are marked unstable due to issues discussed on
 # https://github.com/JuliaLang/julia/issues/55147
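This new method is the CPU-side counterpart of the batched GPU entry point: it evaluates several trees against the same data and returns the outputs and completion flags in the same container shape as the input (tuple in, tuple out; vector in, vector out). A small illustrative sketch:

    using DynamicExpressions

    operators = OperatorEnum(; binary_operators=[+, *], unary_operators=[cos])
    x1 = Node(Float64; feature=1)
    trees = (cos(x1 * 0.5), x1 * x1 + 1.0)

    X = randn(Float64, 1, 32)
    ys, oks = eval_tree_array(trees, X, operators)
    # ys is a 2-tuple of Vector{Float64}; oks is a 2-tuple of Bool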
diff --git a/test/Project.toml b/test/Project.toml
index 952e0370..3e47197d 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,6 +1,7 @@
 [deps]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
diff --git a/test/test_cuda.jl b/test/test_cuda.jl
new file mode 100644
index 00000000..b281c024
--- /dev/null
+++ b/test/test_cuda.jl
@@ -0,0 +1,189 @@
+@testitem "Random Evals: Single Tree" begin
+    using DynamicExpressions, CUDA, Random
+    using DynamicExpressions.AsArrayModule: as_array
+
+    include("tree_gen_utils.jl")
+
+    safe_sin(x) = isfinite(x) ? sin(x) : convert(eltype(x), NaN)
+    safe_cos(x) = isfinite(x) ? cos(x) : convert(eltype(x), NaN)
+
+    ext = Base.get_extension(DynamicExpressions, :DynamicExpressionsCUDAExt)
+    const FakeCuArray = ext.FakeCuArray
+
+    operators = OperatorEnum(;
+        binary_operators=[+, -, *, /], unary_operators=[safe_sin, safe_cos]
+    )
+
+    for T in (Float32, Float64, ComplexF64)
+        for seed in 0:10
+            Random.seed!(seed)
+            nrow = rand(10:30)
+            nnodes = rand(10:25)
+            tree = gen_random_tree_fixed_size(nnodes, operators, 3, T)
+            X = randn(T, 3, nrow)
+
+            y, completed = eval_tree_array(tree, X, operators)
+            y_gpu, completed_gpu = eval_tree_array(tree, FakeCuArray(X), operators)
+
+            # TODO: Fix this
+            # @test completed == completed_gpu
+            if completed
+                @test y ≈ y_gpu
+            end
+        end
+    end
+end
+
+@testitem "Random Evals: Multiple Trees" begin
+    using DynamicExpressions, CUDA, Random
+    using DynamicExpressions.AsArrayModule: as_array
+
+    include("tree_gen_utils.jl")
+
+    safe_sin(x) = isfinite(x) ? sin(x) : convert(eltype(x), NaN)
+    safe_cos(x) = isfinite(x) ? cos(x) : convert(eltype(x), NaN)
+
+    operators = OperatorEnum(;
+        binary_operators=[+, -, *, /], unary_operators=[safe_sin, safe_cos]
+    )
+
+    ext = Base.get_extension(DynamicExpressions, :DynamicExpressionsCUDAExt)
+    const FakeCuArray = ext.FakeCuArray
+
+    for T in (Float32, Float64, ComplexF64), ntrees in (2, 3), seed in 0:10
+        Random.seed!(seed)
+
+        nrow = rand(10:30)
+        nnodes = rand(10:25, ntrees)
+        use_tuple = rand(Bool)
+
+        buffer = rand(Bool) ? ones(Int32, 8, sum(nnodes)) : nothing
+        gpu_buffer = rand(Bool) ? FakeCuArray(ones(Int32, 8, sum(nnodes))) : nothing
+        gpu_workspace = rand(Bool) ? FakeCuArray(ones(T, nrow + 1, sum(nnodes))) : nothing
+
+        trees = ntuple(i -> gen_random_tree_fixed_size(nnodes[i], operators, 3, T), ntrees)
+        trees = use_tuple ? trees : collect(trees)
+
+        X = randn(T, 3, nrow)
+
+        y, completed = eval_tree_array(trees, X, operators)
+        gpu_y, gpu_completed = eval_tree_array(
+            trees, FakeCuArray(X), operators; buffer, gpu_workspace, gpu_buffer
+        )
+
+        # TODO: Fix this
+        # @test completed == gpu_completed
+
+        for i in eachindex(completed, gpu_completed)
+            if completed[i]
+                @test y[i] ≈ gpu_y[i]
+            end
+        end
+
+        # Check return type matches input type (tuple or vector)
+        if use_tuple
+            @test y isa Tuple
+            @test gpu_y isa Tuple
+        else
+            @test y isa Vector
+            @test gpu_y isa Vector
+        end
+    end
+end
+
+@testitem "Pre-Computed Buffers: Basic Equivalence" begin
+    using DynamicExpressions, CUDA, Random
+    using DynamicExpressions.AsArrayModule: as_array
+
+    ext = Base.get_extension(DynamicExpressions, :DynamicExpressionsCUDAExt)
+    const FakeCuArray = ext.FakeCuArray
+
+    # No random trees here; we define a fixed tree
+    x1, x2, x3 = (i -> Node(Float64; feature=i)).(1:3)
+    operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[sin, cos])
+
+    Random.seed!(0)
+    tree = sin(x1 * 3.1 - x3 * 0.9 + 0.2) * x2 - x3 * x3 * 1.5
+    X = randn(Float64, 3, 100)
+
+    y1, _ = eval_tree_array(tree, X, operators)
+    y2, _ = eval_tree_array(tree, FakeCuArray(X), operators)
+    @test y1 ≈ y2
+end
+
+@testitem "Pre-Computed Buffers: Using Provided Buffers" begin
+    using DynamicExpressions, CUDA, Random
+    using DynamicExpressions.AsArrayModule: as_array
+
+    ext = Base.get_extension(DynamicExpressions, :DynamicExpressionsCUDAExt)
+    const FakeCuArray = ext.FakeCuArray
+
+    x1, x2, x3 = (i -> Node(Float64; feature=i)).(1:3)
+    operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[sin, cos])
+
+    Random.seed!(0)
+    tree = sin(x1 * 3.1 - x3 * 0.9 + 0.2) * x2 - x3 * x3 * 1.5
+    X = randn(Float64, 3, 100)
+
+    y1, _ = eval_tree_array(tree, X, operators)
+
+    # Extract arrays
+    (; val, roots, buffer, num_nodes, num_launches) = as_array(Int32, tree)
+    gpu_buffer = FakeCuArray(buffer)
+    gpu_workspace = FakeCuArray(zeros(Float64, size(X, 2) + 1, num_nodes))
+    copyto!((@view gpu_workspace[end, :]), val)
+
+    y3, _ = eval_tree_array(
+        tree,
+        FakeCuArray(X),
+        operators;
+        gpu_workspace,
+        gpu_buffer,
+        roots,
+        num_nodes,
+        num_launches,
+        update_buffers=Val(false),
+    )
+    @test y1 ≈ y3
+end
+
+@testitem "Pre-Computed Buffers: Modified Values" begin
+    using DynamicExpressions, CUDA, Random
+    using DynamicExpressions.AsArrayModule: as_array
+
+    ext = Base.get_extension(DynamicExpressions, :DynamicExpressionsCUDAExt)
+    const FakeCuArray = ext.FakeCuArray
+
+    x1, x2, x3 = (i -> Node(Float64; feature=i)).(1:3)
+    operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[sin, cos])
+
+    Random.seed!(0)
+    tree = sin(x1 * 3.1 - x3 * 0.9 + 0.2) * x2 - x3 * x3 * 1.5
+    X = randn(Float64, 3, 100)
+
+    y1, _ = eval_tree_array(tree, X, operators)
+
+    (; val, roots, buffer, num_nodes, num_launches) = as_array(Int32, tree)
+    gpu_buffer = FakeCuArray(buffer)
+    gpu_workspace = FakeCuArray(zeros(Float64, size(X, 2) + 1, num_nodes))
+    gpu_workspace[end, :] .= val
+
+    # Change a constant (0.9 to 0.8)
+    i = findfirst(gpu_workspace[end, :] .== 0.9)
+    gpu_workspace[end, i] = 0.8
+
+    tree_prime = sin(x1 * 3.1 - x3 * 0.8 + 0.2) * x2 - x3 * x3 * 1.5
+    y1_prime, _ = eval_tree_array(tree_prime, X, operators)
+    y3_prime, _ = eval_tree_array(
+        x1,  # dummy tree
+        FakeCuArray(X),
+        operators;
+        gpu_workspace,
+        gpu_buffer,
+        roots,
+        num_nodes,
+        num_launches,
+        update_buffers=Val(false),
+    )
+    @test y1_prime ≈ y3_prime
+end
diff --git a/test/unittest.jl b/test/unittest.jl
index 42ae11bb..4685c119 100644
--- a/test/unittest.jl
+++ b/test/unittest.jl
@@ -133,3 +133,4 @@ include("test_expression_math.jl")
 include("test_structured_expression.jl")
 include("test_readonlynode.jl")
 include("test_zygote_gradient_wrapper.jl")
+include("test_cuda.jl")
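Every test above goes through FakeCuArray, so the suite exercises the extension's kernels even on CPU-only CI machines. On a machine with a working GPU, the same comparison can be pointed at a real device (illustrative; `tree`, `X`, and `operators` as in the tests above):

    using CUDA
    if CUDA.functional()
        y_gpu, _ = eval_tree_array(tree, CuArray(X), operators)
        @test Array(y_gpu) ≈ eval_tree_array(tree, X, operators)[1]
    end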