Skip to content

Commit

Permalink
Adds squeeze and excitation (scSE) modules, resolves #157
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-j-h committed May 29, 2019
1 parent 54e20dc commit 7f00fc7
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
63 changes: 63 additions & 0 deletions robosat/scse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Squeeze and Excitation blocks - attention for classification and segmentation
See:
- https://arxiv.org/abs/1709.01507 - Squeeze-and-Excitation Networks
- https://arxiv.org/abs/1803.02579 - Concurrent Spatial and Channel 'Squeeze & Excitation' in Fully Convolutional Networks
"""

import torch
import torch.nn as nn


class SpatialSqChannelEx:
"""Spatial Squeeze and Channel Excitation (cSE) block
See https://arxiv.org/abs/1803.02579 Figure 1 b
"""

def __init__(self, num_in, r):
super().__init__()
self.fc0 = Conv1x1(num_in, num_in // r)
self.fc1 = Conv1x1(num_in // r, num_in)

def forward(self, x):
xx = nn.functional.adaptive_avg_pool2d(x, 1)
xx = self.fc0(xx)
xx = nn.functional.relu(xx, inplace=True)
xx = self.fc1(xx)
xx = nn.functional.sigmoid(xx)
return x * xx


class ChannelSqSpatialEx:
"""Channel Squeeze and Spatial Excitation (sSE) block
See https://arxiv.org/abs/1803.02579 Figure 1 c
"""

def __init__(self, num_in):
super().__init__()
self.conv = Conv1x1(num_in, 1)

def forward(self, x):
xx = self.conv(x)
xx = nn.functional.sigmoid(xx)
return x * xx


class SpatialChannelSqChannelEx:
"""Concurrent Spatial and Channel Squeeze and Channel Excitation (csSE) block
See https://arxiv.org/abs/1803.02579 Figure 1 d
"""

def __init__(self, num_in, r=16):
super().__init__()

self.cse = SpatialSqChannelEx(num_in, r)
self.sse = ChannelSqSpatialEx(num_in)

def forward(self, x):
return self.cse(x) + self.sse(x)


def Conv1x1(num_in, num_out):
return nn.Conv2d(num_in, num_out, kernel_size=1, bias=False)
2 changes: 2 additions & 0 deletions robosat/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from torchvision.models import resnet50

from robosat.scse import SpatialChannelSqChannelEx


class ConvRelu(nn.Module):
"""3x3 convolution followed by ReLU activation building block.
Expand Down

0 comments on commit 7f00fc7

Please sign in to comment.