# Provenance: payload of git patch e10b17e2918560a4722a418105ac28ea66a2e03a
# From: "Daniel J. Hofmann"
# Date: Wed, 22 May 2019 00:28:54 +0200
# Subject: [PATCH] Adds squeeze and excitation (scSE) modules, resolves #157
#
#  robosat/scse.py | 63 +++++++++++++++++++++++++++++++++++++++++++++++++
#  robosat/unet.py |  2 ++
#  2 files changed, 65 insertions(+)
#  create mode 100644 robosat/scse.py
#
# --- robosat/scse.py (new file, index 00000000..efb23dc3) ---
"""Squeeze and Excitation blocks - attention for classification and segmentation

See:
- https://arxiv.org/abs/1709.01507 - Squeeze-and-Excitation Networks
- https://arxiv.org/abs/1803.02579 - Concurrent Spatial and Channel 'Squeeze & Excitation' in Fully Convolutional Networks

"""

import torch
import torch.nn as nn


class SpatialSqChannelEx(nn.Module):
    """Spatial Squeeze and Channel Excitation (cSE) block

    See https://arxiv.org/abs/1803.02579 Figure 1 b

    The input is average-pooled to one value per channel, passed through a
    two-layer 1x1-convolution bottleneck (ReLU in between) and a sigmoid; the
    resulting per-channel gate rescales the input.

    Args:
        num_in: number of channels of the input feature map.
        r: reduction ratio of the bottleneck; the hidden layer has
            ``num_in // r`` channels, but never fewer than one.

    Raises:
        ValueError: if ``r`` is not a positive integer.
    """

    def __init__(self, num_in, r):
        # Must derive from nn.Module: otherwise the layers below are not
        # registered as parameters and the block is not callable.
        super().__init__()

        if r <= 0:
            raise ValueError("reduction ratio r must be positive, got {}".format(r))

        # Keep at least one hidden channel so the bottleneck does not collapse
        # to zero channels when num_in < r.
        num_hidden = max(num_in // r, 1)

        self.fc0 = Conv1x1(num_in, num_hidden)
        self.fc1 = Conv1x1(num_hidden, num_in)

    def forward(self, x):
        """Return `x` rescaled by a per-channel gate in (0, 1).

        Assumes `x` is a (N, C, H, W) feature map with C == num_in.
        """

        xx = nn.functional.adaptive_avg_pool2d(x, 1)
        xx = self.fc0(xx)
        xx = nn.functional.relu(xx, inplace=True)
        xx = self.fc1(xx)
        # torch.sigmoid replaces the deprecated nn.functional.sigmoid.
        xx = torch.sigmoid(xx)
        return x * xx


class ChannelSqSpatialEx(nn.Module):
    """Channel Squeeze and Spatial Excitation (sSE) block

    See https://arxiv.org/abs/1803.02579 Figure 1 c

    A 1x1 convolution squeezes all channels into a single map; its sigmoid is
    used as a per-location gate that rescales the input.

    Args:
        num_in: number of channels of the input feature map.
    """

    def __init__(self, num_in):
        super().__init__()
        self.conv = Conv1x1(num_in, 1)

    def forward(self, x):
        """Return `x` rescaled by a single-channel gate in (0, 1).

        Assumes `x` is a (N, C, H, W) feature map with C == num_in.
        """

        xx = self.conv(x)
        xx = torch.sigmoid(xx)
        return x * xx


class SpatialChannelSqChannelEx(nn.Module):
    """Concurrent Spatial and Channel Squeeze and Channel Excitation (csSE) block

    See https://arxiv.org/abs/1803.02579 Figure 1 d

    Runs the cSE and sSE blocks on the same input and adds their outputs.

    Args:
        num_in: number of channels of the input feature map.
        r: reduction ratio forwarded to the cSE block.
    """

    def __init__(self, num_in, r=16):
        super().__init__()

        self.cse = SpatialSqChannelEx(num_in, r)
        self.sse = ChannelSqSpatialEx(num_in)

    def forward(self, x):
        """Return the element-wise sum of the cSE and sSE outputs for `x`."""

        return self.cse(x) + self.sse(x)


def Conv1x1(num_in, num_out):
    """Return a bias-free 1x1 `nn.Conv2d` mapping `num_in` to `num_out` channels."""

    return nn.Conv2d(num_in, num_out, kernel_size=1, bias=False)


# --- robosat/unet.py (modified, index bedabea1..8ca9d6c0, hunk @@ -14,6 +14,8 @@) ---
# The same patch adds the following import to robosat/unet.py, after
# `from torchvision.models import resnet50` and before `class ConvRelu(nn.Module)`:
#
#     from robosat.scse import SpatialChannelSqChannelEx