-
Notifications
You must be signed in to change notification settings - Fork 58
/
Copy pathtrain.py
33 lines (32 loc) · 976 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from crf import CRF
from features import *
import re, sys
import pickle
# Path of the training data set (first command-line argument).  Kept as a
# module-level name for backward compatibility; it is None, instead of the
# module raising IndexError, when no argument was supplied.
training_file = sys.argv[1] if len(sys.argv) > 1 else None

# Regular expressions handed to MatchRegex.functions() to build the
# regex-based feature functions.  Raw strings keep every backslash literal,
# so the pattern text is identical to the original non-raw literals.
MATCH_REGEX_PATTERNS = (
    r'^[^0-9a-zA-Z\-]+$',
    r'^[^0-9\-]+$',
    r'^[A-Z]+$',
    r'^-?[1-9][0-9]*\.[0-9]+$',
    r'^[1-9][0-9\.]+[a-z]+$',
    r'^[0-9]+$',
    r'^[A-Z][a-z]+$',
    r'^([A-Z][a-z]*)+$',
    r'^[^aeiouAEIOU]+$',
)

# Number of trailing samples excluded from training and used for the
# printed spot check afterwards.
_HELD_OUT = 5


def main():
    """Train a CRF on the given data set, pickle it, and print a spot check.

    Usage: train.py TRAINING_FILE MODEL_FILE

    The last ``_HELD_OUT`` samples are left out of training; after the model
    has been written to MODEL_FILE, each of them is printed together with the
    model's prediction and the reference labels.

    Exits with a usage message when fewer than two arguments are given.
    """
    # Validate both arguments up front so a missing output path is reported
    # before the training run rather than after it.
    if len(sys.argv) < 3:
        sys.exit('usage: %s TRAINING_FILE MODEL_FILE' % sys.argv[0])
    model_file = sys.argv[2]

    labels, obsrvs, word_sets, word_data, label_data = fit_dataset(training_file)

    feature_functions = (
        Membership.functions(labels, *word_sets.values())
        + MatchRegex.functions(labels, *MATCH_REGEX_PATTERNS)
    )
    # Disabled experiment, kept for reference: one indicator feature per
    # (label, observation) pair, firing when the lower-cased token at
    # position i equals the observation.
    #   + [lambda yp, y, x_v, i, _y=_y, _x=_x:
    #          1 if i < len(x_v) and y == _y and x_v[i].lower() == _x else 0
    #      for _y in labels
    #      for _x in obsrvs]
    crf = CRF(labels=list(labels), feature_functions=feature_functions)

    crf.train(word_data[:-_HELD_OUT], label_data[:-_HELD_OUT])

    # The context manager guarantees the model file is flushed and closed.
    with open(model_file, 'wb') as model_out:
        pickle.dump(crf, model_out)

    # Slicing (rather than indexing -5..-1) also copes with data sets that
    # contain fewer than _HELD_OUT samples.
    for words, expected in zip(word_data[-_HELD_OUT:], label_data[-_HELD_OUT:]):
        # Single-argument parenthesised print behaves the same on Python 2
        # and Python 3.
        print(words)
        print(crf.predict(words))
        print(expected)


if __name__ == '__main__':
    main()