diff --git a/summit/strategies/base.py b/summit/strategies/base.py index 1b7860aa..b17e9f55 100644 --- a/summit/strategies/base.py +++ b/summit/strategies/base.py @@ -104,35 +104,12 @@ def transform_inputs_outputs(self, ds: DataSet, **kwargs): ): # Add descriptors to the dataset var_descriptor_names = variable.ds.data_columns - if all( - np.isin(var_descriptor_names, new_ds.columns.levels[0].to_list()) - ): - # Make the descriptors columns a metadata column - column_list_1 = new_ds.columns.levels[0].to_list() - ix = [ - column_list_1.index(d_name) for d_name in var_descriptor_names - ] - column_codes_2 = list(new_ds.columns.codes[1]) - ix_code = [ - np.where(new_ds.columns.codes[0] == tmp_ix)[0][0] - for tmp_ix in ix - ] - for ixc in ix_code: - column_codes_2[ixc] = 0 - new_ds.columns = new_ds.columns.set_codes(column_codes_2, level=1) - else: - indices = new_ds[variable.name].values - descriptors = variable.ds.loc[indices] - descriptors.index = new_ds.index - new_ds = new_ds.join(descriptors, how="inner") - - # Make the original descriptors column a metadata column - column_list_1 = new_ds.columns.levels[0].to_list() - ix = column_list_1.index(variable.name) - column_codes_2 = list(new_ds.columns.codes[1]) - ix_code = np.where(new_ds.columns.codes[0] == ix)[0][0] - column_codes_2[ix_code] = 1 - new_ds.columns = new_ds.columns.set_codes(column_codes_2, level=1) + new_ds = new_ds.merge(variable.ds, left_on=variable.name, right_index=True, how="left") + + # Make the original categorical column a metadata column + original_categorical = new_ds[variable.name].copy() + new_ds = new_ds.drop((variable.name, "DATA"), axis=1) + new_ds[variable.name, "METADATA"] = original_categorical # Normalize descriptors between 0 and 1 if min_max_scale_inputs: