diff --git a/robosat/fpn.py b/robosat/fpn.py index 8b3f0b49..5a323a7d 100644 --- a/robosat/fpn.py +++ b/robosat/fpn.py @@ -46,6 +46,9 @@ def __init__(self, num_filters=256, pretrained=True): def forward(self, x): # Bottom-up pathway, from ResNet + size = x.size() + assert size[-1] % 32 == 0 and size[-2] % 32 == 0, "image resolution has to be divisible by 32 for resnet" + enc0 = self.resnet.conv1(x) enc0 = self.resnet.bn1(enc0) enc0 = self.resnet.relu(enc0)