# function.py (forked from zzd1992/Image-Local-Attention)
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from localAttention import (similar_forward,
similar_backward,
weighting_forward,
weighting_backward_ori,
weighting_backward_weight)
# Public API of this module: the two functional ops and both module variants.
__all__ = [
    'f_similar',
    'f_weighting',
    'LocalAttention',
    'TorchLocalAttention',
]
class similarFunction(Function):
    """Autograd wrapper around the compiled ``similar_forward``/``similar_backward`` ops.

    ``forward`` delegates to ``similar_forward(x_ori, x_loc, kH, kW)``.
    ``backward`` returns gradients for ``x_ori`` and ``x_loc``, and ``None``
    for the two window-size arguments. This is presumably the compiled
    counterpart of ``TorchLocalAttention.f_similar``; TODO confirm that the
    output layouts match.
    """

    @staticmethod
    def forward(ctx, x_ori, x_loc, kH, kW):
        # Each input's gradient is computed from the other input, so keep both.
        ctx.save_for_backward(x_ori, x_loc)
        ctx.kHW = (kH, kW)
        return similar_forward(x_ori, x_loc, kH, kW)

    # NOTE(review): the @once_differentiable decorator is commented out in
    # this file. The backward ops are opaque extension calls, so higher-order
    # gradients through them are presumably not meaningful. Confirm before
    # relying on double backward.
    @staticmethod
    def backward(ctx, grad_outputs):
        kernel_h, kernel_w = ctx.kHW
        saved_ori, saved_loc = ctx.saved_tensors
        # Judging by the call sites, the trailing flag selects which operand's
        # gradient is produced (True -> x_ori, False -> x_loc).
        d_ori = similar_backward(saved_loc, grad_outputs, kernel_h, kernel_w, True)
        d_loc = similar_backward(saved_ori, grad_outputs, kernel_h, kernel_w, False)
        return d_ori, d_loc, None, None
class weightingFunction(Function):
    """Autograd wrapper around the compiled ``weighting_*`` ops.

    ``forward`` delegates to ``weighting_forward(x_ori, x_weight, kH, kW)``.
    ``backward`` returns gradients for ``x_ori`` and ``x_weight``, and
    ``None`` for the two window-size arguments. This is presumably the
    compiled counterpart of ``TorchLocalAttention.f_weighting``; TODO confirm
    that the layouts match.
    """

    @staticmethod
    def forward(ctx, x_ori, x_weight, kH, kW):
        # Each input's gradient is computed from the other input, so keep both.
        ctx.save_for_backward(x_ori, x_weight)
        ctx.kHW = (kH, kW)
        return weighting_forward(x_ori, x_weight, kH, kW)

    # NOTE(review): the @once_differentiable decorator is commented out in
    # this file. The backward ops are opaque extension calls, so higher-order
    # gradients through them are presumably not meaningful. Confirm before
    # relying on double backward.
    @staticmethod
    def backward(ctx, grad_outputs):
        kernel_h, kernel_w = ctx.kHW
        saved_ori, saved_weight = ctx.saved_tensors
        d_ori = weighting_backward_ori(saved_weight, grad_outputs, kernel_h, kernel_w)
        d_weight = weighting_backward_weight(saved_ori, grad_outputs, kernel_h, kernel_w)
        return d_ori, d_weight, None, None
# Functional entry points for the custom autograd ops; call these rather than
# the Function classes directly.
f_similar, f_weighting = similarFunction.apply, weightingFunction.apply
class LocalAttention(nn.Module):
    """Local attention layer backed by the compiled ``localAttention`` ops.

    Three bias-free 1x1 convolutions project the input. The first two
    projections are scored against each other over a ``kH x kW`` window with
    ``f_similar``. The scores are softmax-normalised over the last dimension
    and used by ``f_weighting`` to aggregate the third projection.

    Args:
        inp_channels: number of channels of the input feature map.
        out_channels: number of channels of each projection and of the output.
        kH, kW: height and width of the local window.
    """

    def __init__(self, inp_channels, out_channels, kH, kW):
        super().__init__()

        def make_proj():
            return nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)

        # Creation order (conv1, conv2, conv3) is kept so parameter
        # initialisation and state_dict ordering are unchanged.
        self.conv1 = make_proj()
        self.conv2 = make_proj()
        self.conv3 = make_proj()
        self.kH, self.kW = kH, kW

    def forward(self, x):
        query = self.conv1(x)
        key = self.conv2(x)
        value = self.conv3(x)
        scores = f_similar(query, key, self.kH, self.kW)
        attn = F.softmax(scores, -1)
        return f_weighting(value, attn, self.kH, self.kW)
class TorchLocalAttention(nn.Module):
    """Pure-PyTorch local attention with the same structure as ``LocalAttention``.

    Uses ``F.unfold`` plus a batched ``matmul`` instead of the compiled
    ``localAttention`` extension.

    Window sizes are assumed to be odd. With the symmetric ``k // 2`` padding
    used here, an even window produces ``(H + 1) * (W + 1)`` unfold positions,
    and the reshapes below then fail.

    Args:
        inp_channels: number of channels of the input feature map.
        out_channels: number of channels of each projection and of the output.
        kH, kW: height and width of the local window.
    """

    def __init__(self, inp_channels, out_channels, kH, kW):
        super().__init__()
        self.conv1 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
        self.conv3 = nn.Conv2d(inp_channels, out_channels, kernel_size=1, bias=False)
        self.kH = kH
        self.kW = kW

    @staticmethod
    def f_similar(x_theta, x_phi, kh, kw):
        """Dot-product similarity between each position and its local window.

        Args:
            x_theta: ``(N, C, H, W)`` features at the window centres.
            x_phi: ``(N, C, H, W)`` features whose ``kh x kw`` neighbourhoods
                (zero-padded at the border) are compared against ``x_theta``.
            kh, kw: window height and width (assumed odd).

        Returns:
            ``(N, H, W, kh * kw)`` scores. The last axis enumerates window
            offsets in row-major order.
        """
        n, c, h, w = x_theta.size()
        pad = (kh // 2, kw // 2)
        # (N, C, H, W) -> (N*H*W, 1, C): one row vector per spatial position.
        theta = x_theta.permute(0, 2, 3, 1).reshape(n * h * w, 1, c)
        # unfold -> (N, C*kh*kw, H*W). The layout is channel-major, so it
        # splits as (C, kh*kw).
        phi = F.unfold(x_phi, kernel_size=(kh, kw), stride=1, padding=pad)
        phi = phi.reshape(n, c, kh * kw, h * w).permute(0, 3, 1, 2)
        phi = phi.reshape(n * h * w, c, kh * kw)
        # (N*H*W, 1, C) @ (N*H*W, C, kh*kw) -> (N*H*W, 1, kh*kw)
        out = torch.matmul(theta, phi)
        return out.view(n, h, w, kh * kw)

    @staticmethod
    def f_weighting(x_theta, x_phi, kh, kw):
        """Aggregate the ``kh x kw`` neighbourhoods of ``x_theta`` using weights ``x_phi``.

        Args:
            x_theta: ``(N, C, H, W)`` values. Neighbourhoods are zero-padded at
                the border.
            x_phi: ``(N, H, W, kh * kw)`` weights, e.g. the softmax of
                ``f_similar``'s output. The tensor need not be contiguous.
            kh, kw: window height and width (assumed odd).

        Returns:
            ``(N, C, H, W)`` weighted sum over each window.
        """
        n, c, h, w = x_theta.size()
        pad = (kh // 2, kw // 2)
        # unfold -> (N, C*kh*kw, H*W) -> (N*H*W, C, kh*kw)
        patches = F.unfold(x_theta, kernel_size=(kh, kw), stride=1, padding=pad)
        patches = patches.permute(0, 2, 1).reshape(n * h * w, c, kh * kw)
        # Use reshape, not view: view raises on non-contiguous weight tensors.
        weights = x_phi.reshape(n * h * w, kh * kw, 1)
        # (N*H*W, C, kh*kw) @ (N*H*W, kh*kw, 1) -> (N*H*W, C)
        out = torch.matmul(patches, weights).squeeze(-1)
        return out.view(n, h, w, c).permute(0, 3, 1, 2).contiguous()

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = self.conv3(x)
        weight = self.f_similar(x1, x2, self.kH, self.kW)
        weight = F.softmax(weight, -1)
        out = self.f_weighting(x3, weight, self.kH, self.kW)
        return out
if __name__ == '__main__':
    # Smoke test for the compiled op. The demo runs on CUDA, so fail with a
    # clear message instead of an opaque error when no GPU is present.
    if not torch.cuda.is_available():
        raise SystemExit('This demo requires a CUDA device.')

    b, c, h, w = 8, 3, 32, 32
    kH, kW = 5, 5
    x = torch.rand(b, c, h, w).cuda()
    m = LocalAttention(c, c, kH, kW)
    m.cuda()
    y = m(x)

    # Run the pure-PyTorch reference with identical 1x1 conv weights so the
    # result can be sanity-checked. Both modules expose conv1/conv2/conv3.
    ref = TorchLocalAttention(c, c, kH, kW).cuda()
    ref.load_state_dict(m.state_dict())
    y_ref = ref(x)
    print('output shape:', tuple(y.shape))
    print('max |LocalAttention - TorchLocalAttention|:',
          (y - y_ref).abs().max().item())

    # Also exercise the custom backward kernels.
    y.sum().backward()