CBAM, a simple yet effective attention module for feed-forward convolutional neural networks. It is a lightweight and general module that can be integrated into any CNN architectures seamlessly and is end-to-end trainable along with base CNNs. CBAM is used to increase representation power by using attention mechanism: focusing on important features and suppressing unnecessary ones. To achieve this, we sequentially apply channel and spatial attention modules so that each of the branches can learn 'what' and 'where' to attend in the channel and spatial axes respectively.
class CBAM(nn.Module):
def __init__(self, channels, r):
super(CBAM, self).__init__()
self.channels = channels
self.r = r
self.sam = SAM(bias=False)
self.cam = CAM(channels=self.channels, r=self.r)
def forward(self, x):
output = self.cam(x)
output = self.sam(output)
return output + x
CAM generate a channel attention map by exploiting the inter-channel relationship of features. As each channel of a feature map is considered as a feature detector, channel attention focuses on 'what' is meaningful given an input image.
class CAM(nn.Module):
def __init__(self, channels, r):
super(CAM, self).__init__()
self.channels = channels
self.r = r
self.linear = nn.Sequential(
nn.Linear(in_features=self.channels, out_features=self.channels//self.r, bias=True),
nn.ReLU(inplace=True),
nn.Linear(in_features=self.channels//self.r, out_features=self.channels, bias=True))
def forward(self, x):
max = F.adaptive_max_pool2d(x, output_size=1)
avg = F.adaptive_avg_pool2d(x, output_size=1)
b, c, _, _ = x.size()
linear_max = self.linear(max.view(b,c)).view(b, c, 1, 1)
linear_avg = self.linear(avg.view(b,c)).view(b, c, 1, 1)
output = linear_max + linear_avg
output = F.sigmoid(output) * x
return output
SAM generate a spatial attention map by utilizing the inter-spatial relationship of features. Different from the channel attention, the spatial attention focuses on 'where' is an informative part, which is complementary to the channel attention.
class SAM(nn.Module):
def __init__(self, bias=False):
super(SAM, self).__init__()
self.bias = bias
self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3, dilation=1, bias=self.bias)
def forward(self, x):
max = torch.max(x,1)[0].unsqueeze(1)
avg = torch.mean(x,1).unsqueeze(1)
concat = torch.cat((max,avg), dim=1)
output = self.conv(concat)
output = F.sigmoid(output) * x
return output
Evaluated CBAM using VGG19 model with BatchNorm on CIFAR100 dataset.
The configurations used to train VGG:
- AdamW Optimizer
- 0.0001 Learning Rate
- 0.005 Weight Decay
- Cross Entropy Loss
- 50 Epochs
Model | Top-1 Acc | Top-5 Acc |
---|---|---|
VGG19 | 39.87% | 67.69% |
VGG19 + CBAM 1 | 40.95% | 68.24% |
To implement CBAM, please use cbam.py
file that I provided.
More detailed information can be seen using the link down below:
https://arxiv.org/abs/1807.06521
Footnotes
-
Reduction ratio of 4 was used during training ↩