Store descriptor keys as native Julia dtypes
However, there are still a few `CUDNN_xyz_t` datatypes, which are Cenums.
We could still map those to Julia integers if serialization turns out to be difficult otherwise.
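
As a hedged illustration of that note (the names below are stand-ins defined via Base's @enum, not cuDNN's actual CEnum types, which convert the same way), a Cenum value round-trips through a plain Int that hashes and serializes trivially:

    # Stand-in for a CUDNN_xyz_t Cenum such as cudnnDataType_t (hypothetical names).
    @enum FakeDataType_t FAKE_DATA_FLOAT=0 FAKE_DATA_DOUBLE=1 FAKE_DATA_HALF=2

    as_key(dt)     = Int(dt)   # e.g. Int(FAKE_DATA_FLOAT) == 0; plain Int is easy to hash/serialize
    from_key(T, i) = T(i)      # recover the enum value when calling back into the C API

    @assert from_key(FakeDataType_t, as_key(FAKE_DATA_HALF)) === FAKE_DATA_HALF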
RomeoV committed Jun 12, 2023
1 parent ab8883d commit e9fa1b3
Showing 1 changed file with 10 additions and 16 deletions.
lib/cudnn/src/convolution.jl (26 changes: 10 additions & 16 deletions)

@@ -191,12 +191,6 @@ end
 
 ## Utilities to find a fast algorithm
 
-# Helper fct to recover cudnn descriptor tuples from cudnn descriptor pointers
-# so that we can cache algorithms based on data descriptors.
-# Actually just reverses the cache dict and returns the descriptor as a tuple.
-map_cudnn_ptr_to_jl_tuple(cache_dict, desc_ptr) = Dict(zip(values(cache_dict),
-                                                            keys(cache_dict)))[desc_ptr]
-
 const cudnnConvolutionFwdAlgoPerfCache = Dict{Tuple,cudnnConvolutionFwdAlgoPerf_t}()
 const cudnnConvolutionFwdAlgoPerfCacheLock = ReentrantLock()
 
@@ -207,11 +201,11 @@ const cudnnConvolutionFwdAlgoPerfCacheLock = ReentrantLock()
 It can be set to false when beta is zero to save an allocation and must otherwise be set to true.
 """
 function cudnnConvolutionFwdAlgoPerf(xDesc, x, wDesc, w, convDesc, yDesc, y, biasDesc, activation, allocateTmpBuf=true)
-    xDesc_native = map_cudnn_ptr_to_jl_tuple(cudnnTensorDescriptorCache, xDesc)
-    wDesc_native = map_cudnn_ptr_to_jl_tuple(cudnnFilterDescriptorCache, wDesc)
-    convDesc_native = map_cudnn_ptr_to_jl_tuple(cudnnConvolutionDescriptorCache, convDesc)
+    xDesc_native = cudnnGetTensorDescriptor(xDesc)
+    wDesc_native = cudnnGetFilterDescriptor(wDesc)
+    convDesc_native = cudnnGetConvolutionDescriptor(convDesc)
     biasDesc_native = (isnothing(biasDesc) ? nothing
-                       : map_cudnn_ptr_to_jl_tuple(cudnnTensorDescriptorCache, biasDesc))
+                       : cudnnGetTensorDescriptor(biasDesc))
 
     key = (xDesc_native, wDesc_native, convDesc_native, biasDesc, activation)
     val = lock(cudnnConvolutionFwdAlgoPerfCacheLock) do
@@ -249,9 +243,9 @@ const cudnnConvolutionBwdDataAlgoPerfCacheLock = ReentrantLock()
 It can be set to false when beta is zero to save an allocation and must otherwise be set to true.
 """
 function cudnnConvolutionBwdDataAlgoPerf(wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, allocateTmpBuf=true)
-    wDesc_native = map_cudnn_ptr_to_jl_tuple(cudnnFilterDescriptorCache, wDesc)
-    dyDesc_native = map_cudnn_ptr_to_jl_tuple(cudnnTensorDescriptorCache, dyDesc)
-    convDesc_native = map_cudnn_ptr_to_jl_tuple(cudnnConvolutionDescriptorCache, convDesc)
+    wDesc_native = cudnnGetFilterDescriptor(wDesc)
+    dyDesc_native = cudnnGetTensorDescriptor(dyDesc)
+    convDesc_native = cudnnGetConvolutionDescriptor(convDesc)
 
     key = (wDesc_native, dyDesc_native, convDesc_native)
     val = lock(cudnnConvolutionBwdDataAlgoPerfCacheLock) do
@@ -289,9 +283,9 @@ const cudnnConvolutionBwdFilterAlgoPerfCacheLock = ReentrantLock()
 It can be set to false when beta is zero to save an allocation and must otherwise be set to true.
 """
 function cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, dyDesc, dy, convDesc, dwDesc, dw, allocateTmpBuf=true)
-    xDesc_native = map_cudnn_ptr_to_jl_tuple(cudnnTensorDescriptorCache, xDesc)
-    dyDesc_native = map_cudnn_ptr_to_jl_tuple(cudnnTensorDescriptorCache, dyDesc)
-    convDesc_native = map_cudnn_ptr_to_jl_tuple(cudnnConvolutionDescriptorCache, convDesc)
+    xDesc_native = cudnnGetTensorDescriptor(xDesc)
+    dyDesc_native = cudnnGetTensorDescriptor(dyDesc)
+    convDesc_native = cudnnGetConvolutionDescriptor(convDesc)
 
     key = (xDesc_native, dyDesc_native, convDesc_native)
     val = lock(cudnnConvolutionBwdFilterAlgoPerfCacheLock) do
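
For context, here is a minimal, hypothetical sketch (stand-in names and placeholder logic, not the actual convolution.jl code) of the caching pattern these native keys feed into: tuples of plain Julia values hash by content, so logically equal descriptors hit the same cache entry even when their underlying C pointers differ.

    const AlgoCache     = Dict{Tuple,Int}()   # stand-in for cudnnConvolutionFwdAlgoPerfCache
    const AlgoCacheLock = ReentrantLock()

    # Placeholders for the descriptor queries (cudnnGetTensorDescriptor etc.)
    # and for the expensive cudnnFindConvolution*-style algorithm search.
    describe_inputs(x, w) = ((eltype(x), size(x)), (eltype(w), size(w)))
    find_best_algo(x, w)  = 0

    function cached_algo(x, w)
        key = describe_inputs(x, w)           # tuples of native Julia values
        lock(AlgoCacheLock) do
            get!(() -> find_best_algo(x, w), AlgoCache, key)
        end
    end

    cached_algo(rand(Float32, 4, 4), rand(Float32, 3, 3))  # searches once; later calls reuse the cached entry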
