From ca3bc000d248ab564f1018faca60a7c1a3c44be1 Mon Sep 17 00:00:00 2001 From: "Daniel J. Hofmann" Date: Wed, 1 Aug 2018 15:12:22 +0200 Subject: [PATCH] Works around internal PyTorch bug related to shared weights --- robosat/unet.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/robosat/unet.py b/robosat/unet.py index 2e455e1a..e9e4c5b3 100644 --- a/robosat/unet.py +++ b/robosat/unet.py @@ -93,11 +93,8 @@ def __init__(self, num_classes, num_filters=32, pretrained=True): self.resnet = resnet50(pretrained=pretrained) - self.enc0 = nn.Sequential(self.resnet.conv1, self.resnet.bn1, self.resnet.relu, self.resnet.maxpool) - self.enc1 = self.resnet.layer1 # 256 - self.enc2 = self.resnet.layer2 # 512 - self.enc3 = self.resnet.layer3 # 1024 - self.enc4 = self.resnet.layer4 # 2048 + # Access resnet directly in forward pass; do not store refs here due to + # https://github.com/pytorch/pytorch/issues/8392 self.center = DecoderBlock(2048, num_filters * 8) @@ -120,11 +117,15 @@ def forward(self, x): The networks output tensor. """ - enc0 = self.enc0(x) - enc1 = self.enc1(enc0) - enc2 = self.enc2(enc1) - enc3 = self.enc3(enc2) - enc4 = self.enc4(enc3) + enc0 = self.resnet.conv1(x) + enc0 = self.resnet.bn1(enc0) + enc0 = self.resnet.relu(enc0) + enc0 = self.resnet.maxpool(enc0) + + enc1 = self.resnet.layer1(enc0) + enc2 = self.resnet.layer2(enc1) + enc3 = self.resnet.layer3(enc2) + enc4 = self.resnet.layer4(enc3) center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2))