Commit 54360f7

add backbones

NastyaMittseva committed Jan 10, 2022
1 parent 378f13e commit 54360f7
Showing 3 changed files with 208 additions and 34 deletions.
90 changes: 57 additions & 33 deletions network/AEI_Net.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from .AADLayer import *
+from network.resnet import MLAttrEncoderResnet
 
 
 def weight_init(m):
@@ -30,31 +31,42 @@ def __init__(self, in_c, out_c, norm=nn.BatchNorm2d):
         self.bn = norm(out_c)
         self.lrelu = nn.LeakyReLU(0.1, inplace=True)
 
-    def forward(self, input, skip):
+    def forward(self, input, skip, backbone):
         x = self.deconv(input)
         x = self.bn(x)
         x = self.lrelu(x)
-        return torch.cat((x, skip), dim=1)
+        if backbone == 'linknet':
+            return x + skip
+        else:
+            return torch.cat((x, skip), dim=1)
 
 
 class MLAttrEncoder(nn.Module):
-    def __init__(self):
+    def __init__(self, backbone):
         super(MLAttrEncoder, self).__init__()
+        self.backbone = backbone
         self.conv1 = conv4x4(3, 32)
         self.conv2 = conv4x4(32, 64)
         self.conv3 = conv4x4(64, 128)
         self.conv4 = conv4x4(128, 256)
         self.conv5 = conv4x4(256, 512)
         self.conv6 = conv4x4(512, 1024)
         self.conv7 = conv4x4(1024, 1024)
 
-        self.deconv1 = deconv4x4(1024, 1024)
-        self.deconv2 = deconv4x4(2048, 512)
-        self.deconv3 = deconv4x4(1024, 256)
-        self.deconv4 = deconv4x4(512, 128)
-        self.deconv5 = deconv4x4(256, 64)
-        self.deconv6 = deconv4x4(128, 32)
+        if backbone == 'unet':
+            self.deconv1 = deconv4x4(1024, 1024)
+            self.deconv2 = deconv4x4(2048, 512)
+            self.deconv3 = deconv4x4(1024, 256)
+            self.deconv4 = deconv4x4(512, 128)
+            self.deconv5 = deconv4x4(256, 64)
+            self.deconv6 = deconv4x4(128, 32)
+        elif backbone == 'linknet':
+            self.deconv1 = deconv4x4(1024, 1024)
+            self.deconv2 = deconv4x4(1024, 512)
+            self.deconv3 = deconv4x4(512, 256)
+            self.deconv4 = deconv4x4(256, 128)
+            self.deconv5 = deconv4x4(128, 64)
+            self.deconv6 = deconv4x4(64, 32)
         self.apply(weight_init)
 
     def forward(self, Xt):
@@ -73,29 +85,38 @@ def forward(self, Xt):
         z_attr1 = self.conv7(feat6)
         # 1024x2x2
 
-        z_attr2 = self.deconv1(z_attr1, feat6)
-        z_attr3 = self.deconv2(z_attr2, feat5)
-        z_attr4 = self.deconv3(z_attr3, feat4)
-        z_attr5 = self.deconv4(z_attr4, feat3)
-        z_attr6 = self.deconv5(z_attr5, feat2)
-        z_attr7 = self.deconv6(z_attr6, feat1)
+        z_attr2 = self.deconv1(z_attr1, feat6, self.backbone)
+        z_attr3 = self.deconv2(z_attr2, feat5, self.backbone)
+        z_attr4 = self.deconv3(z_attr3, feat4, self.backbone)
+        z_attr5 = self.deconv4(z_attr4, feat3, self.backbone)
+        z_attr6 = self.deconv5(z_attr5, feat2, self.backbone)
+        z_attr7 = self.deconv6(z_attr6, feat1, self.backbone)
         z_attr8 = F.interpolate(z_attr7, scale_factor=2, mode='bilinear', align_corners=True)
         return z_attr1, z_attr2, z_attr3, z_attr4, z_attr5, z_attr6, z_attr7, z_attr8
 
 
 class AADGenerator(nn.Module):
-    def __init__(self, c_id=256, num_blocks=2):
+    def __init__(self, backbone, c_id=256, num_blocks=2):
         super(AADGenerator, self).__init__()
         self.up1 = nn.ConvTranspose2d(c_id, 1024, kernel_size=2, stride=1, padding=0)
         self.AADBlk1 = AAD_ResBlk(1024, 1024, 1024, c_id, num_blocks)
-        self.AADBlk2 = AAD_ResBlk(1024, 1024, 2048, c_id, num_blocks)
-        self.AADBlk3 = AAD_ResBlk(1024, 1024, 1024, c_id, num_blocks)
-        self.AADBlk4 = AAD_ResBlk(1024, 512, 512, c_id, num_blocks)
-        self.AADBlk5 = AAD_ResBlk(512, 256, 256, c_id, num_blocks)
-        self.AADBlk6 = AAD_ResBlk(256, 128, 128, c_id, num_blocks)
-        self.AADBlk7 = AAD_ResBlk(128, 64, 64, c_id, num_blocks)
-        self.AADBlk8 = AAD_ResBlk(64, 3, 64, c_id, num_blocks)
+        if backbone == 'linknet':
+            self.AADBlk2 = AAD_ResBlk(1024, 1024, 1024, c_id, num_blocks)
+            self.AADBlk3 = AAD_ResBlk(1024, 1024, 512, c_id, num_blocks)
+            self.AADBlk4 = AAD_ResBlk(1024, 512, 256, c_id, num_blocks)
+            self.AADBlk5 = AAD_ResBlk(512, 256, 128, c_id, num_blocks)
+            self.AADBlk6 = AAD_ResBlk(256, 128, 64, c_id, num_blocks)
+            self.AADBlk7 = AAD_ResBlk(128, 64, 32, c_id, num_blocks)
+            self.AADBlk8 = AAD_ResBlk(64, 3, 32, c_id, num_blocks)
+        else:
+            self.AADBlk2 = AAD_ResBlk(1024, 1024, 2048, c_id, num_blocks)
+            self.AADBlk3 = AAD_ResBlk(1024, 1024, 1024, c_id, num_blocks)
+            self.AADBlk4 = AAD_ResBlk(1024, 512, 512, c_id, num_blocks)
+            self.AADBlk5 = AAD_ResBlk(512, 256, 256, c_id, num_blocks)
+            self.AADBlk6 = AAD_ResBlk(256, 128, 128, c_id, num_blocks)
+            self.AADBlk7 = AAD_ResBlk(128, 64, 64, c_id, num_blocks)
+            self.AADBlk8 = AAD_ResBlk(64, 3, 64, c_id, num_blocks)
 
         self.apply(weight_init)
 
     def forward(self, z_attr, z_id):
@@ -111,19 +132,22 @@ def forward(self, z_attr, z_id):
         return torch.tanh(y)
 
 
 class AEI_Net(nn.Module):
-    def __init__(self, c_id=256, num_blocks=2):
+    def __init__(self, backbone, num_blocks=2, c_id=256):
         super(AEI_Net, self).__init__()
-        self.encoder = MLAttrEncoder()
-        self.generator = AADGenerator(c_id, num_blocks)
+        if backbone in ['unet', 'linknet']:
+            self.encoder = MLAttrEncoder(backbone)
+        elif backbone == 'resnet':
+            self.encoder = MLAttrEncoderResnet()
+        self.generator = AADGenerator(backbone, c_id, num_blocks)
 
     def forward(self, Xt, z_id):
         attr = self.encoder(Xt)
         Y = self.generator(attr, z_id)
         return Y, attr
 
     def get_attr(self, X):
         # with torch.no_grad():
         return self.encoder(X)

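The heart of this change is the new skip-connection style in deconv4x4: the 'unet' path concatenates encoder and decoder features along the channel axis, while the 'linknet' path sums them, which keeps decoder channel counts flat and explains the halved deconv widths above. A minimal standalone sketch of the difference (shapes are illustrative, not taken from the commit):

# U-Net-style vs LinkNet-style skips; channel counts must match for the sum.
import torch

x = torch.randn(1, 512, 8, 8)     # upsampled decoder features
skip = torch.randn(1, 512, 8, 8)  # encoder features at the same scale

unet_out = torch.cat((x, skip), dim=1)  # (1, 1024, 8, 8): channels double
linknet_out = x + skip                  # (1, 512, 8, 8): channels unchanged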
149 changes: 149 additions & 0 deletions network/resnet.py
@@ -0,0 +1,149 @@
import torch.nn as nn
import math


def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)


class BasicBlock(nn.Module):
expansion = 1

def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride

def forward(self, x):
residual = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)

if self.downsample is not None:
residual = self.downsample(x)

out += residual
out = self.relu(out)

return out


class Bottleneck(nn.Module):
expansion = 4

def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride

def forward(self, x):
residual = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)

out = self.conv3(out)
out = self.bn3(out)

if self.downsample is not None:
residual = self.downsample(x)

out += residual
out = self.relu(out)

return out


class ResNet(nn.Module):

def __init__(self, block, layers, num_classes=1000, include_top=True):
self.inplanes = 64
super(ResNet, self).__init__()
self.include_top = include_top

self.conv0 = nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3, bias=False)
self.bn0 = nn.BatchNorm2d(64)
self.relu0 = nn.ReLU(inplace=True)

self.conv1 = nn.Conv2d(64, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)

self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
self.layer2 = self._make_layer(block, 64, layers[1], stride=2)
self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
self.layer4 = self._make_layer(block, 256, layers[3], stride=2)
self.layer5 = self._make_layer(block, 512, layers[4], stride=2)
self.layer6 = self._make_layer(block, 256, layers[5], stride=2)

for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()

def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)

layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))

return nn.Sequential(*layers)

def forward(self, x):
x0 = self.conv0(x)
x0 = self.bn0(x0)
x0 = self.relu0(x0)

x1 = self.conv1(x0)
x1 = self.bn1(x1)
x1 = self.relu(x1)

x2 = self.layer1(x1)
x3 = self.layer2(x2)
x4 = self.layer3(x3)
x5 = self.layer4(x4)
x6 = self.layer5(x5)
x7 = self.layer6(x6)

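        # attribute feature pyramid, coarsest (x7) to finest (x0) —
        # same ordering as MLAttrEncoder's z_attr1..z_attr8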
return x7, x6, x5, x4, x3, x2, x1, x0


def MLAttrEncoderResnet(**kwargs):
model = ResNet(Bottleneck, [2, 2, 2, 2, 2, 2], **kwargs)
return model
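As a sanity check, here is a small sketch (hypothetical; it assumes a 256×256 input and is run from the repo root) showing that the ResNet encoder's pyramid matches the channel counts the default AADGenerator branch expects (1024, 2048, 1024, 512, 256, 128, 64, 64):

# Shape check for the new ResNet attribute encoder.
import torch
from network.resnet import MLAttrEncoderResnet

enc = MLAttrEncoderResnet().eval()
with torch.no_grad():
    feats = enc(torch.randn(1, 3, 256, 256))
for f in feats:
    print(tuple(f.shape))
# Expected, given Bottleneck.expansion == 4:
# (1, 1024, 2, 2), (1, 2048, 4, 4), (1, 1024, 8, 8), (1, 512, 16, 16),
# (1, 256, 32, 32), (1, 128, 64, 64), (1, 64, 128, 128), (1, 64, 256, 256)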
3 changes: 2 additions & 1 deletion train.py
@@ -169,7 +169,7 @@ def train(args, device):
     max_epoch = args.max_epoch
 
     # initializing main models
-    G = AEI_Net(c_id=512, num_blocks=args.num_blocks).to(device)
+    G = AEI_Net(args.backbone, num_blocks=args.num_blocks, c_id=512).to(device)
     D = MultiscaleDiscriminator(input_nc=3, n_layers=5, norm_layer=torch.nn.InstanceNorm2d).to(device)
     G.train()
     D.train()
@@ -268,6 +268,7 @@ def main(args):
     parser.add_argument('--weight_eyes', default=0., type=float, help='Eyes Loss weight')
     # training params you may want to change
+    parser.add_argument('--backbone', default='unet', const='unet', nargs='?', choices=['unet', 'linknet', 'resnet'], help='Backbone for attribute encoder')
     parser.add_argument('--num_blocks', default=2, type=int, help='Numbers of AddBlocks at AddResblock')
     parser.add_argument('--same_person', default=0.2, type=float, help='Probability of using same person identity during training')
     parser.add_argument('--same_identity', default=True, type=bool, help='Using simswap approach, when source_id = target_id. Only possible with vgg=True')
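Since --backbone defaults to 'unet' (and, via nargs='?' with const='unet', a bare --backbone also resolves to 'unet'), existing launch commands keep their previous behaviour; `python train.py --backbone linknet` selects the new summed-skip decoder. Below is a quick end-to-end smoke test of the three options — a sketch that assumes the repo's network package is importable and substitutes random tensors for real data:

# Build AEI_Net with each backbone and check the generated face shape.
import torch
from network.AEI_Net import AEI_Net

for backbone in ['unet', 'linknet', 'resnet']:
    G = AEI_Net(backbone, num_blocks=2, c_id=512).eval()
    Xt = torch.randn(1, 3, 256, 256)   # stand-in target image
    z_id = torch.randn(1, 512)         # stand-in identity embedding (c_id=512, as in train.py)
    with torch.no_grad():
        Y, attr = G(Xt, z_id)
    print(backbone, tuple(Y.shape))    # expected: (1, 3, 256, 256) for each backbone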
