From b731fd2e041fe5976dcd8a4c1b104bcfcad3956c Mon Sep 17 00:00:00 2001 From: Denis Barbier Date: Fri, 3 Jan 2020 12:09:59 +0100 Subject: [PATCH] Replace _support function For unknown reasons, np.sum is slow on a very large boolean array. --- mlxtend/frequent_patterns/apriori.py | 41 ++++++++-------------------- 1 file changed, 11 insertions(+), 30 deletions(-) diff --git a/mlxtend/frequent_patterns/apriori.py b/mlxtend/frequent_patterns/apriori.py index fb5c78ff4..1c7cf086e 100644 --- a/mlxtend/frequent_patterns/apriori.py +++ b/mlxtend/frequent_patterns/apriori.py @@ -121,32 +121,6 @@ def apriori(df, min_support=0.5, use_colnames=False, max_len=None, verbose=0, """ - def _support(_x, _n_rows, _is_sparse): - """DRY private method to calculate support as the - row-wise sum of values / number of rows - - Parameters - ----------- - - _x : matrix of bools or binary - - _n_rows : numeric, number of rows in _x - - _is_sparse : bool True if _x is sparse - - Returns - ----------- - np.array, shape = (n_rows, ) - - Examples - ----------- - For usage examples, please see - http://rasbt.github.io/mlxtend/user_guide/frequent_patterns/apriori/ - - """ - out = (np.sum(_x, axis=0) / _n_rows) - return np.array(out).reshape(-1) - if min_support <= 0.: raise ValueError('`min_support` must be a positive ' 'number within the interval `(0, 1]`. ' @@ -180,7 +154,17 @@ def _support(_x, _n_rows, _is_sparse): # dense DataFrame X = df.values is_sparse = False - support = _support(X, X.shape[0], is_sparse) + if is_sparse: + # Count nonnull entries via direct access to X indices; + # this requires X to be stored in CSC format, and to call + # X.eliminate_zeros() to remove null entries from X. + support = np.array([X.indptr[idx+1] - X.indptr[idx] + for idx in range(X.shape[1])], dtype=int) + else: + # Faster than np.count_nonzero(X, axis=0) or np.sum(X, axis=0), why? 
+ support = np.array([np.count_nonzero(X[:, idx]) + for idx in range(X.shape[1])], dtype=int) + support = support / X.shape[0] support_dict = {1: support[support >= min_support]} itemset_dict = {1: [(idx,) for idx in np.where(support >= min_support)[0]]} max_itemset = 1 @@ -199,9 +183,6 @@ def _support(_x, _n_rows, _is_sparse): processed += 1 count[:] = 0 for item in itemset: - # Count nonnull entries via direct access to X indices; - # this requires X to be stored in CSC format, and to call - # X.eliminate_zeros() to remove null entries from X. count[X.indices[X.indptr[item]:X.indptr[item+1]]] += 1 support = np.count_nonzero(count == len(itemset)) / X.shape[0] if support >= min_support: