Commit 74f7d1b

Introduce SinglesEnv (#683)
* initial commit
* black
* unused imports
* fix this
* don't need this
* fix
* put import back
* import
* fix
* fix
* simplify
* name fix
* simplify
* fix
* introduce SinglesEnv
* fix tests
* unused import
* fix integration tests
* fix examples
* simplify
* update docs
* rename files
* format
* polish
* unused import
* condense code
* polish
* bugfix
* add strictness
* fix test
* format and parameterize strict
* add parameter
* fix
* format
* invalid causes default
* fix test
* fix test
* experiment
* debug
* debug
* fix bug
* cleanup
* put tests back to normal
* avoid returning DefaultBattleOrder if at all possible during conversions
* add "fake" parameter which will allow conversions to be fabricated if they are invalid, if at all possible
* add fake as parameter to env
* bugfix
* fix test
* bugfix
* add docstring
* up the timeout
* separate tests
* remove init
* new .rst docs
1 parent a85cea7 commit 74f7d1b

File tree

13 files changed: +714, -1096 lines

docs/source/examples/rl_with_gymnasium_wrapper.rst

+5, -7

@@ -5,7 +5,7 @@ Reinforcement learning with the Gymnasium wrapper
 
 The corresponding complete source code can be found `here <https://github.com/hsahovic/poke-env/blob/master/examples/rl_with_new_gymnasium_wrapper.py>`__.
 
-The goal of this example is to demonstrate how to use the `farama gymnasium <https://gymnasium.farama.org/>`__ interface proposed by ``EnvPlayer``, and to train a simple deep reinforcement learning agent comparable in performance to the ``MaxDamagePlayer`` we created in :ref:`max_damage_player`.
+The goal of this example is to demonstrate how to use the `farama gymnasium <https://gymnasium.farama.org/>`__ interface proposed by ``PokeEnv``, and to train a simple deep reinforcement learning agent comparable in performance to the ``MaxDamagePlayer`` we created in :ref:`max_damage_player`.
 
 .. note:: This example necessitates `keras-rl <https://github.com/keras-rl/keras-rl>`__ (compatible with Tensorflow 1.X) or `keras-rl2 <https://github.com/wau/keras-rl2>`__ (Tensorflow 2.X), which implement numerous reinforcement learning algorithms and offer a simple API fully compatible with the Gymnasium API. You can install them by running ``pip install keras-rl`` or ``pip install keras-rl2``. If you are unsure, ``pip install keras-rl2`` is recommended.
 
@@ -33,7 +33,7 @@ for each component of the embedding vector and return them as a ``gymnasium.Spac
 Defining rewards
 ^^^^^^^^^^^^^^^^
 
-Rewards are signals that the agent will use in its optimization process (a common objective is optimizing a discounted total reward). ``EnvPlayer`` objects provide a helper method, ``reward_computing_helper``, that can help defining simple symmetric rewards that take into account fainted pokemons, remaining hp, status conditions and victory.
+Rewards are signals that the agent will use in its optimization process (a common objective is optimizing a discounted total reward). ``PokeEnv`` objects provide a helper method, ``reward_computing_helper``, that can help defining simple symmetric rewards that take into account fainted pokemons, remaining hp, status conditions and victory.
 
 We will use this method to define the following reward:
 
@@ -135,8 +135,6 @@ Instantiating train environment and evaluation environment
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 Normally, to ensure isolation between training and testing, two different environments are created.
-The base class ``EnvPlayer`` allows you to choose the opponent either when you instantiate it or replace it during training
-with the ``set_opponent`` method.
 If you don't want the player to start challenging the opponent you can set ``start_challenging=False`` when creating it.
 In this case, we want them to start challenging right away:
 
@@ -270,7 +268,7 @@ This can be done with the following code:
     )
     ...
 
-The ``reset_env`` method of the ``EnvPlayer`` class allows you to reset the environment
+The ``reset_env`` method of the ``PokeEnv`` class allows you to reset the environment
 to a clean state, including internal counters for victories, battles, etc.
 
 It takes two optional parameters:
@@ -301,7 +299,7 @@ In order to evaluate the player with the provided method, we need to use a backg
 
 The ``result`` method of the ``Future`` object will block until the task is done and will return the result.
 
-.. warning:: ``background_evaluate_player`` requires the challenge loop to be stopped. To ensure this use method ``reset_env(restart=False)`` of ``EnvPlayer``.
+.. warning:: ``background_evaluate_player`` requires the challenge loop to be stopped. To ensure this use method ``reset_env(restart=False)`` of ``PokeEnv``.
 
 .. warning:: If you call ``result`` before the task is finished, the main thread will be blocked. Only do that if the agent is operating on a different thread than the one asking for the result.
 
@@ -337,7 +335,7 @@ To use the ``cross_evaluate`` method, the strategy is the same to the one used f
     print(tabulate(table))
     ...
 
-.. warning:: ``background_cross_evaluate`` requires the challenge loop to be stopped. To ensure this use method ``reset_env(restart=False)`` of ``EnvPlayer``.
+.. warning:: ``background_cross_evaluate`` requires the challenge loop to be stopped. To ensure this use method ``reset_env(restart=False)`` of ``PokeEnv``.
 
 .. warning:: If you call ``result`` before the task is finished, the main thread will be blocked. Only do that if the agent is operating on a different thread than the one asking for the result.
 
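The docs above point readers at ``reward_computing_helper`` for building simple symmetric rewards from fainted pokemons, remaining HP, status conditions and victory. As a rough illustration of what such a symmetric reward looks like, here is a self-contained sketch; the ``MonState``/``BattleState`` containers and the weight values are hypothetical stand-ins, not poke-env's actual API:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class MonState:
    """Hypothetical stand-in for a poke-env Pokemon's state."""
    hp_fraction: float  # remaining HP in [0, 1]
    fainted: bool = False
    statused: bool = False


@dataclass
class BattleState:
    """Hypothetical stand-in for a poke-env battle."""
    team: List[MonState]
    opponent_team: List[MonState]
    won: bool = False
    lost: bool = False


def symmetric_reward(
    b: BattleState,
    fainted_value: float = 2.0,
    hp_value: float = 1.0,
    status_value: float = 0.5,
    victory_value: float = 30.0,
) -> float:
    """Score each side with the same function and subtract, so the
    reward is symmetric: swapping the two sides flips its sign."""

    def side_score(team: List[MonState]) -> float:
        score = sum(m.hp_fraction * hp_value for m in team)
        score -= sum(fainted_value for m in team if m.fainted)
        score -= sum(status_value for m in team if m.statused)
        return score

    reward = side_score(b.team) - side_score(b.opponent_team)
    if b.won:
        reward += victory_value
    elif b.lost:
        reward -= victory_value
    return reward
```

Because the same scoring function is applied to both sides and subtracted, a position and its mirror image always receive opposite rewards.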

docs/source/modules/env.rst

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
.. _env:
2+
3+
The env object and related subclasses
4+
========================================
5+
6+
.. contents:: :local:
7+
8+
PokeEnv
9+
******
10+
11+
.. automodule:: poke_env.player.env
12+
:members:
13+
:undoc-members:
14+
:show-inheritance:
15+
16+
SinglesEnv
17+
*************
18+
19+
.. automodule:: poke_env.player.singles_env
20+
:members:
21+
:undoc-members:
22+
:show-inheritance:

docs/source/modules/player.rst

-17
@@ -5,14 +5,6 @@ The player object and related subclasses
 
 .. contents:: :local:
 
-Env player
-**********
-
-.. automodule:: poke_env.player.env_player
-    :members:
-    :undoc-members:
-    :show-inheritance:
-
 Player
 ******
 
@@ -21,19 +13,10 @@ Player
     :undoc-members:
     :show-inheritance:
 
-GymnasiumEnv
-************
-
-.. automodule:: poke_env.player.gymnasium_api
-    :members:
-    :undoc-members:
-    :show-inheritance:
-
 Random Player
 *************
 
 .. automodule:: poke_env.player.random_player
     :members:
     :undoc-members:
     :show-inheritance:
-

examples/env_example.py

+42
@@ -0,0 +1,42 @@
+import numpy as np
+import numpy.typing as npt
+from gymnasium.spaces import Box
+from pettingzoo.test.parallel_test import parallel_api_test
+
+from poke_env.environment.abstract_battle import AbstractBattle
+from poke_env.player import SinglesEnv
+
+
+class TestEnv(SinglesEnv[npt.NDArray[np.float32]]):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.observation_spaces = {
+            agent: Box(np.array([0, 0]), np.array([6, 6]), dtype=np.float32)
+            for agent in self.possible_agents
+        }
+
+    def calc_reward(self, battle) -> float:
+        return self.reward_computing_helper(battle)
+
+    def embed_battle(self, battle: AbstractBattle):
+        to_embed = []
+        fainted_mons = 0
+        for mon in battle.team.values():
+            if mon.fainted:
+                fainted_mons += 1
+        to_embed.append(fainted_mons)
+        fainted_enemy_mons = 0
+        for mon in battle.opponent_team.values():
+            if mon.fainted:
+                fainted_enemy_mons += 1
+        to_embed.append(fainted_enemy_mons)
+        return np.array(to_embed)
+
+
+if __name__ == "__main__":
+    gymnasium_env = TestEnv(
+        battle_format="gen8randombattle",
+        start_challenging=True,
+    )
+    parallel_api_test(gymnasium_env)
+    gymnasium_env.close()
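In the example above, ``embed_battle`` reduces a battle to two numbers, the fainted counts of each side, which is why the observation space is a ``Box`` bounded by ``[0, 0]`` and ``[6, 6]`` (a singles team holds at most six pokemons). The counting logic can be sketched in isolation, with plain booleans standing in for poke-env's ``mon.fainted`` flags; the ``fainted_embedding`` helper is a hypothetical name for this illustration:

```python
from typing import Iterable, List


def fainted_embedding(
    team_fainted: Iterable[bool], opponent_fainted: Iterable[bool]
) -> List[int]:
    """Mirror of the example's embed_battle: one fainted count per side.

    The boolean iterables are hypothetical stand-ins for iterating over
    battle.team.values() / battle.opponent_team.values() and reading
    each mon's fainted flag.
    """
    return [
        sum(1 for fainted in team_fainted if fainted),
        sum(1 for fainted in opponent_fainted if fainted),
    ]


# With full six-mon teams, every component stays inside the declared
# Box bounds [0, 6].
obs = fainted_embedding([True, False, True, False, False, False], [False] * 6)
assert all(0 <= x <= 6 for x in obs)
```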

examples/gymnasium_example.py

-95
This file was deleted.

integration_tests/test_env.py

+84
@@ -0,0 +1,84 @@
+import numpy as np
+import pytest
+from gymnasium.spaces import Box
+from pettingzoo.test.parallel_test import parallel_api_test
+
+from poke_env.player import SinglesEnv
+
+
+class SinglesTestEnv(SinglesEnv):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.observation_spaces = {
+            agent: Box(np.array([0]), np.array([1]), dtype=np.int64)
+            for agent in self.possible_agents
+        }
+
+    def calc_reward(self, battle) -> float:
+        return 0.0
+
+    def embed_battle(self, battle):
+        return np.array([0])
+
+
+def play_function(env, n_battles):
+    for _ in range(n_battles):
+        done = False
+        env.reset()
+        while not done:
+            actions = {name: env.action_space(name).sample() for name in env.agents}
+            _, _, terminated, truncated, _ = env.step(actions)
+            done = any(terminated.values()) or any(truncated.values())
+
+
+@pytest.mark.timeout(120)
+def test_env_run():
+    for gen in range(4, 10):
+        env = SinglesTestEnv(
+            battle_format=f"gen{gen}randombattle",
+            log_level=25,
+            start_challenging=False,
+            strict=False,
+        )
+        env.start_challenging(3)
+        play_function(env, 3)
+        env.close()
+
+
+@pytest.mark.timeout(60)
+def test_repeated_runs():
+    env = SinglesTestEnv(
+        battle_format="gen8randombattle",
+        log_level=25,
+        start_challenging=False,
+        strict=False,
+    )
+    env.start_challenging(2)
+    play_function(env, 2)
+    env.start_challenging(2)
+    play_function(env, 2)
+    env.close()
+    env = SinglesTestEnv(
+        battle_format="gen9randombattle",
+        log_level=25,
+        start_challenging=False,
+        strict=False,
+    )
+    env.start_challenging(2)
+    play_function(env, 2)
+    env.start_challenging(2)
+    play_function(env, 2)
+    env.close()
+
+
+@pytest.mark.timeout(60)
+def test_env_api():
+    for gen in range(4, 10):
+        env = SinglesTestEnv(
+            battle_format=f"gen{gen}randombattle",
+            log_level=25,
+            start_challenging=True,
+            strict=False,
+        )
+        parallel_api_test(env)
+        env.close()
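The ``play_function`` helper above collapses PettingZoo's per-agent ``terminated`` and ``truncated`` dictionaries into a single done flag after each ``env.step``: the episode ends as soon as any agent is terminated or truncated. That aggregation in isolation (``episode_done`` is a hypothetical name, not part of the test file):

```python
def episode_done(terminated: dict, truncated: dict) -> bool:
    """Collapse per-agent terminated/truncated dicts into one done flag,
    the same way play_function does after each env.step call."""
    return any(terminated.values()) or any(truncated.values())


# Mid-battle, with both agents still active, the episode is not done.
assert not episode_done(
    {"p1": False, "p2": False}, {"p1": False, "p2": False}
)
```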
