From d8c37983c837c24d86f053b42d20c55b25e9f92e Mon Sep 17 00:00:00 2001 From: gierle Date: Thu, 20 Apr 2023 14:08:04 +0200 Subject: [PATCH 1/2] fix: update pyro_ppl and adjust auxiliary shape Tests failed, ... - since Pyro 1.6.0 did not support PyTorch 2.0 - because of [the reduction of channel numbers in this commit](https://github.com/automl/NASLib/commit/a4412f18bf06f1c1d52f16ef5aa3c12d011a1573) --- requirements.txt | 2 +- tests/test_nb301_search_space.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 74b689a8a3..d0c7aff666 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ xgboost==1.4.2 emcee==3.1.0 pybnn==0.0.5 grakel==0.1.8 -pyro-ppl==1.6.0 +pyro-ppl==1.8.4 # additional from setup.py prev tqdm==4.61.1 diff --git a/tests/test_nb301_search_space.py b/tests/test_nb301_search_space.py index f1c8bdce46..31470cdee8 100644 --- a/tests/test_nb301_search_space.py +++ b/tests/test_nb301_search_space.py @@ -117,7 +117,7 @@ def test_forward_pass_aux_head(self): graph(torch.randn(3, 3, 32, 32)) aux_out = graph.auxiliary_logits() - self.assertEqual(aux_out.shape, (3, 512, 8, 8)) + self.assertEqual(aux_out.shape, (3, 256, 8, 8)) def test_forward_pass_aux_head_eval(self): graph = create_model() From e20f95f64106f9924384fb74b07b55d1a6efa190 Mon Sep 17 00:00:00 2001 From: gierle Date: Thu, 20 Apr 2023 15:45:22 +0200 Subject: [PATCH 2/2] fix: update deprecated function call at Stem op --- naslib/search_spaces/hierarchical/graph.py | 8 ++++---- naslib/search_spaces/simple_cell/graph.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/naslib/search_spaces/hierarchical/graph.py b/naslib/search_spaces/hierarchical/graph.py index 4c61ff2189..824a7e6f66 100755 --- a/naslib/search_spaces/hierarchical/graph.py +++ b/naslib/search_spaces/hierarchical/graph.py @@ -74,7 +74,7 @@ def __init__(self): self.add_nodes_from([i for i in range(1, 9)]) self.add_edges_from([(i, i + 1) for i in range(1, 8)]) - self.edges[1, 2].set("op", ops.Stem(16)) + self.edges[1, 2].set("op", ops.Stem(C_out=16)) self.edges[2, 3].set("op", cells[0]) self.edges[3, 4].set( "op", ops.SepConv(16, 32, kernel_size=3, stride=2, padding=1) @@ -117,7 +117,7 @@ def prepare_evaluation(self): single_instances=False, ) - self.edges[1, 2].set("op", ops.Stem(channels[0])) + self.edges[1, 2].set("op", ops.Stem(C_out=channels[0])) self.edges[2, 3].set("op", cells[0].copy()) self.edges[3, 4].set( "op", @@ -191,7 +191,7 @@ def _expand(self): # single_instances=False # ) - # self.edges[1, 2].set('op', ops.Stem(channels[0])) + # self.edges[1, 2].set('op', ops.Stem(C_out=channels[0])) # self.edges[2, 3].set('op', cells[0].copy()) # self.edges[3, 4].set('op', ops.SepConv(channels[0], channels[1], kernel_size=3, stride=2, padding=1)) # self.edges[4, 5].set('op', cells[1].copy()) @@ -400,7 +400,7 @@ def __init__(self): self.add_nodes_from([i for i in range(1, 15)]) self.add_edges_from([(i, i + 1) for i in range(1, 14)]) - self.edges[1, 2].set("op", ops.Stem(channels[0])) + self.edges[1, 2].set("op", ops.Stem(C_out=channels[0])) self.edges[2, 3].set("op", cells[0].copy()) self.edges[3, 4].set( "op", diff --git a/naslib/search_spaces/simple_cell/graph.py b/naslib/search_spaces/simple_cell/graph.py index de51ba346f..580bd9ae4f 100644 --- a/naslib/search_spaces/simple_cell/graph.py +++ b/naslib/search_spaces/simple_cell/graph.py @@ -139,7 +139,7 @@ def __init__( # Compile the ops self.edges[1, 2].set( - "op", ops.Stem(channels[0]) + "op", ops.Stem(C_out=channels[0]) ) # we can also set a compiled op. Will be ignored by compile() def set_channels(edge, C):