Skip to content

Commit

Permalink
Support Nvidia Hopper GPUs
Browse files Browse the repository at this point in the history
  • Loading branch information
giordano committed Jan 13, 2024
1 parent d046063 commit b67cee6
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ CUDAExt = "CUDA"
CairoMakieExt = "CairoMakie"

[compat]
CUDA = "3.8.4, 3.12, 4.4"
CUDA = "3.8.4, 3.12, 4.4, 5"
CairoMakie = "0.7, 0.10.7"
CpuId = "0.3"
DocStringExtensions = "0.9"
Expand Down
32 changes: 25 additions & 7 deletions ext/CUDAExt/implementations/peakflops_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function _theoretical_peakflops_gpu_cudacores(; device, dtype)
elseif dtype == Float64
max_peakflops *= 1
else
throw(ArgumentError("Unsupported dtype."))
throw(ArgumentError("Unsupported dtype $(dtype)."))
end
return max_peakflops
end
Expand All @@ -60,7 +60,9 @@ function _theoretical_peakflops_gpu_tensorcores(;
device=CUDA.device(), dtype=Float16, verbose=true
)
cap = CUDA.capability(device)
if cap == v"8.0.0"
if cap == v"9.0.0"
devtype = :Hopper
elseif cap == v"8.0.0"
devtype = :A100
elseif cap == v"7.0.0"
devtype = :V100
Expand All @@ -70,10 +72,26 @@ function _theoretical_peakflops_gpu_tensorcores(;
max_clock_rate = CUDA.attribute(device, CUDA.CU_DEVICE_ATTRIBUTE_CLOCK_RATE) # in kHz
num_tensor_cores = ntensorcores(device)
max_peakflops = max_clock_rate * num_tensor_cores * 1e-9 # in TFLOP/s
if devtype == :A100
if devtype == :Hopper
# matrix dimensions 8x8x4, factor 2 for nflops in A*B+C see
# * <https://resources.nvidia.com/en-us-tensor-core/gtc22-whitepaper-hopper> (figures 10-11)
# * <https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/> (figures 5-8)
if Symbol(dtype) == :Float16
# matrix dimensions 8x8x4, factor 2 for nflops in A*B+C
# see e.g. https://peerj.com/articles/cs-330.pdf
max_peakflops *= 2 * 16 * 8 * 4 # XXX: Wrong result!
elseif Symbol(dtype) in (:Float32, :TensorFloat32, :TF32)
max_peakflops *= 2 * 8 * 8 * 4 # XXX: Wrong result!
elseif Symbol(dtype) == :Float64
max_peakflops *= 2 * 4 * 4 * 2
elseif Symbol(dtype) == :Int8
max_peakflops *= 2 * 2 * 32 * 8 * 4 # XXX: Wrong result!
else
throw(ArgumentError("Unsupported dtype $(dtype)."))
end
elseif devtype == :A100
if Symbol(dtype) == :Float16
# matrix dimensions 8x8x4, factor 2 for nflops in A*B+C see
# e.g. <https://doi.org/10.7717/peerj-cs.330> or
# <https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/nvidia-ampere-architecture-whitepaper.pdf>
max_peakflops *= 2 * 8 * 8 * 4
elseif Symbol(dtype) in (:Float32, :TensorFloat32, :TF32)
max_peakflops *= 2 * 4 * 8 * 4
Expand All @@ -82,13 +100,13 @@ function _theoretical_peakflops_gpu_tensorcores(;
elseif Symbol(dtype) == :Int8
max_peakflops *= 2 * 2 * 8 * 8 * 4
else
throw(ArgumentError("Unsupported dtype."))
throw(ArgumentError("Unsupported dtype $(dtype)."))
end
elseif devtype == :V100
if Symbol(dtype) == :Float16
max_peakflops *= 2 * 4 * 4 * 4
else
throw(ArgumentError("Unsupported dtype."))
throw(ArgumentError("Unsupported dtype $(dtype)."))
end
end
return max_peakflops
Expand Down
4 changes: 2 additions & 2 deletions ext/CUDAExt/peakflops_gpu_wmmas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ function _peakflops_gpu_wmmas(;
dtype_a = dtype_b = BFloat16
dtype_c = dtype_d = Float32
else
throw(ArgumentError("Unsupported dtype."))
throw(ArgumentError("Unsupported dtype $(dtype)."))
end
d_a = CUDA.rand(dtype_a, m, k)
d_b = CUDA.rand(dtype_b, k, n)
Expand All @@ -165,7 +165,7 @@ function _peakflops_gpu_wmmas(;
elseif Symbol(dtype) in (:BFloat16, :BF16)
kernel = @cuda launch = false _kernel_wmma_bf16_lowlevel(d_a, d_b, d_c, d_d)
else
throw(ArgumentError("Unsupported dtype."))
throw(ArgumentError("Unsupported dtype $(dtype)."))
end
warpsize = CUDA.attribute(device, CUDA.CU_DEVICE_ATTRIBUTE_WARP_SIZE)
# @show threads
Expand Down

0 comments on commit b67cee6

Please sign in to comment.