WIP API for Yolact, python package #323

Open
wants to merge 16 commits into master
5 changes: 4 additions & 1 deletion data/config.py
@@ -303,7 +303,9 @@ def print(self):


# ----------------------- MASK BRANCH TYPES ----------------------- #

# FIXME: cannot import mask_type?
mask_type_DIRECT = 0
mask_type_LINCOMB = 1
mask_type = Config({
    # Direct produces masks directly as the output of each pred module.
    # This is denoted as fc-mask in the paper.
Expand Down Expand Up @@ -819,6 +821,7 @@ def set_cfg(config_name:str):

    if cfg.name is None:
        cfg.name = config_name.split('_config')[0]
    return cfg

def set_dataset(dataset_name:str):
    """ Sets the dataset of the current config. """
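With the return cfg added above, set_cfg now hands back the mutated global config instead of returning nothing, so callers can grab a handle in one line. A minimal sketch, assuming one of the standard config names defined in data/config.py:

from data.config import set_cfg

cfg = set_cfg('yolact_base_config')  # replaces the global cfg in place and returns it
print(cfg.name)                      # 'yolact_base'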
13 changes: 13 additions & 0 deletions layers/modules/concat.py
@@ -0,0 +1,13 @@
import torch
import torch.nn as nn

class Concat(nn.Module):
    """Runs several sub-networks on the same input and concatenates their outputs."""

    def __init__(self, nets, extra_params):
        super().__init__()

        self.nets = nn.ModuleList(nets)
        self.extra_params = extra_params

    def forward(self, x):
        # Concat each output along the channel dimension
        return torch.cat([net(x) for net in self.nets], dim=1, **self.extra_params)
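For reference, a quick usage sketch (not part of the PR): two illustrative 1x1 conv branches whose outputs Concat joins along the channel dimension. extra_params is forwarded to torch.cat, so an empty dict is the usual case.

import torch
import torch.nn as nn
from layers.modules.concat import Concat

branches = [nn.Conv2d(3, 8, kernel_size=1), nn.Conv2d(3, 8, kernel_size=1)]
concat = Concat(branches, extra_params={})

x = torch.randn(1, 3, 32, 32)
out = concat(x)  # two (1, 8, 32, 32) outputs concatenated -> (1, 16, 32, 32)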
33 changes: 33 additions & 0 deletions layers/modules/fast_mask_iou.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch
from torch import nn
import torch.nn.functional as F

# local imports
from data.config import Config
from utils.functions import make_net


# As of March 10, 2019, PyTorch DataParallel still doesn't support JIT Script Modules
use_jit = torch.cuda.device_count() <= 1
if not use_jit:
    print('Multiple GPUs detected! Turning off JIT.')

ScriptModuleWrapper = torch.jit.ScriptModule if use_jit else nn.Module
script_method_wrapper = torch.jit.script_method if use_jit else lambda fn, _rcn=None: fn
Owner commented:

This part should probably go in its own file, since it should only be loaded once and might need to be used in multiple places.

Contributor Author replied:

Not sure why that's important, but like this: fccd89b?
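For concreteness, a sketch of the extraction being suggested, with the file path a hypothetical placeholder (the path actually chosen in fccd89b may differ). fpn.py and fast_mask_iou.py would then import ScriptModuleWrapper and script_method_wrapper from it instead of each redefining them:

# utils/script_module_wrapper.py (hypothetical location)
import torch
import torch.nn as nn

# As of March 10, 2019, PyTorch DataParallel still doesn't support JIT Script Modules
use_jit = torch.cuda.device_count() <= 1
if not use_jit:
    print('Multiple GPUs detected! Turning off JIT.')

ScriptModuleWrapper = torch.jit.ScriptModule if use_jit else nn.Module
script_method_wrapper = torch.jit.script_method if use_jit else lambda fn, _rcn=None: fn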



class FastMaskIoUNet(ScriptModuleWrapper):

    def __init__(self, config:Config):
        super().__init__()

        cfg = config
        input_channels = 1
        last_layer = [(cfg.num_classes-1, 1, {})]
        self.maskiou_net, _ = make_net(input_channels, cfg.maskiou_net + last_layer, include_last_relu=True)

    def forward(self, x):
        x = self.maskiou_net(x)
        # Global max pool over the remaining spatial dims: one IoU score per class
        maskiou_p = F.max_pool2d(x, kernel_size=x.size()[2:]).squeeze(-1).squeeze(-1)
        return maskiou_p
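A hedged usage sketch (not part of the PR): it assumes a YOLACT++ config, since those define cfg.maskiou_net, and relies on set_cfg returning the config as changed in data/config.py above. The input shape is illustrative.

import torch
from data.config import set_cfg
from layers.modules.fast_mask_iou import FastMaskIoUNet

cfg = set_cfg('yolact_plus_base_config')  # a config that defines maskiou_net
net = FastMaskIoUNet(cfg)

masks = torch.randn(8, 1, 138, 138)  # a batch of single-channel mask predictions
scores = net(masks)                  # shape (8, cfg.num_classes - 1): one IoU score per class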

119 changes: 119 additions & 0 deletions layers/modules/fpn.py
@@ -0,0 +1,119 @@
import torch
from torch import nn
import torch.nn.functional as F


from typing import List

# local imports
from data.config import Config

# As of March 10, 2019, PyTorch DataParallel still doesn't support JIT Script Modules
use_jit = torch.cuda.device_count() <= 1
if not use_jit:
    print('Multiple GPUs detected! Turning off JIT.')

ScriptModuleWrapper = torch.jit.ScriptModule if use_jit else nn.Module
script_method_wrapper = torch.jit.script_method if use_jit else lambda fn, _rcn=None: fn



class FPN(ScriptModuleWrapper):
    """
    Implements a general version of the FPN introduced in
    https://arxiv.org/pdf/1612.03144.pdf

    Parameters (in cfg.fpn):
        - num_features (int): The number of output features in the fpn layers.
        - interpolation_mode (str): The mode to pass to F.interpolate.
        - num_downsample (int): The number of downsampled layers to add onto the selected layers.
                                These extra layers are downsampled from the last selected layer.

    Args:
        - in_channels (list): For each conv layer you supply in the forward pass,
                              how many features will it have?
    """
    __constants__ = ['interpolation_mode', 'num_downsample', 'use_conv_downsample', 'relu_pred_layers',
                     'lat_layers', 'pred_layers', 'downsample_layers', 'relu_downsample_layers']

    def __init__(self, in_channels, config:Config):
        super().__init__()

        cfg = config

        self.lat_layers = nn.ModuleList([
            nn.Conv2d(x, cfg.fpn.num_features, kernel_size=1)
            for x in reversed(in_channels)
        ])

        # This is here for backwards compatibility
        padding = 1 if cfg.fpn.pad else 0
        self.pred_layers = nn.ModuleList([
            nn.Conv2d(cfg.fpn.num_features, cfg.fpn.num_features, kernel_size=3, padding=padding)
            for _ in in_channels
        ])

        if cfg.fpn.use_conv_downsample:
            self.downsample_layers = nn.ModuleList([
                nn.Conv2d(cfg.fpn.num_features, cfg.fpn.num_features, kernel_size=3, padding=1, stride=2)
                for _ in range(cfg.fpn.num_downsample)
            ])

        self.interpolation_mode = cfg.fpn.interpolation_mode
        self.num_downsample = cfg.fpn.num_downsample
        self.use_conv_downsample = cfg.fpn.use_conv_downsample
        self.relu_downsample_layers = cfg.fpn.relu_downsample_layers
        self.relu_pred_layers = cfg.fpn.relu_pred_layers

    @script_method_wrapper
    def forward(self, convouts:List[torch.Tensor]):
        """
        Args:
            - convouts (list): A list of convouts for the corresponding layers in in_channels.
        Returns:
            - A list of FPN convouts in the same order as x with extra downsample layers if requested.
        """

        out = []
        x = torch.zeros(1, device=convouts[0].device)
        for i in range(len(convouts)):
            out.append(x)

        # For backward compatibility, the conv layers are stored in reverse but the input and output are
        # given in the correct order. Thus, use j=-i-1 for the input and output and i for the conv layers.
        j = len(convouts)
        for lat_layer in self.lat_layers:
            j -= 1

            if j < len(convouts) - 1:
                _, _, h, w = convouts[j].size()
                x = F.interpolate(x, size=(h, w), mode=self.interpolation_mode, align_corners=False)

            x = x + lat_layer(convouts[j])
            out[j] = x

        # This janky second loop is here because TorchScript.
        j = len(convouts)
        for pred_layer in self.pred_layers:
            j -= 1
            out[j] = pred_layer(out[j])

            if self.relu_pred_layers:
                F.relu(out[j], inplace=True)

        cur_idx = len(out)

        # In the original paper, this takes care of P6
        if self.use_conv_downsample:
            for downsample_layer in self.downsample_layers:
                out.append(downsample_layer(out[-1]))
        else:
            for idx in range(self.num_downsample):
                # Note: this is an untested alternative to out.append(out[-1][:, :, ::2, ::2]). Thanks TorchScript.
                out.append(nn.functional.max_pool2d(out[-1], 1, stride=2))

        if self.relu_downsample_layers:
            for idx in range(len(out) - cur_idx):
                out[idx] = F.relu(out[idx + cur_idx], inplace=False)

        return out
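For orientation, a minimal call sketch (illustrative, not part of the PR). The channel counts match the C3-C5 outputs of the ResNet backbones used by the standard configs, and the spatial sizes are roughly what a 550x550 input produces:

import torch
from data.config import set_cfg
from layers.modules.fpn import FPN

cfg = set_cfg('yolact_base_config')
fpn = FPN(in_channels=[512, 1024, 2048], config=cfg)

convouts = [torch.randn(1, 512, 69, 69),
            torch.randn(1, 1024, 35, 35),
            torch.randn(1, 2048, 18, 18)]
outs = fpn(convouts)  # len(outs) == 3 + cfg.fpn.num_downsample; each has cfg.fpn.num_features channels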