-
Notifications
You must be signed in to change notification settings - Fork 0
/
k_means.py
64 lines (52 loc) · 2.11 KB
/
k_means.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from IPython.display import clear_output
# Default input CSV and the player attributes used as clustering features.
file = "all_seasons.csv"
features = [
    "player_height",
    "player_weight",
    "draft_year",
    "draft_round",
    "draft_number",
]
def initialise_data(file, features):
    """Load the player table and min-max scale the chosen features into [1, 100].

    Parameters
    ----------
    file : str, path or file-like
        Anything accepted by ``pandas.read_csv``.
    features : list of str
        Column names to keep as clustering features.

    Returns
    -------
    pandas.DataFrame
        One row per input row that has a numeric value in every feature
        column (original index kept). Each feature is rescaled linearly so
        its minimum maps to 1 and its maximum to 100. The lower bound of 1
        keeps all values strictly positive, which the logarithm in
        ``update_centroids`` needs.
    """
    players = pd.read_csv(file)
    # Coerce to numbers so any non-numeric placeholder becomes NaN and is
    # dropped together with genuinely missing values, instead of making the
    # scaling arithmetic below fail on strings.
    data = players[features].apply(pd.to_numeric, errors="coerce").dropna()
    lowest = data.min()
    span = data.max() - lowest
    # A constant column has zero span. Divide by 1 instead, so the column
    # maps to all 1s rather than NaN (0 / 0), which would poison distances.
    span = span.replace(0, 1)
    return (data - lowest) / span * 99 + 1
def initialise_centroids(data, k, random_state=None):
    """Build ``k`` random starting centroids from the observed feature values.

    Each coordinate of each centroid is drawn independently from that
    feature's column. Every coordinate is therefore a value that occurs in
    ``data``, although the combination need not match any single row.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows are samples, columns are features.
    k : int
        Number of centroids to create; must be at least 1.
    random_state : int, numpy.random.Generator or None, optional
        Seed or generator for reproducible draws. ``None`` (the default)
        falls back to NumPy's global random state, as before.

    Returns
    -------
    pandas.DataFrame
        Features as the index and one centroid per column, labelled
        ``0 .. k-1``.

    Raises
    ------
    ValueError
        If ``k`` is less than 1.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    # Turn a seed into ONE shared Generator so successive draws differ.
    # Passing the same int seed to every sample() call would pick the same
    # row position in every column and make all centroids identical.
    rng = None if random_state is None else np.random.default_rng(random_state)
    centroids = [
        # .iloc[0] extracts the scalar; float() on a 1-element Series is
        # deprecated in recent pandas.
        data.apply(lambda col: float(col.sample(random_state=rng).iloc[0]))
        for _ in range(k)
    ]
    return pd.concat(centroids, axis=1)
def get_labels(data, centroids):
    """Assign every row of ``data`` to its nearest centroid.

    ``centroids`` holds one centroid per column, with the features as its
    index. For each row, the Euclidean distance to every centroid is
    computed, and the column label of the closest centroid is returned.
    Ties go to the first centroid column.
    """
    distance_table = pd.DataFrame(
        {
            name: np.sqrt(((data - centre) ** 2).sum(axis=1))
            for name, centre in centroids.items()
        }
    )
    return distance_table.idxmin(axis=1)
def update_centroids(data, labels, mean="geometric"):
    """Recompute each cluster's centroid from the rows currently assigned to it.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows are samples, columns are features. With ``mean="geometric"``
        every value must be strictly positive; the [1, 100] scaling done by
        ``initialise_data`` satisfies this.
    labels : pandas.Series
        Cluster label per row of ``data``, aligned on the index.
    mean : {"geometric", "arithmetic"}, optional
        Per-feature average used for each centroid. ``"geometric"`` (the
        default) keeps the original behaviour. ``"arithmetic"`` is the
        update that minimises within-cluster squared Euclidean distance,
        i.e. the standard k-means step matching the distance in
        ``get_labels``.

    Returns
    -------
    pandas.DataFrame
        Features as the index and one column per label that owns at least
        one row. A cluster with no rows does not appear.

    Raises
    ------
    ValueError
        If ``mean`` is not a supported option.
    """
    if mean == "geometric":
        # Geometric mean = exp(mean(log x)). Computed for all groups in one
        # vectorised groupby instead of a Python-level apply per cluster.
        per_cluster = np.exp(np.log(data).groupby(labels).mean())
    elif mean == "arithmetic":
        per_cluster = data.groupby(labels).mean()
    else:
        raise ValueError(f"mean must be 'geometric' or 'arithmetic', got {mean!r}")
    # groupby yields clusters as rows; transpose to features x clusters.
    return per_cluster.T
def plot_clusters(data, labels, centroids, iteration):
    """Draw the current clustering in two dimensions.

    A 2-component PCA is fitted on ``data``. Both the points and the
    centroids (one per column of ``centroids``, hence the transpose) are
    projected with it. Points are coloured by ``labels`` and the centroids
    are drawn as a second scatter. Earlier notebook output is cleared
    before the new plot is shown.
    """
    projector = PCA(n_components=2)
    points = projector.fit_transform(data)
    centre_points = projector.transform(centroids.T)

    clear_output(wait=True)
    plt.title(f'Iteration {iteration}')
    xs, ys = points[:, 0], points[:, 1]
    plt.scatter(x=xs, y=ys, c=labels)
    centre_xs, centre_ys = centre_points[:, 0], centre_points[:, 1]
    plt.scatter(x=centre_xs, y=centre_ys)
    plt.show()
def main(k=3, max_iters=100):
    """Run k-means on the module-level ``file`` / ``features`` and return the centroids.

    The clustering is plotted after every iteration. The loop stops early
    once an update leaves the centroids unchanged.

    Parameters
    ----------
    k : int, optional
        Number of clusters to start with (default 3).
    max_iters : int, optional
        Maximum number of assign/update iterations (default 100).

    Returns
    -------
    pandas.DataFrame
        Final centroids: features as the index, one column per cluster.
    """
    import warnings

    # Use the module-level settings instead of re-declaring shadowing copies.
    data = initialise_data(file, features)
    centroids = initialise_centroids(data, k)
    # Iterations are numbered 1..max_iters inclusive. The previous
    # `iteration < max_iters` test starting from 1 ran one iteration short.
    for iteration in range(1, max_iters + 1):
        old_centroids = centroids
        labels = get_labels(data, centroids)
        centroids = update_centroids(data, labels)
        plot_clusters(data, labels, centroids, iteration)
        if centroids.shape[1] < old_centroids.shape[1]:
            # update_centroids only returns clusters that still own a row,
            # so an emptied cluster disappears and the effective k shrinks.
            warnings.warn(
                f"Iteration {iteration}: {old_centroids.shape[1]} clusters shrank "
                f"to {centroids.shape[1]} because a cluster received no points."
            )
        if centroids.equals(old_centroids):
            break  # converged: the assignment can no longer change
    return centroids
if __name__ == "__main__":
    # Script entry point: run the clustering, then print the final centroid table.
    centroids = main()
    print(centroids)