diff --git a/games/grid-world/game.jl b/games/grid-world/game.jl
index c195359d..11be2cb3 100644
--- a/games/grid-world/game.jl
+++ b/games/grid-world/game.jl
@@ -12,50 +12,47 @@ const RL = CommonRLInterface
 # episode. Because time is not captured in the state, this introduces a slight bias in
 # the value function.
 const EPISODE_LENGTH_BOUND = 200
+const SIZE = SA[10, 10]
+const REWARDS = Dict(
+  SA[9,3] => 1.0,
+  SA[8,8] => 0.3,
+  SA[4,3] => -1.0,
+  SA[4,6] => -0.5)

 mutable struct World <: AbstractEnv
-  size::SVector{2, Int}
-  rewards::Dict{SVector{2, Int}, Float64}
-  state::SVector{2, Int}
+  position::SVector{2, Int}
   time :: Int
 end

 function World()
-  rewards = Dict(
-    SA[9,3] => 10.0,
-    SA[8,8] => 3.0,
-    SA[4,3] => -10.0,
-    SA[4,6] => -5.0)
   return World(
-    SA[10, 10],
-    rewards,
     SA[rand(1:10), rand(1:10)],
     0)
 end

-RL.reset!(env::World) = (env.state = SA[rand(1:env.size[1]), rand(1:env.size[2])])
+RL.reset!(env::World) = (env.position = SA[rand(1:SIZE[1]), rand(1:SIZE[2])])
 RL.actions(env::World) = [SA[1,0], SA[-1,0], SA[0,1], SA[0,-1]]
-RL.observe(env::World) = env.state
+RL.observe(env::World) = env.position

 RL.terminated(env::World) =
-  haskey(env.rewards, env.state) || env.time > EPISODE_LENGTH_BOUND
+  haskey(REWARDS, env.position) || env.time > EPISODE_LENGTH_BOUND

 function RL.act!(env::World, a)
   # 40% chance of going in a random direction (=30% chance of going in a wrong direction)
   if rand() < 0.4
     a = rand(actions(env))
   end
-  env.state = clamp.(env.state + a, SA[1,1], env.size)
+  env.position = clamp.(env.position + a, SA[1,1], SIZE)
   env.time += 1
-  return get(env.rewards, env.state, 0.0)
+  return get(REWARDS, env.position, 0.0)
 end

 @provide RL.player(env::World) = 1 # An MDP is a one player game
 @provide RL.players(env::World) = [1]
-@provide RL.observations(env::World) = [SA[x, y] for x in 1:env.size[1], y in 1:env.size[2]]
-@provide RL.clone(env::World) = World(env.size, copy(env.rewards), env.state, env.time)
-@provide RL.state(env::World) = env.state
-@provide RL.setstate!(env::World, s) = (env.state = s)
+@provide RL.observations(env::World) = [SA[x, y] for x in 1:SIZE[1], y in 1:SIZE[2]]
+@provide RL.clone(env::World) = World(env.position, env.time)
+@provide RL.state(env::World) = env.position
+@provide RL.setstate!(env::World, s) = (env.position = s)
 @provide RL.valid_action_mask(env::World) = BitVector([1, 1, 1, 1])

 # Additional functions needed by AlphaZero.jl that are not present in
@@ -64,11 +61,11 @@ end
 # CommonRLInterfaceWrapper.

 function GI.render(env::World)
-  for y in reverse(1:env.size[2])
-    for x in 1:env.size[1]
+  for y in reverse(1:SIZE[2])
+    for x in 1:SIZE[1]
       s = SA[x, y]
-      r = get(env.rewards, s, 0.0)
-      if env.state == s
+      r = get(REWARDS, s, 0.0)
+      if env.position == s
         c = ("+",)
       elseif r > 0
         c = (crayon"green", "o")
@@ -84,7 +81,7 @@ function GI.render(env::World)
 end

 function GI.vectorize_state(env::World, state)
-  v = zeros(Float32, env.size[1], env.size[2])
+  v = zeros(Float32, SIZE[1], SIZE[2])
   v[state[1], state[2]] = 1
   return v
 end
@@ -107,8 +104,8 @@ function GI.read_state(env::World)
     @assert length(s) == 2
     x = parse(Int, s[1])
     y = parse(Int, s[2])
-    @assert 1 <= x <= env.size[1]
-    @assert 1 <= y <= env.size[2]
+    @assert 1 <= x <= SIZE[1]
+    @assert 1 <= y <= SIZE[2]
     return SA[x, y]
   catch e
     return nothing
diff --git a/games/grid-world/params.jl b/games/grid-world/params.jl
index 955f16d3..412ec43a 100644
--- a/games/grid-world/params.jl
+++ b/games/grid-world/params.jl
@@ -37,7 +37,6 @@ learning = LearningParams(
   use_gpu=false,
   use_position_averaging=false,
   samples_weighing_policy=CONSTANT_WEIGHT,
-  rewards_renormalization=10,
   l2_regularization=1e-4,
   optimiser=Adam(lr=5e-3),
   batch_size=64,
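
Note on the two changes taken together: the values in REWARDS are the old reward values divided by 10, so every step reward now lies in [-1, 1] by construction, which is what makes the rewards_renormalization=10 learning parameter redundant. Below is a minimal smoke-test sketch of that invariant, not part of the patch; it assumes game.jl and its dependencies (CommonRLInterface, StaticArrays, Crayons, AlphaZero's GI) are loadable from games/grid-world/, e.g. via the project's main.jl. Since RL.reset! only resamples the position, the sketch resets env.time manually between episodes.

# Hypothetical smoke test for the rescaled rewards; not part of this PR.
include("game.jl")  # assumed to bring World, RL, and REWARDS into scope

env = World()
for _ in 1:100
  RL.reset!(env)
  env.time = 0  # RL.reset! only resamples the position, not the clock
  while !RL.terminated(env)
    r = RL.act!(env, rand(RL.actions(env)))
    @assert -1.0 <= r <= 1.0  # holds by construction of REWARDS
  end
end
println("all step rewards within [-1, 1]")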