Skip to content

Commit

Permalink
Implements Feature Pyramid Network (FPN), closes #60
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-j-h authored and Bhargav Kowshik committed Aug 10, 2018
1 parent 495edfb commit 6f8cd87
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 15 deletions.
141 changes: 141 additions & 0 deletions robosat/fpn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""Feature Pyramid Network (FPN) on top of ResNet. Comes with task-specific heads on top of it.
See:
- https://arxiv.org/abs/1612.03144 - Feature Pyramid Networks for Object Detection
- http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf - A Unified Architecture for Instance
and Semantic Segmentation
"""

import torch
import torch.nn as nn

from torchvision.models import resnet50


class FPN(nn.Module):
    """Feature Pyramid Network (FPN): top-down architecture with lateral connections.

    Can be used as feature extractor for object detection or segmentation.
    """

    def __init__(self, num_filters=256, pretrained=True):
        """Creates an `FPN` instance for feature extraction.

        Args:
          num_filters: the number of filters in each output pyramid level
          pretrained: use ImageNet pre-trained backbone feature extractor
        """

        super().__init__()

        self.resnet = resnet50(pretrained=pretrained)

        # Access resnet directly in forward pass; do not store refs here due to
        # https://github.com/pytorch/pytorch/issues/8392

        # Lateral 1x1 convolutions project the backbone feature maps onto a
        # common `num_filters` channel dimension; the input channel counts
        # (2048, 1024, 512, 256) are fixed by resnet50's bottleneck layers.
        self.lateral4 = Conv1x1(2048, num_filters)
        self.lateral3 = Conv1x1(1024, num_filters)
        self.lateral2 = Conv1x1(512, num_filters)
        self.lateral1 = Conv1x1(256, num_filters)

        # 3x3 convolutions smooth the merged maps to reduce upsampling aliasing.
        self.smooth4 = Conv3x3(num_filters, num_filters)
        self.smooth3 = Conv3x3(num_filters, num_filters)
        self.smooth2 = Conv3x3(num_filters, num_filters)
        self.smooth1 = Conv3x3(num_filters, num_filters)

    def forward(self, x):
        """Computes the four pyramid feature maps for the input batch.

        Args:
          x: image batch of shape (N, 3, H, W); NOTE(review): H and W should be
             divisible by 32 so the nearest-neighbor upsampled maps line up with
             the lateral maps — confirm with callers.

        Returns:
          Tuple (map1, map2, map3, map4) of feature maps at 1/4, 1/8, 1/16 and
          1/32 of the input resolution, each with `num_filters` channels.
        """

        # Bottom-up pathway, from ResNet

        enc0 = self.resnet.conv1(x)
        enc0 = self.resnet.bn1(enc0)
        enc0 = self.resnet.relu(enc0)
        enc0 = self.resnet.maxpool(enc0)

        enc1 = self.resnet.layer1(enc0)
        enc2 = self.resnet.layer2(enc1)
        enc3 = self.resnet.layer3(enc2)
        enc4 = self.resnet.layer4(enc3)

        # Lateral connections

        lateral4 = self.lateral4(enc4)
        lateral3 = self.lateral3(enc3)
        lateral2 = self.lateral2(enc2)
        lateral1 = self.lateral1(enc1)

        # Top-down pathway: coarser maps are upsampled and merged by addition.
        # Use `interpolate`; `nn.functional.upsample` is deprecated since
        # PyTorch 0.4.1 and is a thin wrapper around interpolate anyway.

        map4 = lateral4
        map3 = lateral3 + nn.functional.interpolate(map4, scale_factor=2, mode="nearest")
        map2 = lateral2 + nn.functional.interpolate(map3, scale_factor=2, mode="nearest")
        map1 = lateral1 + nn.functional.interpolate(map2, scale_factor=2, mode="nearest")

        # Reduce aliasing effect of upsampling

        map4 = self.smooth4(map4)
        map3 = self.smooth3(map3)
        map2 = self.smooth2(map2)
        map1 = self.smooth1(map1)

        return map1, map2, map3, map4


class FPNSegmentation(nn.Module):
    """Semantic segmentation model on top of a Feature Pyramid Network (FPN)."""

    def __init__(self, num_classes, num_filters=128, num_filters_fpn=256, pretrained=True):
        """Creates an `FPNSegmentation` instance for feature extraction.

        Args:
          num_classes: number of classes to predict
          num_filters: the number of filters in each segmentation head pyramid level
          num_filters_fpn: the number of filters in each FPN output pyramid level
          pretrained: use ImageNet pre-trained backbone feature extractor
        """

        super().__init__()

        # Feature Pyramid Network (FPN) with four feature maps of resolutions
        # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.

        self.fpn = FPN(num_filters=num_filters_fpn, pretrained=pretrained)

        # The segmentation heads on top of the FPN: two 3x3 convolutions per
        # pyramid level, reducing channels from `num_filters_fpn` to `num_filters`.

        self.head1 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))
        self.head2 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))
        self.head3 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))
        self.head4 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))

        # Final classifier over the concatenation of all four head outputs.
        self.final = nn.Conv2d(4 * num_filters, num_classes, kernel_size=3, padding=1)

    def forward(self, x):
        """Predicts per-pixel class scores for the input batch.

        Args:
          x: image batch of shape (N, 3, H, W)

        Returns:
          Class score map of shape (N, num_classes, H, W).
        """

        map1, map2, map3, map4 = self.fpn(x)

        # Bring all pyramid levels to the 1/4 resolution of map1 so they can be
        # concatenated. Use `interpolate`; `nn.functional.upsample` is
        # deprecated since PyTorch 0.4.1.
        map4 = nn.functional.interpolate(self.head4(map4), scale_factor=8, mode="nearest")
        map3 = nn.functional.interpolate(self.head3(map3), scale_factor=4, mode="nearest")
        map2 = nn.functional.interpolate(self.head2(map2), scale_factor=2, mode="nearest")
        map1 = self.head1(map1)

        final = self.final(torch.cat([map4, map3, map2, map1], dim=1))

        # Upsample from 1/4 back to full input resolution.
        return nn.functional.interpolate(final, scale_factor=4, mode="bilinear", align_corners=False)


class Conv1x1(nn.Module):
    """Pointwise (1x1) convolution without bias, used for lateral connections."""

    def __init__(self, num_in, num_out):
        super().__init__()
        # Note: attribute must stay named `block` to keep state_dict keys stable.
        self.block = nn.Conv2d(in_channels=num_in, out_channels=num_out, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.block(x)
        return out


class Conv3x3(nn.Module):
    """3x3 convolution without bias; padding of one keeps spatial size unchanged."""

    def __init__(self, num_in, num_out):
        super().__init__()
        # Note: attribute must stay named `block` to keep state_dict keys stable.
        self.block = nn.Conv2d(in_channels=num_in, out_channels=num_out, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        out = self.block(x)
        return out
4 changes: 2 additions & 2 deletions robosat/tools/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.autograd

from robosat.config import load_config
from robosat.unet import UNet
from robosat.fpn import FPNSegmentation


def add_parser(subparser):
Expand All @@ -25,7 +25,7 @@ def main(args):
dataset = load_config(args.dataset)

num_classes = len(dataset["common"]["classes"])
net = UNet(num_classes)
net = FPNSegmentation(num_classes)

def map_location(storage, _):
return storage.cpu()
Expand Down
4 changes: 2 additions & 2 deletions robosat/tools/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from PIL import Image

from robosat.datasets import BufferedSlippyMapDirectory
from robosat.unet import UNet
from robosat.fpn import FPNSegmentation
from robosat.config import load_config
from robosat.colors import continuous_palette_for_color
from robosat.transforms import ConvertImageMode, ImageToTensor
Expand Down Expand Up @@ -59,7 +59,7 @@ def map_location(storage, _):
# https://github.com/pytorch/pytorch/issues/7178
chkpt = torch.load(args.checkpoint, map_location=map_location)

net = UNet(num_classes).to(device)
net = FPNSegmentation(num_classes).to(device)
net = nn.DataParallel(net)

if cuda:
Expand Down
4 changes: 2 additions & 2 deletions robosat/tools/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flask import Flask, send_file, render_template, abort

from robosat.tiles import fetch_image
from robosat.unet import UNet
from robosat.fpn import FPNSegmentation
from robosat.config import load_config
from robosat.colors import make_palette
from robosat.transforms import ConvertImageMode, ImageToTensor
Expand Down Expand Up @@ -180,7 +180,7 @@ def map_location(storage, _):

num_classes = len(self.dataset["common"]["classes"])

net = UNet(num_classes).to(self.device)
net = FPNSegmentation(num_classes).to(self.device)
net = nn.DataParallel(net)

if self.cuda:
Expand Down
11 changes: 2 additions & 9 deletions robosat/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from robosat.datasets import SlippyMapTilesConcatenation
from robosat.metrics import MeanIoU
from robosat.losses import CrossEntropyLoss2d
from robosat.unet import UNet
from robosat.fpn import FPNSegmentation
from robosat.utils import plot
from robosat.config import load_config

Expand All @@ -51,25 +51,18 @@ def main(args):
if model["common"]["cuda"] and not torch.cuda.is_available():
sys.exit("Error: CUDA requested but not available")

# if args.batch_size < 2:
# sys.exit('Error: PSPNet requires more than one image for BatchNorm in Pyramid Pooling')

os.makedirs(model["common"]["checkpoint"], exist_ok=True)

num_classes = len(dataset["common"]["classes"])
net = UNet(num_classes)
net = DataParallel(net)
net = net.to(device)
net = FPNSegmentation(num_classes).to(device)

if model["common"]["cuda"]:
torch.backends.cudnn.benchmark = True

optimizer = Adam(net.parameters(), lr=model["opt"]["lr"], weight_decay=model["opt"]["decay"])

weight = torch.Tensor(dataset["weights"]["values"])

criterion = CrossEntropyLoss2d(weight=weight).to(device)
# criterion = FocalLoss2d(weight=weight).to(device)

train_loader, val_loader = get_dataset_loaders(model, dataset)

Expand Down
1 change: 1 addition & 0 deletions robosat/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(self, num_classes, num_filters=32, pretrained=True):
Args:
num_classes: number of classes to predict.
num_filters: the number of filters for the decoder block
pretrained: use ImageNet pre-trained backbone feature extractor
"""

Expand Down

0 comments on commit 6f8cd87

Please sign in to comment.