From 15ec41ac54117e8d154e09c315d1f36c7387cf4a Mon Sep 17 00:00:00 2001 From: Sawood Alam Date: Tue, 24 Jan 2017 13:42:42 -0500 Subject: [PATCH] Classifier evaluation and validation (#142) * Benchmark reporter improvement * Added reset method placeholder * Added fundamental validation API * Added initial validation task with custom reporting * k-fold accuracy report generated * Hardcoded number removed * Confusion matrix generated * Reporting confusion matrix with various derived stats * Reordered a derived attribute * Fixed test failutres if redis is not running * Added row and column for class-wise precision and recall in the confusion matrix * Remove warnings of uninitialized instances of test teardown * Checking right conditiones for the test teardown * Added reset methods in Bayes, LSI, and Bayes backends * Corrected teardown conditionals * Added tests for reset functionality * Corrected typo in documentation * Validation tasks added for Redis backend and LSI * Added a message if Redis server is not running during validation * Loaded validator module in the gem * A more meaningful table title * Renamed a method to better reflect the role * Added classifier auto instantiation to validate method, reordered for readability * Added options argument in validate method * Renamed methods with more suitable names * Added optional header printing in run report * Adding accuracy in the confusion matrix reporting --- Rakefile | 8 + lib/classifier-reborn.rb | 1 + .../backends/bayes_memory_backend.rb | 4 + .../backends/bayes_redis_backend.rb | 6 + lib/classifier-reborn/bayes.rb | 21 ++- lib/classifier-reborn/lsi.rb | 6 +- .../validators/classifier_validator.rb | 169 ++++++++++++++++++ test/backends/backend_redis_test.rb | 2 +- test/bayes/bayesian_common_benchmarks.rb | 16 +- test/bayes/bayesian_common_tests.rb | 12 ++ test/bayes/bayesian_integration_test.rb | 2 +- test/bayes/bayesian_redis_test.rb | 8 +- test/lsi/lsi_test.rb | 11 ++ test/validators/classifier_validation.rb | 78 ++++++++ 14 files changed, 331 insertions(+), 13 deletions(-) create mode 100644 lib/classifier-reborn/validators/classifier_validator.rb create mode 100644 test/validators/classifier_validation.rb diff --git a/Rakefile b/Rakefile index a32c0b2..c002347 100644 --- a/Rakefile +++ b/Rakefile @@ -23,6 +23,14 @@ Rake::TestTask.new(:bench) do |t| t.verbose = true end +# Run validations +desc 'Run all validations' +Rake::TestTask.new(:validate) do |t| + t.libs << 'lib' + t.pattern = 'test/*/*_validation.rb' + t.verbose = true +end + # Make a console, useful when working on tests desc 'Generate a test console' task :console do diff --git a/lib/classifier-reborn.rb b/lib/classifier-reborn.rb index 6a368a1..85969f2 100644 --- a/lib/classifier-reborn.rb +++ b/lib/classifier-reborn.rb @@ -28,3 +28,4 @@ require_relative 'classifier-reborn/category_namer' require_relative 'classifier-reborn/bayes' require_relative 'classifier-reborn/lsi' +require_relative 'classifier-reborn/validators/classifier_validator' diff --git a/lib/classifier-reborn/backends/bayes_memory_backend.rb b/lib/classifier-reborn/backends/bayes_memory_backend.rb index 40a76f5..f366f46 100644 --- a/lib/classifier-reborn/backends/bayes_memory_backend.rb +++ b/lib/classifier-reborn/backends/bayes_memory_backend.rb @@ -61,5 +61,9 @@ def delete_category_word(category, word) def word_in_category?(category, word) @categories[category].key?(word) end + + def reset + initialize + end end end diff --git a/lib/classifier-reborn/backends/bayes_redis_backend.rb b/lib/classifier-reborn/backends/bayes_redis_backend.rb index 9171c84..b3bc7ba 100644 --- a/lib/classifier-reborn/backends/bayes_redis_backend.rb +++ b/lib/classifier-reborn/backends/bayes_redis_backend.rb @@ -92,5 +92,11 @@ def delete_category_word(category, word) def word_in_category?(category, word) @redis.hexists(category, word) end + + def reset + @redis.flushdb + @redis.set(:total_words, 0) + @redis.set(:total_trainings, 0) + end end end diff --git a/lib/classifier-reborn/bayes.rb b/lib/classifier-reborn/bayes.rb index 7783506..9d12110 100644 --- a/lib/classifier-reborn/bayes.rb +++ b/lib/classifier-reborn/bayes.rb @@ -25,7 +25,7 @@ class Bayes # stopwords: nil Accepts path to a text file or an array of words, when supplied, overwrites the default stopwords; assign empty string or array to disable stopwords # backend: BayesMemoryBackend.new Alternatively, BayesRedisBackend.new for persistent storage def initialize(*args) - initial_categories = [] + @initial_categories = [] options = { language: 'en', enable_threshold: false, threshold: 0.0, @@ -36,12 +36,12 @@ def initialize(*args) if arg.is_a?(Hash) options.merge!(arg) else - initial_categories.push(arg) + @initial_categories.push(arg) end end unless options.key?(:auto_categorize) - options[:auto_categorize] = initial_categories.empty? ? true : false + options[:auto_categorize] = @initial_categories.empty? ? true : false end @language = options[:language] @@ -51,9 +51,7 @@ def initialize(*args) @enable_stemmer = options[:enable_stemmer] @backend = options[:backend] - initial_categories.each do |c| - add_category(c) - end + populate_initial_categories if options.key?(:stopwords) custom_stopwords options[:stopwords] @@ -244,8 +242,19 @@ def add_category(category) alias_method :append_category, :add_category + def reset + @backend.reset + populate_initial_categories + end + private + def populate_initial_categories + @initial_categories.each do |c| + add_category(c) + end + end + # Overwrites the default stopwords for current language with supplied list of stopwords or file def custom_stopwords(stopwords) unless stopwords.is_a?(Enumerable) diff --git a/lib/classifier-reborn/lsi.rb b/lib/classifier-reborn/lsi.rb index 123d540..99f5220 100644 --- a/lib/classifier-reborn/lsi.rb +++ b/lib/classifier-reborn/lsi.rb @@ -86,7 +86,7 @@ def <<(item) add_item(item) end - # Returns the categories for a given indexed items. You are free to add and remove + # Returns categories for a given indexed item. You are free to add and remove # items from this as you see fit. It does not invalide an index to change its categories. def categories_for(item) return [] unless @items[item] @@ -300,6 +300,10 @@ def highest_ranked_stems(doc, count = 3) top_n.collect { |x| @word_list.word_for_index(content_vector_array.index(x)) } end + def reset + initialize(auto_rebuild: @auto_rebuild, cache_node_vectors: @cache_node_vectors) + end + private def build_reduced_matrix(matrix, cutoff = 0.75) diff --git a/lib/classifier-reborn/validators/classifier_validator.rb b/lib/classifier-reborn/validators/classifier_validator.rb new file mode 100644 index 0000000..61d8735 --- /dev/null +++ b/lib/classifier-reborn/validators/classifier_validator.rb @@ -0,0 +1,169 @@ +module ClassifierReborn + module ClassifierValidator + + module_function + + def cross_validate(classifier, sample_data, fold=10, *options) + classifier = ClassifierReborn::const_get(classifier).new(options) if classifier.is_a?(String) + sample_data.shuffle! + partition_size = sample_data.length / fold + partitioned_data = sample_data.each_slice(partition_size) + conf_mats = [] + fold.times do |i| + training_data = partitioned_data.take(fold) + test_data = training_data.slice!(i) + conf_mats << validate(classifier, training_data.flatten!(1), test_data) + end + classifier.reset() + generate_report(conf_mats) + end + + def validate(classifier, training_data, test_data, *options) + classifier = ClassifierReborn::const_get(classifier).new(options) if classifier.is_a?(String) + classifier.reset() + training_data.each do |rec| + classifier.train(rec.first, rec.last) + end + evaluate(classifier, test_data) + end + + def evaluate(classifier, test_data) + conf_mat = empty_conf_mat(classifier.categories.sort) + test_data.each do |rec| + actual = rec.first.tr('_', ' ').capitalize + predicted = classifier.classify(rec.last) + conf_mat[actual][predicted] += 1 unless predicted.nil? + end + conf_mat + end + + def generate_report(*conf_mats) + conf_mats.flatten! + accumulated_conf_mat = conf_mats.length == 1 ? conf_mats.first : empty_conf_mat(conf_mats.first.keys.sort) + header = "Run Total Correct Incorrect Accuracy" + puts + puts " Run Report ".center(header.length, "-") + puts header + puts "-" * header.length + if conf_mats.length > 1 + conf_mats.each_with_index do |conf_mat, i| + run_report = build_run_report(conf_mat) + print_run_report(run_report, i+1) + conf_mat.each do |actual, cols| + cols.each do |predicted, v| + accumulated_conf_mat[actual][predicted] += v + end + end + end + puts "-" * header.length + end + run_report = build_run_report(accumulated_conf_mat) + print_run_report(run_report, "All") + puts + print_conf_mat(accumulated_conf_mat) + puts + conf_tab = conf_mat_to_tab(accumulated_conf_mat) + print_conf_tab(conf_tab) + end + + def build_run_report(conf_mat) + correct = incorrect = 0 + conf_mat.each do |actual, cols| + cols.each do |predicted, v| + if actual == predicted + correct += v + else + incorrect += v + end + end + end + total = correct + incorrect + {total: total, correct: correct, incorrect: incorrect, accuracy: divide(correct, total)} + end + + def conf_mat_to_tab(conf_mat) + conf_tab = Hash.new {|h, k| h[k] = {p: {t: 0, f: 0}, n: {t: 0, f: 0}}} + conf_mat.each_key do |positive| + conf_mat.each do |actual, cols| + cols.each do |predicted, v| + conf_tab[positive][positive == predicted ? :p : :n][actual == predicted ? :t : :f] += v + end + end + end + conf_tab + end + + def print_run_report(stats, prefix="", print_header=false) + puts "#{"Run".rjust([3, prefix.length].max)} Total Correct Incorrect Accuracy" if print_header + puts "#{prefix.to_s.rjust(3)} #{stats[:total].to_s.rjust(9)} #{stats[:correct].to_s.rjust(9)} #{stats[:incorrect].to_s.rjust(9)} #{stats[:accuracy].round(5).to_s.ljust(7, '0').rjust(9)}" + end + + def print_conf_mat(conf_mat) + header = ["Predicted ->"] + conf_mat.keys + ["Total", "Recall"] + cell_size = header.map(&:length).max + header = header.map{|h| h.rjust(cell_size)}.join(" ") + puts " Confusion Matrix ".center(header.length, "-") + puts header + puts "-" * header.length + predicted_totals = conf_mat.keys.map{|predicted| [predicted, 0]}.to_h + correct = 0 + conf_mat.each do |k, rec| + actual_total = rec.values.reduce(:+) + puts ([k.ljust(cell_size)] + rec.values.map{|v| v.to_s.rjust(cell_size)} + [actual_total.to_s.rjust(cell_size), divide(rec[k], actual_total).round(5).to_s.rjust(cell_size)]).join(" ") + rec.each do |cat, val| + predicted_totals[cat] += val + correct += val if cat == k + end + end + total = predicted_totals.values.reduce(:+) + puts "-" * header.length + puts (["Total".ljust(cell_size)] + predicted_totals.values.map{|v| v.to_s.rjust(cell_size)} + [total.to_s.rjust(cell_size), "".rjust(cell_size)]).join(" ") + puts (["Precision".ljust(cell_size)] + predicted_totals.keys.map{|k| divide(conf_mat[k][k], predicted_totals[k]).round(5).to_s.rjust(cell_size)} + ["Accuracy ->".rjust(cell_size), divide(correct, total).round(5).to_s.rjust(cell_size)]).join(" ") + end + + def print_conf_tab(conf_tab) + conf_tab.each do |positive, tab| + puts "# Positive class: #{positive}" + derivations = conf_tab_derivations(tab) + print_derivations(derivations) + puts + end + end + + def conf_tab_derivations(tab) + positives = tab[:p][:t] + tab[:n][:f] + negatives = tab[:n][:t] + tab[:p][:f] + total = positives + negatives + { + total_population: positives + negatives, + condition_positive: positives, + condition_negative: negatives, + true_positive: tab[:p][:t], + true_negative: tab[:n][:t], + false_positive: tab[:p][:f], + false_negative: tab[:n][:f], + prevalence: divide(positives, total), + specificity: divide(tab[:n][:t], negatives), + recall: divide(tab[:p][:t], positives), + precision: divide(tab[:p][:t], tab[:p][:t] + tab[:p][:f]), + accuracy: divide(tab[:p][:t] + tab[:n][:t], total), + f1_score: divide(2 * tab[:p][:t], 2 * tab[:p][:t] + tab[:p][:f] + tab[:n][:f]) + } + end + + def print_derivations(derivations) + max_len = derivations.keys.map(&:length).max + derivations.each do |k, v| + puts k.to_s.tr('_', ' ').capitalize.ljust(max_len) + " : " + v.to_s + end + end + + def empty_conf_mat(categories) + categories.map{|actual| [actual, categories.map{|predicted| [predicted, 0]}.to_h]}.to_h + end + + def divide(dividend, divisor) + divisor.zero? ? 0.0 : dividend / divisor.to_f + end + end +end diff --git a/test/backends/backend_redis_test.rb b/test/backends/backend_redis_test.rb index 3fe050f..48c7015 100644 --- a/test/backends/backend_redis_test.rb +++ b/test/backends/backend_redis_test.rb @@ -16,6 +16,6 @@ def setup end def teardown - @backend.instance_variable_get(:@redis).flushdb + @backend.instance_variable_get(:@redis).flushdb if defined? @backend end end diff --git a/test/bayes/bayesian_common_benchmarks.rb b/test/bayes/bayesian_common_benchmarks.rb index 24ac2ab..2c4cc5c 100644 --- a/test/bayes/bayesian_common_benchmarks.rb +++ b/test/bayes/bayesian_common_benchmarks.rb @@ -3,11 +3,25 @@ module BayesianCommonBenchmarks MAX_RECORDS = 5000 - class BenchmarkReporter < Minitest::Reporters::RubyMateReporter + class BenchmarkReporter < Minitest::Reporters::BaseReporter + include ANSI::Code + def before_suite(suite) + puts puts ([suite] + BayesianCommonBenchmarks.bench_range).join("\t") end + def after_suite(suite) + end + + def report + super + puts + puts('Finished in %.5fs' % total_time) + print('%d tests, %d assertions, ' % [count, assertions]) + color = failures.zero? && errors.zero? ? :green : :red + print(send(color) { '%d failures, %d errors, ' } % [failures, errors]) + print(yellow { '%d skips' } % skips) puts end end diff --git a/test/bayes/bayesian_common_tests.rb b/test/bayes/bayesian_common_tests.rb index a393225..902ba92 100644 --- a/test/bayes/bayesian_common_tests.rb +++ b/test/bayes/bayesian_common_tests.rb @@ -179,6 +179,18 @@ def test_custom_file_stopwords refute_equal Float::INFINITY, classifier.classify_with_score('To be or not to be')[1] end + def test_reset + @classifier.add_category 'Test' + assert_equal %w(Test Interesting Uninteresting).sort, @classifier.categories.sort + @classifier.reset + assert_equal %w(Interesting Uninteresting).sort, @classifier.categories.sort + classifier = empty_classifier + classifier.train('Ruby', 'A really sweet language') + assert classifier.categories.include?('Ruby') + classifier.reset + assert classifier.categories.empty? + end + private def another_classifier diff --git a/test/bayes/bayesian_integration_test.rb b/test/bayes/bayesian_integration_test.rb index a7664b8..c2aea46 100644 --- a/test/bayes/bayesian_integration_test.rb +++ b/test/bayes/bayesian_integration_test.rb @@ -26,7 +26,7 @@ def setup end def teardown - @redis_backend.instance_variable_get(:@redis).flushdb + @redis_backend.instance_variable_get(:@redis).flushdb unless @redis_backend.nil? end def test_equality_of_backends diff --git a/test/bayes/bayesian_redis_test.rb b/test/bayes/bayesian_redis_test.rb index 2348b87..0a23b6f 100644 --- a/test/bayes/bayesian_redis_test.rb +++ b/test/bayes/bayesian_redis_test.rb @@ -8,11 +8,11 @@ class BayesianRedisTest < Minitest::Test def setup begin + @old_stopwords = Hasher::STOPWORDS['en'] @backend = ClassifierReborn::BayesRedisBackend.new @backend.instance_variable_get(:@redis).config(:set, "save", "") @alternate_backend = ClassifierReborn::BayesRedisBackend.new(db: 1) @classifier = ClassifierReborn::Bayes.new 'Interesting', 'Uninteresting', backend: @backend - @old_stopwords = Hasher::STOPWORDS['en'] rescue Redis::CannotConnectError => e skip(e) end @@ -20,7 +20,9 @@ def setup def teardown Hasher::STOPWORDS['en'] = @old_stopwords - @backend.instance_variable_get(:@redis).flushdb - @alternate_backend.instance_variable_get(:@redis).flushdb + if defined? @backend + @backend.instance_variable_get(:@redis).flushdb + @alternate_backend.instance_variable_get(:@redis).flushdb + end end end diff --git a/test/lsi/lsi_test.rb b/test/lsi/lsi_test.rb index d20caf2..5ec3f42 100644 --- a/test/lsi/lsi_test.rb +++ b/test/lsi/lsi_test.rb @@ -200,4 +200,15 @@ def test_warn_when_adding_bad_document def test_summary assert_equal 'This text involves dogs too [...] This text also involves cats', Summarizer.summary([@str1, @str2, @str3, @str4, @str5].join, 2) end + + def test_reset + lsi = ClassifierReborn::LSI.new + assert lsi.items.empty? + lsi.add_item @str1, 'Dog' + refute lsi.items.empty? + lsi.reset + assert lsi.items.empty? + lsi.add_item @str3, 'Cat' + refute lsi.items.empty? + end end diff --git a/test/validators/classifier_validation.rb b/test/validators/classifier_validation.rb new file mode 100644 index 0000000..b5f8d3e --- /dev/null +++ b/test/validators/classifier_validation.rb @@ -0,0 +1,78 @@ +# encoding: utf-8 + +require File.dirname(__FILE__) + '/../test_helper' +require File.dirname(__FILE__) + '/../../lib/classifier-reborn/validators/classifier_validator' +require_relative '../data/test_data_loader' + +class ClassifierValidation < Minitest::Test + class ValidationReporter < Minitest::Reporters::BaseReporter + REPORT_WIDTH = 80 + + def before_suite(suite) + puts + puts "# #{suite}" + puts + end + + def after_suite(suite) + puts + end + + def before_test(test) + super + validation_name = test.name.gsub(/^test_/, '') + puts " #{validation_name} ".center(REPORT_WIDTH, "=") + end + + def after_test(test) + super + puts "-" * REPORT_WIDTH + puts + end + + def report + super + puts('Finished in %.5fs' % total_time) + puts + end + end + Minitest::Reporters.use! ValidationReporter.new + + SAMPLE_SIZE = 5000 + + def setup + data = TestDataLoader.sms_data + if data.length < SAMPLE_SIZE + TestDataLoader.report_insufficient_data(data.length, SAMPLE_SIZE) + skip(e) + end + @sample_data = data.take(SAMPLE_SIZE).collect { |line| line.strip.split("\t") } + end + + def test_bayes_classifier_10_fold_cross_validate_memory + classifier = ClassifierReborn::Bayes.new + ClassifierValidator.cross_validate(classifier, @sample_data) + end + + def test_bayes_classifier_3_fold_cross_validate_redis + begin + backend = ClassifierReborn::BayesRedisBackend.new + backend.instance_variable_get(:@redis).config(:set, "save", "") + classifier = ClassifierReborn::Bayes.new backend: backend + ClassifierValidator.cross_validate(classifier, @sample_data, 3) + rescue Redis::CannotConnectError => e + puts "Unable to connect to Redis server" + skip(e) + end + end + + def test_lsi_classifier_5_fold_cross_validate + lsi = ClassifierReborn::LSI.new + required_methods = [:train, :classify, :categories] + unless required_methods.reduce(true){|m, o| m && lsi.respond_to?(o)} + puts "TODO: LSI is not validatable until all of the #{required_methods} methods are implemented!" + skip + end + ClassifierValidator.cross_validate(lsi, @sample_data, 5) + end +end