-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkNN.py
30 lines (22 loc) · 774 Bytes
/
kNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from helper_functions import *
from scipy.stats import mode
class KNNClassifier:
    """k-nearest-neighbours classifier that reports its error rate on a test set.

    Rows of the data arrays passed to :meth:`predict` hold the class label in
    column 0 and the features in the remaining columns.
    """

    def __init__(self, k):
        """
        :param k: number of nearest training rows that vote on each prediction
        """
        self.k = k

    def predict(self, metric, trainset, testset):
        '''
        Classify each row of ``testset`` by majority vote of its ``k`` nearest
        training rows and return the fraction that was misclassified.

        :param metric: distance metric, compatible with the cdist function from scipy
        :param trainset: 2-D array; column 0 is the label, the rest are features
        :param testset: 2-D array with the same column layout as ``trainset``
        :return: error rate (misclassified test rows / number of test rows)
        :raises ValueError: if ``testset`` contains no rows
        '''
        train_x = trainset[:, 1:]
        train_y = trainset[:, 0]
        test_x = testset[:, 1:]
        test_y = testset[:, 0]
        if len(test_y) == 0:
            # An error rate over zero samples is undefined; fail with a clear
            # message rather than a bare ZeroDivisionError at the end.
            raise ValueError("testset must contain at least one row")

        incorrect = 0
        for item, true_label in zip(test_x, test_y):
            # get_knn comes from helper_functions (not visible here); presumably
            # it returns indices into train_x of the k nearest rows — verify.
            neighbours = get_knn(train_x, item.reshape(1, -1), metric, self.k)
            # axis=None flattens the neighbour labels so the vote yields a
            # single value whatever the index array's shape. `.mode` is a
            # length-1 array on older SciPy and a scalar on SciPy >= 1.11;
            # `.item()` turns either form into one plain Python value.
            pred = mode(train_y[neighbours], axis=None).mode.item()
            if pred != true_label:
                incorrect += 1
        return incorrect / len(test_y)