Skip to content

Commit

Permalink
Merge pull request mkocabas#2 from leVirve-arxiv/master
Browse files Browse the repository at this point in the history
Fix the bug for non-square inputs
  • Loading branch information
Muhammed Kocabas authored Jul 14, 2018
2 parents 83bda07 + 45df391 commit 7504c52
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions CoordConv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def forward(self, input_tensor):
"""
batch_size_tensor = input_tensor.shape[0]

xx_ones = torch.ones([1, self.x_dim], dtype=torch.int32)
xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32)
xx_ones = xx_ones.unsqueeze(-1)

xx_range = torch.arange(self.x_dim, dtype=torch.int32).unsqueeze(0)
Expand All @@ -24,7 +24,7 @@ def forward(self, input_tensor):
xx_channel = torch.matmul(xx_ones, xx_range)
xx_channel = xx_channel.unsqueeze(-1)

yy_ones = torch.ones([1, self.y_dim], dtype=torch.int32)
yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32)
yy_ones = yy_ones.unsqueeze(1)

yy_range = torch.arange(self.y_dim, dtype=torch.int32).unsqueeze(0)
Expand All @@ -33,8 +33,8 @@ def forward(self, input_tensor):
yy_channel = torch.matmul(yy_range, yy_ones)
yy_channel = yy_channel.unsqueeze(-1)

xx_channel = xx_channel.permute(0, 3, 1, 2)
yy_channel = yy_channel.permute(0, 3, 1, 2)
xx_channel = xx_channel.permute(0, 3, 2, 1)
yy_channel = yy_channel.permute(0, 3, 2, 1)

xx_channel = xx_channel.float() / (self.x_dim - 1)
yy_channel = yy_channel.float() / (self.y_dim - 1)
Expand Down

0 comments on commit 7504c52

Please sign in to comment.