Skip to content

Commit

Permalink
More improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Oct 31, 2024
1 parent acebd14 commit 9f4abf5
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 2 deletions.
42 changes: 42 additions & 0 deletions src/collection/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,45 @@ macro make_fused(fusion_style, type_name, fused_name)
end
end
end

"""
@make_get_fused fusion_style type_name fused_named
This macro
- Defines a type type_name
- Defines a macro, `@fused_name`, using the fusion type `fusion_style`
This allows users to flexibility
to customize their broadcast fusion.
# Example
```julia
import MultiBroadcastFusion as MBF
MBF.@make_type MyFusedBroadcast
MBF.@make_get_fused MBF.fused_direct MyFusedBroadcast get_fused
x1 = rand(3,3)
y1 = rand(3,3)
y2 = rand(3,3)
# 4 reads, 2 writes
fmb = @get_fused begin
@. y1 = x1
@. y2 = x1
end
@test fmb isa MyFusedBroadcast
```
"""
macro make_get_fused(fusion_style, type_name, fused_name)
t = esc(type_name)
f = esc(fused_name)
return quote
macro $f(expr)
_pairs = esc($(fusion_style)(expr))
t = $t
quote
$t($_pairs)
end
end
end
end
1 change: 1 addition & 0 deletions src/execution/fused_kernels.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@make_type FusedMultiBroadcast
@make_fused fused_direct FusedMultiBroadcast fused_direct
@make_fused fused_assemble FusedMultiBroadcast fused_assemble
@make_get_fused fused_direct FusedMultiBroadcast get_fused_direct

struct MBF_CPU end
struct MBF_CUDA end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
using TestEnv
TestEnv.activate()
using CUDA # (optional)
using Revise; include(joinpath("test", "execution", "parameter_memory.jl"))
using Revise; include(joinpath("test", "execution", "kernel_splitting.jl"))
=#

include("utils_test.jl")
Expand Down
73 changes: 73 additions & 0 deletions test/execution/measure_parameter_memory.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#=
using TestEnv
TestEnv.activate()
using CUDA # (optional)
using Revise; include(joinpath("test", "execution", "measure_parameter_memory.jl"))
=#

include("utils_test.jl")
include("utils_setup.jl")
include("utils_benchmark.jl")

import MultiBroadcastFusion as MBF

#! format: off
function perf_kernel_shared_reads_fused!(X, Y)
(; x1, x2, x3, x4) = X
(; y1, y2, y3, y4) = Y
# TODO: can we write this more compactly with `@fused_assemble`?

# Let's make sure that every broadcasted object is different,
# so that we use up a lot of parameter memory:
fmb = MBF.@get_fused_direct begin
@. y1 = x1
@. y2 = x1 + x2
@. y3 = x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 # breaks on A100 due to too much parameter memory
end
MBFExt = Base.get_extension(MBF, :MultiBroadcastFusionCUDAExt)
@show MBFExt.param_usage_args(fmb)
end
#! format: on

@static get(ENV, "USE_CUDA", nothing) == "true" && using CUDA
use_cuda = @isdefined(CUDA) && CUDA.has_cuda() # will be true if you first run `using CUDA`
AType = use_cuda ? CUDA.CuArray : Array
device_name = use_cuda ? CUDA.name(CUDA.device()) : "CPU"
bm = Benchmark(; device_name, float_type = Float32)
problem_size = (50, 5, 5, 6, 5400)

array_size = problem_size # array
X = get_arrays(:x, AType, bm.float_type, array_size)
Y = get_arrays(:y, AType, bm.float_type, array_size)
@testset "Test measuring parameter memory" begin
use_cuda && perf_kernel_shared_reads_fused!(X, Y)
end

nothing
3 changes: 2 additions & 1 deletion test/execution/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ using Revise; include(joinpath("test", "execution", "runtests.jl"))
@safetestset "fused_shared_reads" begin; @time include("bm_fused_shared_reads.jl"); end
@safetestset "fused_shared_reads_writes" begin; @time include("bm_fused_shared_reads_writes.jl"); end
@safetestset "bm_fused_reads_vs_hard_coded" begin; @time include("bm_fused_reads_vs_hard_coded.jl"); end
@safetestset "parameter_memory" begin; @time include("parameter_memory.jl"); end
@safetestset "measure_parameter_memory" begin; @time include("measure_parameter_memory.jl"); end
@safetestset "kernel_splitting" begin; @time include("kernel_splitting.jl"); end
#! format: on

0 comments on commit 9f4abf5

Please sign in to comment.