From 33b7b87a02c2b409eced8e00e4ff01f5e559d360 Mon Sep 17 00:00:00 2001 From: Kobi Felton Date: Tue, 4 Jul 2023 23:42:19 +0100 Subject: [PATCH] Deprecate LogSpaceObjectives and add LogTransform (#263) --- summit/strategies/base.py | 114 ++++++++++++++++++++++++++++---------- 1 file changed, 86 insertions(+), 28 deletions(-) diff --git a/summit/strategies/base.py b/summit/strategies/base.py index 046a8322..1b7860aa 100644 --- a/summit/strategies/base.py +++ b/summit/strategies/base.py @@ -1,3 +1,4 @@ +import warnings from summit.domain import * from summit.utils.dataset import DataSet @@ -6,7 +7,7 @@ from abc import ABC, abstractmethod, abstractclassmethod -from typing import Type, Tuple +from typing import List, Type, Tuple import json __all__ = [ @@ -367,7 +368,7 @@ def un_transform(self, ds, **kwargs): return new_ds def to_dict(self, **kwargs): - """ Output a dictionary representation of the transform""" + """Output a dictionary representation of the transform""" return dict( transform_domain=self.transform_domain.to_dict(), name=self.__class__.__name__, @@ -514,7 +515,7 @@ def transform_inputs_outputs(self, ds, **kwargs): return inputs, outputs def to_dict(self): - """ Output a dictionary representation of the transform""" + """Output a dictionary representation of the transform""" transform_params = dict(expression=self.expression, maximize=self.maximize) d = super().to_dict(**transform_params) return d @@ -558,22 +559,23 @@ class LogSpaceObjectives(Transform): def __init__(self, domain: Domain): super().__init__(domain) - objectives = [ - (i, v) - for i, v in enumerate(self.transform_domain.variables) - if v.is_objective - ] - # Check that the domain has objectives - num_objectives = len(objectives) - if num_objectives == 0: - raise ValueError( - f"The domain must have objectives. Currently has {num_objectives} objectives." - ) + warnings.warn( + "This class will be deprecated in a future version of summit.", + DeprecationWarning, + stacklevel=2, + ) + + self.to_transform = [v.name for v in self.transform_domain.output_variables] - # Rename objectives in new domain - for i, v in objectives: + # Rename variables and set new bounds in transformeed domain + for name in self.to_transform: + v = self.transform_domain[name] + if v is None: + raise ValueError(f"Variable {name} not found in domain.") v.name = "log_" + v.name + v._lower_bound = np.log(v.bounds[0]) + v._upper_bound = np.log(v.bounds[1]) def transform_inputs_outputs(self, ds, **kwargs): """Transform of data into inputs and outptus for a strategy @@ -594,15 +596,70 @@ def transform_inputs_outputs(self, ds, **kwargs): inputs, outputs Datasets with the input and output datasets """ - inputs, outputs = super().transform_inputs_outputs(ds, **kwargs) - if (outputs.any() < 0).any(): - raise ValueError("Cannot complete log transform for values less than zero.") - outputs = outputs.apply(np.log) - columns = [v.name for v in self.transform_domain.variables if v.is_objective] - outputs = DataSet(outputs.data_to_numpy(), columns=columns) - return inputs, outputs + for name in self.to_transform: + ds.loc[:, ("log_" + name, "DATA")] = np.log(ds[name].astype(float).values) + return super().transform_inputs_outputs(ds, **kwargs) - def un_transform(self, ds, **kwargs): + def un_transform(self, ds: DataSet, **kwargs): + """Untransform objectives from log space + + Parameters + ---------- + ds: `DataSet` + Dataset with columns corresponding to the inputs and objectives of the domain. + copy: bool, optional + Copy the dataset internally. Defaults to True. + transform_descriptors: bool, optional + Transform the descriptors into continuous variables. Default True. + """ + for name in self.to_transform: + if not "log_" + name in ds.data_columns: + continue + ds.loc[:, (name, "DATA")] = np.exp(ds["log_" + name].astype(float).values) + ds = super().un_transform(ds, **kwargs) + return ds + + +class LogTransform(Transform): + def __init__(self, domain: Domain, to_transform: List[str]): + super().__init__(domain) + + self.to_transform = to_transform + + # Rename variables and set new bounds in transformeed domain + for name in self.to_transform: + v = self.transform_domain[name] + if v is None: + raise ValueError(f"Variable {name} not found in domain.") + v.name = "log_" + v.name + v._lower_bound = np.log(v.bounds[0]) + v._upper_bound = np.log(v.bounds[1]) + # v.bounds = (np.log(v.bounds[0]), np.log(v.bounds[1])) + + def transform_inputs_outputs(self, ds, **kwargs): + """Transform of data into inputs and outptus for a strategy + + This will do a log transform on the objectives (outputs). + + Parameters + ---------- + ds: `DataSet` + Dataset with columns corresponding to the inputs and objectives of the domain. + copy: bool, optional + Copy the dataset internally. Defaults to True. + transform_descriptors: bool, optional + Transform the descriptors into continuous variables. Default True. + + Returns + ------- + inputs, outputs + Datasets with the input and output datasets + """ + for name in self.to_transform: + ds.loc[:, ("log_" + name, "DATA")] = np.log(ds[name].astype(float).values) + return super().transform_inputs_outputs(ds, **kwargs) + + def un_transform(self, ds: DataSet, **kwargs): """Untransform objectives from log space Parameters @@ -614,10 +671,11 @@ def un_transform(self, ds, **kwargs): transform_descriptors: bool, optional Transform the descriptors into continuous variables. Default True. """ + for name in self.to_transform: + if not "log_" + name in ds.data_columns: + continue + ds.loc[:, (name, "DATA")] = np.exp(ds["log_" + name].astype(float).values) ds = super().un_transform(ds, **kwargs) - for v in self.domain.variables: - if v.is_objective and ds.get("log_" + v.name): - ds[v.name] = np.exp(ds["log_" + v.name]) return ds @@ -1195,7 +1253,7 @@ def _closest_point_index(design_point, candidate_matrix): def _design_distances(design_point, candidate_matrix): - """ Return the distances between a design_point and all candidates""" + """Return the distances between a design_point and all candidates""" diff = design_point - candidate_matrix squared = np.power(diff, 2) summed = np.sum(squared, axis=1)