forked from aetherks/TideExtract
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunet.py
60 lines (56 loc) · 2.31 KB
/
unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Adapted from: https://github.com/milesial/Pytorch-UNet
# Changes:
# 1) an Nbase parameter that scales the UNet's channel widths
#    without changing its topology,
# 2) an optional input batch-normalization layer, and
# 3) a temporal embedding that is fed into the input convolution block.
from unet_parts_t import *
class UNet(nn.Module):
    """UNet whose input block is conditioned on a temporal embedding.

    Adapted from milesial/Pytorch-UNet. Encoder channel widths scale with
    ``Nbase`` (Nbase, 2*Nbase, 4*Nbase, 8*Nbase, 16*Nbase // factor) without
    changing the topology. Only the input block (``inc``) receives the
    embedding produced by ``tempEmb``.

    Args:
        n_channels: Number of channels of the input tensor.
        n_classes: Number of output channels produced by ``outc``.
        bilinear: Forwarded to every ``Up`` block. When True, ``factor = 2``
            and the bottleneck/decoder widths are halved.
        Nbase: Channel width of the first encoder level.
        inpBNFlag: If True, ``nn.BatchNorm2d(n_channels)`` is applied to the
            raw input before ``inc``.
        n_emb: Embedding size passed to ``TemporalEncoding`` and
            ``DoubleConvTime``.
    """

    def __init__(self, n_channels, n_classes, bilinear=False, Nbase=16,
                 inpBNFlag=False, n_emb=100):
        super().__init__()
        self.tempEmb = TemporalEncoding(n_emb=n_emb)
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.inpBNFlag = inpBNFlag
        # Activation checkpointing is off until use_checkpointing() is called.
        self._checkpointing = False
        if inpBNFlag:
            self.inpBN = nn.BatchNorm2d(n_channels)
        # Time-conditioned replacement for the original DoubleConv input block.
        self.inc = DoubleConvTime(n_channels, Nbase, n_emb=n_emb)
        self.down1 = Down(Nbase, Nbase * 2)
        self.down2 = Down(Nbase * 2, Nbase * 4)
        self.down3 = Down(Nbase * 4, Nbase * 8)
        factor = 2 if bilinear else 1
        self.down4 = Down(Nbase * 8, Nbase * 16 // factor)
        self.up1 = Up(Nbase * 16, Nbase * 8 // factor, bilinear)
        self.up2 = Up(Nbase * 8, Nbase * 4 // factor, bilinear)
        self.up3 = Up(Nbase * 4, Nbase * 2 // factor, bilinear)
        self.up4 = Up(Nbase * 2, Nbase, bilinear)
        self.outc = OutConv(Nbase, n_classes)

    def _run(self, module, *args):
        """Call ``module(*args)``, via activation checkpointing when enabled."""
        # getattr: instances unpickled from before this flag existed default to off.
        if getattr(self, "_checkpointing", False) and torch.is_grad_enabled():
            from torch.utils.checkpoint import checkpoint

            # use_reentrant=False: the reentrant variant only produces gradients
            # when some *input* requires grad, so a stage fed the raw image
            # (e.g. ``inc``) could silently receive no parameter gradients.
            return checkpoint(module, *args, use_reentrant=False)
        return module(*args)

    def forward(self, x, t=None):
        """Run the encoder/decoder and return the output of ``outc``.

        Args:
            x: Input batch; presumably (N, n_channels, H, W), which is what
                ``BatchNorm2d`` requires when ``inpBNFlag`` is set.
                TODO confirm for the helper blocks.
            t: Time input, passed unchanged to ``self.tempEmb``.
                NOTE(review): whether ``TemporalEncoding`` accepts the
                default ``None`` is not visible here.

        Returns:
            The tensor produced by ``self.outc`` (``n_classes`` channels).
        """
        time_emb = self.tempEmb(t)
        x1 = self.inpBN(x) if self.inpBNFlag else x
        x1 = self._run(self.inc, x1, time_emb)
        x2 = self._run(self.down1, x1)
        x3 = self._run(self.down2, x2)
        x4 = self._run(self.down3, x3)
        x5 = self._run(self.down4, x4)
        # Decoder: each Up block also receives the matching encoder output.
        x = self._run(self.up1, x5, x4)
        x = self._run(self.up2, x, x3)
        x = self._run(self.up3, x, x2)
        x = self._run(self.up4, x, x1)
        return self._run(self.outc, x)

    def use_checkpointing(self, enabled=True):
        """Enable (or disable) activation checkpointing for the UNet stages.

        While enabled and autograd is recording, ``inc``, ``down1``-``down4``,
        ``up1``-``up4`` and ``outc`` run through
        ``torch.utils.checkpoint.checkpoint``. Their intermediate activations
        are not kept and are recomputed during backward, trading compute for
        memory. The submodules themselves are not replaced, so ``state_dict``
        keys are unchanged.

        NOTE(review): recomputation re-runs each stage's forward, so forward
        side effects (e.g. BatchNorm running-statistic updates in training
        mode) may be applied a second time. Confirm this is acceptable.
        """
        self._checkpointing = bool(enabled)