forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathWeightNorm.cpp
117 lines (100 loc) · 4.45 KB
/
WeightNorm.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/NativeFunctions.h"
#include <cstring>
#include <memory>
#include <sstream>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
namespace at {
namespace native {
// Norm of `v` taken over every dimension except `dim`.
//
// The result has v.dim() dimensions, all of size 1 except `dim`, which keeps
// v.size(dim); it therefore broadcasts against `v`.
// Special case: dim == -1 does NOT mean "last dimension" here -- it requests a
// single norm over all elements of `v` (whatever v.norm(pow) returns).
//
// Kept deliberately close to the Python reference for readability; possible
// optimizations (e.g. funnelling everything into one return for RVO) are deferred.
Tensor norm_except_dim(const Tensor & v, int64_t pow, int64_t dim)
{
  // NOTE(review): contiguous()/view()/norm() below are presumably dispatched
  // through VariableType (so the result stays differentiable) -- confirm.
  if (dim == -1) {
    return v.norm(pow);
  }

  if (dim == 0) {
    // Flatten to (size(0), rest) and reduce across the columns.
    const int64_t leading = v.size(0);
    std::vector<int64_t> out_shape(v.dim(), 1);
    out_shape[0] = leading;
    return v.contiguous().view({leading, -1}).norm(pow, 1).view(out_shape);
  }

  const int64_t last_dim = v.dim() - 1;
  if (dim == last_dim) {
    // Flatten to (rest, size(last)) and reduce across the rows.
    const int64_t trailing = v.size(last_dim);
    std::vector<int64_t> out_shape(v.dim(), 1);
    out_shape[last_dim] = trailing;
    return v.contiguous().view({-1, trailing}).norm(pow, 0).view(out_shape);
  }

  // Any other dim: swap it to the front, reuse the dim == 0 path, swap back.
  // To consider: calling at::native::norm_except_dim directly would skip one
  // dynamic dispatch. (optimize?)
  return at::norm_except_dim(v.transpose(0, dim), pow, 0).transpose(0, dim);
}
// Weight normalization forward: w = v * (g / ||v||), where ||v|| is the 2-norm
// of `v` over every dimension except `dim` (see norm_except_dim; dim == -1
// means one norm over all of v).
//
// v_in and g_in must live on the same device; both are made contiguous first.
// CUDA tensors normalized along the first or last dimension take the fused
// kernel via at::_weight_norm_cuda_interface (its first output is returned);
// everything else is composed from double-differentiable primitive ops.
Tensor _weight_norm
  (const Tensor & v_in,
   const Tensor & g_in,
   int64_t dim)
{
  AT_CHECK(
    v_in.device() == g_in.device(),
    "weight_norm: expected v_in and g_in to be on the same device, but v_in is "
    "on ", v_in.device(), " and g_in is on ", g_in.device());

  const auto v = v_in.contiguous();
  const auto g = g_in.contiguous();

  const bool outer_dim = (dim == 0) || (dim == v.dim() - 1);
  if (!v.type().is_cuda() || !outer_dim) {
    // Fallback: plain differentiable ops, so double backward works.
    // (at::native::norm_except_dim would probably be fine here as well.)
    return v * (g / at::norm_except_dim(v, 2, dim));
  }

  // No derivative is declared for weight_norm itself, so this routes back
  // through VariableType.cpp, which records a WeightNormFusedBackward node.
  return std::get<0>(at::_weight_norm_cuda_interface(v, g, dim));
}
// Differentiable backward path, an alternative to weight_norm_cuda_backward, to be used
// when backward is itself creating a graph.
// The GradMode::is_enabled() check must be performed within Functions.cpp; that's why we
// define a separate function here, instead of inlining it in weight_norm_cuda_backward.
std::tuple<Tensor, Tensor> _weight_norm_differentiable_backward
(const Tensor & grad_w,
const Tensor & saved_v,
const Tensor & saved_g,
const Tensor & saved_norms,
int64_t dim)
{
// In Functions.cpp, the HardshrinkBackward object supplies "grad.contiguous()"
// as the first argument, so grad_w should be contiguous here.
// All these checks should succeed:
AT_CHECK(grad_w.is_contiguous(), "grad_w must be contiguous");
AT_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous");
AT_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous");
AT_CHECK(saved_norms.is_contiguous(), "saved_norms must be contiguous");
int64_t last_dim = saved_v.dim() - 1;
int64_t last_size = saved_v.size(last_dim);
// Like weight_norm_fused_backward, weight_norm_differentiable_backward should only ever be called
// through a WeightNormFusedBackward object, so we expect that dim == 0 || dim == saved_v.size(-1)
AT_CHECK(dim == 0 || dim == last_dim, "Expected dim to be the first or last dimension");
// saved_g and saved_norms are already shaped to broadcast over the correct dimensions
// ...but saved_norms might be Float when saved_g and saved_v are half.
// To consider: saved_norms.to(..., True /*non_blocking*/);
auto norms = saved_norms.to(saved_g.type().scalarType());
std::vector<int64_t> bcast_size(saved_v.dim(), 1);
// Analytic backward path using differentiable primitive ops
if (dim == 0) {
bcast_size[0] = saved_v.size(0);
auto per_dim_sums = (grad_w*saved_v).view({saved_v.size(0), -1}).sum(1).view(bcast_size);
auto grad_v = (saved_g/norms)*(grad_w - saved_v*(per_dim_sums/(norms*norms)));
auto grad_g = per_dim_sums/norms;
return std::tuple<Tensor, Tensor>{grad_v, grad_g};
} else { // dim == last_dim
bcast_size[last_dim] = last_size;
auto per_dim_sums = (grad_w*saved_v).view({-1, last_size}).sum(0).view(bcast_size);
auto grad_v = (saved_g/norms)*(grad_w - saved_v*(per_dim_sums/(norms*norms)));
auto grad_g = per_dim_sums/norms;
return std::tuple<Tensor, Tensor>{grad_v, grad_g};
}
}
} // namespace native
} // namespace at