Skip to content

Commit

Permalink
Works around internal PyTorch bug related to shared weights
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-j-h committed Aug 1, 2018
1 parent da30c57 commit ca3bc00
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions robosat/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,8 @@ def __init__(self, num_classes, num_filters=32, pretrained=True):

self.resnet = resnet50(pretrained=pretrained)

self.enc0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu, self.resnet.maxpool)
self.enc1 = self.resnet.layer1 # 256
self.enc2 = self.resnet.layer2 # 512
self.enc3 = self.resnet.layer3 # 1024
self.enc4 = self.resnet.layer4 # 2048
# Access resnet directly in forward pass; do not store refs here due to
# https://github.com/pytorch/pytorch/issues/8392

self.center = DecoderBlock(2048, num_filters * 8)

Expand All @@ -120,11 +117,15 @@ def forward(self, x):
The networks output tensor.
"""

enc0 = self.enc0(x)
enc1 = self.enc1(enc0)
enc2 = self.enc2(enc1)
enc3 = self.enc3(enc2)
enc4 = self.enc4(enc3)
enc0 = self.resnet.conv1(x)
enc0 = self.resnet.bn1(enc0)
enc0 = self.resnet.relu(enc0)
enc0 = self.resnet.maxpool(enc0)

enc1 = self.resnet.layer1(enc0)
enc2 = self.resnet.layer2(enc1)
enc3 = self.resnet.layer3(enc2)
enc4 = self.resnet.layer4(enc3)

center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2))

Expand Down

0 comments on commit ca3bc00

Please sign in to comment.