-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
58 lines (39 loc) · 1.27 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch, os, sys, cv2, json, argparse, random
import torch.nn as nn
from torch.nn import init
import functools
import torch.optim as optim
import torchvision.models as tvmodels
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as func
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
class VGG19_5_4(nn.Module):
    """Frozen VGG-19 feature extractor truncated at relu5_4.

    Loads the ImageNet-pretrained VGG-19 trunk, keeps layers [0:36)
    (index 35 is the relu5_4 activation), moves it to the GPU and sets
    it to eval mode. Intended as a fixed feature network for perceptual
    losses; callers should not backprop into its weights.
    """

    def __init__(self):
        super(VGG19_5_4, self).__init__()
        # BUG FIX: this class previously loaded vgg16, whose feature stack
        # has only 31 layers, so [:36] silently kept the entire vgg16 trunk
        # (including the last pool) instead of stopping at relu5_4.
        # vgg19's features run 0..36, and [:36] ends exactly at relu5_4.
        features = list(tvmodels.vgg19(pretrained=True).cuda().features)[:36]
        self.features = nn.ModuleList(features).eval()

    def forward(self, x):
        """Run x through the truncated trunk and return the relu5_4 map.

        x: image batch normalized the way the pretrained VGG expects —
        presumably (N, 3, H, W) ImageNet-normalized; TODO confirm at call site.
        """
        for ii, model in enumerate(self.features):
            x = model(x)
        return x
class VGG16_3_3(nn.Module):
    """Frozen VGG-16 feature extractor truncated at relu3_3.

    Keeps layers [0:16) of the ImageNet-pretrained VGG-16 trunk
    (index 15 is the relu3_3 activation), placed on the GPU in eval
    mode. Used as a fixed feature network for perceptual losses.
    """

    def __init__(self):
        super(VGG16_3_3, self).__init__()
        trunk = tvmodels.vgg16(pretrained=True).cuda().features
        # Slice off everything past relu3_3 and freeze via eval().
        self.features = nn.ModuleList(list(trunk)[:16]).eval()

    def forward(self, x):
        """Return the relu3_3 feature map for input batch x."""
        for layer in self.features:
            x = layer(x)
        return x
class PerceptualLoss(nn.Module):
    """Weighted mix of pixel-space L1 and VGG-feature-space MSE.

    loss = 0.8 * L1(output, target)
         + 0.2 * MSE(vgg(output), vgg(target))
    where vgg is a frozen VGG-16 truncated at relu3_3.
    """

    def __init__(self):
        super().__init__()
        self.vgg = VGG16_3_3()
        self.l2 = nn.MSELoss()
        self.l1 = nn.L1Loss()

    def forward(self, output, target):
        """Compute the combined perceptual loss between output and target."""
        pixel_term = self.l1(output, target)
        feat_term = self.l2(self.vgg(output), self.vgg(target))
        return 0.8 * pixel_term + 0.2 * feat_term