From 4ac405abdacec827563798f218b536624ddf00d2 Mon Sep 17 00:00:00 2001 From: Gregory Johnson Date: Tue, 17 Oct 2017 14:22:50 -0700 Subject: [PATCH 1/3] sets variable to be on same gpu as img1 --- pytorch_ssim/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pytorch_ssim/__init__.py b/pytorch_ssim/__init__.py index 89da734..f8ecd8b 100644 --- a/pytorch_ssim/__init__.py +++ b/pytorch_ssim/__init__.py @@ -51,6 +51,10 @@ def forward(self, img1, img2): window = self.window else: window = create_window(self.window_size, channel).type_as(img1) + + if window.is_cuda(): + window.set_device(img1.get_device()) + self.window = window self.channel = channel @@ -60,4 +64,8 @@ def forward(self, img1, img2): def ssim(img1, img2, window_size = 11, size_average = True): (_, channel, _, _) = img1.size() window = create_window(window_size, channel).type_as(img1) + + if window.is_cuda(): + window.set_device(img1.get_device()) + return _ssim(img1, img2, window, window_size, channel, size_average) From 881d210fba3fc9af9d67787db11ae29b8f118268 Mon Sep 17 00:00:00 2001 From: Gregory Johnson Date: Tue, 17 Oct 2017 14:50:05 -0700 Subject: [PATCH 2/3] fixes with respect to not setting cuda device properly --- pytorch_ssim/__init__.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pytorch_ssim/__init__.py b/pytorch_ssim/__init__.py index f8ecd8b..738e803 100644 --- a/pytorch_ssim/__init__.py +++ b/pytorch_ssim/__init__.py @@ -50,10 +50,11 @@ def forward(self, img1, img2): if channel == self.channel and self.window.data.type() == img1.data.type(): window = self.window else: - window = create_window(self.window_size, channel).type_as(img1) + window = create_window(self.window_size, channel) - if window.is_cuda(): - window.set_device(img1.get_device()) + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) self.window = window self.channel = channel @@ -63,9 +64,10 @@ def forward(self, img1, img2): def 
ssim(img1, img2, window_size = 11, size_average = True): (_, channel, _, _) = img1.size() - window = create_window(window_size, channel).type_as(img1) + window = create_window(window_size, channel) - if window.is_cuda(): - window.set_device(img1.get_device()) + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) return _ssim(img1, img2, window, window_size, channel, size_average) From 55b586831e24df0545553dc950fb8e24bb093105 Mon Sep 17 00:00:00 2001 From: Gregory Johnson Date: Wed, 18 Oct 2017 11:21:38 -0700 Subject: [PATCH 3/3] updates for 3D --- pytorch_ssim/__init__.py | 109 +++++++++++++++++++++++++++------------ 1 file changed, 76 insertions(+), 33 deletions(-) diff --git a/pytorch_ssim/__init__.py b/pytorch_ssim/__init__.py index 738e803..907624a 100644 --- a/pytorch_ssim/__init__.py +++ b/pytorch_ssim/__init__.py @@ -8,23 +8,71 @@ def gaussian(window_size, sigma): gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) return gauss/gauss.sum() -def create_window(window_size, channel): - _1D_window = gaussian(window_size, 1.5).unsqueeze(1) - _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) - window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) - return window +def do_conv(conv_func, img, windows, channel): -def _ssim(img1, img2, window, window_size, channel, size_average = True): - mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) - mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) + ndims = len(windows) + + for i in range(0, ndims): + + window = windows[i] + + padding_amt = int((np.max(window.size())-1) /2 ) + padding = [0]*(ndims) + padding[i] = padding_amt + padding = tuple(padding) + + img = conv_func(img, window, padding=padding, groups = channel) + + return img + +def create_windows(img, window_sizes, sigma): + + windows = list() + + ndims = len(img.size())-2 + + 
if type(window_sizes) is not list: + window_sizes = [window_sizes]*ndims + + for i in range(0, ndims): + g = gaussian(window_sizes[i], 1.5) + + for j in range(0, ndims+1): + g = g.unsqueeze(-1) + + g = g.transpose(0, i+2) + + g = Variable(g) + if img.is_cuda: + g = g.cuda(img.get_device()) + g = g.type_as(img) + + windows.append(g) + + return windows + +def _ssim(img1, img2, windows, channel, size_average = True): + + ndims = len(windows) + + if ndims == 1: + conv_func = F.conv1d + if ndims == 2: + conv_func = F.conv2d + if ndims == 3: + conv_func = F.conv3d + + mu1 = do_conv(conv_func, img1, windows, channel = channel) + mu2 = do_conv(conv_func, img2, windows, channel = channel) + mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) mu1_mu2 = mu1*mu2 - sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq - sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq - sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 + sigma1_sq = do_conv(conv_func, img1*img1, windows, channel = channel) - mu1_sq + sigma2_sq = do_conv(conv_func, img2*img2, windows, channel = channel) - mu2_sq + sigma12 = do_conv(conv_func, img1*img2, windows, channel = channel) - mu1_mu2 C1 = 0.01**2 C2 = 0.03**2 @@ -37,37 +85,32 @@ def _ssim(img1, img2, window, window_size, channel, size_average = True): return ssim_map.mean(1).mean(1).mean(1) class SSIM(torch.nn.Module): - def __init__(self, window_size = 11, size_average = True): + def __init__(self, window_size = 11, sigma = 0.15, size_average = True): super(SSIM, self).__init__() + self.window_size = window_size + self.sigma = sigma + self.size_average = size_average self.channel = 1 - self.window = create_window(window_size, self.channel) + + self.windows = None def forward(self, img1, img2): - (_, channel, _, _) = img1.size() + imsize = img1.size() + channel = imsize[1] - if channel == self.channel and self.window.data.type() == 
img1.data.type(): - window = self.window - else: - window = create_window(self.window_size, channel) - - if img1.is_cuda: - window = window.cuda(img1.get_device()) - window = window.type_as(img1) - - self.window = window + if self.windows is None: + self.windows = create_windows(img1, self.window_size, self.sigma) self.channel = channel - return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + return _ssim(img1, img2, self.windows, self.channel, self.size_average) -def ssim(img1, img2, window_size = 11, size_average = True): - (_, channel, _, _) = img1.size() - window = create_window(window_size, channel) - - if img1.is_cuda: - window = window.cuda(img1.get_device()) - window = window.type_as(img1) +def ssim(img1, img2, window_size = 11, sigma = 1.5, size_average = True): + imsize = img1.size() + channel = imsize[1] + + windows = create_windows(img1, window_size, sigma) - return _ssim(img1, img2, window, window_size, channel, size_average) + return _ssim(img1, img2, windows, channel, size_average)