diff --git a/hyperbox/networks/repnas/repnas_spos.py b/hyperbox/networks/repnas/repnas_spos.py index 99ee2b4..58e62b8 100644 --- a/hyperbox/networks/repnas/repnas_spos.py +++ b/hyperbox/networks/repnas/repnas_spos.py @@ -156,7 +156,7 @@ def _initialize_weights(self): supernet = RepNAS() rm = DartsMutator(supernet) rm.reset() - if i < 0: + if i < 5: # Bool mask mask_type = 'bool' mask = {} diff --git a/hyperbox/networks/repnas/utils.py b/hyperbox/networks/repnas/utils.py index fbe1e65..a4cf2f2 100644 --- a/hyperbox/networks/repnas/utils.py +++ b/hyperbox/networks/repnas/utils.py @@ -60,7 +60,7 @@ def replace(net): inc = first.in_channels ouc = first.out_channels s = first.stride - p = first.padding + p = ks//2 g = first.groups reparam = nn.Conv2d(in_channels=inc, out_channels=ouc, kernel_size=ks, stride=s, padding=p, dilation=1, groups=g)