diff --git a/CoordConv.py b/CoordConv.py index 33da5df..d2e7cdd 100644 --- a/CoordConv.py +++ b/CoordConv.py @@ -112,7 +112,10 @@ class CoordConv(nn.Module): def __init__(self, in_channels, out_channels, with_r=False, **kwargs): super().__init__() self.addcoords = AddCoords(with_r=with_r) - self.conv = nn.Conv2d(in_channels + 2, out_channels, **kwargs) + in_size = in_channels+2 + if with_r: + in_size += 1 + self.conv = nn.Conv2d(in_size, out_channels, **kwargs) def forward(self, x): ret = self.addcoords(x)