diff --git a/src/collection/macros.jl b/src/collection/macros.jl index c4ce9ee..5e20976 100644 --- a/src/collection/macros.jl +++ b/src/collection/macros.jl @@ -55,3 +55,45 @@ macro make_fused(fusion_style, type_name, fused_name) end end end + +""" + @make_get_fused fusion_style type_name fused_name + +This macro + - Defines a type `type_name` + - Defines a macro, `@fused_name`, using the fusion type `fusion_style` + +This allows users the flexibility +to customize their broadcast fusion. + +# Example +```julia +import MultiBroadcastFusion as MBF +MBF.@make_type MyFusedBroadcast +MBF.@make_get_fused MBF.fused_direct MyFusedBroadcast get_fused + +x1 = rand(3,3) +y1 = rand(3,3) +y2 = rand(3,3) + +# 4 reads, 2 writes +fmb = @get_fused begin + @. y1 = x1 + @. y2 = x1 +end +@test fmb isa MyFusedBroadcast +``` +""" +macro make_get_fused(fusion_style, type_name, fused_name) + t = esc(type_name) + f = esc(fused_name) + return quote + macro $f(expr) + _pairs = esc($(fusion_style)(expr)) + t = $t + quote + $t($_pairs) + end + end + end +end diff --git a/src/execution/fused_kernels.jl b/src/execution/fused_kernels.jl index 67608c6..e45e1e6 100644 --- a/src/execution/fused_kernels.jl +++ b/src/execution/fused_kernels.jl @@ -1,6 +1,7 @@ @make_type FusedMultiBroadcast @make_fused fused_direct FusedMultiBroadcast fused_direct @make_fused fused_assemble FusedMultiBroadcast fused_assemble +@make_get_fused fused_direct FusedMultiBroadcast get_fused_direct struct MBF_CPU end struct MBF_CUDA end diff --git a/test/execution/parameter_memory.jl b/test/execution/kernel_splitting.jl similarity index 98% rename from test/execution/parameter_memory.jl rename to test/execution/kernel_splitting.jl index bbf90a0..e086859 100644 --- a/test/execution/parameter_memory.jl +++ b/test/execution/kernel_splitting.jl @@ -2,7 +2,7 @@ using TestEnv TestEnv.activate() using CUDA # (optional) -using Revise; include(joinpath("test", "execution", "parameter_memory.jl")) +using Revise; include(joinpath("test", 
"execution", "kernel_splitting.jl")) =# include("utils_test.jl") diff --git a/test/execution/measure_parameter_memory.jl b/test/execution/measure_parameter_memory.jl new file mode 100644 index 0000000..669c3e8 --- /dev/null +++ b/test/execution/measure_parameter_memory.jl @@ -0,0 +1,73 @@ +#= +using TestEnv +TestEnv.activate() +using CUDA # (optional) +using Revise; include(joinpath("test", "execution", "measure_parameter_memory.jl")) +=# + +include("utils_test.jl") +include("utils_setup.jl") +include("utils_benchmark.jl") + +import MultiBroadcastFusion as MBF + +#! format: off +function perf_kernel_shared_reads_fused!(X, Y) + (; x1, x2, x3, x4) = X + (; y1, y2, y3, y4) = Y + # TODO: can we write this more compactly with `@fused_assemble`? + + # Let's make sure that every broadcasted object is different, + # so that we use up a lot of parameter memory: + fmb = MBF.@get_fused_direct begin + @. y1 = x1 + @. y2 = x1 + x2 + @. y3 = x1 + x2 + x3 + @. y4 = x1 * x2 + x3 + x4 + @. y1 = x1 * x2 + x3 + x4 + x1 + @. y2 = x1 * x2 + x3 + x4 + x1 + x2 + @. y3 = x1 * x2 + x3 + x4 + x1 + x2 + x3 + @. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + @. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + @. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + @. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3 + @. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + @. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + @. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + @. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3 + @. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + @. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + @. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + @. 
y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3 + @. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + @. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + @. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + @. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3 + @. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + @. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + @. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + @. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3 + @. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + @. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + @. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 # breaks on A100 due to too much parameter memory + end + MBFExt = Base.get_extension(MBF, :MultiBroadcastFusionCUDAExt) + @show MBFExt.param_usage_args(fmb) +end +#! format: on + +@static get(ENV, "USE_CUDA", nothing) == "true" && using CUDA +use_cuda = @isdefined(CUDA) && CUDA.has_cuda() # will be true if you first run `using CUDA` +AType = use_cuda ? CUDA.CuArray : Array +device_name = use_cuda ? 
CUDA.name(CUDA.device()) : "CPU" +bm = Benchmark(; device_name, float_type = Float32) +problem_size = (50, 5, 5, 6, 5400) + +array_size = problem_size # array +X = get_arrays(:x, AType, bm.float_type, array_size) +Y = get_arrays(:y, AType, bm.float_type, array_size) +@testset "Test measuring parameter memory" begin + use_cuda && perf_kernel_shared_reads_fused!(X, Y) +end + +nothing diff --git a/test/execution/runtests.jl b/test/execution/runtests.jl index 6c5d4c0..57d5a40 100644 --- a/test/execution/runtests.jl +++ b/test/execution/runtests.jl @@ -6,5 +6,6 @@ using Revise; include(joinpath("test", "execution", "runtests.jl")) @safetestset "fused_shared_reads" begin; @time include("bm_fused_shared_reads.jl"); end @safetestset "fused_shared_reads_writes" begin; @time include("bm_fused_shared_reads_writes.jl"); end @safetestset "bm_fused_reads_vs_hard_coded" begin; @time include("bm_fused_reads_vs_hard_coded.jl"); end -@safetestset "parameter_memory" begin; @time include("parameter_memory.jl"); end +@safetestset "measure_parameter_memory" begin; @time include("measure_parameter_memory.jl"); end +@safetestset "kernel_splitting" begin; @time include("kernel_splitting.jl"); end #! format: on