Gridworld constants and reward size fix #103

Open
wants to merge 2 commits into master
49 changes: 23 additions & 26 deletions games/grid-world/game.jl
@@ -12,50 +12,47 @@ const RL = CommonRLInterface
 # episode. Because time is not captured in the state, this introduces a slight bias in
 # the value function.
 const EPISODE_LENGTH_BOUND = 200
+const SIZE = SA[10, 10]
+const REWARDS = Dict(
+    SA[9,3] => 1.0,
+    SA[8,8] => 0.3,
+    SA[4,3] => -1.0,
+    SA[4,6] => -0.5)

 mutable struct World <: AbstractEnv
-    size::SVector{2, Int}
-    rewards::Dict{SVector{2, Int}, Float64}
-    state::SVector{2, Int}
+    position::SVector{2, Int}
     time::Int
 end

 function World()
-    rewards = Dict(
-        SA[9,3] => 10.0,
-        SA[8,8] => 3.0,
-        SA[4,3] => -10.0,
-        SA[4,6] => -5.0)
     return World(
-        SA[10, 10],
-        rewards,
         SA[rand(1:10), rand(1:10)],
         0)
 end

-RL.reset!(env::World) = (env.state = SA[rand(1:env.size[1]), rand(1:env.size[2])])
+RL.reset!(env::World) = (env.position = SA[rand(1:SIZE[1]), rand(1:SIZE[2])])
 RL.actions(env::World) = [SA[1,0], SA[-1,0], SA[0,1], SA[0,-1]]
-RL.observe(env::World) = env.state
+RL.observe(env::World) = env.position

 RL.terminated(env::World) =
-    haskey(env.rewards, env.state) || env.time > EPISODE_LENGTH_BOUND
+    haskey(REWARDS, env.position) || env.time > EPISODE_LENGTH_BOUND

 function RL.act!(env::World, a)
     # 40% chance of going in a random direction (=30% chance of going in a wrong direction)
     if rand() < 0.4
         a = rand(actions(env))
     end
-    env.state = clamp.(env.state + a, SA[1,1], env.size)
+    env.position = clamp.(env.position + a, SA[1,1], SIZE)
     env.time += 1
-    return get(env.rewards, env.state, 0.0)
+    return get(REWARDS, env.position, 0.0)
 end

 @provide RL.player(env::World) = 1 # An MDP is a one player game
 @provide RL.players(env::World) = [1]
-@provide RL.observations(env::World) = [SA[x, y] for x in 1:env.size[1], y in 1:env.size[2]]
-@provide RL.clone(env::World) = World(env.size, copy(env.rewards), env.state, env.time)
-@provide RL.state(env::World) = env.state
-@provide RL.setstate!(env::World, s) = (env.state = s)
+@provide RL.observations(env::World) = [SA[x, y] for x in 1:SIZE[1], y in 1:SIZE[2]]
+@provide RL.clone(env::World) = World(env.position, env.time)
+@provide RL.state(env::World) = env.position
+@provide RL.setstate!(env::World, s) = (env.position = s)
 @provide RL.valid_action_mask(env::World) = BitVector([1, 1, 1, 1])

 # Additional functions needed by AlphaZero.jl that are not present in
@@ -64,11 +61,11 @@ end
 # CommonRLInterfaceWrapper.

 function GI.render(env::World)
-    for y in reverse(1:env.size[2])
-        for x in 1:env.size[1]
+    for y in reverse(1:SIZE[2])
+        for x in 1:SIZE[1]
             s = SA[x, y]
-            r = get(env.rewards, s, 0.0)
-            if env.state == s
+            r = get(REWARDS, s, 0.0)
+            if env.position == s
                 c = ("+",)
             elseif r > 0
                 c = (crayon"green", "o")
@@ -84,7 +81,7 @@ function GI.render(env::World)
 end

 function GI.vectorize_state(env::World, state)
-    v = zeros(Float32, env.size[1], env.size[2])
+    v = zeros(Float32, SIZE[1], SIZE[2])
     v[state[1], state[2]] = 1
     return v
 end
@@ -107,8 +104,8 @@ function GI.read_state(env::World)
         @assert length(s) == 2
         x = parse(Int, s[1])
         y = parse(Int, s[2])
-        @assert 1 <= x <= env.size[1]
-        @assert 1 <= y <= env.size[2]
+        @assert 1 <= x <= SIZE[1]
+        @assert 1 <= y <= SIZE[2]
         return SA[x, y]
     catch e
         return nothing
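For context on how the touched functions fit together, here is a minimal rollout sketch (not part of the diff; it assumes the revised game.jl above has been included along with its CommonRLInterface and StaticArrays dependencies, and the rollout helper name is made up for illustration):

    function rollout()
        env = World()                   # agent starts at a random cell
        RL.reset!(env)
        total = 0.0
        while !RL.terminated(env)       # ends on a reward cell or once env.time exceeds EPISODE_LENGTH_BOUND
            a = rand(RL.actions(env))   # pick one of the four moves uniformly
            total += RL.act!(env, a)    # act! returns the REWARDS entry at the new position, or 0.0
        end
        return total                    # episode return on the rescaled reward values
    end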
1 change: 0 additions & 1 deletion games/grid-world/params.jl
@@ -37,7 +37,6 @@ learning = LearningParams(
     use_gpu=false,
     use_position_averaging=false,
     samples_weighing_policy=CONSTANT_WEIGHT,
-    rewards_renormalization=10,
     l2_regularization=1e-4,
     optimiser=Adam(lr=5e-3),
     batch_size=64,
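Dropping this parameter mirrors the ten-fold reward rescale in game.jl (10.0 / 3.0 / -10.0 / -5.0 became 1.0 / 0.3 / -1.0 / -0.5). Assuming rewards_renormalization simply divides raw rewards by that constant before they reach the learner (an inference about AlphaZero.jl's LearningParams, not something stated in this diff), the values the learner sees are unchanged. A quick sanity check:

    old_raw  = [10.0, 3.0, -10.0, -5.0]   # reward values before this PR
    old_seen = old_raw ./ 10              # with rewards_renormalization = 10
    new_seen = [1.0, 0.3, -1.0, -0.5]     # new REWARDS values, no renormalization
    @assert old_seen ≈ new_seen           # effective reward scale is preserved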