diff --git a/CHANGELOG.md b/CHANGELOG.md index 2929be4..28936ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -135,3 +135,7 @@ ## [1.0.3] - 2024-07-31 * Comparison with another package. + +## [1.0.4] - 2024-07-04 + +* Fix predict method for cupy diff --git a/src/action_rules/action_rules.py b/src/action_rules/action_rules.py index f260d26..6c94f81 100644 --- a/src/action_rules/action_rules.py +++ b/src/action_rules/action_rules.py @@ -634,7 +634,11 @@ def predict(self, frame_row: Union['cudf.Series', 'pandas.Series']) -> Union['cu values.append(list(column_values.keys())[list(column_values.values()).index(index_value_tuple)]) new_values = tuple(values) predicted = [] - for i, action_rule in enumerate(self.output.action_rules): + if self.is_gpu_np: + action_rules = self.output.action_rules.get() + else: + action_rules = self.output.action_rules + for i, action_rule in action_rules: if set(action_rule['undesired']['itemset']) <= set(new_values): predicted_row = frame_row.copy() for recommended in set(action_rule['desired']['itemset']) - set(new_values):