Skip to content

Commit

Permalink
Merge pull request #4 from marsggbo/master
Browse files Browse the repository at this point in the history
merge master of marggbo to master of pprp
  • Loading branch information
PJDong authored Oct 13, 2021
2 parents 627f7b5 + 2db0003 commit d0df3c8
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 37 deletions.
6 changes: 6 additions & 0 deletions hyperbox/networks/base_nas_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,13 @@ def sub_filter_start_end(kernel_size, sub_kernel_size):
name = ''.join(key.split('.candidates')[0])
module = self.get_module_by_name(name)
if isinstance(module, spaces.OperationSpace):
cand_indices = {}
for idx, cand in enumerate(module.candidates_original):
cand_indices[cand.__class__.__name__] = idx
index = module.index
if index is None:
cand_index = int(key.split('.candidates.')[1].split('.')[0])
index = cand_indices[module.candidates[cand_index].__class__.__name__]
prefix, suffix = key.split('.candidates.')
prefix += '.candidates'
suffix = '.'.join(suffix.split('.')[1:])
Expand Down
12 changes: 8 additions & 4 deletions hyperbox/networks/bnnas/ea_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
if __name__ == '__main__':
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = BNNet(num_classes=10, search_depth=False)
ckpt = '/home/xihe/xinhe/hyperbox/logs/runs/bnnas_c10_bn_depth_adam0.001_sync_hete/2021-10-02_23-05-51/checkpoints/epoch=390_val/acc=27.8100.ckpt'
# ckpt = '/home/xihe/xinhe/hyperbox/logs/runs/bnnas_c10_bn_depth_adam0.001_sync_hete/2021-10-02_23-05-51/checkpoints/epoch=390_val/acc=27.8100.ckpt'
# ckpt = '/home/xihe/xinhe/hyperbox/logs/runs/bnnas_c10_all_depth_adam0.001_sync_hete/2021-10-02_23-05-43/checkpoints/epoch=339_val/acc=43.1300.ckpt'
# ckpt = '/home/xihe/xinhe/hyperbox/logs/runs/bnnas_c10_bn_adam0.001_sync_hete/2021-10-06_06-29-41/checkpoints/epoch=392_val/acc=28.8200.ckpt'
ckpt = '/home/xihe/xinhe/hyperbox/logs/runs/bnnas_c10_all_adam0.001_sync_hete/2021-10-06_06-31-00/checkpoints/epoch=302_val/acc=44.0300.ckpt'
ckpt = torch.load(ckpt, map_location='cpu')
weights = {}
for key in ckpt['state_dict']:
Expand All @@ -24,10 +26,12 @@
net = net.to(device)

# method 1
ea = EAMutator(net, num_population=50, algorithm='top')
mode = 'all'
search_algorithm = 'cars'
ea = EAMutator(net, num_population=50, algorithm=search_algorithm)
# ea.load_ckpt('epoch2.pth')
eval_func = lambda arch, net: net.bn_metrics().item()
ea.search(20, eval_func, verbose=True, filling_history=True)
ea.search(21, eval_func, verbose=True, filling_history=True)
size = np.array([pool['size'] for pool in ea.history.values()])
metric = np.array([pool['metric'] for pool in ea.history.values()])
indices = np.argsort(size)
Expand All @@ -37,7 +41,7 @@
pareto_indices = pareto_lists[0] # e.g., [75, 87, 113, 201, 205]
plot_pareto_fronts(
size, metric, pareto_indices, 'model size (MB)', 'BN-based metric',
figname=f'bn_depth_pareto_epoch_single{epoch}.pdf'
figname=f'{mode}_pareto_searchepoch{epoch}_{search_algorithm}.pdf'
)

# method 2
Expand Down
4 changes: 3 additions & 1 deletion hyperbox/networks/darts/darts_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ def forward(self, x):
# noinspection PyUnresolvedReferences
padding = torch.zeros(n, c, h, w, device=device, requires_grad=False)
return padding'''
return x * 0
if self.stride == 1:
return x.mul(0.)
return x[:,:,::self.stride,::self.stride].mul(0.)

@staticmethod
def is_zero_layer():
Expand Down
3 changes: 2 additions & 1 deletion hyperbox/networks/repnas/rep_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
kernel_size=1, stride=1, padding=0, groups=groups, bias=False))
self.dbb_1x1_kxk.add_module('bn1', BNAndPadLayer(
pad_pixels=padding, num_features=internal_channels_1x1_3x3, affine=True))
padding = self.padding - 1
self.dbb_1x1_kxk.add_module('conv2', nn.Conv2d(in_channels=internal_channels_1x1_3x3, out_channels=out_channels,
kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=False))
kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False))
self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels))

def forward(self, input): # input: [5, 16, 32, 32]
Expand Down
58 changes: 39 additions & 19 deletions hyperbox/networks/repnas/repnas_spos.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@


class RepBlock(nn.Module):
def __init__(self, idx, i_block, inp, outp, stride, mask):
def __init__(self, idx, i_block, inp, outp, stride, mask, ks=3):
super(RepBlock, self).__init__()
self.stride = stride
self.use_skip_connect = stride==1 and inp==outp

self.block = OperationSpace(
[
DBBORIGIN(inp, outp, kernel_size=3, stride=stride),
DBBAVG(inp, outp, kernel_size=3, stride=stride),
DBBORIGIN(inp, outp, kernel_size=ks, stride=stride),
DBBAVG(inp, outp, kernel_size=ks, stride=stride),
DBB1x1(inp, outp, stride=stride),
DBB1x1kxk(inp, outp, kernel_size=3, stride=stride),
DBB1x1kxk(inp, outp, kernel_size=ks, stride=stride),
],
return_mask=False,
mask=mask,
Expand Down Expand Up @@ -92,14 +92,14 @@ def __init__(

self._initialize_weights()

def _make_blocks(self, idx, blocks, in_channels, channels, stride):
def _make_blocks(self, idx, blocks, in_channels, channels, stride, kernel_size=5):
result = []
for i_block in range(blocks):
stride = stride if i_block == 0 else 1
inp = in_channels if i_block == 0 else channels
outp = channels

result.append(RepBlock(idx, i_block, inp, outp, stride, self.mask))
result.append(RepBlock(idx, i_block, inp, outp, stride, self.mask, kernel_size))

return result

Expand Down Expand Up @@ -148,19 +148,39 @@ def _initialize_weights(self):
DBB1x1
DBB1x1kxk
"""
from copy import deepcopy
import numpy as np
from hyperbox.mutator import RandomMutator, DartsMutator

net = RepNAS()
net.eval()
rm = DartsMutator(net)
rm.reset()
net = RepNAS(mask=rm._cache)

x = torch.zeros(2, 3, 32, 32)
y1 = net(x)
replace(net)
net.eval()
y2 = net(x)
print(y1,y2)
print(np.allclose(y1.detach().numpy(), y2.detach().numpy(), atol=1e-5))
for i in range(10):
supernet = RepNAS()
rm = DartsMutator(supernet)
rm.reset()
if i < 5:
# Bool mask
mask_type = 'bool'
mask = {}
threshold = 0.25
for key, value in rm._cache.items():
mask[key] = value.detach()>threshold
elif i < 10:
# float mask
mask_type = 'float'
mask = rm._cache
if i % 2 == 0:
net_type = 'RepNAS(mask=mask)'
net = RepNAS(mask=mask)
else:
net_type = 'supernet.build_subnet(mask=mask)'
net = supernet.build_subnet(mask=mask)
net.eval()
print(f"{i} {mask_type} {net_type}")

x = torch.zeros(8, 3, 32, 32)
y1 = net(x).abs().sum()
replace(net)
net.eval()
y2 = net(x).abs().sum()
print(f"{y1.abs().sum():.8f} \n{y2.abs().sum():.8f}")
# print(y1.softmax(-1),'\n',y2.softmax(-1))
print(np.allclose(y1.detach().numpy(), y2.detach().numpy(), atol=1e-5))
27 changes: 15 additions & 12 deletions hyperbox/networks/repnas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from hyperbox.networks.repnas.rep_ops import *


def fuse(candidates, kernel_size=3):
def fuse(candidates, weights, kernel_size=3):
k_list = []
b_list = []

for i in range(len(candidates)):
op = candidates[i]
weight = weights[i].float()
if op.__class__.__name__ == "DBB1x1kxk":
if hasattr(op.dbb_1x1_kxk, 'idconv1'):
k1 = op.dbb_1x1_kxk.idconv1.get_actual_kernel()
Expand All @@ -38,32 +39,34 @@ def fuse(candidates, kernel_size=3):
k, b = k2, b2
else:
raise "TypeError: Not In DBBAVG DBB1x1kxk DBB1x1 DBBORIGIN."
k_list.append(k)
b_list.append(b)
k_list.append(k.detach() * weight)
b_list.append(b.detach() * weight)

return transII_addbranch(k_list, b_list)


def replace(net):
for name, module in net.named_modules():
if isinstance(module, OperationSpace):
k, b = fuse(module.candidates)
first = module.candidates[0]
candidates = []
weights = []
for idx, weight in enumerate(module.mask):
if weight:
candidates.append(module.candidates_original[idx])
weights.append(weight)
ks = max([c_.kernel_size for c_ in candidates])
k, b = fuse(candidates, weights, ks)
first = module.candidates_original[0]
inc = first.in_channels
ouc = first.out_channels
ks = first.kernel_size
s = first.stride
p = first.padding
p = ks//2
g = first.groups
reparam = nn.Conv2d(in_channels=inc, out_channels=ouc, kernel_size=ks,
stride=s, padding=p, dilation=1, groups=g)
reparam.weight.data = k
reparam.bias.data = b

for i in range(len(module.candidates)):
op = module.candidates[i]
for para in op.parameters():
para.detach_()

module.candidates_original = [reparam]
module.candidates = torch.nn.ModuleList([reparam])
module.mask = torch.tensor([True])

0 comments on commit d0df3c8

Please sign in to comment.