Skip to content

Commit 8ffc901

Browse files
authored
Tkurth/sgbn fixes (#1685)
* fixing order of class instantiation and device extraction in mixed precision lamb * this commit fixes the SGBN graph capture problem by caching the cudnn plan and re-using it * disentangling the mplamb MR and SGBN MR * cleaner caching
1 parent 30a7ad3 commit 8ffc901

File tree

3 files changed

+622
-511
lines changed

3 files changed

+622
-511
lines changed

apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp

Lines changed: 70 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77

88
#include "norm_sample.h"
99

10+
// define this enum:
11+
enum bn_type { BN_FWD, BN_BWD };
12+
13+
// this is a global variable
14+
static std::map<std::vector<int64_t>, cudnn_frontend::ExecutionPlan> gbn_plan_cache;
15+
1016
at::Tensor gbn_forward(const at::Tensor& x,
1117
const at::Tensor& scale,
1218
const at::Tensor& bias,
@@ -38,28 +44,43 @@ at::Tensor gbn_forward(const at::Tensor& x,
3844
void_peer_buffers.push_back((void*)addr);
3945
}
4046

47+
// we need the peer size for the buffer reset
48+
size_t peer_size = 1;
49+
for (size_t i = 0; i < 4; ++i){
50+
peer_size *= peerDims[i];
51+
}
52+
53+
// sanity check
4154
assert(bn_group == void_peer_buffers.size());
42-
run_batch_norm_forward(
43-
perChannelDims,
44-
epsilonDims,
45-
tensorDims,
46-
peerDims,
47-
x.data_ptr(),
48-
y.data_ptr(),
49-
scale.data_ptr(),
50-
bias.data_ptr(),
51-
running_mean.data_ptr(),
52-
running_var.data_ptr(),
53-
running_mean.data_ptr(),
54-
running_var.data_ptr(),
55-
minibatch_mean.data_ptr(),
56-
minibatch_inv_var.data_ptr(),
57-
void_peer_buffers,
58-
epsilon,
59-
momentum,
60-
rank_id
61-
);
6255

56+
// check if plan already exists
57+
std::vector<int64_t> fv = {(int64_t)BN_FWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF};
58+
if ( gbn_plan_cache.find(fv) == gbn_plan_cache.end() ) {
59+
auto plan = run_batch_norm_forward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF);
60+
gbn_plan_cache.emplace(fv, std::move(plan));
61+
}
62+
63+
// get plan and handle
64+
auto plan = gbn_plan_cache.find(fv)->second;
65+
66+
// execute
67+
execute_batch_norm_forward(plan,
68+
x.data_ptr(),
69+
y.data_ptr(),
70+
scale.data_ptr(),
71+
bias.data_ptr(),
72+
running_mean.data_ptr(),
73+
running_var.data_ptr(),
74+
running_mean.data_ptr(),
75+
running_var.data_ptr(),
76+
minibatch_mean.data_ptr(),
77+
minibatch_inv_var.data_ptr(),
78+
void_peer_buffers,
79+
static_cast<double>(epsilon),
80+
static_cast<double>(momentum),
81+
peer_size,
82+
rank_id);
83+
6384
return y;
6485
}
6586

@@ -98,26 +119,37 @@ std::vector<at::Tensor> gbn_backward(
98119
void_peer_buffers.push_back((void*)addr);
99120
}
100121

122+
// we need the peer size for the buffer reset
123+
size_t peer_size = 1;
124+
for (size_t i = 0; i < 4; ++i){
125+
peer_size *= peerDims[i];
126+
}
127+
101128
assert(bn_group == void_peer_buffers.size());
102129

103-
run_batch_norm_backward(
104-
perChannelDims,
105-
epsilonDims,
106-
tensorDims,
107-
peerDims,
108-
x.data_ptr(),
109-
dy.data_ptr(),
110-
scale.data_ptr(),
111-
minibatch_mean.data_ptr(),
112-
minibatch_inv_var.data_ptr(),
113-
x_grad.data_ptr(),
114-
scale_grad.data_ptr(),
115-
bias_grad.data_ptr(),
116-
void_peer_buffers,
117-
epsilon,
118-
rank_id);
119-
120-
130+
std::vector<int64_t> fv = {(int64_t)BN_BWD, N, C, H, W, bn_group, (int64_t)CUDNN_DATA_HALF};
131+
if ( gbn_plan_cache.find(fv) == gbn_plan_cache.end() ) {
132+
auto plan = run_batch_norm_backward(tensorDims, perChannelDims, epsilonDims, peerDims, CUDNN_DATA_HALF);
133+
gbn_plan_cache.emplace(fv, std::move(plan));
134+
}
135+
136+
// get plan and handle
137+
auto plan = gbn_plan_cache.find(fv)->second;
138+
139+
// execute
140+
execute_batch_norm_backward(plan,
141+
x.data_ptr(),
142+
dy.data_ptr(),
143+
scale.data_ptr(),
144+
minibatch_mean.data_ptr(),
145+
minibatch_inv_var.data_ptr(),
146+
void_peer_buffers,
147+
x_grad.data_ptr(),
148+
scale_grad.data_ptr(),
149+
bias_grad.data_ptr(),
150+
static_cast<double>(epsilon),
151+
peer_size,
152+
rank_id);
121153

122154
return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};
123155
}

0 commit comments

Comments
 (0)