@@ -190,6 +190,8 @@ def _init_search(self):
190
190
self ._K = 0
191
191
self ._iter_best_config = self .trial_count = 1
192
192
self ._reset_times = 0
193
+ # record intermediate trial cost
194
+ self ._trial_cost = {}
193
195
194
196
@property
195
197
def step_lower_bound (self ) -> float :
@@ -237,7 +239,8 @@ def complete_config(self, partial_config: Dict,
237
239
''' generate a complete config from the partial config input
238
240
add minimal resource to config if available
239
241
'''
240
- if self ._reset_times : # not the first time, use random gaussian
242
+ if self ._reset_times and partial_config == self .init_config :
243
+ # not the first time to complete init_config, use random gaussian
241
244
normalized = self .normalize (partial_config )
242
245
for key in normalized :
243
246
# don't change unordered cat choice
@@ -258,21 +261,22 @@ def complete_config(self, partial_config: Dict,
258
261
normalized [key ] = max (l , min (u , normalized [key ] + delta ))
259
262
# use best config for unordered cat choice
260
263
config = self .denormalize (normalized )
264
+ self ._reset_times += 1
261
265
else :
266
+ # first time init_config, or other configs, take as is
262
267
config = partial_config .copy ()
263
268
264
269
for key , value in self .space .items ():
265
270
if key not in config :
266
271
config [key ] = value
267
- logger .debug (f'before random { config } ' )
272
+ # logger.debug(f'before random {config}')
268
273
for _ , generated in generate_variants ({'config' : config }):
269
274
config = generated ['config' ]
270
275
break
271
- logger .debug (f'after random { config } ' )
276
+ # logger.debug(f'after random {config}')
272
277
273
278
if self ._resource :
274
279
config [self .prune_attr ] = self .min_resource
275
- self ._reset_times += 1
276
280
return config
277
281
278
282
def create (self , init_config : Dict , obj : float , cost : float ) -> Searcher :
@@ -442,7 +446,8 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
442
446
if proposed_by == self .incumbent :
443
447
# proposed by current incumbent and no better
444
448
self ._num_complete4incumbent += 1
445
- cost = result .get (self .cost_attr )
449
+ cost = result .get (
450
+ self .cost_attr ) if result else self ._trial_cost .get (trial_id )
446
451
if cost : self ._cost_complete4incumbent += cost
447
452
if self ._num_complete4incumbent >= 2 * self .dim and \
448
453
self ._num_allowed4incumbent == 0 :
@@ -483,6 +488,9 @@ def on_trial_result(self, trial_id: str, result: Dict):
483
488
self ._num_allowed4incumbent = 2 * self .dim
484
489
self ._proposed_by .clear ()
485
490
self ._iter_best_config = self .trial_count
491
+ cost = result .get (self .cost_attr )
492
+ # record the cost in case it is pruned and cost info is lost
493
+ self ._trial_cost [trial_id ] = cost
486
494
487
495
def rand_vector_unit_sphere (self , dim ) -> np .ndarray :
488
496
vec = self ._random .normal (0 , 1 , dim )
0 commit comments