From 43a13a4062a3cb4ea273d6491caaabda588f4f87 Mon Sep 17 00:00:00 2001 From: Joao Hespanha Date: Mon, 2 Sep 2024 17:19:58 -0700 Subject: [PATCH 1/3] Fix and unit test for issue "Wrong style for state report for TicTacToeEnv() #1079" --- .../src/environments/examples/TicTacToeEnv.jl | 1 + .../test/environments/examples/tic_tac_toe.jl | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl index ff4c89b4d..027502793 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl @@ -72,6 +72,7 @@ RLBase.players(::TicTacToeEnv) = (Player(:Cross), Player(:Nought)) RLBase.state(env::TicTacToeEnv, ::Observation, ::DefaultPlayer) = state(env, Observation{Int}(), Player(:Any)) RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}, player) = env.board +RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}) = env.board RLBase.state(env::TicTacToeEnv, ::RLBase.AbstractStateStyle) = state(env::TicTacToeEnv, Observation{Int}(), Player(1)) RLBase.state(env::TicTacToeEnv, ::Observation{Int}, player::Player) = get_tic_tac_toe_state_info()[env].index diff --git a/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl b/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl index 0eca516ff..f5b15f289 100644 --- a/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl +++ b/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl @@ -3,15 +3,15 @@ using ReinforcementLearningEnvironments, ReinforcementLearningBase, ReinforcementLearningCore trajectory_1 = Trajectory( - CircularArraySARTSTraces(; capacity = 1), + CircularArraySARTSTraces(; capacity=1), BatchSampler(1), - InsertSampleRatioController(n_inserted = -1), + InsertSampleRatioController(n_inserted=-1), ) trajectory_2 = Trajectory( - CircularArraySARTSTraces(; capacity = 1), + CircularArraySARTSTraces(; capacity=1), BatchSampler(1), - InsertSampleRatioController(n_inserted = -1), + InsertSampleRatioController(n_inserted=-1), ) multiagent_policy = MultiAgentPolicy(PlayerTuple( @@ -30,6 +30,7 @@ @test length(state_space(env, Observation{Int}())) == 5478 @test RLBase.state(env, Observation{BitArray{3}}(), Player(:Cross)) == env.board + @test RLBase.state(env, Observation{BitArray{3}}()) == env.board @test RLBase.state_space(env, Observation{BitArray{3}}(), Player(:Cross)) isa ArrayProductDomain @test RLBase.state_space(env, Observation{String}(), Player(:Cross)) isa DomainSets.FullSpace{String} @test RLBase.state(env, Observation{String}(), Player(:Cross)) isa String From b8572cc6b1279ab656040f02c407a9768d65de5c Mon Sep 17 00:00:00 2001 From: Jeremiah <4462211+jeremiahpslewis@users.noreply.github.com> Date: Tue, 3 Sep 2024 10:33:42 +0200 Subject: [PATCH 2/3] Adapt TicTacToe fix --- .../src/environments/examples/TicTacToeEnv.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl index 027502793..ba7a56350 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl @@ -70,10 +70,9 @@ end RLBase.players(::TicTacToeEnv) = (Player(:Cross), Player(:Nought)) -RLBase.state(env::TicTacToeEnv, ::Observation, ::DefaultPlayer) = state(env, Observation{Int}(), Player(:Any)) +RLBase.state(env::TicTacToeEnv, o::Observation, ::DefaultPlayer) = state(env, o, Player(:Any)) RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}, player) = env.board -RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}) = env.board -RLBase.state(env::TicTacToeEnv, ::RLBase.AbstractStateStyle) = state(env::TicTacToeEnv, Observation{Int}(), Player(1)) +RLBase.state(env::TicTacToeEnv, ::RLBase.AbstractStateStyle) = state(env::TicTacToeEnv, Observation{Int}(), Player(:Any)) RLBase.state(env::TicTacToeEnv, ::Observation{Int}, player::Player) = get_tic_tac_toe_state_info()[env].index From e0b4d8286e3a484ecd4a9e9bc5ee03d78ae913f0 Mon Sep 17 00:00:00 2001 From: Jeremiah <4462211+jeremiahpslewis@users.noreply.github.com> Date: Tue, 3 Sep 2024 11:22:02 +0200 Subject: [PATCH 3/3] Refactor TicTacToe state / multiplayer handling --- .../src/environments/examples/TicTacToeEnv.jl | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl index ba7a56350..9f2cf02cb 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl @@ -70,20 +70,19 @@ end RLBase.players(::TicTacToeEnv) = (Player(:Cross), Player(:Nought)) -RLBase.state(env::TicTacToeEnv, o::Observation, ::DefaultPlayer) = state(env, o, Player(:Any)) -RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}, player) = env.board -RLBase.state(env::TicTacToeEnv, ::RLBase.AbstractStateStyle) = state(env::TicTacToeEnv, Observation{Int}(), Player(:Any)) -RLBase.state(env::TicTacToeEnv, ::Observation{Int}, player::Player) = +RLBase.state(env::TicTacToeEnv, o::Observation, ::RLBase.AbstractPlayer) = state(env, o) +RLBase.state(env::TicTacToeEnv, ::RLBase.AbstractStateStyle) = state(env::TicTacToeEnv, Observation{Int}()) +RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}) = env.board +RLBase.state(env::TicTacToeEnv, ::Observation{Int}) = get_tic_tac_toe_state_info()[env].index -RLBase.state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, player::Player) = ArrayProductDomain(fill(false:true, 3, 3, 3)) -RLBase.state_space(env::TicTacToeEnv, ::Observation{Int}, player::Player) = +RLBase.state_space(env::TicTacToeEnv, o::Observation, ::RLBase.AbstractPlayer) = state_space(env, o) +RLBase.state_space(::TicTacToeEnv, ::Observation{BitArray{3}}) = ArrayProductDomain(fill(false:true, 3, 3, 3)) +RLBase.state_space(::TicTacToeEnv, ::Observation{Int}) = Base.OneTo(length(get_tic_tac_toe_state_info())) -RLBase.state_space(env::TicTacToeEnv, ::Observation{String}, player::Player) = fullspace(String) +RLBase.state_space(::TicTacToeEnv, ::Observation{String}) = fullspace(String) -RLBase.state(env::TicTacToeEnv, ::Observation{String}) = state(env::TicTacToeEnv, Observation{String}(), Player(1)) - -function RLBase.state(env::TicTacToeEnv, ::Observation{String}, player::Player) +function RLBase.state(env::TicTacToeEnv, ::Observation{String}) buff = IOBuffer() for i in 1:3 for j in 1:3