From 9225a27caf96fbc222d11ce4d20318148c30d39d Mon Sep 17 00:00:00 2001 From: lukas Date: Sun, 4 Aug 2024 20:15:37 +0200 Subject: [PATCH 1/3] fix: predict method for cupy --- CHANGELOG.md | 4 ++++ src/action_rules/action_rules.py | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2929be4..28936ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -135,3 +135,7 @@ ## [1.0.3] - 2024-07-31 * Comparison with another package. + +## [1.0.4] - 2024-07-04 + +* Fix predict method for cupy diff --git a/src/action_rules/action_rules.py b/src/action_rules/action_rules.py index f260d26..6c94f81 100644 --- a/src/action_rules/action_rules.py +++ b/src/action_rules/action_rules.py @@ -634,7 +634,11 @@ def predict(self, frame_row: Union['cudf.Series', 'pandas.Series']) -> Union['cu values.append(list(column_values.keys())[list(column_values.values()).index(index_value_tuple)]) new_values = tuple(values) predicted = [] - for i, action_rule in enumerate(self.output.action_rules): + if self.is_gpu_np: + action_rules = self.output.action_rules.get() + else: + action_rules = self.output.action_rules + for i, action_rule in action_rules: if set(action_rule['undesired']['itemset']) <= set(new_values): predicted_row = frame_row.copy() for recommended in set(action_rule['desired']['itemset']) - set(new_values): From 1758a7728986adb85789ace4d735ff36262f9b19 Mon Sep 17 00:00:00 2001 From: lukas Date: Sun, 4 Aug 2024 20:18:56 +0200 Subject: [PATCH 2/3] fix: missing enumerate --- src/action_rules/action_rules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/action_rules/action_rules.py b/src/action_rules/action_rules.py index 6c94f81..881f536 100644 --- a/src/action_rules/action_rules.py +++ b/src/action_rules/action_rules.py @@ -638,7 +638,7 @@ def predict(self, frame_row: Union['cudf.Series', 'pandas.Series']) -> Union['cu action_rules = self.output.action_rules.get() else: action_rules = self.output.action_rules - for i, action_rule in action_rules: + for i, action_rule in enumerate(action_rules): if set(action_rule['undesired']['itemset']) <= set(new_values): predicted_row = frame_row.copy() for recommended in set(action_rule['desired']['itemset']) - set(new_values): From 1c39538716b63d06417594a283961e1471391f8a Mon Sep 17 00:00:00 2001 From: lukas Date: Sun, 4 Aug 2024 20:41:09 +0200 Subject: [PATCH 3/3] fix: support return value --- src/action_rules/action_rules.py | 6 +----- src/action_rules/candidates/candidate_generator.py | 6 +++--- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/action_rules/action_rules.py b/src/action_rules/action_rules.py index 881f536..f260d26 100644 --- a/src/action_rules/action_rules.py +++ b/src/action_rules/action_rules.py @@ -634,11 +634,7 @@ def predict(self, frame_row: Union['cudf.Series', 'pandas.Series']) -> Union['cu values.append(list(column_values.keys())[list(column_values.values()).index(index_value_tuple)]) new_values = tuple(values) predicted = [] - if self.is_gpu_np: - action_rules = self.output.action_rules.get() - else: - action_rules = self.output.action_rules - for i, action_rule in enumerate(action_rules): + for i, action_rule in enumerate(self.output.action_rules): if set(action_rule['undesired']['itemset']) <= set(new_values): predicted_row = frame_row.copy() for recommended in set(action_rule['desired']['itemset']) - set(new_values): diff --git a/src/action_rules/candidates/candidate_generator.py b/src/action_rules/candidates/candidate_generator.py index 21c69a9..a5d86ab 100644 --- a/src/action_rules/candidates/candidate_generator.py +++ b/src/action_rules/candidates/candidate_generator.py @@ -432,7 +432,7 @@ def get_support( - Ensure that the `item` index is within the bounds of the frame's rows. - For sparse matrices, the sum is computed efficiently by leveraging sparse matrix operations. """ - return frame[item].sum() + return int(frame[item].sum()) def process_flexible_candidates( self, @@ -593,8 +593,8 @@ def process_items( if self.in_stop_list(itemset_prefix + (item,), stop_list_itemset): continue - undesired_support = undesired_frame[item].sum() - desired_support = desired_frame[item].sum() + undesired_support = self.get_support(undesired_frame, item) + desired_support = self.get_support(desired_frame, item) if verbose: print('SUPPORT for: ' + str(itemset_prefix + (item,)))