cross_validation.py
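
# Cross-validate a saved CNN tweet classifier. Example invocation (package
# and file names are illustrative; the relative imports below mean the script
# has to be run as a module from inside its package):
#
#     python -m sentiment.cross_validation -m model -d tweets.txt -o scores.csv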
from collections import namedtuple
import argparse
import csv

import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import precision_score, recall_score

from .argtypes import positive_integer
from .cnn import CNN, LabeledTweet
from .util import parse_tweets


def parse_args():
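    # Build the CLI: model path, dataset file, output CSV, and
    # training hyperparameters.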
    parser = argparse.ArgumentParser(description='Evaluate a CNN')
    # TODO Validation
    parser.add_argument('-m', '--model', required=True)
    parser.add_argument(
        '-d', '--dataset', required=True,
        type=argparse.FileType('r')
    )
    # TODO It sucks that we have to specify an output file.
    # We can't use stdout, though, since keras is cluttering that up.
    parser.add_argument(
        '-o', '--output', required=True,
        type=argparse.FileType('w')
    )
    parser.add_argument(
        '-b', '--batch-size',
        default=50,
        type=positive_integer
    )
    parser.add_argument(
        '-e', '--epochs',
        default=1,
        type=positive_integer
    )
    return parser.parse_args()
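

# Per-fold result: arrays of per-class precision (p) and recall (r).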
EvaluationResult = namedtuple('EvaluationResult', ['p', 'r'])


# TODO Should we really assume that tweets is just a list?
def evaluate(
    model,
    train_tweets,
    test_tweets,
    train_labels,
    test_labels,
    batch_size,
    epochs
):
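    # Load the saved model afresh for this fold, train it on the training
    # split, and score its predictions on the held-out split.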
    cnn = CNN()
    cnn.load(model)
    cnn.fit_generator(
        # TODO Ugh, really? Maybe just use or wrap fit?
        lambda: (
            LabeledTweet(label=label, tweet=tweet)
            for label, tweet in zip(train_labels, train_tweets)
        ),
        batch_size=batch_size,
        nb_epoch=epochs,
        samples_per_epoch=len(train_tweets)
    )
    predictions = cnn.predict(test_tweets)
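    # The predicted label is the index of the highest-scoring class.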
    predicted_labels = [p.argmax() for p in predictions['output']]
    return EvaluationResult(
        p=precision_score(test_labels, predicted_labels, average=None),
        r=recall_score(test_labels, predicted_labels, average=None)
    )


def cross_validate(model, dataset, epochs, batch_size, output):
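    # Run 10-fold cross-validation over the dataset, writing per-class
    # precision and recall for every fold to the output CSV.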
    # TODO It sucks that this reopens the file
    labeled_tweets = list(parse_tweets(dataset.name))
    texts = np.array([labeled_tweet.tweet for labeled_tweet in labeled_tweets])
    labels = np.array([labeled_tweet.label for labeled_tweet in labeled_tweets])
    output_writer = csv.DictWriter(
        output,
        [
            'positive_precision', 'negative_precision', 'neutral_precision',
            'positive_recall', 'negative_recall', 'neutral_recall'
        ]
    )
    output_writer.writeheader()
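    # Split into 10 consecutive folds (KFold does not shuffle by default).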
    cv = KFold(n_splits=10)
    for train, test in cv.split(texts):
        scores = evaluate(
            model,
            texts[train], texts[test],
            labels[train], labels[test],
            batch_size, epochs
        )
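        # Column order assumes the label encoding 0=positive, 1=negative,
        # 2=neutral; average=None returns per-class scores in label order.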
        output_writer.writerow({
            'positive_precision': scores.p[0],
            'negative_precision': scores.p[1],
            'neutral_precision': scores.p[2],
            'positive_recall': scores.r[0],
            'negative_recall': scores.r[1],
            'neutral_recall': scores.r[2]
        })


def main():
    args = parse_args()
    cross_validate(**vars(args))


if __name__ == '__main__':
    main()