-
-
Notifications
You must be signed in to change notification settings - Fork 205
Training and using a classifier
EdwardRaff edited this page Jun 3, 2015
·
1 revision
This example quickly shows how to create a data set for classification, and how to train and use a classifier. This is only a very basic, bare-bones piece of code. The "Naive Bayes" algorithm is used as the classifier, but that doesn't mean you will always want to use it - it's just for demonstration purposes here.
import java.io.File;
import jsat.ARFFLoader;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.bayesian.NaiveBayes;
/**
 * Bare-bones walkthrough of classification: load an ARFF file, wrap it as a
 * classification data set, train a Naive Bayes model on it, and then report
 * how that model does on the very same points it was trained on.
 *
 * @author Edward Raff
 */
public class ClassificationExample
{
    /**
     * Runs the example end to end and prints one line per data point followed
     * by a summary line with the error count and error rate.
     *
     * @param args the command line arguments
     */
    public static void main(String[] args)
    {
        String nominalPath = "uci" + File.separator + "nominal" + File.separator;
        File arffFile = new File(nominalPath + "iris.arff");
        DataSet rawData = ARFFLoader.loadArffFile(arffFile);

        // Index '0' is passed as the attribute to use as the target class.
        ClassificationDataSet labeledData = new ClassificationDataSet(rawData, 0);

        Classifier model = new NaiveBayes();
        model.trainC(labeledData);

        int mistakes = 0;
        int row = 0;
        while (row < rawData.getSampleSize())
        {
            // Pull the point from the classification data set, not the raw one:
            // the class value has been removed from the points held there.
            DataPoint features = labeledData.getDataPoint(row);
            // The true category is available from the classification data set.
            int actual = labeledData.getDataPointCategory(row);

            // The result carries a probability estimate per possible class value.
            // A classifier without probability support reports its prediction
            // with total confidence.
            CategoricalResults estimate = model.classify(features);
            int guess = estimate.mostLikely();
            if (guess != actual)
            {
                mistakes++;
            }

            System.out.println(row + "| True Class: " + actual
                    + ", Predicted: " + guess
                    + ", Confidence: " + estimate.getProb(guess));
            row++;
        }

        double errorPercent = 100.0 * mistakes / rawData.getSampleSize();
        System.out.println(mistakes + " errors were made, " + errorPercent + "% error rate");
    }
}