From 0c10bfe6f9fa520a856a8c2654821822efe26c20 Mon Sep 17 00:00:00 2001 From: lukas Date: Tue, 6 Aug 2024 11:02:09 +0200 Subject: [PATCH] feat: onehot dataset as input --- src/action_rules/action_rules.py | 96 +++++++++++++++++++++++++++++++- tests/test_action_rules.py | 63 +++++++++++++++++++++ 2 files changed, 158 insertions(+), 1 deletion(-) diff --git a/src/action_rules/action_rules.py b/src/action_rules/action_rules.py index 2afe625..8610e87 100644 --- a/src/action_rules/action_rules.py +++ b/src/action_rules/action_rules.py @@ -117,6 +117,7 @@ def __init__( self.pd = None # type: Optional[ModuleType] self.is_gpu_np = False self.is_gpu_pd = False + self.is_onehot = False def count_max_nodes(self, stable_items_binding: dict, flexible_items_binding: dict) -> int: """ @@ -331,6 +332,98 @@ def one_hot_encode( data = self.pd.concat(to_concat, axis=1) # type: ignore return data + def fit_onehot( + self, + data: Union['cudf.DataFrame', 'pandas.DataFrame'], + stable_attributes: dict, + flexible_attributes: dict, + target: dict, + target_undesired_state: str, + target_desired_state: str, + use_sparse_matrix: bool = False, + use_gpu: bool = False, + ): + """ + Preprocess and fit the model using one-hot encoded attributes. + + This method prepares the dataset for generating action rules by + performing one-hot encoding on the specified stable, flexible, + and target attributes. The resulting dataset is then used to fit + the model using the `fit` method. + + Parameters + ---------- + data : Union[cudf.DataFrame, pandas.DataFrame] + The dataset to be processed and used for fitting the model. + stable_attributes : dict + A dictionary mapping stable attribute names to lists of column + names corresponding to those attributes. + flexible_attributes : dict + A dictionary mapping flexible attribute names to lists of column + names corresponding to those attributes. + target : dict + A dictionary mapping the target attribute name to a list of + column names corresponding to that attribute. + target_undesired_state : str + The undesired state of the target attribute, used in action rule generation. + target_desired_state : str + The desired state of the target attribute, used in action rule generation. + use_sparse_matrix : bool, optional + If True, a sparse matrix is used in the fitting process. Default is False. + use_gpu : bool, optional + If True, the GPU (cuDF) is used for data processing if available. + Default is False. + + Notes + ----- + The method modifies the dataset by: + 1. Renaming columns according to the stable, flexible, and target attributes. + 2. Removing columns that are not associated with any of these attributes. + 3. Passing the processed dataset and relevant attribute lists to the `fit` method + to generate action rules. + + This method ensures that the dataset is correctly preprocessed for rule + generation, focusing on the specified attributes and their one-hot encoded forms. + """ + self.is_onehot = True + new_labels = [] + attributes_stable = set([]) + attribtes_flexible = set([]) + attribute_target = '' + remove_cols = [] + for label in data.columns: + to_remove = True + for attribute, columns in stable_attributes.items(): + if label in columns: + new_labels.append(attribute + '__' + label) + attributes_stable.add(attribute) + to_remove = False + for attribute, columns in flexible_attributes.items(): + if label in columns: + new_labels.append(attribute + '__' + label) + attribtes_flexible.add(attribute) + to_remove = False + for attribute, columns in target.items(): + if label in columns: + new_labels.append(attribute + '__' + label) + attribute_target = attribute + to_remove = False + if to_remove: + new_labels.append(label) + remove_cols.append(label) + data.columns = new_labels + data = data.drop(columns=remove_cols) + self.fit( + data, + list(attributes_stable), + list(attribtes_flexible), + attribute_target, + target_undesired_state, + target_desired_state, + use_sparse_matrix, + use_gpu, + ) + def fit( self, data: Union['cudf.DataFrame', 'pandas.DataFrame'], @@ -377,7 +470,8 @@ def fit( if self.output is not None: raise RuntimeError("The model is already fit.") self.set_array_library(use_gpu, data) - data = self.one_hot_encode(data, stable_attributes, flexible_attributes, target) + if not self.is_onehot: + data = self.one_hot_encode(data, stable_attributes, flexible_attributes, target) data, columns = self.df_to_array(data, use_sparse_matrix) stable_items_binding, flexible_items_binding, target_items_binding, column_values = self.get_bindings( diff --git a/tests/test_action_rules.py b/tests/test_action_rules.py index 35f04f9..7041262 100644 --- a/tests/test_action_rules.py +++ b/tests/test_action_rules.py @@ -319,6 +319,69 @@ def test_fit_raises_error_when_already_fit(action_rules): ) +def test_fit_onehot(action_rules): + """ + Test the fit_onehot method. + + Parameters + ---------- + action_rules : ActionRules + The ActionRules instance to test. + + Asserts + ------- + Asserts that the fit_onehot method processes the data correctly and fits the model. + """ + df = pd.DataFrame( + { + 'young': [0, 1, 0, 0], + 'old': [1, 0, 1, 1], + 'high': [1, 1, 0, 0], + 'low': [0, 0, 1, 1], + 'animals': [1, 1, 1, 0], + 'toys': [0, 0, 1, 1], + 'no': [0, 0, 1, 1], + 'yes': [1, 1, 0, 0], + } + ) + + stable_attributes = {'age': ['young', 'old']} + flexible_attributes = {'income': ['high', 'low'], 'hobby': ['animals', 'toys']} + target = {'target': ['yes', 'no']} + + action_rules.fit_onehot( + data=df, + stable_attributes=stable_attributes, + flexible_attributes=flexible_attributes, + target=target, + target_undesired_state='no', + target_desired_state='yes', + use_sparse_matrix=False, + use_gpu=False, + ) + + # Check that the model has been fitted + assert action_rules.output is not None + assert isinstance(action_rules.output, Output) + + # Check if the columns were renamed correctly and irrelevant columns removed + expected_columns = [ + 'age__young', + 'age__old', + 'income__high', + 'income__low', + 'hobby__animals', + 'hobby__toys', + 'target__yes', + 'target__no', + ] + assert set(df.columns) == set(expected_columns) + + # Check if the correct attributes were passed to the fit method + assert action_rules.rules is not None + assert len(action_rules.rules.action_rules) > 0 # Rules should have been generated + + def test_get_rules(action_rules): """ Test the get_rules method.