Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simple XOR example #104

Open
bkmgit opened this issue Jun 2, 2020 · 0 comments
Open

Simple XOR example #104

bkmgit opened this issue Jun 2, 2020 · 0 comments

Comments

@bkmgit
Copy link

bkmgit commented Jun 2, 2020

It may be helpful to add a simple XOR example. The example below treats XOR as a classification problem, which is one way to do it; a more traditional method could also be added if that is considered useful.

require 'numo/narray'
require 'chainer'

# Small two-layer network used to learn the XOR function.
#
# Architecture: Linear -> sigmoid -> Linear. Both Linear links are created
# with an input size of nil, presumably so red-chainer infers the input width
# on the first forward pass (lazy initialisation) — confirm against the
# red-chainer Linear documentation.
class XOR < Chainer::Chain
  L = Chainer::Links::Connection::Linear
  F = Chainer::Functions

  # @param n_units [Integer] number of units in the hidden layer
  # @param n_out [Integer] number of outputs produced by the second layer
  def initialize(n_units, n_out)
    super()
    init_scope do
      @l1 = L.new(nil, out_size: n_units)
      @l2 = L.new(nil, out_size: n_out)
    end
  end

  # Training objective: mean squared error between the network output for
  # +x+ and the target +y+.
  #
  # @param x [Chainer::Variable] input batch
  # @param y [Chainer::Variable] target batch (same shape as the output)
  # @return [Chainer::Variable] scalar loss
  def call(x, y)
    F::Loss::MeanSquaredError.mean_squared_error(fwd(x), y)
  end

  # Forward pass. The second layer's output is returned as-is (no final
  # activation), so callers see raw scores.
  #
  # @param x [Chainer::Variable] input batch
  # @return [Chainer::Variable] output of the second Linear link
  def fwd(x)
    hidden = F::Activation::Sigmoid.sigmoid(@l1.(x))
    @l2.(hidden)
  end
end

# Select device id -1 (the CPU, by Chainer convention), make it the default,
# and grab its array module for building the data arrays.
device = Chainer::Device.create(-1)
Chainer::Device.change_default(device)
xm = device.xm

# Hidden layer of 4 units and 2 outputs (one score per class).
model = XOR.new(4, 2)

optimizer = Chainer::Optimizers::Adam.new
optimizer.setup(model)

# The complete XOR truth table: every input pair and its class label.
inputs = [[0, 0], [1, 0], [0, 1], [1, 1]]
labels = [0, 1, 1, 0]

# Picking row k of the 2x2 identity matrix gives the one-hot vector for
# class k, so indexing with the label list yields one target row per sample.
one_hot = xm::SFloat.cast(xm::SFloat.eye(2)[labels, false])

# With only four possible inputs, the same samples serve for training and
# testing; the test targets are the plain class labels.
x_train = xm::SFloat.cast(inputs)
y_train = one_hot
x_test = x_train
y_test = xm::SFloat.cast(labels)

# Train: 10,000 full-batch optimizer steps over the training set, printing a
# progress dot every 1,000 iterations.
print("Training ")

10_000.times do |step|
  print(".") if (step % 1000).zero?
  # Distinct names keep these per-step Variables block-local; reusing x/y
  # here would reassign any outer script locals with those names.
  x_var = Chainer::Variable.new(x_train)
  y_var = Chainer::Variable.new(y_train)
  model.cleargrads
  loss = model.(x_var, y_var)
  loss.backward
  optimizer.update
end

puts

# Test: run the trained network on the test inputs, take the arg-max of each
# output row as the predicted class, and compare it with the expected label.
xt = Chainer::Variable.new(x_test)
yt = model.fwd(xt)
n_row, n_col = yt.data.shape

puts "Result : Correct Answer : Answer <= One-Hot"
ok = 0
n_row.times do |i|
  # dup gives a standalone 1-D copy of row i, so max_index is a column index
  # within that row. NOTE(review): confirm whether max_index on a view would
  # instead index the underlying 2-D buffer.
  row = yt.data[i, true].dup
  ans = row.max_index
  expected = y_test[i]
  if ans == expected
    ok += 1
    print("OK")
  else
    print("--")
  end
  # print, not printf: these strings are data, not format templates.
  print(" : #{expected.to_i} :")
  # Predicted class, then the raw output row it was derived from.
  puts " #{ans.to_i} <= #{row.to_a}"
end
puts "Row: #{n_row}, Column: #{n_col}"
puts "Accuracy rate : #{ok}/#{n_row} = #{ok.to_f / n_row}"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant