-
Notifications
You must be signed in to change notification settings - Fork 13
/
nb.py
109 lines (65 loc) · 2.26 KB
/
nb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from data_reader import load_data
import numpy as np
from uncompress import *
import os
from sklearn.naive_bayes import GaussianNB
import cPickle
if __name__=='__main__':
batch_size = 1
data_X,data_Y = load_data()
indices = np.random.permutation(np.arange(data_X.shape[0]))
data_X = data_X[indices,:,:]
data_Y = data_Y[indices]
X = []
Y = []
gnb = GaussianNB()
fp = open(os.path.join('nb_logs','nb_object' + '.save'), 'wb')
cPickle.dump(gnb, fp, protocol = cPickle.HIGHEST_PROTOCOL)
fp.close()
s = 0.0
i = 0
for step in range(data_X.shape[0]/batch_size):
print "Step:",step
batch_x, batch_y = data_X[step*batch_size:(step+1)*batch_size],data_Y[step*batch_size:(step+1)*batch_size]
batch_x = uncompress(batch_x,86796)
# print batch_x.shape
batch_x = np.sum(batch_x,axis=1)
# print batch_x.shape
batch_x = np.squeeze(batch_x)
# print batch_x.shape
# print 'y'
# print batch_y.shape
batch_y = np.repeat(batch_y,50,axis=0)
# print batch_y.shape
gnb.partial_fit(batch_x,batch_y,classes=[0,1])
# X.append(batch_x)
# Y.append(batch_y)
# break
# X = np.vstack(X)
# Y = np.squeeze(np.asarray(Y))
# print X.shape,Y.shape
for step in range(data_X.shape[0]/batch_size):
print "Step:",step
batch_x, batch_y = data_X[step*batch_size:(step+1)*batch_size],data_Y[step*batch_size:(step+1)*batch_size]
batch_x = uncompress(batch_x,86796)
# print batch_x.shape
batch_x = np.sum(batch_x,axis=1)
# print batch_x.shape
batch_x = np.squeeze(batch_x)
# print batch_x.shape
# print 'y'
# print batch_y.shape
batch_y = np.repeat(batch_y,50,axis=0)
# print batch_y.shape
# gnb.partial_fit(batch_x,batch_y,classes=[0,1])
x = gnb.score(batch_x,batch_y)
print x
s += x
i +=1
print 'average : ', s/i
# gnb.fit(X,Y)
#
print s/i
fp = open(os.path.join('nb_logs','nb_object' + '.save'), 'wb')
cPickle.dump(gnb, fp, protocol = cPickle.HIGHEST_PROTOCOL)
fp.close()