Skip to content

Commit

Permalink
Update transforms (#262)
Browse files (browse the repository at this point in the history)
* Fix drop level (pass `level=0` when dropping categorical and one-hot columns from the MultiIndex columns)

* Add "mixed" categorical method (use descriptors when `variable.ds` is not None, one-hot encoding otherwise)
  • Loading branch information
marcosfelt authored Jul 3, 2023
1 parent 361186e commit 3df83ef
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions summit/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ def transform_inputs_outputs(self, ds: DataSet, **kwargs):
self.input_means, self.input_stds = {}, {}
self.output_means, self.output_stds = {}, {}
self.encoders = {}
for variable in self.transform_domain.input_variables:
for variable in self.domain.input_variables:
if (
isinstance(variable, CategoricalVariable)
and categorical_method == "descriptors"
and (categorical_method == "descriptors" or (categorical_method == "mixed" and variable.ds is not None))
):
# Add descriptors to the dataset
var_descriptor_names = variable.ds.data_columns
Expand All @@ -118,7 +118,7 @@ def transform_inputs_outputs(self, ds: DataSet, **kwargs):
]
for ixc in ix_code:
column_codes_2[ixc] = 0
new_ds.columns.set_codes(column_codes_2, level=1, inplace=True)
new_ds.columns = new_ds.columns.set_codes(column_codes_2, level=1)
else:
indices = new_ds[variable.name].values
descriptors = variable.ds.loc[indices]
Expand All @@ -131,7 +131,7 @@ def transform_inputs_outputs(self, ds: DataSet, **kwargs):
column_codes_2 = list(new_ds.columns.codes[1])
ix_code = np.where(new_ds.columns.codes[0] == ix)[0][0]
column_codes_2[ix_code] = 1
new_ds.columns.set_codes(column_codes_2, level=1, inplace=True)
new_ds.columns = new_ds.columns.set_codes(column_codes_2, level=1)

# Normalize descriptors between 0 and 1
if min_max_scale_inputs:
Expand All @@ -146,7 +146,7 @@ def transform_inputs_outputs(self, ds: DataSet, **kwargs):
input_columns.extend(var_descriptor_names)
elif (
isinstance(variable, CategoricalVariable)
and categorical_method == "one-hot"
and (categorical_method == "one-hot" or categorical_method == "mixed" and variable.ds is None)
):
# Create one-hot encoding columns & insert to DataSet
enc = OneHotEncoder(categories=[variable.levels])
Expand All @@ -159,7 +159,7 @@ def transform_inputs_outputs(self, ds: DataSet, **kwargs):
self.encoders[variable.name] = enc

# Drop old categorical column, then write as metadata
new_ds = new_ds.drop(variable.name, axis=1)
new_ds = new_ds.drop(variable.name, axis=1, level=0)
new_ds[variable.name, "METADATA"] = values
elif (
isinstance(variable, CategoricalVariable) and categorical_method == None
Expand All @@ -184,7 +184,7 @@ def transform_inputs_outputs(self, ds: DataSet, **kwargs):
f"Variable {variable.name} is not a continuous or categorical variable."
)

for variable in self.transform_domain.output_variables:
for variable in self.domain.output_variables:
if variable.name in data_columns and variable.is_objective:
if isinstance(variable, CategoricalVariable):
raise DomainError(
Expand Down Expand Up @@ -262,11 +262,11 @@ def un_transform(self, ds, **kwargs):

# Determine input and output columns in dataset
new_ds = ds.copy()
for i, variable in enumerate(self.transform_domain.input_variables):
for i, variable in enumerate(self.domain.input_variables):
# Categorical variables with descriptors
if (
isinstance(variable, CategoricalVariable)
and categorical_method == "descriptors"
and ((categorical_method == "descriptors") or (categorical_method == "mixed" and variable.ds is not None))
):
var_descriptor_names = variable.ds.data_columns
# Unnormalize descriptors between 0 and 1
Expand Down Expand Up @@ -318,7 +318,7 @@ def un_transform(self, ds, **kwargs):
# Categorical variables using one-hot encoding
elif (
isinstance(variable, CategoricalVariable)
and categorical_method == "one-hot"
and ((categorical_method == "one-hot") or (categorical_method == "mixed" and variable.ds is None))
):
# Get one-hot encoder
enc = self.encoders[variable.name]
Expand All @@ -331,7 +331,7 @@ def un_transform(self, ds, **kwargs):
values = enc.inverse_transform(one_hot)

# Add to dataset and drop one-hot encoding
new_ds = new_ds.drop(one_hot_names, axis=1)
new_ds = new_ds.drop(one_hot_names, axis=1, level=0)
new_ds[variable.name, "DATA"] = values
# Plain categorical variables
elif isinstance(variable, CategoricalVariable):
Expand All @@ -352,7 +352,7 @@ def un_transform(self, ds, **kwargs):
else:
raise DomainError(f"Variable {variable.name} is not in the dataset.")

for variable in self.transform_domain.output_variables:
for variable in self.domain.output_variables:
if variable.name in data_columns and variable.is_objective:
if standardize_outputs:
mean = self.output_means[variable.name]
Expand Down

0 comments on commit 3df83ef

Please sign in to comment.