Skip to content

Commit

Permalink
dynamically ask cuDNN about available number of alorithms (root-proje…
Browse files Browse the repository at this point in the history
  • Loading branch information
kgizdov authored and lmoneta committed Aug 12, 2020
1 parent 0cb647c commit db20c3f
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tmva/tmva/src/DNN/Architectures/Cudnn/Propagate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,8 @@ void TCudnn<AFloat>::InitializeConvWorkspace(TWorkspace * & workspace,
/**
* I'm sure there may be a faster way, but this works
*/
int convRequestedAlgoCount{8}; // requestedAlgoCount is setting how many algorithms to try, can be tuned, fixed for now as all available
int convRequestedAlgoCount{0}; // requestedAlgoCount is setting how many algorithms to try
CUDNNCHECK(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnnHandle, &convRequestedAlgoCount)) // ask cuDNN how much it can try

int algoCount;
cudnnConvolutionFwdAlgoPerf_t convFwdPerfResults[convRequestedAlgoCount]; // this will store metrics to choose convolution algorithm
Expand Down Expand Up @@ -575,7 +576,7 @@ void TCudnn<AFloat>::InitializeConvWorkspace(TWorkspace * & workspace,
/**
* I'm sure there may be a faster way, but this works
*/
convRequestedAlgoCount = 6; // reset to max number of available backward algorithms
CUDNNCHECK(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnnHandle, &convRequestedAlgoCount)) // ask cuDNN how much it can try
cudnnConvolutionBwdDataAlgoPerf_t convBwdDataPerfResults[convRequestedAlgoCount]; // this will store metrics to choose convolution algorithm
CUDNNCHECK(cudnnFindConvolutionBackwardDataAlgorithm(
cudnnHandle,
Expand Down Expand Up @@ -650,7 +651,7 @@ void TCudnn<AFloat>::InitializeConvWorkspace(TWorkspace * & workspace,
/**
* I'm sure there may be a faster way, but this works
*/
convRequestedAlgoCount = 6; // reset to max number of available backward algorithms
CUDNNCHECK(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnnHandle, &convRequestedAlgoCount)) // ask cuDNN how much it can try
cudnnConvolutionBwdFilterAlgoPerf_t convBwdFilterPerfResults[convRequestedAlgoCount]; // this will store metrics to choose convolution algorithm
CUDNNCHECK(cudnnFindConvolutionBackwardFilterAlgorithm(
cudnnHandle,
Expand Down

0 comments on commit db20c3f

Please sign in to comment.