-
Notifications
You must be signed in to change notification settings - Fork 0
/
CrossValSev.py
53 lines (41 loc) · 1.87 KB
/
CrossValSev.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import BaggingClassifier, RandomForestClassifier
import sklearn.discriminant_analysis as skl_da
import sklearn.linear_model as skl_lm
import sklearn.neighbors as skl_nb
# ---------------------------------------------------------------------------
# Experiment setup: load the data, shuffle the row order and build the list of
# candidate classifiers that are compared by k-fold cross-validation.
# ---------------------------------------------------------------------------
cross = 10  # number of cross-validation folds

# 'seaborn-white' was renamed to 'seaborn-v0_8-white' in matplotlib 3.6 and the
# old name was later removed, so pick whichever name this installation offers
# (and fall back to the default style if neither is available).
for _style in ('seaborn-v0_8-white', 'seaborn-white'):
    if _style in plt.style.available:
        plt.style.use(_style)
        break

# '?' is read as a missing value; rows containing any missing value are dropped.
# drop=True stops the old row labels from being inserted as an extra 'index'
# column, which would otherwise stay in X and be fed to the models as a feature.
songs = (
    pd.read_csv('training_data.csv', na_values='?', dtype={'ID': str})
    .dropna()
    .reset_index(drop=True)
)

# NOTE(review): every column except 'label' is used as a feature, including
# 'ID' if the file contains one -- confirm that this is intended.
X = songs.drop(columns=['label'])
Y = songs['label']

# Random permutation of the row positions 0..n-1 (sampling without replacement).
randomize_indices = np.random.choice(X.shape[0], X.shape[0], replace=False)

# Classifiers under comparison; the column order of `misclassification`
# follows the order of this list.
models = [
    RandomForestClassifier(),
    skl_da.LinearDiscriminantAnalysis(),
    skl_da.QuadraticDiscriminantAnalysis(),
    BaggingClassifier(),
    skl_lm.LogisticRegression(),
    skl_nb.KNeighborsClassifier(n_neighbors=14),
]
size = len(models)
print(size)

# misclassification[i, m] = validation error of model m on fold i.
# len() is used instead of np.shape(models)[0] so the estimator list is never
# coerced into a NumPy array just to count its entries.
misclassification = np.zeros((cross, size))
# k-fold cross-validation: each fold in turn is held out for validation while
# every model is refitted on the remaining rows; the fraction of wrong
# predictions is stored in misclassification[fold, model].
n_samples = X.shape[0]
if n_samples < cross:
    raise ValueError(
        'need at least %d samples for %d-fold cross validation, got %d'
        % (cross, cross, n_samples)
    )

# np.array_split yields exactly `cross` folds whose sizes differ by at most
# one. Slicing fixed ceil-sized chunks instead can exhaust the samples early
# and leave the last folds empty (e.g. 11 samples in 10 folds).
folds = np.array_split(randomize_indices, cross)

for i, validation_rows in enumerate(folds):
    # Positional boolean mask: True for every row that is not held out.
    train_mask = np.ones(n_samples, dtype=bool)
    train_mask[validation_rows] = False

    X_train = X.iloc[train_mask]
    Y_train = Y.iloc[train_mask]
    X_validation = X.iloc[validation_rows]
    Y_validation = Y.iloc[validation_rows]

    # Iterate over the list directly rather than via np.shape(models)[0],
    # which would coerce the estimator list into an array on every fold.
    for m, model in enumerate(models):  # try different models
        model.fit(X_train, Y_train)
        prediction = model.predict(X_validation)
        misclassification[i, m] = np.mean(prediction != Y_validation)
# Summarise the results as a box plot: one box per column of
# `misclassification`, i.e. one box per model, labelled in the same order in
# which the models were appended to `models`.
method_names = ('RandomForest', 'LDA', 'QDA', 'Bagging', 'Log. Reg.', 'K-NN')
tick_positions = np.arange(1, size + 1)  # boxes are drawn at x = 1..size

plt.boxplot(misclassification)
plt.xticks(tick_positions, method_names)
plt.ylabel('validation error')
plt.title('cross validation error for different methods')
plt.show()