diff --git a/docs/sources/user_guide/frequent_patterns/association_rules.ipynb b/docs/sources/user_guide/frequent_patterns/association_rules.ipynb
index 4a4156fda..99519183c 100644
--- a/docs/sources/user_guide/frequent_patterns/association_rules.ipynb
+++ b/docs/sources/user_guide/frequent_patterns/association_rules.ipynb
@@ -2418,13 +2418,16 @@
},
{
"cell_type": "code",
+
"execution_count": 20,
+
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
+
"/tmp/ipykernel_34953/2823279667.py:23: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise an error in a future version of pandas. Value 'nan' has dtype incompatible with bool, please explicitly cast to a compatible dtype first.\n",
" df.iloc[idx[i], col[i]] = np.nan\n",
"/tmp/ipykernel_34953/2823279667.py:23: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise an error in a future version of pandas. Value 'nan' has dtype incompatible with bool, please explicitly cast to a compatible dtype first.\n",
@@ -2438,6 +2441,7 @@
"/tmp/ipykernel_34953/2823279667.py:23: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise an error in a future version of pandas. Value 'nan' has dtype incompatible with bool, please explicitly cast to a compatible dtype first.\n",
" df.iloc[idx[i], col[i]] = np.nan\n",
"/tmp/ipykernel_34953/2823279667.py:23: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise an error in a future version of pandas. Value 'nan' has dtype incompatible with bool, please explicitly cast to a compatible dtype first.\n",
+
" df.iloc[idx[i], col[i]] = np.nan\n"
]
},
@@ -2489,6 +2493,7 @@
"
True | \n",
" False | \n",
" NaN | \n",
+
" \n",
" \n",
" 1 | \n",
@@ -2710,6 +2715,7 @@
]
},
"execution_count": 21,
+
"metadata": {},
"output_type": "execute_result"
}
@@ -2718,6 +2724,7 @@
"frequent_itemsets = fpgrowth(df, min_support=0.6, null_values = True, use_colnames=True)\n",
"# frequent_itemsets = fpmax(df, min_support=0.6, null_values = True, use_colnames=True)\n",
"rules = association_rules(frequent_itemsets, len(df), df, null_values = True, metric=\"confidence\", min_threshold=0.8)\n",
+
"rules"
]
},
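The notebook cell above is the user-facing side of this change: once the input contains NaNs, fpgrowth/fpmax need null_values=True, and association_rules additionally needs the transaction count and the original frame. A minimal sketch of that call pattern, assuming a patched mlxtend with null-value support; the toy DataFrame is illustrative, not the notebook's actual data:

    import numpy as np
    import pandas as pd
    from mlxtend.frequent_patterns import fpgrowth, association_rules

    # hypothetical one-hot transactions; NaN marks an unknown entry
    df = pd.DataFrame(
        {
            "beer": [True, True, False, True, True],
            "chips": [True, True, True, True, False],
            "salsa": [True, False, True, True, True],
        },
        dtype=object,  # object dtype avoids the bool/NaN FutureWarning seen above
    )
    df.iloc[0, 2] = np.nan

    frequent_itemsets = fpgrowth(df, min_support=0.6, null_values=True, use_colnames=True)
    # with nulls, pass the transaction count and the original frame as well
    rules = association_rules(
        frequent_itemsets, len(df), df, null_values=True,
        metric="confidence", min_threshold=0.8,
    )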
diff --git a/mlxtend/frequent_patterns/association_rules.py b/mlxtend/frequent_patterns/association_rules.py
index c3ca9c249..ca4594318 100644
--- a/mlxtend/frequent_patterns/association_rules.py
+++ b/mlxtend/frequent_patterns/association_rules.py
@@ -34,7 +34,7 @@
def association_rules(
    df: pd.DataFrame,
-    num_itemsets: int,
+    num_itemsets: Optional[int] = None,
    df_orig: Optional[pd.DataFrame] = None,
    null_values=False,
    metric="confidence",
@@ -54,8 +54,8 @@ def association_rules(
    df_orig : pandas DataFrame (default: None)
      DataFrame with original input data. Only provided when null_values exist

-    num_itemsets : int
-      Number of transactions in original input data
+    num_itemsets : int (default: None)
+      Number of transactions in the original input data (df_orig); required when null_values is True

    null_values : bool (default: False)
      In case there are null values as NaNs in the original input data
@@ -119,6 +119,10 @@ def association_rules(
    if null_values and df_orig is None:
        raise TypeError("If null values exist, df_orig must be provided.")

+    # if null values exist, num_itemsets must be provided as well
+    if null_values and num_itemsets is None:
+        raise TypeError("If null values exist, num_itemsets must be provided.")
+
    # check for valid input
    fpc.valid_input_check(df_orig, null_values)
@@ -285,7 +289,9 @@ def certainty_metric_helper(sAC, sA, sC, disAC, disA, disC, dis_int, dis_int_):
    # if the input dataframe is complete
    if not null_values:
        disAC, disA, disC, dis_int, dis_int_ = 0, 0, 0, 0, 0
-        num_itemsets = 1
+        # num_itemsets is not meaningful without nulls; keep its old fixed value
+        if num_itemsets is None:
+            num_itemsets = 1
    else:
        an = list(antecedent)
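With the signature change above, num_itemsets is no longer a required positional argument, and the new guard fails fast instead of silently computing metrics from a missing transaction count. A sketch of the intended behavior under this patch; frequent_itemsets and df_orig here are made-up stand-ins:

    import pandas as pd
    from mlxtend.frequent_patterns import association_rules

    frequent_itemsets = pd.DataFrame(
        {
            "support": [0.8, 0.6, 0.6],
            "itemsets": [frozenset({"beer"}), frozenset({"chips"}),
                         frozenset({"beer", "chips"})],
        }
    )
    df_orig = pd.DataFrame({"beer": [True, True, False, True, True],
                            "chips": [True, True, True, True, False]})

    try:
        # null_values=True but num_itemsets left at its default
        association_rules(frequent_itemsets, df_orig=df_orig, null_values=True)
    except TypeError as err:
        print(err)  # If null values exist, num_itemsets must be provided.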
diff --git a/mlxtend/frequent_patterns/fpcommon.py b/mlxtend/frequent_patterns/fpcommon.py
index fb0ed2ac3..7d3047917 100644
--- a/mlxtend/frequent_patterns/fpcommon.py
+++ b/mlxtend/frequent_patterns/fpcommon.py
@@ -31,7 +31,7 @@ def setup_fptree(df, min_support):
    )
    item_support = np.array(
-        np.sum(np.logical_or(df.values == 1, df.values is True), axis=0)
+        np.nansum(df.values, axis=0)
        / (float(num_itemsets) - np.nansum(disabled, axis=0))
    )
    item_support = item_support.reshape(-1)
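The fpcommon change fixes a subtle bug: `df.values is True` is an identity comparison with the True singleton, so it is always False for a NumPy array, and the old expression only ever counted exact 1s. `np.nansum` counts 1/True entries per column while ignoring NaNs, which is what the null-aware support needs. A standalone illustration with made-up values:

    import numpy as np

    values = np.array([[1.0, np.nan],
                       [1.0, 0.0],
                       [0.0, 1.0]])

    print(values is True)             # False: identity test, not element-wise
    print(np.nansum(values, axis=0))  # [2. 1.]: per-item counts, NaNs ignored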