diff --git a/robosat/scse.py b/robosat/scse.py new file mode 100644 index 00000000..8eeade41 --- /dev/null +++ b/robosat/scse.py @@ -0,0 +1,63 @@ +"""Squeeze and Excitation blocks - attention for classification and segmentation + +See: +- https://arxiv.org/abs/1709.01507 - Squeeze-and-Excitation Networks +- https://arxiv.org/abs/1803.02579 - Concurrent Spatial and Channel 'Squeeze & Excitation' in Fully Convolutional Networks + +""" + +import torch +import torch.nn as nn + + +class SpatialSqChannelEx(nn.Module): + """Spatial Squeeze and Channel Excitation (cSE) block + See https://arxiv.org/abs/1803.02579 Figure 1 b + """ + + def __init__(self, num_in, r): + super().__init__() + self.fc0 = Conv1x1(num_in, num_in // r) + self.fc1 = Conv1x1(num_in // r, num_in) + + def forward(self, x): + xx = nn.functional.adaptive_avg_pool2d(x, 1) + xx = self.fc0(xx) + xx = nn.functional.relu(xx, inplace=True) + xx = self.fc1(xx) + xx = torch.sigmoid(xx) + return x * xx + + +class ChannelSqSpatialEx(nn.Module): + """Channel Squeeze and Spatial Excitation (sSE) block + See https://arxiv.org/abs/1803.02579 Figure 1 c + """ + + def __init__(self, num_in): + super().__init__() + self.conv = Conv1x1(num_in, 1) + + def forward(self, x): + xx = self.conv(x) + xx = torch.sigmoid(xx) + return x * xx + + +class SpatialChannelSqChannelEx(nn.Module): + """Concurrent Spatial and Channel Squeeze and Channel Excitation (csSE) block + See https://arxiv.org/abs/1803.02579 Figure 1 d + """ + + def __init__(self, num_in, r=16): + super().__init__() + + self.cse = SpatialSqChannelEx(num_in, r) + self.sse = ChannelSqSpatialEx(num_in) + + def forward(self, x): + return self.cse(x) + self.sse(x) + + +def Conv1x1(num_in, num_out): + return nn.Conv2d(num_in, num_out, kernel_size=1, bias=False) diff --git a/robosat/unet.py b/robosat/unet.py index 1acc4ac7..9b0ddaf1 100644 --- a/robosat/unet.py +++ b/robosat/unet.py @@ -14,6 +14,8 @@ from torchvision.models import resnet50 +from robosat.scse import SpatialChannelSqChannelEx + class ConvRelu(nn.Module): """3x3 convolution followed by ReLU activation building block. @@ -91,10 +93,23 @@ def __init__(self, num_classes, num_filters=32, pretrained=True): # Todo: make input channels configurable, not hard-coded to three channels for RGB - self.resnet = resnet50(pretrained=pretrained) - # Access resnet directly in forward pass; do not store refs here due to # https://github.com/pytorch/pytorch/issues/8392 + self.resnet = resnet50(pretrained=pretrained) + + # seSE blocks to append to encoder and decoder as recommended by + # https://arxiv.org/abs/1803.02579 + self.scse0 = SpatialChannelSqChannelEx(64) + self.scse1 = SpatialChannelSqChannelEx(256) + self.scse2 = SpatialChannelSqChannelEx(512) + self.scse3 = SpatialChannelSqChannelEx(1024) + self.scse4 = SpatialChannelSqChannelEx(2048) + + self.scse5 = SpatialChannelSqChannelEx(num_filters * 8) + self.scse6 = SpatialChannelSqChannelEx(num_filters * 8) + self.scse7 = SpatialChannelSqChannelEx(num_filters * 2) + self.scse8 = SpatialChannelSqChannelEx(num_filters * 2 * 2) + self.scse9 = SpatialChannelSqChannelEx(num_filters) self.center = DecoderBlock(2048, num_filters * 8) @@ -122,20 +137,21 @@ def forward(self, x): enc0 = self.resnet.conv1(x) enc0 = self.resnet.bn1(enc0) enc0 = self.resnet.relu(enc0) + enc0 = self.scse0(enc0) enc0 = self.resnet.maxpool(enc0) - enc1 = self.resnet.layer1(enc0) - enc2 = self.resnet.layer2(enc1) - enc3 = self.resnet.layer3(enc2) - enc4 = self.resnet.layer4(enc3) + enc1 = self.scse1(self.resnet.layer1(enc0)) + enc2 = self.scse2(self.resnet.layer2(enc1)) + enc3 = self.scse3(self.resnet.layer3(enc2)) + enc4 = self.scse4(self.resnet.layer4(enc3)) center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2)) - dec0 = self.dec0(torch.cat([enc4, center], dim=1)) - dec1 = self.dec1(torch.cat([enc3, dec0], dim=1)) - dec2 = self.dec2(torch.cat([enc2, dec1], dim=1)) - dec3 = self.dec3(torch.cat([enc1, dec2], dim=1)) - dec4 = self.dec4(dec3) + dec0 = self.scse5(self.dec0(torch.cat([enc4, center], dim=1))) + dec1 = self.scse6(self.dec1(torch.cat([enc3, dec0], dim=1))) + dec2 = self.scse7(self.dec2(torch.cat([enc2, dec1], dim=1))) + dec3 = self.scse8(self.dec3(torch.cat([enc1, dec2], dim=1))) + dec4 = self.scse9(self.dec4(dec3)) dec5 = self.dec5(dec4) return self.final(dec5)