Skip to content

Commit

Permalink
Add script to tune parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasfaingnaert committed Nov 22, 2023
1 parent 6a8e8cb commit 6afd275
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1 @@
test/Manifest.toml
Manifest.toml
8 changes: 5 additions & 3 deletions configs/configs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using GemmKernels
using LinearAlgebra
using ForwardDiff
using Octavian

struct Configuration
name # Human-readable name of the configuration.
Expand Down Expand Up @@ -238,10 +239,10 @@ macro get_wmma_config()
CD_type,
transpose_a,
transpose_b,
mul!,
Octavian.matmul!,
Epilogue.Default(),
verify_default,
Kernel.matmul_pipelined,
kernel,
wmma_baseline)
end end)
end
Expand Down Expand Up @@ -520,7 +521,8 @@ function get_configs()
[2, 2, 1],
[1, 1, 2],
[2, 2, 2]], [[2048, 2048, 2048]]),
zero_c in [false]
zero_c in [false],
kernel in [Kernel.matmul_pipelined]

push!(rv, @get_wmma_config)
end
Expand Down
1 change: 1 addition & 0 deletions src/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ function matmul_pipelined(conf::GemmKernels.Config, a, b, c, d,

@loopinfo unroll for j = 1 : num_fragments_n
b_tile = translate_offset(warp_tile.KN, (K = 0, N = (j-1)*conf.compute_op_shape.N))
@assert ((b_tile.base.M + b_tile.offset.M) < conf.matmul_shape.M) && ((b_tile.base.K + b_tile.offset.K) < conf.matmul_shape.K)
@inbounds @immutable b_frags[nxt_stage, j] = transf_sh2rf_b(Operator.load_b(conf.operator, conf.shared_b_layout, shmem_b, b_tile), b_tile)
end

Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
XUnit = "3e3c03f2-1a94-11e9-2981-050a4ca824ab"
5 changes: 5 additions & 0 deletions tuning/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[deps]
Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PythonPlot = "274fc56d-3b97-40fa-a1cd-1b4a50311bf9"
208 changes: 208 additions & 0 deletions tuning/tune-wmma.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
using CUDA, GemmKernels
using Hyperopt
using Plots
pythonplot()

const NUM_SAMPLES = 250
const NUM_SAMPLES_PLOT = 250

include("../configs/configs.jl")

AB_type = Float16
CD_type = Float32

zero_c = true

OP_M, OP_N, OP_K = 16, 16, 16

markershapes = Dict(
"NN" => :circle,
"TT" => :cross,
"TN" => :diamond,
"NT" => :dtriangle,
)

function print_counters(counters)
count_str(categ) = "$(counters[categ]) ($(round(100*counters[categ]/counters["total"]; digits=1))%)"

println("Total: $(counters["total"]) configurations")
println(repeat("-", 100))
println("Skipped due to invalid GemmKernels config: $(count_str("invalid_config"))")
println("Produced incorrect result: $(count_str("invalid_result"))")
println("Threw an error: $(count_str("error"))")
println("Successful runs: $(count_str("success"))")
end

function optimise(transpose_a, transpose_b)
M = N = K = 4096

BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64
WARPS_M = 4
WARPS_N = 2
kernel = Kernel.matmul_singlestage
cf = @get_wmma_config
c_h, a, b, c, d = generate_inputs(cf)

@info "Starting hyperopt..."

counters = Dict(
"total" => 0,
"invalid_config" => 0,
"invalid_result" => 0,
"error" => 0,
"success" => 0,
)

ho = @hyperopt for i = 1000,
BLOCK_M = 2 .^ (1:8),
BLOCK_N = 2 .^ (1:8),
BLOCK_K = 2 .^ (1:8),
WARPS_M = 2 .^ (0:3),
WARPS_N = 2 .^ (0:3),
kernel in [Kernel.matmul_singlestage, Kernel.matmul_pipelined]

counters["total"] += 1

try
cf = @get_wmma_config
catch err
if isa(err, GemmKernels.ConfigError)
counters["invalid_config"] += 1
return Inf
end
end

@info "Trying configuration: $(cf.config)"

d .= 0

try
run_gemm(cf, a, b, c, d)
catch err
counters["error"] += 1

if isa(err, CuError)
@error "Configuration failed: $(cf.config)"
rethrow()
end

@info "Skipping configuration: $(cf.config)\n" * sprint(Base.showerror, err)
return Inf
end

if !verify(cf, c_h, d)
@warn "Configuration produced invalid result: $(cf.config)"
counters["invalid_result"] += 1
return Inf
end

times = []

try
for i in 1:NUM_SAMPLES
prof = CUDA.@profile run_gemm(cf, a, b, c, d)
push!(times, sum(prof.device[!, "stop"] - prof.device[!, "start"]))
end
catch err
counters["error"] += 1

if isa(err, CuError)
@error "Configuration failed: $(cf.config)"
rethrow()
end

@info "Skipping configuration: $(cf.config)\n" * sprint(Base.showerror, err)
return Inf
end

counters["success"] += 1
return minimum(times)
end

print_counters(counters)

ho, counters
end

get_label(transpose_a, transpose_b) = "$(transpose_a ? "T" : "N")$(transpose_b ? "T" : "N")"

function make_plot(BLOCK_M, BLOCK_N, BLOCK_K, WARPS_M, WARPS_N, kernel, transpose_a, transpose_b)
label = get_label(transpose_a, transpose_b)

N_vals = 2 .^ (7:14)
gemmkernels = []
cublas = []

for N in N_vals
@show N
M = K = N

cf = @get_wmma_config
c_h, a, b, c, d = generate_inputs(cf)

samples = []

for i in 1:NUM_SAMPLES_PLOT
prof = CUDA.@profile run_gemm(cf, a, b, c, d)
push!(samples, sum(prof.device[!, "stop"] - prof.device[!, "start"]))
end

push!(gemmkernels, minimum(samples))

samples = []

for i in 1:NUM_SAMPLES_PLOT
prof = CUDA.@profile run_baseline(cf, a, b, c, d)
push!(samples, sum(prof.device[!, "stop"] - prof.device[!, "start"]))
end

push!(cublas, minimum(samples))
end

ratios = 100 .* cublas ./ gemmkernels

plot!(N_vals, ratios, label=label, markershape=markershapes[label], xscale=:log2, ylims=(0, max(100, ratios...)))
title!("$AB_type x $AB_type = $CD_type")
xlabel!("Matrix size [-]")
ylabel!("Performance relative to cuBLAS [%]")
end

function main()
hos = Dict()

total_counters = Dict(
"total" => 0,
"invalid_config" => 0,
"invalid_result" => 0,
"error" => 0,
"success" => 0,
)

for transpose_a in [false, true],
transpose_b in [false, true]
hos[(transpose_a, transpose_b)], counters = optimise(transpose_a, transpose_b)
total_counters = Dict(k => total_counters[k] + counters[k] for k in keys(counters))
end

println(repeat("=", 100))
println("Overall configurations:")
println(repeat("=", 100))

print_counters(total_counters)

println("Optimal parameters:")

for transpose_a in [false, true],
transpose_b in [false, true]
println("$(get_label(transpose_a, transpose_b)): $(hos[(transpose_a, transpose_b)].minimizer)")
end

for transpose_a in [false, true],
transpose_b in [false, true]
make_plot(hos[(transpose_a, transpose_b)].minimizer..., transpose_a, transpose_b)
end

savefig("plot.pdf")
end

isinteractive() || main()

0 comments on commit 6afd275

Please sign in to comment.