Skip to content

Commit

Permalink
Merge pull request #13 from hatappi/progress-bar
Browse files Browse the repository at this point in the history
Add LogReport, PrintReport, ProgressBar extensions
  • Loading branch information
hatappi authored Oct 13, 2017
2 parents 12c9226 + 94725d9 commit f9bcff0
Show file tree
Hide file tree
Showing 11 changed files with 322 additions and 21 deletions.
5 changes: 3 additions & 2 deletions lib/chainer/iterators/serial_iterator.rb
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module Chainer
module Iterators
class SerialIterator < Chainer::Dataset::Iterator
attr_reader :epoch

def initialize(dataset, batch_size, repeat: true, shuffle: true)
@dataset = dataset
@batch_size = batch_size
Expand Down Expand Up @@ -40,7 +42,6 @@ def next
end

@epoch += 1
puts "epoch is #{@epoch}"
@is_new_epoch = true
else
@is_new_epoch = false
Expand All @@ -51,7 +52,7 @@ def next
end

# Returns the training progress as a fractional epoch count: the epoch
# counter plus the fraction of the dataset consumed at the current position.
def epoch_detail
  # to_f forces float division; with integer operands the in-epoch fraction
  # would otherwise be truncated.
  @epoch + @current_position.to_f / @dataset.size
end

def reset
Expand Down
17 changes: 17 additions & 0 deletions lib/chainer/link.rb
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def namedparams(include_uninit: true)
end
end
end

# Yields this link under the root path '/'.
# Nothing is yielded when +skipself+ is true.
def namedlinks(skipself: false)
  return if skipself

  yield('/', self)
end
end

class Chain < Link
Expand Down Expand Up @@ -97,5 +101,18 @@ def namedparams(include_uninit: true)
end
end
end

# Yields every link in this chain's hierarchy together with its path.
#
# The chain itself is reported as '/' (unless +skipself+ is true). Each child
# named in @children is reported as '/<name>', and the child's own descendants
# are reported with that prefix prepended to their paths.
#
# @param skipself [Boolean] when true, the chain itself is not yielded
# @yieldparam path [String] slash-separated path of the link
# @yieldparam link [Object] the link found at that path
def namedlinks(skipself: false)
  yield('/', self) unless skipself
  @children.each do |name|
    # Resolve the child once and reuse it, so a String name and a Symbol name
    # behave the same for both the yield and the recursion.
    child = instance_variable_get(name)
    prefix = "/#{name}"
    yield(prefix, child)
    # The child was just yielded above, so it must skip itself when recursing.
    child.namedlinks(skipself: true) do |path, link|
      yield(prefix + path, link)
    end
  end
end
end
end
6 changes: 2 additions & 4 deletions lib/chainer/links/model/classifier.rb
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@ def call(*args)
@y = @predictor.(*x)

@loss = @lossfun.call(@y, t)
# TODO: reporter
Chainer::Reporter.save_report({loss: @loss}, self)
if @compute_accuracy
@accuracy = @accfun.call(@y, t)

puts "> #{@accuracy.data.to_a}"
# TODO:reporter
Chainer::Reporter.save_report({accuracy: @accuracy}, self)
end
@loss
end
Expand Down
111 changes: 105 additions & 6 deletions lib/chainer/reporter.rb
Original file line number Diff line number Diff line change
@@ -1,31 +1,130 @@
module Chainer
# Holds the list of active reporters used by classes that include this module.
# Reporter#scope appends the reporter before yielding and pops it afterwards,
# and Reporter.save_report forwards to the last element of the list.
module ReportService
@@reporters = []
end

class Reporter
include ReportService

# Starts with no registered observers and an empty observation hash.
def initialize
  @observer_names, @observation = {}, {}
end

def add_observer(name, observer)
@observer_names[observer.object_id] = name
# Reports +values+ to the reporter whose scope is currently active (the last
# one pushed by Reporter#scope).
#
# Does nothing when no scope is active, so code that reports unconditionally
# can also be called outside of any reporter scope.
#
# @param values [Hash] observed values keyed by name
# @param observer [Object, nil] registered observer the values belong to
def self.save_report(values, observer=nil)
  reporter = @@reporters.last
  reporter.report(values, observer) if reporter
end

def add_observers(prefix, observers)
observers.each do |(name, observer)|
@observer_names[observer.object_id] = prefix + name
# Stores +values+ in the current observation hash.
#
# Without an observer the values are merged in under their own keys. With an
# observer each key is stored as "<observer name>/<key>"; an observer that was
# never registered raises a RuntimeError.
def report(values, observer=nil)
  # TODO: keep_graph_on_report option
  return @observation.update(values) unless observer

  observer_name = @observer_names.fetch(observer.object_id) do
    raise "Given observer is not registered to the reporter."
  end
  values.each { |key, value| @observation["#{observer_name}/#{key}"] = value }
end

# Registers +observer+ under +name+, keyed by the observer's object id.
def add_observer(name, observer)
  @observer_names.store(observer.object_id, name)
end

# Makes this reporter the active one and redirects reported values into
# +observation+ for the duration of the given block.
#
# The previous observation hash and the reporter stack are restored even when
# the block raises, so a failure inside the block cannot leave this reporter
# active or pointing at the caller's hash.
#
# @param observation [Hash] hash that receives values reported inside the block
# @return [Object] the reporter popped from the stack
def scope(observation)
  @@reporters << self
  old = @observation
  @observation = observation
  popped = nil
  begin
    yield
  ensure
    @observation = old
    popped = @@reporters.pop
  end
  popped
end
end

# Online accumulator for a stream of scalar values. It keeps only the running
# sum, the running sum of squares and the count, from which the mean and the
# standard deviation are derived.
class Summary
  def initialize
    @x = 0   # running sum of the added values
    @x2 = 0  # running sum of the squared values
    @n = 0   # number of values added
  end

  # Adds a scalar value.
  # Args:
  #     value: Scalar value to accumulate.
  def add(value)
    @x += value
    @x2 += value * value
    @n += 1
  end

  # Computes the mean.
  # The result is NaN when no value has been added.
  def compute_mean
    @x.to_f / @n
  end

  # Computes and returns the mean and standard deviation values.
  # Returns:
  #     array: Mean and standard deviation values.
  def make_statistics
    # Reuse compute_mean so both methods agree; plain `/` on integer sums
    # would truncate the mean and corrupt the variance.
    mean = compute_mean
    var = @x2.to_f / @n - mean * mean
    # Floating-point rounding can push the variance of near-constant data just
    # below zero, which Math.sqrt rejects; treat that as zero.
    std = Math.sqrt(var < 0 ? 0.0 : var)
    [mean, std]
  end
end

# Online summarization of a sequence of dictionaries.
# ``DictSummary`` computes the statistics of a given set of scalars online.
# It only computes the statistics for scalar values and variables of scalar values in the dictionaries.
class DictSummary
  def initialize
    # One Summary per key, created on first access.
    @summaries = Hash.new { |h, k| h[k] = Summary.new }
  end

  # Adds a dictionary of scalars.
  # Args:
  #     d (dict): Dictionary of scalars to accumulate. Only elements of
  #               scalars, zero-dimensional arrays, and variables of
  #               zero-dimensional arrays are accumulated.
  def add(d)
    d.each do |k, v|
      # Unwrap variables so the underlying data is what gets accumulated.
      v = v.data if v.kind_of?(Chainer::Variable)
      next unless v.class.method_defined?(:to_i) || (v.class.method_defined?(:ndim) && v.ndim == 0)

      @summaries[k].add(v)
    end
  end

  # Creates a dictionary of mean values.
  # It returns a single dictionary that holds a mean value for each entry added to the summary.
  #
  # Returns:
  #     dict: Dictionary of mean values.
  def compute_mean
    @summaries.each_with_object({}) { |(name, summary), h| h[name] = summary.compute_mean }
  end

  # Creates a dictionary of statistics.
  # It returns a single dictionary that holds mean and standard deviation
  # values for every entry added to the summary. For an entry of name
  # ``'key'``, these values are added to the dictionary by names ``'key'`` and ``'key.std'``, respectively.
  # The ``'key.std'`` name is always a String, also when the entry name is a Symbol.
  #
  # Returns:
  #     dict: Dictionary of statistics of all entries.
  def make_statistics
    stats = {}
    @summaries.each do |name, summary|
      mean, std = summary.make_statistics
      stats[name] = mean
      # Interpolation works for Symbol names too, where `name + '.std'` raises.
      stats["#{name}.std"] = std
    end
    stats
  end
end
end
49 changes: 49 additions & 0 deletions lib/chainer/training/extensions/log_report.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
require 'tempfile'
require 'json'

module Chainer
module Training
module Extensions
class LogReport < Extension
attr_reader :log

def initialize(keys: nil, trigger: [1, 'epoch'], postprocess: nil, log_name: 'log')
@keys = keys
@trigger = Chainer::Training::Util.get_trigger(trigger)
Expand All @@ -12,6 +17,50 @@ def initialize(keys: nil, trigger: [1, 'epoch'], postprocess: nil, log_name: 'lo
init_summary
end

# Accumulates the trainer's current observation and, when the trigger fires,
# appends the averaged values to the log and rewrites the JSON log file.
#
# @param trainer [Object] must provide +observation+, +updater+ (with +epoch+
#   and +iteration+), +elapsed_time+ and +out+ (the output directory)
def call(trainer)
  observation = trainer.observation

  if @keys.nil?
    @summary.add(observation)
  else
    # Compare names as symbols so String and Symbol keys match each other.
    by_symbol = observation.each_with_object({}) { |(k, v), h| h[k.to_sym] = v }
    selected = @keys.select { |k| by_symbol.key?(k.to_sym) }
    @summary.add(selected.each_with_object({}) { |k, h| h[k.to_s] = by_symbol[k.to_sym] })
  end

  # if trigger is true, output the result
  return unless @trigger.(trainer)

  stats = @summary.compute_mean
  stats_cpu = stats.each_with_object({}) do |(name, value), h|
    h[name] = value.to_f # copy to CPU
  end

  updater = trainer.updater
  stats_cpu['epoch'] = updater.epoch
  stats_cpu['iteration'] = updater.iteration
  stats_cpu['elapsed_time'] = trainer.elapsed_time

  @postprocess.(stats_cpu) unless @postprocess.nil?

  @log << stats_cpu

  unless @log_name.nil?
    # example: sprintf("%{a}, %{b}", {a: "1", b: "2"})
    #          => "1, 2"
    # Named references are looked up by Symbol, so hand sprintf symbol keys.
    format_args = stats_cpu.each_with_object({}) { |(k, v), h| h[k.to_sym] = v }
    log_name = sprintf(@log_name, format_args)

    # Write to a temporary file inside the output directory and move it into
    # place afterwards. Tempfile.create takes basename and tmpdir positionally.
    temp_file = Tempfile.create(log_name, trainer.out)
    begin
      JSON.dump(@log, temp_file)
    ensure
      # Close before moving so buffered JSON reaches the file.
      temp_file.close
    end

    new_path = File.join(trainer.out, log_name)
    FileUtils.mv(temp_file.path, new_path)
  end

  init_summary
end

private

def init_summary
Expand Down
43 changes: 39 additions & 4 deletions lib/chainer/training/extensions/print_report.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,59 @@ module Chainer
module Training
module Extensions
class PrintReport < Extension
def initialize(entries, log_report: 'LogReport', out: STDOUT)
# Prepares the header line and one format template per entry.
#
# @param entries [Array<String>] names of the log values to print, in column order
# @param log_report [String, LogReport] the log report extension itself, or
#   the name under which it can be fetched from the trainer
# @param out [IO] stream the report is written to
def initialize(entries, log_report: 'Chainer::Training::Extensions::LogReport', out: STDOUT)
  @entries = entries
  @log_report = log_report
  @out = out

  @log_len = 0 # number of observations already printed

  # format information
  header = []
  @templates = entries.map do |entry|
    # Each column is at least 10 characters wide, or as wide as its name.
    width = [10, entry.size].max
    header << sprintf("%-#{width}s", entry)
    # A formatted cell is `width` characters plus one separating space, so the
    # blank used for a missing value must be width + 1 to keep columns aligned.
    [entry, "%-#{width}g ", ' ' * (width + 1)]
  end
  @header = header.join(' ') + "\n"
end

# Writes the header once, then prints every log entry that has not been
# printed yet.
#
# @param trainer [Object] passed to the log report, or used to look the log
#   report up by name via +get_extension+
# @raise [TypeError] when @log_report is neither a String nor a LogReport
def call(trainer)
  if @header
    @out.write(@header)
    @header = nil
  end

  # Start from the configured value so every branch below, including the error
  # message, sees the real object rather than an unassigned local.
  log_report = @log_report
  case log_report
  when String
    log_report = trainer.get_extension(log_report)
  when LogReport
    log_report.(trainer) # update the log report
  else
    raise TypeError, "log report has a wrong type #{log_report.class}"
  end

  log = log_report.log
  while log.size > @log_len
    # "\033[J" is the ANSI "erase in display" sequence, written before each row.
    @out.write("\033[J")
    print(log[@log_len])
    @log_len += 1
  end
end

private

# Writes one row for +observation+: each entry present in it is formatted
# with its template, each missing entry is replaced by its blank filler, and
# the row is terminated with a newline.
def print(observation)
  @templates.each do |entry, template, empty|
    cell = observation.key?(entry) ? sprintf(template, observation[entry]) : empty
    @out.write(cell)
  end
  @out.write("\n")
end
end
end
Expand Down
Loading

0 comments on commit f9bcff0

Please sign in to comment.