Skip to content

Commit

Permalink
Merge pull request #8 from lukassykora/predict
Browse files Browse the repository at this point in the history
Predict
  • Loading branch information
lukassykora authored Jul 22, 2024
2 parents e8ccdaf + f96bbc5 commit 2c75e42
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 6 deletions.
80 changes: 75 additions & 5 deletions src/action_rules/action_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import itertools
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Optional, Union # noqa

from .candidates.candidate_generator import CandidateGenerator
from .output.output import Output
Expand Down Expand Up @@ -347,6 +347,11 @@ def fit(
Use GPU (cuDF) for data processing if available.
use_sparse_matrix : bool, optional
If True, rhe sparse matrix is used. Default is False.
Raises
------
RuntimeError
If the model has already been fitted.
"""
if self.output is not None:
raise RuntimeError("The model is already fit.")
Expand Down Expand Up @@ -534,16 +539,81 @@ def get_split_tables(
frames[item] = data[:, mask]
return frames

def get_rules(self) -> Optional[Output]:
def get_rules(self) -> Output:
"""
Return the generated action rules if available.
Raises
------
RuntimeError
If the model has not been fitted.
Returns
-------
Optional[Output]
The generated action rules, or None if no rules have been generated.
"""
if self.output is None:
return None
else:
return self.output
raise RuntimeError("The model is not fit.")
return self.output

def predict(self, frame_row: Union['cudf.Series', 'pandas.Series']) -> Union['cudf.DataFrame', 'pandas.DataFrame']:
"""
Predict recommended actions based on the provided row of data.
This method applies the fitted action rules to the given row of data and generates
a DataFrame with recommended actions if any of the action rules are triggered.
Parameters
----------
frame_row : Union['cudf.Series', 'pandas.Series']
A row of data in the form of a cuDF or pandas Series. The Series should
contain the features required by the action rules.
Returns
-------
Union['cudf.DataFrame', 'pandas.DataFrame']
A DataFrame with the recommended actions. The DataFrame includes the following columns:
- The original attributes with recommended changes.
- 'ActionRules_RuleIndex': Index of the action rule applied.
- 'ActionRules_UndesiredSupport': Support of the undesired part of the rule.
- 'ActionRules_DesiredSupport': Support of the desired part of the rule.
- 'ActionRules_UndesiredConfidence': Confidence of the undesired part of the rule.
- 'ActionRules_DesiredConfidence': Confidence of the desired part of the rule.
- 'ActionRules_Uplift': Uplift value of the rule.
Raises
------
RuntimeError
If the model has not been fitted.
Notes
-----
The method compares the given row of data against the undesired itemsets of the action rules.
If a match is found, it applies the desired itemset changes and records the action rule's
metadata. The result is a DataFrame with one or more rows representing the recommended actions
for the given data.
"""
if self.output is None:
raise RuntimeError("The model is not fit.")
index_value_tuples = list(zip(frame_row.index, frame_row))
values = []
column_values = self.output.column_values
for index_value_tuple in index_value_tuples:
values.append(list(column_values.keys())[list(column_values.values()).index(index_value_tuple)])
new_values = tuple(values)
predicted = []
for i, action_rule in enumerate(self.output.action_rules):
if set(action_rule['undesired']['itemset']) <= set(new_values):
predicted_row = frame_row.copy()
for recommended in set(action_rule['desired']['itemset']) - set(new_values):
attribute, value = column_values[recommended]
predicted_row[attribute + ' (Recommended)'] = value
predicted_row['ActionRules_RuleIndex'] = i
predicted_row['ActionRules_UndesiredSupport'] = action_rule['undesired']['support']
predicted_row['ActionRules_DesiredSupport'] = action_rule['desired']['support']
predicted_row['ActionRules_UndesiredConfidence'] = action_rule['undesired']['confidence']
predicted_row['ActionRules_DesiredConfidence'] = action_rule['desired']['confidence']
predicted_row['ActionRules_Uplift'] = action_rule['uplift']
predicted.append(predicted_row)
return self.pd.DataFrame(predicted) # type: ignore
47 changes: 46 additions & 1 deletion tests/test_action_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,52 @@ def test_get_rules(action_rules):
-------
Asserts that the generated rules are correctly returned.
"""
assert action_rules.get_rules() is None
with pytest.raises(RuntimeError, match="The model is not fit."):
assert action_rules.get_rules() is None
action_rules.output = MagicMock()
assert action_rules.get_rules() is not None
assert action_rules.get_rules() == action_rules.output


def test_predict(action_rules):
"""
Test the predict method of the ActionRules class.
Parameters
----------
action_rules : ActionRules
The ActionRules instance to test.
Asserts
-------
Asserts that the prediction works correctly and returns the expected DataFrame.
"""
frame_row = pd.Series({'stable': 'a', 'flexible': 'z'})
with pytest.raises(RuntimeError, match="The model is not fit."):
action_rules.predict(frame_row)
df = pd.DataFrame({'stable': ['a', 'b', 'a'], 'flexible': ['x', 'y', 'z'], 'target': ['yes', 'no', 'no']})
action_rules.fit(
df,
stable_attributes=['stable'],
flexible_attributes=['flexible'],
target='target',
target_undesired_state='no',
target_desired_state='yes',
)
result = action_rules.predict(frame_row)
assert not result.empty
assert 'flexible (Recommended)' in result.columns
assert 'ActionRules_RuleIndex' in result.columns
assert 'ActionRules_UndesiredSupport' in result.columns
assert 'ActionRules_DesiredSupport' in result.columns
assert 'ActionRules_UndesiredConfidence' in result.columns
assert 'ActionRules_DesiredConfidence' in result.columns
assert 'ActionRules_Uplift' in result.columns

assert result.iloc[0]['flexible (Recommended)'] == 'x'
assert result.iloc[0]['ActionRules_RuleIndex'] == 0
assert result.iloc[0]['ActionRules_UndesiredSupport'] == 1
assert result.iloc[0]['ActionRules_DesiredSupport'] == 1
assert result.iloc[0]['ActionRules_UndesiredConfidence'] == 1.0
assert result.iloc[0]['ActionRules_DesiredConfidence'] == 1.0
assert result.iloc[0]['ActionRules_Uplift'] == 1.0

0 comments on commit 2c75e42

Please sign in to comment.