diff --git a/src/action_rules/action_rules.py b/src/action_rules/action_rules.py index f28da5c..d22c487 100644 --- a/src/action_rules/action_rules.py +++ b/src/action_rules/action_rules.py @@ -3,7 +3,7 @@ import itertools import warnings from collections import defaultdict -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional, Union # noqa from .candidates.candidate_generator import CandidateGenerator from .output.output import Output @@ -347,6 +347,11 @@ def fit( Use GPU (cuDF) for data processing if available. use_sparse_matrix : bool, optional If True, rhe sparse matrix is used. Default is False. + + Raises + ------ + RuntimeError + If the model has already been fitted. """ if self.output is not None: raise RuntimeError("The model is already fit.") @@ -534,16 +539,81 @@ def get_split_tables( frames[item] = data[:, mask] return frames - def get_rules(self) -> Optional[Output]: + def get_rules(self) -> Output: """ Return the generated action rules if available. + Raises + ------ + RuntimeError + If the model has not been fitted. + Returns ------- Optional[Output] The generated action rules, or None if no rules have been generated. """ if self.output is None: - return None - else: - return self.output + raise RuntimeError("The model is not fit.") + return self.output + + def predict(self, frame_row: Union['cudf.Series', 'pandas.Series']) -> Union['cudf.DataFrame', 'pandas.DataFrame']: + """ + Predict recommended actions based on the provided row of data. + + This method applies the fitted action rules to the given row of data and generates + a DataFrame with recommended actions if any of the action rules are triggered. + + Parameters + ---------- + frame_row : Union['cudf.Series', 'pandas.Series'] + A row of data in the form of a cuDF or pandas Series. The Series should + contain the features required by the action rules. + + Returns + ------- + Union['cudf.DataFrame', 'pandas.DataFrame'] + A DataFrame with the recommended actions. The DataFrame includes the following columns: + - The original attributes with recommended changes. + - 'ActionRules_RuleIndex': Index of the action rule applied. + - 'ActionRules_UndesiredSupport': Support of the undesired part of the rule. + - 'ActionRules_DesiredSupport': Support of the desired part of the rule. + - 'ActionRules_UndesiredConfidence': Confidence of the undesired part of the rule. + - 'ActionRules_DesiredConfidence': Confidence of the desired part of the rule. + - 'ActionRules_Uplift': Uplift value of the rule. + + Raises + ------ + RuntimeError + If the model has not been fitted. + + Notes + ----- + The method compares the given row of data against the undesired itemsets of the action rules. + If a match is found, it applies the desired itemset changes and records the action rule's + metadata. The result is a DataFrame with one or more rows representing the recommended actions + for the given data. + """ + if self.output is None: + raise RuntimeError("The model is not fit.") + index_value_tuples = list(zip(frame_row.index, frame_row)) + values = [] + column_values = self.output.column_values + for index_value_tuple in index_value_tuples: + values.append(list(column_values.keys())[list(column_values.values()).index(index_value_tuple)]) + new_values = tuple(values) + predicted = [] + for i, action_rule in enumerate(self.output.action_rules): + if set(action_rule['undesired']['itemset']) <= set(new_values): + predicted_row = frame_row.copy() + for recommended in set(action_rule['desired']['itemset']) - set(new_values): + attribute, value = column_values[recommended] + predicted_row[attribute + ' (Recommended)'] = value + predicted_row['ActionRules_RuleIndex'] = i + predicted_row['ActionRules_UndesiredSupport'] = action_rule['undesired']['support'] + predicted_row['ActionRules_DesiredSupport'] = action_rule['desired']['support'] + predicted_row['ActionRules_UndesiredConfidence'] = action_rule['undesired']['confidence'] + predicted_row['ActionRules_DesiredConfidence'] = action_rule['desired']['confidence'] + predicted_row['ActionRules_Uplift'] = action_rule['uplift'] + predicted.append(predicted_row) + return self.pd.DataFrame(predicted) # type: ignore diff --git a/tests/test_action_rules.py b/tests/test_action_rules.py index 4029bd0..6d86c0a 100644 --- a/tests/test_action_rules.py +++ b/tests/test_action_rules.py @@ -309,7 +309,52 @@ def test_get_rules(action_rules): ------- Asserts that the generated rules are correctly returned. """ - assert action_rules.get_rules() is None + with pytest.raises(RuntimeError, match="The model is not fit."): + assert action_rules.get_rules() is None action_rules.output = MagicMock() assert action_rules.get_rules() is not None assert action_rules.get_rules() == action_rules.output + + +def test_predict(action_rules): + """ + Test the predict method of the ActionRules class. + + Parameters + ---------- + action_rules : ActionRules + The ActionRules instance to test. + + Asserts + ------- + Asserts that the prediction works correctly and returns the expected DataFrame. + """ + frame_row = pd.Series({'stable': 'a', 'flexible': 'z'}) + with pytest.raises(RuntimeError, match="The model is not fit."): + action_rules.predict(frame_row) + df = pd.DataFrame({'stable': ['a', 'b', 'a'], 'flexible': ['x', 'y', 'z'], 'target': ['yes', 'no', 'no']}) + action_rules.fit( + df, + stable_attributes=['stable'], + flexible_attributes=['flexible'], + target='target', + target_undesired_state='no', + target_desired_state='yes', + ) + result = action_rules.predict(frame_row) + assert not result.empty + assert 'flexible (Recommended)' in result.columns + assert 'ActionRules_RuleIndex' in result.columns + assert 'ActionRules_UndesiredSupport' in result.columns + assert 'ActionRules_DesiredSupport' in result.columns + assert 'ActionRules_UndesiredConfidence' in result.columns + assert 'ActionRules_DesiredConfidence' in result.columns + assert 'ActionRules_Uplift' in result.columns + + assert result.iloc[0]['flexible (Recommended)'] == 'x' + assert result.iloc[0]['ActionRules_RuleIndex'] == 0 + assert result.iloc[0]['ActionRules_UndesiredSupport'] == 1 + assert result.iloc[0]['ActionRules_DesiredSupport'] == 1 + assert result.iloc[0]['ActionRules_UndesiredConfidence'] == 1.0 + assert result.iloc[0]['ActionRules_DesiredConfidence'] == 1.0 + assert result.iloc[0]['ActionRules_Uplift'] == 1.0