-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathfastflow.py
160 lines (145 loc) · 5.81 KB
/
fastflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import FrEIA.framework as Ff
import FrEIA.modules as Fm
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import constants as const
def subnet_conv_func(kernel_size, hidden_ratio):
    """Create a subnet constructor for the flow's coupling blocks.

    Args:
        kernel_size: Kernel size shared by both convolutions of the subnet.
        hidden_ratio: Factor applied to the subnet's input channel count to
            obtain the hidden width (the product is truncated with ``int``).

    Returns:
        A callable ``(in_channels, out_channels) -> nn.Sequential`` that builds
        a Conv2d -> ReLU -> Conv2d stack. Both convolutions use "same"
        padding, so the spatial size of the input is preserved.
    """

    def subnet_conv(in_channels, out_channels):
        width = int(in_channels * hidden_ratio)
        layers = [
            nn.Conv2d(in_channels, width, kernel_size, padding="same"),
            nn.ReLU(),
            nn.Conv2d(width, out_channels, kernel_size, padding="same"),
        ]
        return nn.Sequential(*layers)

    return subnet_conv
def nf_fast_flow(input_chw, conv3x3_only, hidden_ratio, flow_steps, clamp=2.0):
    """Assemble a sequential invertible network of ``AllInOneBlock`` steps.

    Args:
        input_chw: Dimensions of one input sample, unpacked into
            ``Ff.SequenceINN`` (callers in this file pass ``[C, H, W]``).
        conv3x3_only: If true, every step uses 3x3 subnet convolutions;
            otherwise odd-numbered steps use 1x1 convolutions instead.
        hidden_ratio: Hidden-width ratio forwarded to ``subnet_conv_func``.
        flow_steps: Number of ``AllInOneBlock`` steps to append.
        clamp: Value passed as ``affine_clamping`` to every block.

    Returns:
        The populated ``Ff.SequenceINN``.
    """
    flow = Ff.SequenceINN(*input_chw)
    for step in range(flow_steps):
        # Alternate 3x3 / 1x1 kernels (even / odd steps) unless 3x3-only.
        use_1x1 = step % 2 == 1 and not conv3x3_only
        subnet = subnet_conv_func(1 if use_1x1 else 3, hidden_ratio)
        flow.append(
            Fm.AllInOneBlock,
            subnet_constructor=subnet,
            affine_clamping=clamp,
            permute_soft=False,
        )
    return flow
class FastFlow(nn.Module):
    """FastFlow model: a frozen timm backbone followed by one 2D flow per scale.

    The backbone parameters are frozen in ``__init__`` and the backbone is put
    in eval mode on every ``forward`` call; only the flows (and, for the
    non-transformer backbones, the per-scale ``LayerNorm`` modules) remain
    trainable.

    Args:
        backbone_name: timm model name; must be in ``const.SUPPORTED_BACKBONES``.
        flow_steps: Number of flow steps per feature scale.
        input_size: Side length of the (square) input image. It is used to
            size the LayerNorms and flows and as the anomaly-map resolution.
        conv3x3_only: Forwarded to ``nf_fast_flow``.
        hidden_ratio: Forwarded to ``nf_fast_flow``.
    """

    # Hard-coded feature geometry of the CaiT / DeiT backbones: reduction
    # factor of the token grid and per-token channel count.
    _PATCH_SIZE = 16
    _TOKEN_DIM = 768
    # Number of leading transformer blocks that are executed.
    _DEIT_NUM_BLOCKS = 8  # paper Table 6. Block Index = 7
    _CAIT_NUM_BLOCKS = 41  # paper Table 6. Block Index = 40

    def __init__(
        self,
        backbone_name,
        flow_steps,
        input_size,
        conv3x3_only=False,
        hidden_ratio=1.0,
    ):
        super().__init__()
        assert (
            backbone_name in const.SUPPORTED_BACKBONES
        ), "backbone_name must be one of {}".format(const.SUPPORTED_BACKBONES)

        is_transformer = backbone_name in [const.BACKBONE_CAIT, const.BACKBONE_DEIT]
        if is_transformer:
            self.feature_extractor = timm.create_model(backbone_name, pretrained=True)
            channels = [self._TOKEN_DIM]
            scales = [self._PATCH_SIZE]
        else:
            self.feature_extractor = timm.create_model(
                backbone_name,
                pretrained=True,
                features_only=True,
                out_indices=[1, 2, 3],
            )
            channels = self.feature_extractor.feature_info.channels()
            scales = self.feature_extractor.feature_info.reduction()

        # [C, H, W] of the feature map produced at every scale.
        feature_shapes = [
            [in_channels, int(input_size / scale), int(input_size / scale)]
            for in_channels, scale in zip(channels, scales)
        ]

        if not is_transformer:
            # for transformers, use their pretrained norm w/o grad
            # for resnets, self.norms are trainable LayerNorm
            self.norms = nn.ModuleList(
                [
                    nn.LayerNorm(shape, elementwise_affine=True)
                    for shape in feature_shapes
                ]
            )

        # Freeze the backbone; self.norms is a separate module and stays
        # trainable.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        self.nf_flows = nn.ModuleList(
            [
                nf_fast_flow(
                    shape,
                    conv3x3_only=conv3x3_only,
                    hidden_ratio=hidden_ratio,
                    flow_steps=flow_steps,
                )
                for shape in feature_shapes
            ]
        )
        self.input_size = input_size

    def forward(self, x):
        """Run the backbone and the flows on a batch of images.

        Args:
            x: Input batch fed to the backbone.

        Returns:
            A dict with:
              * ``"loss"``: sum over scales of the batch mean of
                ``0.5 * sum(z ** 2) - log_jac_det``.
              * ``"anomaly_map"`` (only when ``self.training`` is false): the
                per-scale maps ``-exp(-0.5 * mean_c(z ** 2))`` bilinearly
                resized to ``input_size x input_size`` and averaged over
                scales. Values lie in ``[-1, 0)``; larger means less likely
                under the flow.
        """
        # The backbone is frozen, so keep it in eval mode even if the
        # enclosing module was switched to train().
        self.feature_extractor.eval()

        backbone = self.feature_extractor
        if isinstance(backbone, timm.models.vision_transformer.VisionTransformer):
            features = [self._vit_features(x)]
        elif isinstance(backbone, timm.models.cait.Cait):
            features = [self._cait_features(x)]
        else:
            features = [
                norm(feature) for norm, feature in zip(self.norms, backbone(x))
            ]

        loss = 0
        outputs = []
        for flow, feature in zip(self.nf_flows, features):
            output, log_jac_dets = flow(feature)
            loss += torch.mean(
                0.5 * torch.sum(output**2, dim=(1, 2, 3)) - log_jac_dets
            )
            outputs.append(output)

        ret = {"loss": loss}
        if not self.training:
            ret["anomaly_map"] = self._anomaly_map(outputs)
        return ret

    def _tokens_to_map(self, tokens):
        """Reshape ``(N, H*W, C)`` patch tokens into an ``(N, C, H, W)`` map."""
        batch, _, channels = tokens.shape
        side = self.input_size // self._PATCH_SIZE
        return tokens.permute(0, 2, 1).reshape(batch, channels, side, side)

    def _vit_features(self, x):
        """Extract the patch-token feature map from a ViT / DeiT backbone."""
        backbone = self.feature_extractor
        x = backbone.patch_embed(x)

        # Prefix tokens: the class token, plus the distillation token when the
        # backbone has one.
        # NOTE(review): dist_token is read defensively in case the installed
        # timm ViT class does not define it -- confirm against the pinned
        # timm version.
        prefix = [backbone.cls_token.expand(x.shape[0], -1, -1)]
        dist_token = getattr(backbone, "dist_token", None)
        if dist_token is not None:
            prefix.append(dist_token.expand(x.shape[0], -1, -1))
        x = torch.cat((*prefix, x), dim=1)

        x = backbone.pos_drop(x + backbone.pos_embed)
        for i in range(self._DEIT_NUM_BLOCKS):
            x = backbone.blocks[i](x)
        x = backbone.norm(x)

        # Drop exactly the prefix tokens that were prepended above. A fixed
        # slice of 2 would discard a real patch token (and break the reshape)
        # when there is no distillation token.
        x = x[:, len(prefix):, :]
        return self._tokens_to_map(x)

    def _cait_features(self, x):
        """Extract the patch-token feature map from a CaiT backbone."""
        backbone = self.feature_extractor
        x = backbone.patch_embed(x)
        x = x + backbone.pos_embed
        x = backbone.pos_drop(x)
        for i in range(self._CAIT_NUM_BLOCKS):
            x = backbone.blocks[i](x)
        x = backbone.norm(x)
        return self._tokens_to_map(x)

    def _anomaly_map(self, outputs):
        """Fuse the per-scale flow outputs into one input-resolution map."""
        per_scale_maps = []
        for output in outputs:
            log_prob = -torch.mean(output**2, dim=1, keepdim=True) * 0.5
            prob = torch.exp(log_prob)
            per_scale_maps.append(
                F.interpolate(
                    -prob,
                    size=[self.input_size, self.input_size],
                    mode="bilinear",
                    align_corners=False,
                )
            )
        # Stack along a new trailing axis and average the scales away.
        return torch.mean(torch.stack(per_scale_maps, dim=-1), dim=-1)