-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfp_growth.py
123 lines (93 loc) · 2.29 KB
/
fp_growth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# %%
# %env JAVA_HOME="C:\Progra~1\Eclips~1\jdk-17.0.11.9-hotspot"
# %env PYSPARK_PYTHON="python"
# %%
import pandas as pd
import pyspark.sql.functions as F
from itables import show
from pyspark.ml.fpm import FPGrowth
from pyspark.sql import SparkSession
# %%
# Start (or reuse) the Spark session for this notebook. The driver JVM is
# launched with -Xss10m, i.e. a 10 MB thread stack.
# NOTE(review): presumably this gives FP-Growth's deep recursion more stack
# headroom -- confirm before changing it.
session_builder = SparkSession.builder.appName("FP-Growth")
session_builder = session_builder.config(
    "spark.driver.extraJavaOptions", "-Xss10m"
)
spark = session_builder.getOrCreate()
# %% [markdown]
# ## Sample
# %%
# Read the retail workbook into pandas; the path is relative to the working
# directory the notebook is run from.
excel_path = "data/Online Retail.xlsx"
df = pd.read_excel(excel_path)
# %%
# Hand the pandas frame to Spark and mark it for caching: every later step in
# this notebook derives from df_spark.
df_spark = spark.createDataFrame(df)
df_spark = df_spark.cache()
# %%
# Inspect the column names and types Spark inferred.
df_spark.printSchema()
# %%
# Peek at the first rows.
df_spark.show()
# %% [markdown]
# ## Modify
# %% [markdown]
# ### Rename
# %%
# Rename the workbook's CamelCase headers to snake_case, one column at a time.
df_renamed = df_spark
for old_name, new_name in (
    ("InvoiceNo", "invoice_no"),
    ("StockCode", "stock_code"),
    ("Description", "description"),
    ("Quantity", "quantity"),
    ("InvoiceDate", "invoice_date"),
    ("UnitPrice", "unit_price"),
    ("CustomerID", "customer_id"),
    ("Country", "country"),
):
    df_renamed = df_renamed.withColumnRenamed(old_name, new_name)
# %% [markdown]
# ### Remove Duplicates
# %%
# Drop rows that are identical across every column.
df_deduped = df_renamed.distinct()
# %%
# Number of fully duplicated rows that were removed.
df_renamed.count() - df_deduped.count()
# %%
# Keep a single row per (invoice, item description) pair, so an item is listed
# at most once per invoice.
df_deduped_t = df_deduped.drop_duplicates(["invoice_no", "description"])
# %%
# Number of repeated (invoice, item) rows that were removed.
df_deduped.count() - df_deduped_t.count()
# %% [markdown]
# ### Filter Examples
# %%
# Keep only rows that can contribute an item to a basket: a non-empty
# description and invoice number, and a strictly positive quantity and price.
#
# The filter starts from df_deduped (all rows) rather than df_deduped_t.
# dropDuplicates on (invoice_no, description) keeps an arbitrary row per pair,
# so filtering *after* it could discard an item whose surviving row happened
# to be invalid even though a valid row existed for the same pair. Filtering
# first and collapsing repeated (invoice, item) pairs last makes the result
# independent of which duplicate Spark keeps, while still leaving at most one
# row per item per invoice.
df_filtered = (
    df_deduped.filter(F.col("description").isNotNull())
    .filter(F.col("description") != "")
    .filter(F.col("invoice_no").isNotNull())
    .filter(F.col("invoice_no") != "")
    .filter(F.col("quantity") > 0)
    .filter(F.col("unit_price") > 0)
    .dropDuplicates(["invoice_no", "description"])
)
# %%
# Number of (invoice, item) pairs that have no valid row at all and were
# therefore dropped by the filter.
df_deduped_t.count() - df_filtered.count()
# %% [markdown]
# ### Aggregate
# %%
# Build one basket per invoice: the list of item descriptions it contains.
# NOTE(review): FPGrowth expects the items within a basket to be distinct;
# collect_list does not enforce that, so this relies on the upstream
# per-(invoice, description) de-duplication.
basket_items = F.collect_list("description").alias("descriptions")
df_agg = df_filtered.groupby("invoice_no").agg(basket_items)
# %%
# Total number of baskets; used later as the denominator for support.
n_transactions = df_agg.count()
n_transactions
# %% [markdown]
# ## Model
# %% [markdown]
# ### FP-Growth
# %%
# Configure the miner to read baskets from the "descriptions" array column,
# with a minimum support of 0.01 for itemsets and a minimum confidence of 0.8
# for association rules.
fp = FPGrowth(
    itemsCol="descriptions",
    minSupport=0.01,
    minConfidence=0.8,
)
# %%
# Mine the frequent itemsets from the per-invoice baskets.
fp_model = fp.fit(df_agg)
# %%
# Frequent itemsets with their relative support (freq / number of baskets),
# highest support first, collected into pandas for display.
support = F.col("freq") / n_transactions
itemsets_spark = fp_model.freqItemsets.withColumn("support", support)
itemsets = itemsets_spark.orderBy(F.desc("support")).toPandas()
# %%
# Interactive table of the itemsets, with horizontal scrolling enabled.
show(itemsets, scrollX=True)
# %%
# Association rules generated by the fitted model, collected into pandas.
rules = fp_model.associationRules.toPandas()
# %%
# Interactive table of the rules, with horizontal scrolling enabled.
show(rules, scrollX=True)
# %%