-
Notifications
You must be signed in to change notification settings - Fork 4
/
unet3d.py
136 lines (111 loc) · 5.72 KB
/
unet3d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""Adapted from https://github.com/milesial/Pytorch-UNet/tree/master/unet"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNet(nn.Module):
    """A simple 3D U-Net, adapted from the 2D U-Net at
    https://github.com/milesial/Pytorch-UNet/tree/master/unet

    Arguments:
        n_channels: number of input channels; 3 for RGB, 1 for grayscale input
        n_classes: number of output channels/classes
        width_multiplier: how much 'wider' the UNet should be compared with a
            standard UNet; default 1 means 32 -> 64 -> 128 -> 256 -> 512 on the
            encoder path. Higher values scale the kernel count per layer by
            that factor.
        trilinear: use trilinear interpolation to upsample; if False, 3D
            transposed-convolution layers are used instead.
        use_ds_conv: if True, use depthwise-separable convolution layers. In
            practice this is of little help for 3D data: most GPU RAM is spent
            on activations/labels rather than parameters, so little VRAM is
            saved while performance suffers.
    """

    def __init__(self, n_channels, n_classes, width_multiplier=1, trilinear=True, use_ds_conv=False):
        super().__init__()
        base_widths = (32, 64, 128, 256, 512)

        self.n_channels = n_channels
        self.n_classes = n_classes
        self.trilinear = trilinear
        self.channels = [int(w * width_multiplier) for w in base_widths]
        self.convtype = DepthwiseSeparableConv3d if use_ds_conv else nn.Conv3d

        c = self.channels
        # With trilinear upsampling, Up reduces channels via its convolutions,
        # so the deepest encoder stage (and each Up output) is halved.
        halve = 2 if trilinear else 1

        self.inc = DoubleConv(n_channels, c[0], conv_type=self.convtype)
        self.down1 = Down(c[0], c[1], conv_type=self.convtype)
        self.down2 = Down(c[1], c[2], conv_type=self.convtype)
        self.down3 = Down(c[2], c[3], conv_type=self.convtype)
        self.down4 = Down(c[3], c[4] // halve, conv_type=self.convtype)
        self.up1 = Up(c[4], c[3] // halve, trilinear)
        self.up2 = Up(c[3], c[2] // halve, trilinear)
        self.up3 = Up(c[2], c[1] // halve, trilinear)
        self.up4 = Up(c[1], c[0], trilinear)
        self.outc = OutConv(c[0], n_classes)

    def forward(self, x):
        # Encoder: successively downsample while widening channels.
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        enc4 = self.down3(enc3)
        enc5 = self.down4(enc4)
        # Decoder: upsample and fuse with the matching encoder skip feature.
        dec = self.up1(enc5, enc4)
        dec = self.up2(dec, enc3)
        dec = self.up3(dec, enc2)
        dec = self.up4(dec, enc1)
        return self.outc(dec)
class DoubleConv(nn.Module):
    """Two back-to-back (convolution => [BN] => ReLU) stages.

    Arguments:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second stage.
        conv_type: convolution class to use (e.g. nn.Conv3d or a
            depthwise-separable variant with the same call signature).
        mid_channels: channels between the two stages; any falsy value
            (None/0) defaults it to out_channels.
    """

    def __init__(self, in_channels, out_channels, conv_type=nn.Conv3d, mid_channels=None):
        super().__init__()
        mid = mid_channels if mid_channels else out_channels
        stages = [
            conv_type(in_channels, mid, kernel_size=3, padding=1),
            nn.BatchNorm3d(mid),
            nn.ReLU(inplace=True),
            conv_type(mid, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)
class Down(nn.Module):
    """Downscaling step: 2x max-pool over all spatial dims, then DoubleConv."""

    def __init__(self, in_channels, out_channels, conv_type=nn.Conv3d):
        super().__init__()
        pool = nn.MaxPool3d(kernel_size=2)
        conv = DoubleConv(in_channels, out_channels, conv_type=conv_type)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Upscaling step: upsample, pad to match the skip feature, concat, DoubleConv.

    Arguments:
        in_channels: channels after concatenating the upsampled tensor with
            the skip connection.
        out_channels: channels produced by the DoubleConv.
        trilinear: if True, use trilinear interpolation (channel reduction is
            then done by the convolutions); otherwise use ConvTranspose3d,
            which halves the channels itself.
    """

    def __init__(self, in_channels, out_channels, trilinear=True):
        super().__init__()
        # if trilinear, use the normal convolutions to reduce the number of channels
        if trilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, mid_channels=in_channels // 2)
        else:
            self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """x1: decoder feature to upsample; x2: encoder skip feature to fuse."""
        x1 = self.up(x1)
        # Inputs are (N, C, D, H, W). The original (2D-derived) code computed
        # diffs from dims 2 and 3 only and padded just two dims, which pads the
        # wrong axes of a 5-D tensor and leaves depth mismatches unhandled
        # (torch.cat then fails for odd input depths). Pad all three spatial
        # dims so x1 matches x2 exactly.
        diffZ = x2.size(2) - x1.size(2)
        diffY = x2.size(3) - x1.size(3)
        diffX = x2.size(4) - x1.size(4)
        # F.pad expects pads for the last dimension first:
        # (W_left, W_right, H_top, H_bottom, D_front, D_back).
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class OutConv(nn.Module):
    """Final 1x1x1 convolution mapping feature channels to class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
class DepthwiseSeparableConv3d(nn.Module):
    """Depthwise-separable 3D convolution: a per-channel (grouped) spatial
    convolution followed by a 1x1x1 pointwise convolution that mixes channels.

    Arguments:
        nin: input channels.
        nout: output channels.
        kernel_size: spatial kernel size of the depthwise stage.
        padding: padding of the depthwise stage.
        kernels_per_layer: depth multiplier for the depthwise stage (each
            input channel produces this many intermediate channels).
    """

    def __init__(self, nin, nout, kernel_size, padding, kernels_per_layer=1):
        super().__init__()
        inter_channels = nin * kernels_per_layer
        # groups=nin makes the first conv operate on each input channel independently.
        self.depthwise = nn.Conv3d(nin, inter_channels, kernel_size=kernel_size,
                                   padding=padding, groups=nin)
        self.pointwise = nn.Conv3d(inter_channels, nout, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))