Skip to content

Commit

Permalink
Classifier evaluation and validation (jekyll#142)
Browse files Browse the repository at this point in the history
* Benchmark reporter improvement

* Added reset method placeholder

* Added fundamental validation API

* Added initial validation task with custom reporting

* k-fold accuracy report generated

* Hardcoded number removed

* Confusion matrix generated

* Reporting confusion matrix with various derived stats

* Reordered a derived attribute

* Fixed test failutres if redis is not running

* Added row and column for class-wise precision and recall in the confusion matrix

* Remove warnings of uninitialized instances of test teardown

* Checking right conditiones for the test teardown

* Added reset methods in Bayes, LSI, and Bayes backends

* Corrected teardown conditionals

* Added tests for reset functionality

* Corrected typo in documentation

* Validation tasks added for Redis backend and LSI

* Added a message if Redis server is not running during validation

* Loaded validator module in the gem

* A more meaningful table title

* Renamed a method to better reflect the role

* Added classifier auto instantiation to validate method, reordered for readability

* Added options argument in validate method

* Renamed methods with more suitable names

* Added optional header printing in run report

* Adding accuracy in the confusion matrix reporting
  • Loading branch information
ibnesayeed authored and Ch4s3 committed Jan 24, 2017
1 parent 8a12664 commit 15ec41a
Show file tree
Hide file tree
Showing 14 changed files with 331 additions and 13 deletions.
8 changes: 8 additions & 0 deletions Rakefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ Rake::TestTask.new(:bench) do |t|
t.verbose = true
end

# Run validations
desc 'Run all validations'
Rake::TestTask.new(:validate) do |t|
t.libs << 'lib'
t.pattern = 'test/*/*_validation.rb'
t.verbose = true
end

# Make a console, useful when working on tests
desc 'Generate a test console'
task :console do
Expand Down
1 change: 1 addition & 0 deletions lib/classifier-reborn.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@
require_relative 'classifier-reborn/category_namer'
require_relative 'classifier-reborn/bayes'
require_relative 'classifier-reborn/lsi'
require_relative 'classifier-reborn/validators/classifier_validator'
4 changes: 4 additions & 0 deletions lib/classifier-reborn/backends/bayes_memory_backend.rb
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,9 @@ def delete_category_word(category, word)
def word_in_category?(category, word)
@categories[category].key?(word)
end

def reset
initialize
end
end
end
6 changes: 6 additions & 0 deletions lib/classifier-reborn/backends/bayes_redis_backend.rb
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,11 @@ def delete_category_word(category, word)
def word_in_category?(category, word)
@redis.hexists(category, word)
end

def reset
@redis.flushdb
@redis.set(:total_words, 0)
@redis.set(:total_trainings, 0)
end
end
end
21 changes: 15 additions & 6 deletions lib/classifier-reborn/bayes.rb
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Bayes
# stopwords: nil Accepts path to a text file or an array of words, when supplied, overwrites the default stopwords; assign empty string or array to disable stopwords
# backend: BayesMemoryBackend.new Alternatively, BayesRedisBackend.new for persistent storage
def initialize(*args)
initial_categories = []
@initial_categories = []
options = { language: 'en',
enable_threshold: false,
threshold: 0.0,
Expand All @@ -36,12 +36,12 @@ def initialize(*args)
if arg.is_a?(Hash)
options.merge!(arg)
else
initial_categories.push(arg)
@initial_categories.push(arg)
end
end

unless options.key?(:auto_categorize)
options[:auto_categorize] = initial_categories.empty? ? true : false
options[:auto_categorize] = @initial_categories.empty? ? true : false
end

@language = options[:language]
Expand All @@ -51,9 +51,7 @@ def initialize(*args)
@enable_stemmer = options[:enable_stemmer]
@backend = options[:backend]

initial_categories.each do |c|
add_category(c)
end
populate_initial_categories

if options.key?(:stopwords)
custom_stopwords options[:stopwords]
Expand Down Expand Up @@ -244,8 +242,19 @@ def add_category(category)

alias_method :append_category, :add_category

def reset
@backend.reset
populate_initial_categories
end

private

def populate_initial_categories
@initial_categories.each do |c|
add_category(c)
end
end

# Overwrites the default stopwords for current language with supplied list of stopwords or file
def custom_stopwords(stopwords)
unless stopwords.is_a?(Enumerable)
Expand Down
6 changes: 5 additions & 1 deletion lib/classifier-reborn/lsi.rb
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def <<(item)
add_item(item)
end

# Returns the categories for a given indexed items. You are free to add and remove
# Returns categories for a given indexed item. You are free to add and remove
# items from this as you see fit. It does not invalide an index to change its categories.
def categories_for(item)
return [] unless @items[item]
Expand Down Expand Up @@ -300,6 +300,10 @@ def highest_ranked_stems(doc, count = 3)
top_n.collect { |x| @word_list.word_for_index(content_vector_array.index(x)) }
end

def reset
initialize(auto_rebuild: @auto_rebuild, cache_node_vectors: @cache_node_vectors)
end

private

def build_reduced_matrix(matrix, cutoff = 0.75)
Expand Down
169 changes: 169 additions & 0 deletions lib/classifier-reborn/validators/classifier_validator.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
module ClassifierReborn
module ClassifierValidator

module_function

def cross_validate(classifier, sample_data, fold=10, *options)
classifier = ClassifierReborn::const_get(classifier).new(options) if classifier.is_a?(String)
sample_data.shuffle!
partition_size = sample_data.length / fold
partitioned_data = sample_data.each_slice(partition_size)
conf_mats = []
fold.times do |i|
training_data = partitioned_data.take(fold)
test_data = training_data.slice!(i)
conf_mats << validate(classifier, training_data.flatten!(1), test_data)
end
classifier.reset()
generate_report(conf_mats)
end

def validate(classifier, training_data, test_data, *options)
classifier = ClassifierReborn::const_get(classifier).new(options) if classifier.is_a?(String)
classifier.reset()
training_data.each do |rec|
classifier.train(rec.first, rec.last)
end
evaluate(classifier, test_data)
end

def evaluate(classifier, test_data)
conf_mat = empty_conf_mat(classifier.categories.sort)
test_data.each do |rec|
actual = rec.first.tr('_', ' ').capitalize
predicted = classifier.classify(rec.last)
conf_mat[actual][predicted] += 1 unless predicted.nil?
end
conf_mat
end

def generate_report(*conf_mats)
conf_mats.flatten!
accumulated_conf_mat = conf_mats.length == 1 ? conf_mats.first : empty_conf_mat(conf_mats.first.keys.sort)
header = "Run Total Correct Incorrect Accuracy"
puts
puts " Run Report ".center(header.length, "-")
puts header
puts "-" * header.length
if conf_mats.length > 1
conf_mats.each_with_index do |conf_mat, i|
run_report = build_run_report(conf_mat)
print_run_report(run_report, i+1)
conf_mat.each do |actual, cols|
cols.each do |predicted, v|
accumulated_conf_mat[actual][predicted] += v
end
end
end
puts "-" * header.length
end
run_report = build_run_report(accumulated_conf_mat)
print_run_report(run_report, "All")
puts
print_conf_mat(accumulated_conf_mat)
puts
conf_tab = conf_mat_to_tab(accumulated_conf_mat)
print_conf_tab(conf_tab)
end

def build_run_report(conf_mat)
correct = incorrect = 0
conf_mat.each do |actual, cols|
cols.each do |predicted, v|
if actual == predicted
correct += v
else
incorrect += v
end
end
end
total = correct + incorrect
{total: total, correct: correct, incorrect: incorrect, accuracy: divide(correct, total)}
end

def conf_mat_to_tab(conf_mat)
conf_tab = Hash.new {|h, k| h[k] = {p: {t: 0, f: 0}, n: {t: 0, f: 0}}}
conf_mat.each_key do |positive|
conf_mat.each do |actual, cols|
cols.each do |predicted, v|
conf_tab[positive][positive == predicted ? :p : :n][actual == predicted ? :t : :f] += v
end
end
end
conf_tab
end

def print_run_report(stats, prefix="", print_header=false)
puts "#{"Run".rjust([3, prefix.length].max)} Total Correct Incorrect Accuracy" if print_header
puts "#{prefix.to_s.rjust(3)} #{stats[:total].to_s.rjust(9)} #{stats[:correct].to_s.rjust(9)} #{stats[:incorrect].to_s.rjust(9)} #{stats[:accuracy].round(5).to_s.ljust(7, '0').rjust(9)}"
end

def print_conf_mat(conf_mat)
header = ["Predicted ->"] + conf_mat.keys + ["Total", "Recall"]
cell_size = header.map(&:length).max
header = header.map{|h| h.rjust(cell_size)}.join(" ")
puts " Confusion Matrix ".center(header.length, "-")
puts header
puts "-" * header.length
predicted_totals = conf_mat.keys.map{|predicted| [predicted, 0]}.to_h
correct = 0
conf_mat.each do |k, rec|
actual_total = rec.values.reduce(:+)
puts ([k.ljust(cell_size)] + rec.values.map{|v| v.to_s.rjust(cell_size)} + [actual_total.to_s.rjust(cell_size), divide(rec[k], actual_total).round(5).to_s.rjust(cell_size)]).join(" ")
rec.each do |cat, val|
predicted_totals[cat] += val
correct += val if cat == k
end
end
total = predicted_totals.values.reduce(:+)
puts "-" * header.length
puts (["Total".ljust(cell_size)] + predicted_totals.values.map{|v| v.to_s.rjust(cell_size)} + [total.to_s.rjust(cell_size), "".rjust(cell_size)]).join(" ")
puts (["Precision".ljust(cell_size)] + predicted_totals.keys.map{|k| divide(conf_mat[k][k], predicted_totals[k]).round(5).to_s.rjust(cell_size)} + ["Accuracy ->".rjust(cell_size), divide(correct, total).round(5).to_s.rjust(cell_size)]).join(" ")
end

def print_conf_tab(conf_tab)
conf_tab.each do |positive, tab|
puts "# Positive class: #{positive}"
derivations = conf_tab_derivations(tab)
print_derivations(derivations)
puts
end
end

def conf_tab_derivations(tab)
positives = tab[:p][:t] + tab[:n][:f]
negatives = tab[:n][:t] + tab[:p][:f]
total = positives + negatives
{
total_population: positives + negatives,
condition_positive: positives,
condition_negative: negatives,
true_positive: tab[:p][:t],
true_negative: tab[:n][:t],
false_positive: tab[:p][:f],
false_negative: tab[:n][:f],
prevalence: divide(positives, total),
specificity: divide(tab[:n][:t], negatives),
recall: divide(tab[:p][:t], positives),
precision: divide(tab[:p][:t], tab[:p][:t] + tab[:p][:f]),
accuracy: divide(tab[:p][:t] + tab[:n][:t], total),
f1_score: divide(2 * tab[:p][:t], 2 * tab[:p][:t] + tab[:p][:f] + tab[:n][:f])
}
end

def print_derivations(derivations)
max_len = derivations.keys.map(&:length).max
derivations.each do |k, v|
puts k.to_s.tr('_', ' ').capitalize.ljust(max_len) + " : " + v.to_s
end
end

def empty_conf_mat(categories)
categories.map{|actual| [actual, categories.map{|predicted| [predicted, 0]}.to_h]}.to_h
end

def divide(dividend, divisor)
divisor.zero? ? 0.0 : dividend / divisor.to_f
end
end
end
2 changes: 1 addition & 1 deletion test/backends/backend_redis_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ def setup
end

def teardown
@backend.instance_variable_get(:@redis).flushdb
@backend.instance_variable_get(:@redis).flushdb if defined? @backend
end
end
16 changes: 15 additions & 1 deletion test/bayes/bayesian_common_benchmarks.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,25 @@
module BayesianCommonBenchmarks
MAX_RECORDS = 5000

class BenchmarkReporter < Minitest::Reporters::RubyMateReporter
class BenchmarkReporter < Minitest::Reporters::BaseReporter
include ANSI::Code

def before_suite(suite)
puts
puts ([suite] + BayesianCommonBenchmarks.bench_range).join("\t")
end

def after_suite(suite)
end

def report
super
puts
puts('Finished in %.5fs' % total_time)
print('%d tests, %d assertions, ' % [count, assertions])
color = failures.zero? && errors.zero? ? :green : :red
print(send(color) { '%d failures, %d errors, ' } % [failures, errors])
print(yellow { '%d skips' } % skips)
puts
end
end
Expand Down
12 changes: 12 additions & 0 deletions test/bayes/bayesian_common_tests.rb
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,18 @@ def test_custom_file_stopwords
refute_equal Float::INFINITY, classifier.classify_with_score('To be or not to be')[1]
end

def test_reset
@classifier.add_category 'Test'
assert_equal %w(Test Interesting Uninteresting).sort, @classifier.categories.sort
@classifier.reset
assert_equal %w(Interesting Uninteresting).sort, @classifier.categories.sort
classifier = empty_classifier
classifier.train('Ruby', 'A really sweet language')
assert classifier.categories.include?('Ruby')
classifier.reset
assert classifier.categories.empty?
end

private

def another_classifier
Expand Down
2 changes: 1 addition & 1 deletion test/bayes/bayesian_integration_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def setup
end

def teardown
@redis_backend.instance_variable_get(:@redis).flushdb
@redis_backend.instance_variable_get(:@redis).flushdb unless @redis_backend.nil?
end

def test_equality_of_backends
Expand Down
8 changes: 5 additions & 3 deletions test/bayes/bayesian_redis_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,21 @@ class BayesianRedisTest < Minitest::Test

def setup
begin
@old_stopwords = Hasher::STOPWORDS['en']
@backend = ClassifierReborn::BayesRedisBackend.new
@backend.instance_variable_get(:@redis).config(:set, "save", "")
@alternate_backend = ClassifierReborn::BayesRedisBackend.new(db: 1)
@classifier = ClassifierReborn::Bayes.new 'Interesting', 'Uninteresting', backend: @backend
@old_stopwords = Hasher::STOPWORDS['en']
rescue Redis::CannotConnectError => e
skip(e)
end
end

def teardown
Hasher::STOPWORDS['en'] = @old_stopwords
@backend.instance_variable_get(:@redis).flushdb
@alternate_backend.instance_variable_get(:@redis).flushdb
if defined? @backend
@backend.instance_variable_get(:@redis).flushdb
@alternate_backend.instance_variable_get(:@redis).flushdb
end
end
end
11 changes: 11 additions & 0 deletions test/lsi/lsi_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,15 @@ def test_warn_when_adding_bad_document
def test_summary
assert_equal 'This text involves dogs too [...] This text also involves cats', Summarizer.summary([@str1, @str2, @str3, @str4, @str5].join, 2)
end

def test_reset
lsi = ClassifierReborn::LSI.new
assert lsi.items.empty?
lsi.add_item @str1, 'Dog'
refute lsi.items.empty?
lsi.reset
assert lsi.items.empty?
lsi.add_item @str3, 'Cat'
refute lsi.items.empty?
end
end
Loading

0 comments on commit 15ec41a

Please sign in to comment.