Skip to content

Commit

Permalink
Implements Feature Pyramid Network (FPN), closes #60
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-j-h committed Oct 15, 2018
1 parent b7f6ebf commit 717e1a5
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 10 deletions.
141 changes: 141 additions & 0 deletions robosat/fpn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""Feature Pyramid Network (FPN) on top of ResNet. Comes with task-specific heads on top of it.
See:
- https://arxiv.org/abs/1612.03144 - Feature Pyramid Networks for Object Detection
- http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf - A Unified Architecture for Instance
and Semantic Segmentation
"""

import torch
import torch.nn as nn

from torchvision.models import resnet50


class FPN(nn.Module):
    """Feature Pyramid Network (FPN): top-down architecture with lateral connections.

    Can be used as feature extractor for object detection or segmentation.
    """

    def __init__(self, num_filters=256, pretrained=True):
        """Creates an `FPN` instance for feature extraction.

        Args:
          num_filters: the number of filters in each output pyramid level
          pretrained: use ImageNet pre-trained backbone feature extractor
        """

        super().__init__()

        self.resnet = resnet50(pretrained=pretrained)

        # Access resnet directly in forward pass; do not store refs here due to
        # https://github.com/pytorch/pytorch/issues/8392

        # Lateral 1x1 convolutions project the encoder stages (256, 512, 1024
        # and 2048 channels for layer1..layer4) to a common `num_filters` width.
        self.lateral4 = Conv1x1(2048, num_filters)
        self.lateral3 = Conv1x1(1024, num_filters)
        self.lateral2 = Conv1x1(512, num_filters)
        self.lateral1 = Conv1x1(256, num_filters)

        # One 3x3 convolution per pyramid level, applied after merging.
        self.smooth4 = Conv3x3(num_filters, num_filters)
        self.smooth3 = Conv3x3(num_filters, num_filters)
        self.smooth2 = Conv3x3(num_filters, num_filters)
        self.smooth1 = Conv3x3(num_filters, num_filters)

    def forward(self, x):
        """Computes the feature pyramid for a batch of images.

        Args:
          x: the input batch tensor, in the layout the ResNet backbone expects

        Returns:
          The tuple `(map1, map2, map3, map4)` of feature maps ordered from the
          finest to the coarsest level, each with `num_filters` channels.
        """

        # Bottom-up pathway, from ResNet

        enc0 = self.resnet.conv1(x)
        enc0 = self.resnet.bn1(enc0)
        enc0 = self.resnet.relu(enc0)
        enc0 = self.resnet.maxpool(enc0)

        enc1 = self.resnet.layer1(enc0)
        enc2 = self.resnet.layer2(enc1)
        enc3 = self.resnet.layer3(enc2)
        enc4 = self.resnet.layer4(enc3)

        # Lateral connections

        lateral4 = self.lateral4(enc4)
        lateral3 = self.lateral3(enc3)
        lateral2 = self.lateral2(enc2)
        lateral1 = self.lateral1(enc1)

        # Top-down pathway
        #
        # Upsample to the spatial size of the lateral map we merge with rather
        # than by a fixed factor of two: strided layers round odd sizes up, so
        # a finer level is not always exactly twice the size of the coarser
        # one and a fixed factor would make the addition fail. For exact 2x
        # chains this yields the same result as `scale_factor=2`.

        map4 = lateral4
        map3 = lateral3 + nn.functional.interpolate(map4, size=lateral3.shape[-2:], mode="nearest")
        map2 = lateral2 + nn.functional.interpolate(map3, size=lateral2.shape[-2:], mode="nearest")
        map1 = lateral1 + nn.functional.interpolate(map2, size=lateral1.shape[-2:], mode="nearest")

        # Reduce aliasing effect of upsampling

        map4 = self.smooth4(map4)
        map3 = self.smooth3(map3)
        map2 = self.smooth2(map2)
        map1 = self.smooth1(map1)

        return map1, map2, map3, map4


class FPNSegmentation(nn.Module):
    """Semantic segmentation model on top of a Feature Pyramid Network (FPN).
    """

    def __init__(self, num_classes, num_filters=128, num_filters_fpn=256, pretrained=True):
        """Creates an `FPNSegmentation` instance for semantic segmentation.

        Args:
          num_classes: number of classes to predict
          num_filters: the number of filters in each segmentation head pyramid level
          num_filters_fpn: the number of filters in each FPN output pyramid level
          pretrained: use ImageNet pre-trained backbone feature extractor
        """

        super().__init__()

        # Feature Pyramid Network (FPN) with four feature maps of resolutions
        # 1/4, 1/8, 1/16, 1/32 and `num_filters_fpn` filters for all feature maps.

        self.fpn = FPN(num_filters=num_filters_fpn, pretrained=pretrained)

        # The segmentation heads on top of the FPN

        self.head1 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))
        self.head2 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))
        self.head3 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))
        self.head4 = nn.Sequential(Conv3x3(num_filters_fpn, num_filters), Conv3x3(num_filters, num_filters))

        # Classifier over the channel-wise concatenation of the four heads.
        self.final = nn.Conv2d(4 * num_filters, num_classes, kernel_size=3, padding=1)

    def forward(self, x):
        """Predicts per-pixel class scores for a batch of images.

        Args:
          x: the input batch tensor, passed on to the FPN as is

        Returns:
          A tensor with `num_classes` channels and the spatial size of `x`.
        """

        map1, map2, map3, map4 = self.fpn(x)

        # Bring all heads to the resolution of the finest level. Resizing to
        # its actual size (and not by fixed factors 8, 4, 2) keeps the maps
        # concatenable when the levels are not exact multiples of each other.
        finest = map1.shape[-2:]

        map4 = nn.functional.interpolate(self.head4(map4), size=finest, mode="nearest")
        map3 = nn.functional.interpolate(self.head3(map3), size=finest, mode="nearest")
        map2 = nn.functional.interpolate(self.head2(map2), size=finest, mode="nearest")
        map1 = self.head1(map1)

        final = self.final(torch.cat([map4, map3, map2, map1], dim=1))

        # Resize the scores to the input resolution so that the prediction
        # lines up with the image pixel by pixel; same as a factor of four
        # whenever the finest level is exactly a quarter of the input.
        return nn.functional.interpolate(final, size=x.shape[-2:], mode="bilinear", align_corners=False)


class Conv1x1(nn.Module):
    """Pointwise (1x1) 2d convolution without a bias term."""

    def __init__(self, num_in, num_out):
        """Creates a `Conv1x1` instance.

        Args:
          num_in: number of input channels
          num_out: number of output channels
        """

        super().__init__()

        # Keep the attribute name `block`: it is part of the state dict keys.
        self.block = nn.Conv2d(
            in_channels=num_in,
            out_channels=num_out,
            kernel_size=1,
            bias=False,
        )

    def forward(self, x):
        """Applies the convolution to `x` and returns the result."""

        projected = self.block(x)
        return projected


class Conv3x3(nn.Module):
    """3x3 2d convolution with padding of one and without a bias term."""

    def __init__(self, num_in, num_out):
        """Creates a `Conv3x3` instance.

        Args:
          num_in: number of input channels
          num_out: number of output channels
        """

        super().__init__()

        # Keep the attribute name `block`: it is part of the state dict keys.
        self.block = nn.Conv2d(
            in_channels=num_in,
            out_channels=num_out,
            kernel_size=3,
            padding=1,
            bias=False,
        )

    def forward(self, x):
        """Applies the convolution to `x` and returns the result."""

        filtered = self.block(x)
        return filtered
4 changes: 2 additions & 2 deletions robosat/tools/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.autograd

from robosat.config import load_config
from robosat.unet import UNet
from robosat.fpn import FPNSegmentation


def add_parser(subparser):
Expand All @@ -25,7 +25,7 @@ def main(args):
dataset = load_config(args.dataset)

num_classes = len(dataset["common"]["classes"])
net = UNet(num_classes)
net = FPNSegmentation(num_classes)

def map_location(storage, _):
return storage.cpu()
Expand Down
7 changes: 4 additions & 3 deletions robosat/tools/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from PIL import Image

from robosat.datasets import BufferedSlippyMapDirectory
from robosat.unet import UNet
from robosat.fpn import FPNSegmentation
from robosat.config import load_config
from robosat.colors import continuous_palette_for_color
from robosat.transforms import ConvertImageMode, ImageToTensor
Expand Down Expand Up @@ -59,8 +59,9 @@ def map_location(storage, _):
# https://github.com/pytorch/pytorch/issues/7178
chkpt = torch.load(args.checkpoint, map_location=map_location)

net = UNet(num_classes).to(device)
net = nn.DataParallel(net)
net = FPNSegmentation(num_classes)
net = DataParallel(net)
net = net.to(device)

if cuda:
torch.backends.cudnn.benchmark = True
Expand Down
7 changes: 4 additions & 3 deletions robosat/tools/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from flask import Flask, send_file, render_template, abort

from robosat.tiles import fetch_image
from robosat.unet import UNet
from robosat.fpn import FPNSegmentation
from robosat.config import load_config
from robosat.colors import make_palette
from robosat.transforms import ConvertImageMode, ImageToTensor
Expand Down Expand Up @@ -180,8 +180,9 @@ def map_location(storage, _):

num_classes = len(self.dataset["common"]["classes"])

net = UNet(num_classes).to(self.device)
net = nn.DataParallel(net)
net = FPNSegmentation(num_classes)
net = DataParallel(net)
net = net.to(device)

if self.cuda:
torch.backends.cudnn.benchmark = True
Expand Down
4 changes: 2 additions & 2 deletions robosat/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from robosat.datasets import SlippyMapTilesConcatenation
from robosat.metrics import Metrics
from robosat.losses import CrossEntropyLoss2d, mIoULoss2d, FocalLoss2d, LovaszLoss2d
from robosat.unet import UNet
from robosat.fpn import FPNSegmentation
from robosat.utils import plot
from robosat.config import load_config
from robosat.log import Log
Expand Down Expand Up @@ -68,7 +68,7 @@ def main(args):
os.makedirs(model["common"]["checkpoint"], exist_ok=True)

num_classes = len(dataset["common"]["classes"])
net = UNet(num_classes)
net = FPNSegmentation(num_classes)
net = DataParallel(net)
net = net.to(device)

Expand Down
1 change: 1 addition & 0 deletions robosat/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(self, num_classes, num_filters=32, pretrained=True):
Args:
num_classes: number of classes to predict.
num_filters: the number of filters for the decoder block
pretrained: use ImageNet pre-trained backbone feature extractor
"""

Expand Down

0 comments on commit 717e1a5

Please sign in to comment.