Skip to content

Commit

Permalink
Dynamically split kernels based on parameter memory
Browse files Browse the repository at this point in the history
Remove depot

Refactor, move parameter mem to separate file

More improvements
  • Loading branch information
charleskawczynski committed Oct 31, 2024
1 parent 041fdee commit 9a55a0a
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 33 deletions.
1 change: 0 additions & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ agents:

env:
JULIA_LOAD_PATH: "${JULIA_LOAD_PATH}:${BUILDKITE_BUILD_CHECKOUT_PATH}/.buildkite"
JULIA_DEPOT_PATH: "${BUILDKITE_BUILD_PATH}/${BUILDKITE_PIPELINE_SLUG}/depot/default"
JULIA_MAX_NUM_PRECOMPILE_FILES: 100
JULIA_CPU_TARGET: 'broadwell;skylake'
JULIA_NVTX_CALLBACKS: gc
Expand Down
82 changes: 67 additions & 15 deletions ext/MultiBroadcastFusionCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,75 @@ import MultiBroadcastFusion: fused_copyto!

# Tag `CUDA.CuArray`s with the `MBF_CUDA` device singleton, i.e. the same tag
# type that this extension's `fused_copyto!` method dispatches on.
MBF.device(x::CUDA.CuArray) = MBF.MBF_CUDA()

include("parameter_memory.jl")

"""
    partition_kernels(
        fmb,
        fused_broadcast_constructor = MBF.FusedMultiBroadcast,
        args_func::Function = fused_multibroadcast_args,
    )

Split the fused broadcast `fmb` into a collection of fused broadcasts whose
kernel arguments each fit within the parameter-memory limit returned by
`get_param_lim()`.

The split is greedy and preserves the order of `fmb.pairs`. We first attempt
to fuse pairs

    1:N, 1:N-1, 1:N-2, ...

until some range `1:N-k` fits. Next, we attempt to fuse

    N-k+1:N, N-k+1:N-1, N-k+1:N-2, ...

and so forth, until every pair has been assigned to a kernel.

# Arguments
- `fmb`: fused broadcast object with a `pairs` field.
- `fused_broadcast_constructor`: called with a sub-range of `fmb.pairs` to
  build each candidate fused broadcast.
- `args_func`: maps a fused broadcast to the collection of kernel arguments
  whose parameter usage is measured with `param_usage_args`.

# Returns
- `(fmb,)` when the whole fused broadcast already fits in one kernel.
- Otherwise a `Vector{Any}` of fused broadcasts, in order, that together
  cover every element of `fmb.pairs` exactly once.

A single pair cannot be split any further. If one pair alone exceeds the
limit it is still emitted as its own kernel (rather than being dropped), so
that launching it reports the parameter-memory error.
"""
function partition_kernels(
    fmb,
    fused_broadcast_constructor = MBF.FusedMultiBroadcast,
    args_func::Function = fused_multibroadcast_args,
)
    plim = get_param_lim()
    # Fast path: everything fits in a single kernel launch.
    param_usage_args(args_func(fmb)) <= plim && return (fmb,)

    # The split pieces wrap differently-sized pair tuples, so their concrete
    # types differ; hence the `Any` element type.
    fmbs_split = Any[]
    N = length(fmb.pairs)
    i_start = 1
    i_stop = N
    while i_start <= i_stop
        ith_fmb = fused_broadcast_constructor(fmb.pairs[i_start:i_stop])
        # The first (most ambitious) attempt of each round will likely not fit.
        fits = param_usage_args(args_func(ith_fmb)) <= plim
        # Accept a lone pair even when it does not fit: shrinking the range
        # further would silently drop it and every pair after it.
        if fits || i_stop == i_start
            push!(fmbs_split, ith_fmb)
            i_stop == N && break
            i_start = i_stop + 1 # continue right after the accepted range
            i_stop = N # and be ambitious again
        else
            i_stop -= 1
        end
    end
    return fmbs_split
end

function fused_copyto!(fmb::MBF.FusedMultiBroadcast, ::MBF.MBF_CUDA)
(; pairs) = fmb
dest = first(pairs).first
destinations = map(p -> p.first, pairs)
all(a -> axes(a) == axes(dest), destinations) ||
error("Cannot fuse broadcast expressions with unequal broadcast axes")
nitems = length(parent(dest))
CI = CartesianIndices(axes(dest))
kernel =
CUDA.@cuda always_inline = true launch = false fused_copyto_kernel!(
fmb,
CI,
destinations = map(p -> p.first, fmb.pairs)
fmbs = partition_kernels(fmb)
for fmb in fmbs
(; pairs) = fmb
dest = first(pairs).first
dests = map(p -> p.first, pairs)
all(a -> axes(a) == axes(dest), dests) || error(
"Cannot fuse broadcast expressions with unequal broadcast axes",
)
config = CUDA.launch_configuration(kernel.fun)
threads = min(nitems, config.threads)
blocks = cld(nitems, threads)
kernel(fmb, CI; threads, blocks)
nitems = length(parent(dest))
CI = CartesianIndices(axes(dest))
kernel =
CUDA.@cuda always_inline = true launch = false fused_copyto_kernel!(
fmb,
CI,
)
config = CUDA.launch_configuration(kernel.fun)
threads = min(nitems, config.threads)
blocks = cld(nitems, threads)
kernel(fmb, CI; threads, blocks)
end
return destinations
end
import Base.Broadcast
Expand Down
16 changes: 16 additions & 0 deletions ext/parameter_memory.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
    get_param_lim()

Return the kernel parameter-memory limit used when deciding whether a fused
broadcast must be split. The value is compared against sums of `sizeof`, so it
is a size in bytes.

The limit is `32764` when the compiler configuration of the current CUDA
device reports `cap >= v"7.0"` and `ptx >= v"8.1"`, and `4096` otherwise.
"""
function get_param_lim()
    # `cap` / `ptx` are presumably the compute capability and the PTX ISA
    # version of the active compiler configuration -- TODO confirm.
    params = CUDA.compiler_config(CUDA.device()).params
    if params.cap >= v"7.0" && params.ptx >= v"8.1"
        return 32764
    else
        return 4096
    end
end
"""
    param_usage(arg)

Return the size, in bytes, of the type that `arg` has after
`CUDA.cudaconvert`, i.e. of the value as it is handed to a kernel.
"""
param_usage(arg) = sizeof(typeof(CUDA.cudaconvert(arg)))

"""
    param_usage_args(args)

Return the total parameter usage, in bytes, of a kernel launched with the
arguments in the collection `args`: the sum of `param_usage` over `args` plus
the size of `CUDA.KernelState`.

An empty `args` yields just the size of `CUDA.KernelState`.
"""
param_usage_args(args) =
    # `CUDA.KernelState` is a type, not a value: measure it with `sizeof`
    # directly. Routing it through `param_usage` would measure
    # `typeof(CUDA.KernelState) == DataType` instead of the struct itself.
    sum(param_usage, args; init = 0) + sizeof(CUDA.KernelState)

"""
    fused_multibroadcast_args(fmb::MBF.FusedMultiBroadcast)

Return the tuple `(fmb, CI)`, where `CI` is the `CartesianIndices` over the
axes of the destination of the first pair in `fmb.pairs`.
"""
function fused_multibroadcast_args(fmb::MBF.FusedMultiBroadcast)
    first_dest = first(fmb.pairs).first
    return (fmb, CartesianIndices(axes(first_dest)))
end

# TODO: Add recursive version of this (maybe similar to `StructuredPrinting` pattern?)
42 changes: 42 additions & 0 deletions src/collection/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,45 @@ macro make_fused(fusion_style, type_name, fused_name)
end
end
end

"""
    @make_get_fused fusion_style type_name fused_name

Define a macro `@fused_name` that transforms its argument expression with
`fusion_style` and wraps the resulting pairs in `type_name`, returning the
constructed object.

This macro does not define `type_name` itself; in the example below the type
is created beforehand with `@make_type`.

This gives users the flexibility to customize their broadcast fusion.

# Example
```julia
import MultiBroadcastFusion as MBF
MBF.@make_type MyFusedBroadcast
MBF.@make_get_fused MBF.fused_direct MyFusedBroadcast get_fused
x1 = rand(3,3)
y1 = rand(3,3)
y2 = rand(3,3)
# Build the fused broadcast object; `@get_fused` only constructs it
fmb = @get_fused begin
@. y1 = x1
@. y2 = x1
end
@test fmb isa MyFusedBroadcast
```
"""
macro make_get_fused(fusion_style, type_name, fused_name)
# Escape the user-supplied names so they are not renamed by macro hygiene.
t = esc(type_name)
f = esc(fused_name)
return quote
# The generated macro, named after `fused_name`.
macro $f(expr)
# `fusion_style` is applied to the raw expression when the generated macro
# expands; its result is escaped before being spliced into the output.
_pairs = esc($(fusion_style)(expr))
# Bind the (escaped) type name to a local of the generated macro: the
# interpolation inside the nested quote below is resolved when the generated
# macro runs, so it needs a local named `t` to refer to.
t = $t
quote
$t($_pairs)
end
end
end
end
1 change: 1 addition & 0 deletions src/execution/fused_kernels.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# NOTE(review): `@make_type` is not visible here; presumably it defines the
# `FusedMultiBroadcast` container type -- confirm against its definition.
@make_type FusedMultiBroadcast
# NOTE(review): `@make_fused` is only partially visible here; presumably these
# define the `@fused_direct` and `@fused_assemble` macros -- confirm.
@make_fused fused_direct FusedMultiBroadcast fused_direct
@make_fused fused_assemble FusedMultiBroadcast fused_assemble
# Defines `@get_fused_direct`, which wraps the pairs produced by `fused_direct`
# in a `FusedMultiBroadcast` and returns that object.
@make_get_fused fused_direct FusedMultiBroadcast get_fused_direct

# Field-less device tag types. The CUDA extension returns `MBF_CUDA()` from
# `MBF.device` for `CUDA.CuArray` arguments.
struct MBF_CPU end
struct MBF_CUDA end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
using TestEnv
TestEnv.activate()
using CUDA # (optional)
using Revise; include(joinpath("test", "execution", "parameter_memory.jl"))
using Revise; include(joinpath("test", "execution", "kernel_splitting.jl"))
=#

include("utils_test.jl")
Expand Down Expand Up @@ -49,9 +49,7 @@ function perf_kernel_shared_reads_fused!(X, Y)
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 # breaks on A100 due to too much parameter memory
end
end
#! format: on
Expand All @@ -66,18 +64,8 @@ problem_size = (50, 5, 5, 6, 5400)
array_size = problem_size # array
X = get_arrays(:x, AType, bm.float_type, array_size)
Y = get_arrays(:y, AType, bm.float_type, array_size)
@testset "Test breaking case with parameter memory" begin
if use_cuda
try
perf_kernel_shared_reads_fused!(X, Y)
error("The above kernel should error")
catch e
@test startswith(
e.msg,
"Kernel invocation uses too much parameter memory.",
)
end
end
@testset "Test kernel splitting with too much parameter memory" begin
use_cuda && perf_kernel_shared_reads_fused!(X, Y)
end

nothing
73 changes: 73 additions & 0 deletions test/execution/measure_parameter_memory.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#=
using TestEnv
TestEnv.activate()
using CUDA # (optional)
using Revise; include(joinpath("test", "execution", "measure_parameter_memory.jl"))
=#

include("utils_test.jl")
include("utils_setup.jl")
include("utils_benchmark.jl")

import MultiBroadcastFusion as MBF

#! format: off
"""
    perf_kernel_shared_reads_fused!(X, Y)

Build one large fused broadcast from the arrays `x1..x4` in `X` and `y1..y4`
in `Y` with `MBF.@get_fused_direct`, then print (via `@show`) the parameter
usage that the CUDA extension computes for it. The shown value is also the
return value.

The fused broadcast object is only constructed here; this function does not
itself call `fused_copyto!`.
"""
function perf_kernel_shared_reads_fused!(X, Y)
(; x1, x2, x3, x4) = X
(; y1, y2, y3, y4) = Y
# TODO: can we write this more compactly with `@fused_assemble`?

# Let's make sure that every broadcasted object is different,
# so that we use up a lot of parameter memory:
fmb = MBF.@get_fused_direct begin
@. y1 = x1
@. y2 = x1 + x2
@. y3 = x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 # breaks on A100 due to too much parameter memory
end
# The helpers live in the CUDA package extension, so fetch that module first.
MBFExt = Base.get_extension(MBF, :MultiBroadcastFusionCUDAExt)
# NOTE(review): `param_usage_args` sums `param_usage` over the elements of its
# argument, and the extension itself calls it on the `(fmb, CI)` tuple from
# `fused_multibroadcast_args`. Confirm that passing `fmb` directly is intended.
@show MBFExt.param_usage_args(fmb)
end
#! format: on

# Load CUDA only when explicitly requested through the `USE_CUDA` environment variable.
@static get(ENV, "USE_CUDA", nothing) == "true" && using CUDA
use_cuda = @isdefined(CUDA) && CUDA.has_cuda() # will be true if you first run `using CUDA`
# Pick the array type and a device label to match the selected backend.
AType = use_cuda ? CUDA.CuArray : Array
device_name = use_cuda ? CUDA.name(CUDA.device()) : "CPU"
# NOTE(review): `Benchmark` and `get_arrays` are not defined in this file;
# presumably they come from the `utils_*.jl` includes above -- confirm.
bm = Benchmark(; device_name, float_type = Float32)
problem_size = (50, 5, 5, 6, 5400)

array_size = problem_size # array
X = get_arrays(:x, AType, bm.float_type, array_size)
Y = get_arrays(:y, AType, bm.float_type, array_size)
@testset "Test measuring parameter memory" begin
# The measurement relies on the CUDA extension, so it is skipped without CUDA.
use_cuda && perf_kernel_shared_reads_fused!(X, Y)
end

nothing
3 changes: 2 additions & 1 deletion test/execution/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ using Revise; include(joinpath("test", "execution", "runtests.jl"))
@safetestset "fused_shared_reads" begin; @time include("bm_fused_shared_reads.jl"); end
@safetestset "fused_shared_reads_writes" begin; @time include("bm_fused_shared_reads_writes.jl"); end
@safetestset "bm_fused_reads_vs_hard_coded" begin; @time include("bm_fused_reads_vs_hard_coded.jl"); end
@safetestset "parameter_memory" begin; @time include("parameter_memory.jl"); end
@safetestset "measure_parameter_memory" begin; @time include("measure_parameter_memory.jl"); end
@safetestset "kernel_splitting" begin; @time include("kernel_splitting.jl"); end
#! format: on

0 comments on commit 9a55a0a

Please sign in to comment.