update code for AddResBlock
NastyaMittseva committed Jan 10, 2022
1 parent 4a448a7 commit 378f13e
Showing 3 changed files with 47 additions and 39 deletions.
network/AADLayer.py: 60 changes (33 additions, 27 deletions)
@@ -3,7 +3,7 @@
 
 
 class AADLayer(nn.Module):
-    def __init__(self, c_x, attr_c, c_id=256):
+    def __init__(self, c_x, attr_c, c_id):
         super(AADLayer, self).__init__()
         self.attr_c = attr_c
         self.c_id = c_id
@@ -37,41 +37,47 @@ def forward(self, h_in, z_attr, z_id):
         out = (torch.ones_like(M).to(M.device) - M) * A + M * I
         return out
 
 
+class AddBlocksSequential(nn.Sequential):
+    def forward(self, *inputs):
+        h, z_attr, z_id = inputs
+        for i, module in enumerate(self._modules.values()):
+            if i%3 == 0 and i > 0:
+                inputs = (inputs, z_attr, z_id)
+            if type(inputs) == tuple:
+                inputs = module(*inputs)
+            else:
+                inputs = module(inputs)
+        return inputs
+
+
 class AAD_ResBlk(nn.Module):
-    def __init__(self, cin, cout, c_attr, c_id=256):
+    def __init__(self, cin, cout, c_attr, c_id, num_blocks):
         super(AAD_ResBlk, self).__init__()
         self.cin = cin
         self.cout = cout
 
-        self.AAD1 = AADLayer(cin, c_attr, c_id)
-        self.conv1 = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=False)
-        self.relu1 = nn.ReLU(inplace=True)
-
-        self.AAD2 = AADLayer(cin, c_attr, c_id)
-        self.conv2 = nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False)
-        self.relu2 = nn.ReLU(inplace=True)
+        add_blocks = []
+        for i in range(num_blocks):
+            out = cin if i < (num_blocks-1) else cout
+            add_blocks.extend([AADLayer(cin, c_attr, c_id),
+                               nn.ReLU(inplace=True),
+                               nn.Conv2d(cin, out, kernel_size=3, stride=1, padding=1, bias=False)
+                               ])
+        self.add_blocks = AddBlocksSequential(*add_blocks)
 
         if cin != cout:
-            self.AAD3 = AADLayer(cin, c_attr, c_id)
-            self.conv3 = nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False)
-            self.relu3 = nn.ReLU(inplace=True)
+            last_add_block = [AADLayer(cin, c_attr, c_id),
+                              nn.ReLU(inplace=True),
+                              nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False)]
+            self.last_add_block = AddBlocksSequential(*last_add_block)
 
     def forward(self, h, z_attr, z_id):
-        x = self.AAD1(h, z_attr, z_id)
-        x = self.relu1(x)
-        x = self.conv1(x)
-
-        x = self.AAD2(x,z_attr, z_id)
-        x = self.relu2(x)
-        x = self.conv2(x)
-
+        x = self.add_blocks(h, z_attr, z_id)
         if self.cin != self.cout:
-            h = self.AAD3(h, z_attr, z_id)
-            h = self.relu3(h)
-            h = self.conv3(h)
+            h = self.last_add_block(h, z_attr, z_id)
         x = x + h
 
         return x
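
For context, a minimal sketch of how the refactored block is exercised. AddBlocksSequential stores modules in (AADLayer, ReLU, Conv2d) triplets; at every third index after the first it repacks the running tensor with (z_attr, z_id) so the next AADLayer receives all three inputs, while ReLU and Conv2d receive the tensor alone. The constructor arguments below mirror AADBlk4 from AEI_Net.py; the batch size and spatial resolution are illustrative assumptions, not values taken from this commit:

    import torch
    from network.AADLayer import AAD_ResBlk

    # Same channel configuration as AADBlk4; shapes otherwise hypothetical.
    blk = AAD_ResBlk(cin=1024, cout=512, c_attr=512, c_id=512, num_blocks=2)
    h = torch.randn(1, 1024, 28, 28)      # activation from the previous block
    z_attr = torch.randn(1, 512, 28, 28)  # attribute features at this resolution
    z_id = torch.randn(1, 512)            # identity embedding
    out = blk(h, z_attr, z_id)            # expected: (1, 512, 28, 28)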


network/AEI_Net.py: 22 changes (11 additions, 11 deletions)
@@ -84,17 +84,17 @@ def forward(self, Xt):
 
 
 class AADGenerator(nn.Module):
-    def __init__(self, c_id=256):
+    def __init__(self, c_id=256, num_blocks=2):
         super(AADGenerator, self).__init__()
         self.up1 = nn.ConvTranspose2d(c_id, 1024, kernel_size=2, stride=1, padding=0)
-        self.AADBlk1 = AAD_ResBlk(1024, 1024, 1024, c_id)
-        self.AADBlk2 = AAD_ResBlk(1024, 1024, 2048, c_id)
-        self.AADBlk3 = AAD_ResBlk(1024, 1024, 1024, c_id)
-        self.AADBlk4 = AAD_ResBlk(1024, 512, 512, c_id)
-        self.AADBlk5 = AAD_ResBlk(512, 256, 256, c_id)
-        self.AADBlk6 = AAD_ResBlk(256, 128, 128, c_id)
-        self.AADBlk7 = AAD_ResBlk(128, 64, 64, c_id)
-        self.AADBlk8 = AAD_ResBlk(64, 3, 64, c_id)
+        self.AADBlk1 = AAD_ResBlk(1024, 1024, 1024, c_id, num_blocks)
+        self.AADBlk2 = AAD_ResBlk(1024, 1024, 2048, c_id, num_blocks)
+        self.AADBlk3 = AAD_ResBlk(1024, 1024, 1024, c_id, num_blocks)
+        self.AADBlk4 = AAD_ResBlk(1024, 512, 512, c_id, num_blocks)
+        self.AADBlk5 = AAD_ResBlk(512, 256, 256, c_id, num_blocks)
+        self.AADBlk6 = AAD_ResBlk(256, 128, 128, c_id, num_blocks)
+        self.AADBlk7 = AAD_ResBlk(128, 64, 64, c_id, num_blocks)
+        self.AADBlk8 = AAD_ResBlk(64, 3, 64, c_id, num_blocks)
 
         self.apply(weight_init)
 
@@ -112,10 +112,10 @@ def forward(self, z_attr, z_id):
 
 
 class AEI_Net(nn.Module):
-    def __init__(self, c_id=256):
+    def __init__(self, c_id=256, num_blocks=2):
         super(AEI_Net, self).__init__()
         self.encoder = MLAttrEncoder()
-        self.generator = AADGenerator(c_id)
+        self.generator = AADGenerator(c_id, num_blocks)
 
     def forward(self, Xt, z_id):
         attr = self.encoder(Xt)
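
With num_blocks plumbed through, a generator of configurable depth can be built directly. A minimal construction sketch, assuming 256x256 inputs and the 512-d identity embedding used in train.py; the input resolution and exact return values are assumptions, since neither is shown in this diff:

    import torch
    from network.AEI_Net import AEI_Net

    G = AEI_Net(c_id=512, num_blocks=2)
    Xt = torch.randn(1, 3, 256, 256)  # target image batch (assumed resolution)
    z_id = torch.randn(1, 512)        # source identity embedding, matches c_id
    output = G(Xt, z_id)              # typically the swapped image plus
                                      # attribute features; not shown here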
train.py: 4 changes (3 additions, 1 deletion)
@@ -169,7 +169,7 @@ def train(args, device):
     max_epoch = args.max_epoch
 
     # initializing main models
-    G = AEI_Net(c_id=512).to(device)
+    G = AEI_Net(c_id=512, num_blocks=args.num_blocks).to(device)
     D = MultiscaleDiscriminator(input_nc=3, n_layers=5, norm_layer=torch.nn.InstanceNorm2d).to(device)
     G.train()
     D.train()
@@ -267,6 +267,8 @@ def main(args):
     parser.add_argument('--weight_rec', default=10, type=float, help='Reconstruction Loss weight')
     parser.add_argument('--weight_eyes', default=0., type=float, help='Eyes Loss weight')
     # training params you may want to change
+
+    parser.add_argument('--num_blocks', default=2, type=int, help='Number of AddBlocks in each AddResBlock')
     parser.add_argument('--same_person', default=0.2, type=float, help='Probability of using same person identity during training')
     parser.add_argument('--same_identity', default=True, type=bool, help='Using simswap approach, when source_id = target_id. Only possible with vgg=True')
     parser.add_argument('--diff_eq_same', default=False, type=bool, help='Don\'t use info about where identities are different')
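
The new --num_blocks flag exposes the per-block depth as a training-time knob. A hypothetical launch of a deeper variant (all other flags left at their defaults; any required dataset arguments are omitted here):

    python train.py --num_blocks 3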
