Training and using a classifier

EdwardRaff edited this page Jun 3, 2015 · 1 revision

Introduction

This example quickly shows how to create a data set for classification, and how to train and use a classifier. The code is deliberately basic and bare-bones. The "Naive Bayes" algorithm is used as the classifier, but that doesn't mean you will always want to use it; it is here only for demonstration purposes.

Code

import java.io.File;
import jsat.ARFFLoader;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.bayesian.NaiveBayes;


/**
 * A simple example where we load up a data set for classification purposes,
 * train a Naive Bayes classifier on it, and count its prediction errors.
 * 
 * @author Edward Raff
 */
public class ClassificationExample
{

    /**
     * @param args the command line arguments
     */
    public static void main(String[] args)
    {
        String nominalPath = "uci" + File.separator + "nominal" + File.separator;
        File file = new File(nominalPath + "iris.arff");
        DataSet dataSet = ARFFLoader.loadArffFile(file);
        
        //We specify '0' as the index of the categorical attribute we would like to use as the target class. 
        ClassificationDataSet cDataSet = new ClassificationDataSet(dataSet, 0);
        
        int errors = 0;
        Classifier classifier = new NaiveBayes();
        classifier.trainC(cDataSet);
        
        for(int i = 0; i < cDataSet.getSampleSize(); i++)
        {
            //It is important not to mix these up: the class has been removed from the data points in 'cDataSet'
            DataPoint dataPoint = cDataSet.getDataPoint(i);
            //We can grab the true category from the data set
            int truth = cDataSet.getDataPointCategory(i);
            
            //CategoricalResults contains the probability estimates for each possible target class value. 
            //Classifiers that do not support probability estimates will mark their prediction with total confidence. 
            CategoricalResults predictionResults = classifier.classify(dataPoint);
            int predicted = predictionResults.mostLikely();
            if(predicted != truth)
                errors++;
            System.out.println( i + "| True Class: " + truth + ", Predicted: " + predicted + ", Confidence: " + predictionResults.getProb(predicted) );
        }
        
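        //Note: the classifier is evaluated on the same data it was trained on, so this is the training error
        System.out.println(errors + " errors were made, " + 100.0*errors/dataSet.getSampleSize() + "% error rate" );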
    }
}
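
Prediction results, illustrated

The comments in the loop above describe two ideas: a prediction is a set of probability estimates, one per target class, and a classifier without probability support reports its single choice with total confidence. The sketch below illustrates both in plain Java, along with the error rate calculation. It does not use JSAT and is not how JSAT implements them. The class PredictionSketch and its methods are hypothetical names introduced only for this illustration.

```java
import java.util.Arrays;

// Illustrative only: plain-Java stand-ins, not JSAT's implementation.
class PredictionSketch
{
    /** Returns the index of the largest probability, i.e. the predicted class. */
    static int mostLikely(double[] probs)
    {
        int best = 0;
        for (int i = 1; i < probs.length; i++)
            if (probs[i] > probs[best])
                best = i;
        return best;
    }

    /** A classifier without probability estimates puts all mass on its one prediction. */
    static double[] hardPrediction(int numClasses, int predicted)
    {
        double[] probs = new double[numClasses];
        probs[predicted] = 1.0;
        return probs;
    }

    /** Percentage of predictions that disagree with the true labels. */
    static double errorRate(int[] truth, int[] predicted)
    {
        int errors = 0;
        for (int i = 0; i < truth.length; i++)
            if (truth[i] != predicted[i])
                errors++;
        return 100.0 * errors / truth.length;
    }

    public static void main(String[] args)
    {
        double[] soft = {0.1, 0.7, 0.2};
        System.out.println(mostLikely(soft));                      // prints 1
        System.out.println(Arrays.toString(hardPrediction(3, 2))); // prints [0.0, 0.0, 1.0]
        System.out.println(errorRate(new int[]{0, 1, 2, 2},
                                     new int[]{0, 1, 1, 2}) + "% error rate"); // prints 25.0% error rate
    }
}
```

With a "hard" prediction the confidence printed for the chosen class is always 1.0, so the confidence column in the example's output is only informative for classifiers that produce real probability estimates.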