Skip to content

Commit

Permalink
Merge pull request #16 from lukassykora/feature/onehot
Browse files Browse the repository at this point in the history
feat: onehot dataset as input
  • Loading branch information
lukassykora authored Aug 6, 2024
2 parents 8d27547 + 0c10bfe commit 625bfa1
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 1 deletion.
96 changes: 95 additions & 1 deletion src/action_rules/action_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __init__(
self.pd = None # type: Optional[ModuleType]
self.is_gpu_np = False
self.is_gpu_pd = False
self.is_onehot = False

def count_max_nodes(self, stable_items_binding: dict, flexible_items_binding: dict) -> int:
"""
Expand Down Expand Up @@ -331,6 +332,98 @@ def one_hot_encode(
data = self.pd.concat(to_concat, axis=1) # type: ignore
return data

def fit_onehot(
self,
data: Union['cudf.DataFrame', 'pandas.DataFrame'],
stable_attributes: dict,
flexible_attributes: dict,
target: dict,
target_undesired_state: str,
target_desired_state: str,
use_sparse_matrix: bool = False,
use_gpu: bool = False,
):
"""
Preprocess and fit the model using one-hot encoded attributes.
This method prepares the dataset for generating action rules by
performing one-hot encoding on the specified stable, flexible,
and target attributes. The resulting dataset is then used to fit
the model using the `fit` method.
Parameters
----------
data : Union[cudf.DataFrame, pandas.DataFrame]
The dataset to be processed and used for fitting the model.
stable_attributes : dict
A dictionary mapping stable attribute names to lists of column
names corresponding to those attributes.
flexible_attributes : dict
A dictionary mapping flexible attribute names to lists of column
names corresponding to those attributes.
target : dict
A dictionary mapping the target attribute name to a list of
column names corresponding to that attribute.
target_undesired_state : str
The undesired state of the target attribute, used in action rule generation.
target_desired_state : str
The desired state of the target attribute, used in action rule generation.
use_sparse_matrix : bool, optional
If True, a sparse matrix is used in the fitting process. Default is False.
use_gpu : bool, optional
If True, the GPU (cuDF) is used for data processing if available.
Default is False.
Notes
-----
The method modifies the dataset by:
1. Renaming columns according to the stable, flexible, and target attributes.
2. Removing columns that are not associated with any of these attributes.
3. Passing the processed dataset and relevant attribute lists to the `fit` method
to generate action rules.
This method ensures that the dataset is correctly preprocessed for rule
generation, focusing on the specified attributes and their one-hot encoded forms.
"""
self.is_onehot = True
new_labels = []
attributes_stable = set([])
attribtes_flexible = set([])
attribute_target = ''
remove_cols = []
for label in data.columns:
to_remove = True
for attribute, columns in stable_attributes.items():
if label in columns:
new_labels.append(attribute + '_<item_stable>_' + label)
attributes_stable.add(attribute)
to_remove = False
for attribute, columns in flexible_attributes.items():
if label in columns:
new_labels.append(attribute + '_<item_flexible>_' + label)
attribtes_flexible.add(attribute)
to_remove = False
for attribute, columns in target.items():
if label in columns:
new_labels.append(attribute + '_<item_target>_' + label)
attribute_target = attribute
to_remove = False
if to_remove:
new_labels.append(label)
remove_cols.append(label)
data.columns = new_labels
data = data.drop(columns=remove_cols)
self.fit(
data,
list(attributes_stable),
list(attribtes_flexible),
attribute_target,
target_undesired_state,
target_desired_state,
use_sparse_matrix,
use_gpu,
)

def fit(
self,
data: Union['cudf.DataFrame', 'pandas.DataFrame'],
Expand Down Expand Up @@ -377,7 +470,8 @@ def fit(
if self.output is not None:
raise RuntimeError("The model is already fit.")
self.set_array_library(use_gpu, data)
data = self.one_hot_encode(data, stable_attributes, flexible_attributes, target)
if not self.is_onehot:
data = self.one_hot_encode(data, stable_attributes, flexible_attributes, target)
data, columns = self.df_to_array(data, use_sparse_matrix)

stable_items_binding, flexible_items_binding, target_items_binding, column_values = self.get_bindings(
Expand Down
63 changes: 63 additions & 0 deletions tests/test_action_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,69 @@ def test_fit_raises_error_when_already_fit(action_rules):
)


def test_fit_onehot(action_rules):
"""
Test the fit_onehot method.
Parameters
----------
action_rules : ActionRules
The ActionRules instance to test.
Asserts
-------
Asserts that the fit_onehot method processes the data correctly and fits the model.
"""
df = pd.DataFrame(
{
'young': [0, 1, 0, 0],
'old': [1, 0, 1, 1],
'high': [1, 1, 0, 0],
'low': [0, 0, 1, 1],
'animals': [1, 1, 1, 0],
'toys': [0, 0, 1, 1],
'no': [0, 0, 1, 1],
'yes': [1, 1, 0, 0],
}
)

stable_attributes = {'age': ['young', 'old']}
flexible_attributes = {'income': ['high', 'low'], 'hobby': ['animals', 'toys']}
target = {'target': ['yes', 'no']}

action_rules.fit_onehot(
data=df,
stable_attributes=stable_attributes,
flexible_attributes=flexible_attributes,
target=target,
target_undesired_state='no',
target_desired_state='yes',
use_sparse_matrix=False,
use_gpu=False,
)

# Check that the model has been fitted
assert action_rules.output is not None
assert isinstance(action_rules.output, Output)

# Check if the columns were renamed correctly and irrelevant columns removed
expected_columns = [
'age_<item_stable>_young',
'age_<item_stable>_old',
'income_<item_flexible>_high',
'income_<item_flexible>_low',
'hobby_<item_flexible>_animals',
'hobby_<item_flexible>_toys',
'target_<item_target>_yes',
'target_<item_target>_no',
]
assert set(df.columns) == set(expected_columns)

# Check if the correct attributes were passed to the fit method
assert action_rules.rules is not None
assert len(action_rules.rules.action_rules) > 0 # Rules should have been generated


def test_get_rules(action_rules):
"""
Test the get_rules method.
Expand Down

0 comments on commit 625bfa1

Please sign in to comment.