-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6a8e8cb
commit c154582
Showing
6 changed files
with
226 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
test/Manifest.toml | ||
Manifest.toml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
[deps] | ||
Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712" | ||
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" | ||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" | ||
PythonPlot = "274fc56d-3b97-40fa-a1cd-1b4a50311bf9" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
using CUDA, GemmKernels | ||
using Hyperopt | ||
using Plots | ||
# Use the PythonPlot backend for every Plots call made by this script.
pythonplot()

# Profiled launches per candidate configuration during the search (see `optimise`).
const NUM_SAMPLES = 250
# Profiled launches per matrix size when building the comparison plot (see `make_plot`).
const NUM_SAMPLES_PLOT = 250

include("../configs/configs.jl")

# Element types; they also appear in the plot title.
# NOTE(review): the globals below are not read directly in this script; presumably
# `@get_wmma_config` from configs.jl consumes them by name. Confirm before renaming.
AB_type = Float16
CD_type = Float32

zero_c = true

OP_M = 16
OP_N = 16
OP_K = 16
|
||
# Marker used for each operand-layout label (as produced by `get_label`) in the plot.
markershapes = Dict(
    ["NN", "TT", "TN", "NT"] .=> [:circle, :cross, :diamond, :dtriangle],
)
|
||
"""
    print_counters(counters)

Print a summary of a configuration sweep to standard output.

`counters` must map the keys `"total"`, `"invalid_config"`, `"invalid_result"`,
`"error"` and `"success"` to counts. Each category is printed with its count and its
share of `"total"` as a percentage rounded to one decimal. When `"total"` is zero the
share is reported as `0.0%` instead of `NaN%`.
"""
function print_counters(counters)
    total = counters["total"]

    # 0/0 would otherwise be printed as "NaN%" for an empty sweep.
    percentage(n) = total == 0 ? 0.0 : round(100 * n / total; digits=1)
    count_str(categ) = "$(counters[categ]) ($(percentage(counters[categ]))%)"

    println("Total: $(total) configurations")
    println(repeat("-", 100))
    println("Skipped due to invalid GemmKernels config: $(count_str("invalid_config"))")
    println("Produced incorrect result: $(count_str("invalid_result"))")
    println("Threw an error: $(count_str("error"))")
    println("Successful runs: $(count_str("success"))")
    return nothing
end
|
||
"""
    optimise(transpose_a, transpose_b)

Search for the fastest GEMM configuration for one operand layout.

Runs 1000 `@hyperopt` iterations over the tile sizes (`BLOCK_M`, `BLOCK_N`, `BLOCK_K`),
the warp layout (`WARPS_M`, `WARPS_N`) and the kernel variant, on a fixed
4096 x 4096 x 4096 problem. The objective of a candidate is the minimum total device
time over `NUM_SAMPLES` profiled launches. A candidate scores `Inf` when its
configuration is rejected, when it throws, or when its result fails `verify`.

Returns `(ho, counters)`: the Hyperopt result and a `Dict` counting how many candidates
fell into each outcome category (`"total"`, `"invalid_config"`, `"invalid_result"`,
`"error"`, `"success"`).

A `CuError` is always rethrown after being logged; every other error is counted under
`"error"` and the candidate is skipped.
"""
function optimise(transpose_a, transpose_b)
    # NOTE(review): `@get_wmma_config` comes from configs.jl and appears to read the
    # problem/tiling parameters from same-named variables in scope, so the local
    # names below must stay as they are. Confirm against configs.jl.
    M = N = K = 4096

    # Initial configuration; its result is only used here to generate the inputs.
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64
    WARPS_M = 4
    WARPS_N = 2
    kernel = Kernel.matmul_singlestage
    cf = @get_wmma_config
    c_h, a, b, c, d = generate_inputs(cf)

    @info "Starting hyperopt..."

    counters = Dict(
        "total" => 0,
        "invalid_config" => 0,
        "invalid_result" => 0,
        "error" => 0,
        "success" => 0,
    )

    ho = @hyperopt for i = 1000,
            BLOCK_M = 2 .^ (1:8),
            BLOCK_N = 2 .^ (1:8),
            BLOCK_K = 2 .^ (1:8),
            WARPS_M = 2 .^ (0:3),
            WARPS_N = 2 .^ (0:3),
            kernel in [Kernel.matmul_singlestage, Kernel.matmul_pipelined]

        counters["total"] += 1

        try
            cf = @get_wmma_config
        catch err
            if isa(err, GemmKernels.ConfigError)
                counters["invalid_config"] += 1
                return Inf
            end

            # Any other failure must not fall through: `cf` would still hold the
            # previous candidate's configuration and be benchmarked in its place.
            counters["error"] += 1

            if isa(err, CuError)
                @error "Configuration failed while constructing the config"
                rethrow()
            end

            @info "Skipping configuration: could not construct config\n" * sprint(Base.showerror, err)
            return Inf
        end

        @info "Trying configuration: $(cf.config)"

        d .= 0

        # First run: checks that the kernel launches and produces the right result.
        try
            run_gemm(cf, a, b, c, d)
        catch err
            if isa(err, GemmKernels.ConfigError)
                counters["invalid_config"] += 1
                return Inf
            end

            counters["error"] += 1

            if isa(err, CuError)
                @error "Configuration failed: $(cf.config)"
                rethrow()
            end

            @info "Skipping configuration: $(cf.config)\n" * sprint(Base.showerror, err)
            return Inf
        end

        if !verify(cf, c_h, d)
            @warn "Configuration produced invalid result: $(cf.config)"
            counters["invalid_result"] += 1
            return Inf
        end

        # Timed runs: total device time (sum of stop - start) per profiled launch.
        times = try
            map(1:NUM_SAMPLES) do _
                prof = CUDA.@profile run_gemm(cf, a, b, c, d)
                sum(prof.device[!, "stop"] - prof.device[!, "start"])
            end
        catch err
            counters["error"] += 1

            if isa(err, CuError)
                @error "Configuration failed: $(cf.config)"
                rethrow()
            end

            @info "Skipping configuration: $(cf.config)\n" * sprint(Base.showerror, err)
            return Inf
        end

        counters["success"] += 1
        return minimum(times)
    end

    print_counters(counters)

    return ho, counters
end
|
||
"""
    get_label(transpose_a, transpose_b)

Return the two-letter layout label for a pair of transpose flags: `"T"` for a `true`
flag and `"N"` for a `false` one, first for `transpose_a`, then for `transpose_b`.
"""
function get_label(transpose_a, transpose_b)
    letter(transposed) = transposed ? "T" : "N"
    return letter(transpose_a) * letter(transpose_b)
end
|
||
# Minimum, over `nsamples` profiled calls of `f()`, of the total device time
# (sum of stop - start over the profile's device rows) of one call.
function _min_device_time(f, nsamples)
    return minimum(1:nsamples) do _
        prof = CUDA.@profile f()
        sum(prof.device[!, "stop"] - prof.device[!, "start"])
    end
end

"""
    make_plot(BLOCK_M, BLOCK_N, BLOCK_K, WARPS_M, WARPS_N, kernel, transpose_a, transpose_b)

Add one series to the current plot, comparing the given configuration with the
baseline.

For every square problem size `N` in `2 .^ (7:14)` (with `M = K = N`), the minimum
device time over `NUM_SAMPLES_PLOT` profiled launches is taken for both `run_gemm` and
`run_baseline`. The series shows `100 * baseline / gemm` against `N`, labelled and
marked according to the operand layout.
"""
function make_plot(BLOCK_M, BLOCK_N, BLOCK_K, WARPS_M, WARPS_N, kernel, transpose_a, transpose_b)
    label = get_label(transpose_a, transpose_b)

    N_vals = 2 .^ (7:14)

    # One (gemm, baseline) timing pair per problem size.
    # NOTE(review): `@get_wmma_config` appears to read M, N, K and the tiling
    # parameters from same-named variables in scope. Confirm against configs.jl.
    timings = map(N_vals) do N
        @show N
        M = K = N

        cf = @get_wmma_config
        _, a, b, c, d = generate_inputs(cf)

        gemm_time = _min_device_time(() -> run_gemm(cf, a, b, c, d), NUM_SAMPLES_PLOT)
        cublas_time = _min_device_time(() -> run_baseline(cf, a, b, c, d), NUM_SAMPLES_PLOT)

        (gemm_time, cublas_time)
    end

    ratios = [100 * cublas_time / gemm_time for (gemm_time, cublas_time) in timings]

    # NOTE(review): `ylims` is passed on every call, so when several series share one
    # plot the limits probably follow the most recently added series. Verify that
    # earlier series above 100% are not clipped.
    plot!(
        N_vals,
        ratios;
        label=label,
        markershape=markershapes[label],
        xscale=:log2,
        ylims=(0, max(100, maximum(ratios))),
    )
    title!("$AB_type x $AB_type = $CD_type")
    xlabel!("Matrix size [-]")
    return ylabel!("Performance relative to cuBLAS [%]")
end
|
||
"""
    main()

Tune all four operand layouts, report the results, and save the comparison plot.

For each combination of `transpose_a` and `transpose_b`, runs `optimise`, accumulates
the outcome counters, prints the overall counters and the best parameters per layout,
then plots every layout's best configuration with `make_plot` and saves the figure as
`plot.pdf`.
"""
function main()
    # Same visiting order as a nested loop with `transpose_a` outermost.
    layouts = [(ta, tb) for ta in [false, true] for tb in [false, true]]

    hos = Dict()

    total_counters = Dict(
        "total" => 0,
        "invalid_config" => 0,
        "invalid_result" => 0,
        "error" => 0,
        "success" => 0,
    )

    for (transpose_a, transpose_b) in layouts
        ho, counters = optimise(transpose_a, transpose_b)
        hos[(transpose_a, transpose_b)] = ho

        # Fold this layout's counters into the running totals.
        for (category, count) in counters
            total_counters[category] += count
        end
    end

    separator = repeat("=", 100)
    println(separator)
    println("Overall configurations:")
    println(separator)

    print_counters(total_counters)

    println("Optimal parameters:")

    for (transpose_a, transpose_b) in layouts
        best = hos[(transpose_a, transpose_b)].minimizer
        println("$(get_label(transpose_a, transpose_b)): $(best)")
    end

    for (transpose_a, transpose_b) in layouts
        best = hos[(transpose_a, transpose_b)].minimizer
        make_plot(best..., transpose_a, transpose_b)
    end

    return savefig("plot.pdf")
end
|
||
# Run the whole benchmark when executed as a script; in an interactive session the
# file only defines the functions and `main()` has to be called by hand.
if !isinteractive()
    main()
end