-
Notifications
You must be signed in to change notification settings - Fork 0
/
CrossValidation.java
97 lines (79 loc) · 3.66 KB
/
CrossValidation.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
/*
* Standalone application for performing cross validation.
*
* Copyright (C) 2013 Lisa Vitolo <[email protected]>
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the Creative Commons
* Attribution-NonCommercial-ShareAlike 3.0 license.
* You should have received a copy of the license with this product.
* Otherwise, visit http://creativecommons.org/licenses/by-nc-sa/3.0/
*/
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;
import libsvm.*;
/**
 * Command-line tool that grid-searches SVM hyper-parameters using cross validation.
 *
 * <p>Every combination of the C, p and gamma candidates supplied by {@code Constants}
 * is cross-validated with libsvm; the combination whose predictions have the highest
 * Pearson correlation with the gold-standard sample scores is reported on standard output.
 */
public class CrossValidation
{
    /**
     * Runs the grid search and prints the best parameter combination found.
     *
     * @param args paths of the sample files to extract training samples from; at least
     *             one is required, otherwise a usage line is printed and the JVM exits
     *             with status -1
     */
    public static void main(String[] args)
    {
        if (args.length == 0) {
            System.out.println("Usage: CrossValidation <sample files>");
            System.exit(-1);
        }

        /* Suppress all output from libsvm */
        svm_print_interface iface = new svm_print_interface() {
            @Override public void print(String string) {}
        };
        svm.svm_set_print_string_function(iface);

        /* Initialize the Stanford NLP pipeline */
        Properties prop = new Properties();
        prop.put("annotators", "tokenize, ssplit, pos, lemma");
        StanfordCoreNLP pipeline = new StanfordCoreNLP(prop);
        System.out.println(":: Stanford NLP pipeline initialized correctly.");

        /*
         * The usage line advertises several sample files, so every argument is handed
         * to the learner instead of silently dropping everything after the first one.
         */
        SimilarityLearner sl = new SimilarityLearner(pipeline);
        List<TrainingSample> samples = sl.extractFeatures(args);
        svm_problem problem = sl.buildSVMProblem(samples);
        svm_parameter parameters = Constants.getSVMParameters();

        double[] C_values = Constants.getCValues();
        double[] P_values = Constants.getPValues();
        double[] G_values = Constants.getGammaValues();

        /*
         * Start below every possible correlation. Double.MIN_VALUE is the smallest
         * POSITIVE double, so using it here would reject all runs whose correlation
         * is zero or negative and leave the "best" parameters at their 0.0 defaults.
         */
        double bestCorr = Double.NEGATIVE_INFINITY;
        double bestC = 0.0;
        double bestP = 0.0;
        double bestGamma = 0.0;

        /* Filled by libsvm with the cross-validated prediction for each sample */
        double[] targets = new double[ samples.size() ];
        /* Gold standard scores provided with the samples */
        double[] gs = new double[ samples.size() ];
        int next = 0;
        for (TrainingSample sample : samples) {
            gs[next++] = sample.target;
        }

        System.out.println(":: Starting cross validation.");

        for (int iC = 0; iC < C_values.length; iC++) {
            parameters.C = C_values[iC];

            for (int iP = 0; iP < P_values.length; iP++) {
                parameters.p = P_values[iP];

                for (int iG = 0; iG < G_values.length; iG++) {
                    parameters.gamma = G_values[iG];

                    System.out.println("Trying C = " + parameters.C + ", P = " + parameters.p + ", G = " + parameters.gamma);
                    svm.svm_cross_validation(problem, parameters, Constants.getValidationFold(), targets);

                    /* A NaN correlation compares false here and is therefore never selected */
                    double corr = Correlation.getPearsonCorrelation(targets, gs);
                    if (corr > bestCorr) {
                        System.out.println(":: New best correlation is " + corr);
                        bestCorr = corr;
                        bestC = C_values[iC];
                        bestP = P_values[iP];
                        bestGamma = G_values[iG];
                    }
                }
            }
        }

        System.out.println(":: Cross validation finished.");
        System.out.println("C: " + bestC);
        System.out.println("P: " + bestP);
        System.out.println("Gamma: " + bestGamma);
        System.out.println("Best correlation is " + bestCorr);
    }
}