From 45df391b59b649da2e16397536409f87f5396852 Mon Sep 17 00:00:00 2001 From: leVirve Date: Sat, 14 Jul 2018 17:51:58 +0800 Subject: [PATCH] fix the bug for non-square inputs --- CoordConv.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/CoordConv.py b/CoordConv.py index 3cb6529..693f4b6 100644 --- a/CoordConv.py +++ b/CoordConv.py @@ -15,7 +15,7 @@ def forward(self, input_tensor): """ batch_size_tensor = input_tensor.shape[0] - xx_ones = torch.ones([1, self.x_dim], dtype=torch.int32) + xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32) xx_ones = xx_ones.unsqueeze(-1) xx_range = torch.arange(self.x_dim, dtype=torch.int32).unsqueeze(0) @@ -24,7 +24,7 @@ def forward(self, input_tensor): xx_channel = torch.matmul(xx_ones, xx_range) xx_channel = xx_channel.unsqueeze(-1) - yy_ones = torch.ones([1, self.y_dim], dtype=torch.int32) + yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32) yy_ones = yy_ones.unsqueeze(1) yy_range = torch.arange(self.y_dim, dtype=torch.int32).unsqueeze(0) @@ -33,8 +33,8 @@ def forward(self, input_tensor): yy_channel = torch.matmul(yy_range, yy_ones) yy_channel = yy_channel.unsqueeze(-1) - xx_channel = xx_channel.permute(0, 3, 1, 2) - yy_channel = yy_channel.permute(0, 3, 1, 2) + xx_channel = xx_channel.permute(0, 3, 2, 1) + yy_channel = yy_channel.permute(0, 3, 2, 1) xx_channel = xx_channel.float() / (self.x_dim - 1) yy_channel = yy_channel.float() / (self.y_dim - 1)