From 7b18e165f54c4445bc41cc437f71b3245323ca81 Mon Sep 17 00:00:00 2001
From: elliottower
Date: Fri, 3 Feb 2023 02:41:41 -0500
Subject: [PATCH] Added greedy agent as a strong baseline (search depth 2), fixed corner cases in game logic and UI/interactive play

---
 README.md | 16 +-
 gobblet/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 157 bytes
 gobblet/examples/example_tianshou_DQN.py | 64 ++--
 gobblet/examples/example_tianshou_PPO.py | 359 +++++++++++++++++
 gobblet/examples/example_tianshou_greedy.py | 177 +++++++++
 gobblet/examples/example_tianshou_rainbow.py | 360 ++++++++++++++++++
 gobblet/game/board.py | 17 +-
 gobblet/game/gobblet.py | 3 +-
 gobblet/game/greedy_policy.py | 141 +++++++
 gobblet/game/manual_policy.py | 14 +-
 pyproject.toml | 2 +-
 ...ieces_tianshou.cpython-38-pytest-7.1.2.pyc | Bin 0 -> 8065 bytes
 ...covered_pieces.cpython-38-pytest-7.1.2.pyc | Bin 0 -> 11980 bytes
 tests/test_manual_policy_collector.py | 145 +++++++
 14 files changed, 1257 insertions(+), 41 deletions(-)
 create mode 100644 gobblet/__pycache__/__init__.cpython-39.pyc
 create mode 100644 gobblet/examples/example_tianshou_PPO.py
 create mode 100644 gobblet/examples/example_tianshou_greedy.py
 create mode 100644 gobblet/examples/example_tianshou_rainbow.py
 create mode 100644 gobblet/game/greedy_policy.py
 create mode 100644 tests/__pycache__/test_covered_pieces_tianshou.cpython-38-pytest-7.1.2.pyc
 create mode 100644 tests/__pycache__/test_tianshou_covered_pieces.cpython-38-pytest-7.1.2.pyc
 create mode 100644 tests/test_manual_policy_collector.py

diff --git a/README.md b/README.md
index 556862d..25a5ddd 100644
--- a/README.md
+++ b/README.md
@@ -41,7 +41,18 @@ from gobblet import gobblet_v1
 env = gobblet_v1.env()
 ```
 
-### Play against a DQL agent trained with Tianshou
+### Play against a greedy agent
+
+In the terminal, run the following:
+```
+python gobblet/examples/example_tianshou_greedy.py --cpu-players 1
+```
+
+This will launch a game against a greedy agent, which is a very strong baseline. The agent considers all possible moves to a depth of 2: it wins when it can, blocks the opponent's wins, and even forces the opponent into losing moves.
+
+Note: this policy exploits domain knowledge to reconstruct the internal game board from the observation (perfect information) and directly uses functions from `board.py`. Tianshou policies do not get direct access to the environment, only observations and action masks, so the greedy agent should not be compared directly with other RL agents.
+
+### Play against a DQN agent trained with Tianshou
 
 In the terminal, run the following:
 ```
@@ -58,6 +69,9 @@ In the terminal, run the following:
 ```
 python gobblet/examples/example_user_input.py"
 ```
+
+Note: Interactive play can also be enabled in the other example scripts using the argument `--cpu-players 1` (see the example command at the end of this section).
+
 To select a piece size, press a number key `1`, `2`, or `3`, or press `space` to cycle through pieces. Placing a piece is done by clicking on a square on the board. A preview will appear showing legal moves with the selected piece size. Clicking on an already placed piece will pick it up and prompt you to place it in a new location (re-placing in the same location is an illegal move).
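+
+For example, to play interactively against the DQN agent as the second (yellow) player, combine `--cpu-players 1` with the `--player` flag defined in the example scripts' argument parsers and run:
+```
+python gobblet/examples/example_tianshou_DQN.py --cpu-players 1 --player 1
+```
+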
### Create screen recording of a game diff --git a/gobblet/__pycache__/__init__.cpython-39.pyc b/gobblet/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d6ec7a274db870a944dcc2f01962d9a7dbdb287 GIT binary patch literal 157 zcmYe~<>g`kg1pkJ$sqbMh(HF6K#l_t7qb9~6oz01O-8?!3`HPe1o6vBKeRZts8~NW zCnqz%q$IyQwMgG3Ke;qFHLs*t-#xR$qcllBJwGWaC$&VkC Tianshou PettingZoo Wrapper -> PettingZoo Env @@ -320,24 +342,10 @@ def play( recorder = None manual_policy = gobblet_v1.ManualPolicy(env=pettingzoo_env, agent_id=args.player, recorder=recorder) # Gobblet keyboard input requires access to raw_env (uses functions from board) - # Get the first move from the CPU (human goes second)) - if args.player == 1: - result = collector.collect(n_step=1, render=args.render) - - # Get the first move from the player - else: - observation = {"observation": collector.data.obs.obs.flatten(), # Observation not used for manual_policy, bu - "action_mask": collector.data.obs.mask.flatten()} # Collector mask: [1,54], PettingZoo: [54,] - action = manual_policy(observation, pettingzoo_env.agents[0]) - - result = collector.collect_result(action=action.reshape(1), render=args.render) - - while not (collector.data.terminated or collector.data.truncated): + while pettingzoo_env.agents: agent_id = collector.data.obs.agent_id # If it is the players turn and there are less than 2 CPU players (at least one human player) if agent_id == pettingzoo_env.agents[args.player]: - # action_mask = collector.data.obs.mask[0] - # action = np.random.choice(np.arange(len(action_mask)), p=action_mask / np.sum(action_mask)) observation = {"observation": collector.data.obs.obs.flatten(), "action_mask": collector.data.obs.mask.flatten()} # PettingZoo expects a dict with this format action = manual_policy(observation, agent_id) @@ -346,17 +354,19 @@ def play( else: result = collector.collect(n_step=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") + if collector.data.terminated or collector.data.truncated: + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") + if recorder is not None: + recorder.end_recording() if __name__ == "__main__": # train the agent and watch its performance in a match! args = get_args() + print("Training agent...") result, agent = train_agent(args) print("Starting game...") if args.cpu_players == 2: watch(args, agent) else: - play(args, agent) - - #TODO: debug why it seems to not let you move when your smaller pieces are covered (print out the currently selected size and the + play(args, agent) \ No newline at end of file diff --git a/gobblet/examples/example_tianshou_PPO.py b/gobblet/examples/example_tianshou_PPO.py new file mode 100644 index 0000000..10b274a --- /dev/null +++ b/gobblet/examples/example_tianshou_PPO.py @@ -0,0 +1,359 @@ +# adapted from https://github.com/Farama-Foundation/PettingZoo/blob/master/tutorials/Tianshou/3_cli_and_logging.py +""" +This is a full example of using Tianshou with MARL to train agents, complete with argument parsing (CLI) and logging. 
+ +Author: Will (https://github.com/WillDudley) + +Python version used: 3.8.10 + +Requirements: +pettingzoo == 1.22.0 +git+https://github.com/thu-ml/tianshou +""" + +import argparse +import os +from copy import deepcopy +from typing import Optional, Tuple + +import gym +import numpy as np +import torch +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.env.pettingzoo_env import PettingZooEnv +from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy, PPOPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net +from torch.utils.tensorboard import SummaryWriter + +from gobblet import gobblet_v1 +from gobblet.game.collector_manual_policy import ManualPolicyCollector +from gobblet.game.utils import GIFRecorder +import time + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=1626) + parser.add_argument("--eps-test", type=float, default=0.05) + parser.add_argument("--eps-train", type=float, default=0.1) + parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--lr", type=float, default=1e-4) # TODO: Changing this to 1e-5 for some reason makes it pause after 3 or 4 epochs + parser.add_argument( + "--gamma", type=float, default=0.9, help="a smaller gamma favors earlier win" + ) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--epoch", type=int, default=50) + parser.add_argument("--step-per-epoch", type=int, default=1000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument( + "--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128] + ) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.1) + parser.add_argument("--render_mode", type=str, default="human", choices=["human","rgb_array", "text", "text_full"], help="Choose the rendering mode for the game.") + parser.add_argument("--debug", action="store_true", help="Flag to enable to print extra debugging info") + parser.add_argument("--self_play", action="store_true", help="Flag to enable training via self-play (as opposed to fixed opponent)") + parser.add_argument("--cpu-players", type=int, default=2, choices=[1, 2], help="Number of CPU players (options: 1, 2)") + parser.add_argument("--player", type=int, default=0, choices=[0,1], help="Choose which player to play as: red = 0, yellow = 1") + parser.add_argument("--record", action="store_true", help="Flag to save a recording of the game (game.gif)") + parser.add_argument( + "--win-rate", + type=float, + default=0.6, + help="the expected winning rate: Optimal policy can get 0.7", + ) + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="no training, " "watch the play of pre-trained models", + ) + parser.add_argument( + "--agent-id", + type=int, + default=2, + help="the learned agent plays as the" + " agent_id-th player. 
Choices are 1 and 2.", + ) + parser.add_argument( + "--resume-path", + type=str, + default="", + help="the path of agent pth file " "for resuming from a pre-trained agent", + ) + parser.add_argument( + "--opponent-path", + type=str, + default="", + help="the path of opponent agent pth file " + "for resuming from a pre-trained agent", + ) + parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + return parser + + +def get_args() -> argparse.Namespace: + parser = get_parser() + return parser.parse_known_args()[0] + + +def get_agents( + args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, + optim: Optional[torch.optim.Optimizer] = None, +) -> Tuple[BasePolicy, torch.optim.Optimizer, list]: + env = get_env() + observation_space = ( + env.observation_space["observation"] + if isinstance(env.observation_space, gym.spaces.Dict) + else env.observation_space + ) + args.state_shape = ( + observation_space["observation"].shape or observation_space["observation"].n + ) + args.action_shape = env.action_space.shape or env.action_space.n + if agent_learn is None: + # model + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + ).to(args.device) + if optim is None: + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + agent_learn = PPOPolicy( + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq, + ) + if args.resume_path: + agent_learn.load_state_dict(torch.load(args.resume_path)) + + if agent_opponent is None: + if args.self_play: + agent_opponent = deepcopy(agent_learn) + elif args.opponent_path: + agent_opponent = deepcopy(agent_learn) + agent_opponent.load_state_dict(torch.load(args.opponent_path)) + else: + # agent_opponent = RandomPolicy() + agent_opponent = deepcopy(agent_learn) + + if args.agent_id == 1: + agents = [agent_learn, agent_opponent] + else: + agents = [agent_opponent, agent_learn] + policy = MultiAgentPolicyManager(agents, env) + return policy, optim, env.agents + + +def get_env(render_mode=None, args=None): + return PettingZooEnv(gobblet_v1.env(render_mode=render_mode, args=args)) + + +def train_agent( + args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, + optim: Optional[torch.optim.Optimizer] = None, +) -> Tuple[dict, BasePolicy]: + # ======== environment setup ========= + train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) + test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + + # ======== agent setup ========= + policy, optim, agents = get_agents( + args, agent_learn=agent_learn, agent_opponent=agent_opponent, optim=optim + ) + + # ======== collector setup ========= + train_collector = Collector( + policy, + train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True, + ) + test_collector = Collector(policy, test_envs, exploration_noise=True) + # policy.set_eps(1) + train_collector.collect(n_step=args.batch_size * args.training_num) + + # ======== tensorboard logging setup ========= + log_path = os.path.join(args.logdir, "gobblet", "dqn") + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + + # ======== callback functions used 
during training ========= + def save_best_fn(policy): + if hasattr(args, "model_save_path"): + model_save_path = args.model_save_path + else: + model_save_path = os.path.join( + args.logdir, "gobblet", "dqn", "policy.pth" + ) + torch.save( + policy.policies[agents[args.agent_id - 1]].state_dict(), model_save_path + ) + + def stop_fn(mean_rewards): + return mean_rewards >= args.win_rate + + def train_fn(epoch, env_step): + policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_train) + + def train_fn_selfplay(epoch, env_step): + policy.policies[agents[0]].set_eps(args.eps_train) # Same as train_fn but for both agents instead of only learner + policy.policies[agents[1]].set_eps(args.eps_train) + + def test_fn(epoch, env_step): + policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) + + def test_fn_selfplay(epoch, env_step): + policy.policies[agents[0]].set_eps(args.eps_test) # Same as test_fn but for both agents instead of only learner + policy.policies[agents[1]].set_eps(args.eps_test) + + + def reward_metric(rews): + return rews[:, args.agent_id - 1] + + # trainer + result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn_selfplay if args.self_play else train_fn, + test_fn=test_fn_selfplay if args.self_play else train_fn, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + update_per_step=args.update_per_step, + logger=logger, + test_in_train=False, + reward_metric=reward_metric, + ) + + return result, policy.policies[agents[args.agent_id - 1]] + + +# ======== a test function that tests a pre-trained agent ====== +def watch( + args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, +) -> None: + env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode, args=args)]) + policy, optim, agents = get_agents( + args, agent_learn=agent_learn, agent_opponent=agent_opponent + ) + policy.eval() + if not args.self_play: + policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) + else: + policy.policies[agents[0]].set_eps(args.eps_test) + policy.policies[agents[1]].set_eps(args.eps_test) + + collector = Collector(policy, env, exploration_noise=True) + + # First step (while loop stopping conditions are not defined until we run the first step) + result = collector.collect(n_step=1, render=args.render) + time.sleep(0.25) + + while not (collector.data.terminated or collector.data.truncated): + result = collector.collect(n_step=1, render=args.render) + time.sleep(0.25) # Slow down rendering so the actions can be seen sequentially (otherwise moves happen too fast) + + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") + +# TODO: Look more into Tianshou and see if self play is possible +# For watching I think it could just be the same policy for both agents, but for training I think self play would be different +def watch_selfplay(args, agent): + raise NotImplementedError() + env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode, debug=args.debug)]) + agent.set_eps(args.eps_test) + policy = MultiAgentPolicyManager([agent, deepcopy(agent)], env) # fixed here + policy.eval() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews[:, 0].mean()}, length: 
{lens.mean()}") + + +# ======== allows the user to input moves and play vs a pre-trained agent ====== +def play( + args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, +) -> None: + env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode, args=args)]) + # env = get_env(render_mode=args.render_mode, args=args) # Throws error because collector looks for length, could just override though since I'm using my own collector + policy, optim, agents = get_agents( + args, agent_learn=agent_learn, agent_opponent=agent_opponent + ) + # policy.eval() + # policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) + + # Set the CPU agent to continue training + policy.policies[agents[1 - args.player]].set_eps(args.eps_train) + + collector = ManualPolicyCollector(policy, env, exploration_noise=True) # Collector for CPU actions + + pettingzoo_env = env.workers[0].env.env # DummyVectorEnv -> Tianshou PettingZoo Wrapper -> PettingZoo Env + if args.record: + recorder = GIFRecorder() + else: + recorder = None + manual_policy = gobblet_v1.ManualPolicy(env=pettingzoo_env, agent_id=args.player, recorder=recorder) # Gobblet keyboard input requires access to raw_env (uses functions from board) + + while pettingzoo_env.agents: + agent_id = collector.data.obs.agent_id + # If it is the players turn and there are less than 2 CPU players (at least one human player) + if agent_id == pettingzoo_env.agents[args.player]: + # action_mask = collector.data.obs.mask[0] + # action = np.random.choice(np.arange(len(action_mask)), p=action_mask / np.sum(action_mask)) + observation = {"observation": collector.data.obs.obs.flatten(), + "action_mask": collector.data.obs.mask.flatten()} # PettingZoo expects a dict with this format + action = manual_policy(observation, agent_id) + + result = collector.collect_result(action=action.reshape(1), render=args.render) + else: + result = collector.collect(n_step=1, render=args.render) + + if collector.data.terminated or collector.data.truncated: + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") + +if __name__ == "__main__": + # train the agent and watch its performance in a match! + args = get_args() + result, agent = train_agent(args) + print("Starting game...") + if args.cpu_players == 2: + watch(args, agent) + else: + play(args, agent) + + #TODO: debug why it seems to not let you move when your smaller pieces are covered (print out the currently selected size and the diff --git a/gobblet/examples/example_tianshou_greedy.py b/gobblet/examples/example_tianshou_greedy.py new file mode 100644 index 0000000..f83737a --- /dev/null +++ b/gobblet/examples/example_tianshou_greedy.py @@ -0,0 +1,177 @@ +# adapted from https://github.com/Farama-Foundation/PettingZoo/blob/master/tutorials/Tianshou/3_cli_and_logging.py +""" +This is a full example of using Tianshou with MARL to train agents, complete with argument parsing (CLI) and logging. 
+ +Author: Will (https://github.com/WillDudley) + +Python version used: 3.8.10 + +Requirements: +pettingzoo == 1.22.0 +git+https://github.com/thu-ml/tianshou +""" + +import argparse +from typing import Optional, Tuple + + +import torch +from tianshou.env import DummyVectorEnv +from tianshou.env.pettingzoo_env import PettingZooEnv +from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy + + +from gobblet import gobblet_v1 +from gobblet.game.collector_manual_policy import ManualPolicyCollector +from gobblet.game.utils import GIFRecorder +from gobblet.game.greedy_policy import GreedyPolicy +import time + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=1626) + parser.add_argument("--eps-test", type=float, default=0.05) + parser.add_argument("--eps-train", type=float, default=0.1) + parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--lr", type=float, default=1e-4) # TODO: Changing this to 1e-5 for some reason makes it pause after 3 or 4 epochs + parser.add_argument( + "--gamma", type=float, default=0.9, help="a smaller gamma favors earlier win" + ) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--epoch", type=int, default=50) + parser.add_argument("--step-per-epoch", type=int, default=1000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument( + "--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128] + ) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.1) + parser.add_argument("--render_mode", type=str, default="human", choices=["human","rgb_array", "text", "text_full"], help="Choose the rendering mode for the game.") + parser.add_argument("--debug", action="store_true", help="Flag to enable to print extra debugging info") + parser.add_argument("--self_play", action="store_true", help="Flag to enable training via self-play (as opposed to fixed opponent)") + parser.add_argument("--cpu-players", type=int, default=1, choices=[1, 2], help="Number of CPU players (options: 1, 2)") + parser.add_argument("--player", type=int, default=0, choices=[0,1], help="Choose which player to play as: red = 0, yellow = 1") + parser.add_argument("--record", action="store_true", help="Flag to save a recording of the game (game.gif)") + parser.add_argument( + "--win-rate", + type=float, + default=0.6, + help="the expected winning rate: Optimal policy can get 0.7", + ) + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="no training, " "watch the play of pre-trained models", + ) + parser.add_argument( + "--agent-id", + type=int, + default=2, + help="the learned agent plays as the" + " agent_id-th player. 
Choices are 1 and 2.", + ) + parser.add_argument( + "--resume-path", + type=str, + default="", + help="the path of agent pth file " "for resuming from a pre-trained agent", + ) + parser.add_argument( + "--opponent-path", + type=str, + default="", + help="the path of opponent agent pth file " + "for resuming from a pre-trained agent", + ) + parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + return parser + + +def get_args() -> argparse.Namespace: + parser = get_parser() + return parser.parse_known_args()[0] + + +def get_agents() -> Tuple[BasePolicy, list]: + env = get_env() + agents = [GreedyPolicy(), GreedyPolicy()] + policy = MultiAgentPolicyManager(agents, env) + return policy, env.agents + + +def get_env(render_mode=None, args=None): + return PettingZooEnv(gobblet_v1.env(render_mode=render_mode, args=args)) + + +# ======== watch two greedy agents play each other (tends to get stuck in loops) ====== +def watch() -> None: + env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode, args=args)]) + policy, agents = get_agents() + + collector = ManualPolicyCollector(policy, env) + + pettingzoo_env = env.workers[0].env.env # DummyVectorEnv -> Tianshou PettingZoo Wrapper -> PettingZoo Env + if args.record: + recorder = GIFRecorder() + else: + recorder = None + + manual_policy = gobblet_v1.ManualPolicy(env=pettingzoo_env, agent_id=args.player, recorder=recorder) # Gobblet keyboard input requires access to raw_env (uses functions from board) + while pettingzoo_env.agents: + # agent_id = collector.data.obs.agent_id + + result = collector.collect(n_step=1, render=args.render) + + if collector.data.terminated or collector.data.truncated: + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews[:, args.player].mean()}, length: {lens.mean()}") + +# ======== allows the user to input moves and play vs a greedy agent ====== +def play() -> None: + env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode, args=args)]) + + policy, agents = get_agents() + collector = ManualPolicyCollector(policy, env, exploration_noise=True) # Collector for CPU actions + + pettingzoo_env = env.workers[0].env.env # DummyVectorEnv -> Tianshou PettingZoo Wrapper -> PettingZoo Env + if args.record: + recorder = GIFRecorder() + else: + recorder = None + manual_policy = gobblet_v1.ManualPolicy(env=pettingzoo_env, agent_id=args.player, recorder=recorder) # Gobblet keyboard input requires access to raw_env (uses functions from board) + + while pettingzoo_env.agents: + agent_id = collector.data.obs.agent_id + # If it is the players turn and there are less than 2 CPU players (at least one human player) + if agent_id == pettingzoo_env.agents[args.player]: + # action_mask = collector.data.obs.mask[0] + # action = np.random.choice(np.arange(len(action_mask)), p=action_mask / np.sum(action_mask)) + observation = {"observation": collector.data.obs.obs.flatten(), + "action_mask": collector.data.obs.mask.flatten()} # PettingZoo expects a dict with this format + action = manual_policy(observation, agent_id) + + result = collector.collect_result(action=action.reshape(1), render=args.render) + else: + result = collector.collect(n_step=1, render=args.render) + + if collector.data.terminated or collector.data.truncated: + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews[:, args.player].mean()}, length: {lens.mean()}") + +if __name__ == "__main__": + # train the agent and watch its performance in a match! 
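+    # Note: no training step is needed here, as both agents use the rule-based GreedyPolicy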
+ args = get_args() + print("Starting game...") + if args.cpu_players == 2: + watch() + else: + play() \ No newline at end of file diff --git a/gobblet/examples/example_tianshou_rainbow.py b/gobblet/examples/example_tianshou_rainbow.py new file mode 100644 index 0000000..30d73b7 --- /dev/null +++ b/gobblet/examples/example_tianshou_rainbow.py @@ -0,0 +1,360 @@ +# adapted from https://github.com/Farama-Foundation/PettingZoo/blob/master/tutorials/Tianshou/3_cli_and_logging.py +""" +This is a full example of using Tianshou with MARL to train agents, complete with argument parsing (CLI) and logging. + +Author: Will (https://github.com/WillDudley) + +Python version used: 3.8.10 + +Requirements: +pettingzoo == 1.22.0 +git+https://github.com/thu-ml/tianshou +""" + +import argparse +import os +from copy import deepcopy +from typing import Optional, Tuple + +import gym +import numpy as np +import torch +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.env.pettingzoo_env import PettingZooEnv +from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy, RainbowPolicy, PGPolicy # Rainbow and PG don't work directly +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net +from torch.utils.tensorboard import SummaryWriter + +from gobblet import gobblet_v1 +from gobblet.game.collector_manual_policy import ManualPolicyCollector +from gobblet.game.utils import GIFRecorder +import time + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=1626) + parser.add_argument("--eps-test", type=float, default=0.05) + parser.add_argument("--eps-train", type=float, default=0.1) + parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument( + "--gamma", type=float, default=0.9, help="a smaller gamma favors earlier win" + ) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--epoch", type=int, default=50) + parser.add_argument("--step-per-epoch", type=int, default=1000) + parser.add_argument("--step-per-collect", type=int, default=10) + parser.add_argument("--update-per-step", type=float, default=0.1) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument( + "--hidden-sizes", type=int, nargs="*", default=[128, 128, 128, 128] + ) + parser.add_argument("--training-num", type=int, default=10) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.1) + parser.add_argument("--render_mode", type=str, default="human", choices=["human","rgb_array", "text", "text_full"], help="Choose the rendering mode for the game.") + parser.add_argument("--debug", action="store_true", help="Flag to enable to print extra debugging info") + parser.add_argument("--self_play", action="store_true", help="Flag to enable training via self-play (as opposed to fixed opponent)") + parser.add_argument("--cpu-players", type=int, default=2, choices=[1, 2], help="Number of CPU players (options: 1, 2)") + parser.add_argument("--player", type=int, default=0, choices=[0,1], help="Choose which player to play as: red = 0, yellow = 1") + parser.add_argument("--record", action="store_true", 
help="Flag to save a recording of the game (game.gif)") + parser.add_argument( + "--win-rate", + type=float, + default=0.6, + help="the expected winning rate: Optimal policy can get 0.7", + ) + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="no training, " "watch the play of pre-trained models", + ) + parser.add_argument( + "--agent-id", + type=int, + default=2, + help="the learned agent plays as the" + " agent_id-th player. Choices are 1 and 2.", + ) + parser.add_argument( + "--resume-path", + type=str, + default="", + help="the path of agent pth file " "for resuming from a pre-trained agent", + ) + parser.add_argument( + "--opponent-path", + type=str, + default="", + help="the path of opponent agent pth file " + "for resuming from a pre-trained agent", + ) + parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + return parser + + +def get_args() -> argparse.Namespace: + parser = get_parser() + return parser.parse_known_args()[0] + + +def get_agents( + args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, + optim: Optional[torch.optim.Optimizer] = None, +) -> Tuple[BasePolicy, torch.optim.Optimizer, list]: + env = get_env() + observation_space = ( + env.observation_space["observation"] + if isinstance(env.observation_space, gym.spaces.Dict) + else env.observation_space + ) + args.state_shape = ( + observation_space["observation"].shape or observation_space["observation"].n + ) + args.action_shape = env.action_space.shape or env.action_space.n + if agent_learn is None: + # model + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + ).to(args.device) + if optim is None: + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + agent_learn = RainbowPolicy( + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq, + ) + if args.resume_path: + agent_learn.load_state_dict(torch.load(args.resume_path)) + + if agent_opponent is None: + if args.self_play: + agent_opponent = deepcopy(agent_learn) + elif args.opponent_path: + agent_opponent = deepcopy(agent_learn) + agent_opponent.load_state_dict(torch.load(args.opponent_path)) + else: + agent_opponent = RandomPolicy() + + if args.agent_id == 1: + agents = [agent_learn, agent_opponent] + else: + agents = [agent_opponent, agent_learn] + policy = MultiAgentPolicyManager(agents, env) + return policy, optim, env.agents + + +def get_env(render_mode=None, args=None): + return PettingZooEnv(gobblet_v1.env(render_mode=render_mode, args=args)) + + +def train_agent( + args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, + optim: Optional[torch.optim.Optimizer] = None, +) -> Tuple[dict, BasePolicy]: + # ======== environment setup ========= + train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) + test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + + # ======== agent setup ========= + policy, optim, agents = get_agents( + args, agent_learn=agent_learn, agent_opponent=agent_opponent, optim=optim + ) + + # ======== collector setup ========= + train_collector = Collector( + policy, + train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True, + ) 
+ test_collector = Collector(policy, test_envs, exploration_noise=True) + # policy.set_eps(1) + train_collector.collect(n_step=args.batch_size * args.training_num) + + # ======== tensorboard logging setup ========= + log_path = os.path.join(args.logdir, "gobblet", "rainbow") + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + + # ======== callback functions used during training ========= + def save_best_fn(policy): + if hasattr(args, "model_save_path"): + model_save_path = args.model_save_path + else: + model_save_path = os.path.join( + args.logdir, "gobblet", "rainbow", "policy.pth" + ) + torch.save( + policy.policies[agents[args.agent_id - 1]].state_dict(), model_save_path + ) + + def stop_fn(mean_rewards): + return mean_rewards >= args.win_rate + + def train_fn(epoch, env_step): + policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_train) + + def train_fn_selfplay(epoch, env_step): + policy.policies[agents[:]].set_eps(args.eps_train) # Same as train_fn but for both agents instead of only learner + + def test_fn(epoch, env_step): + policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) + + def test_fn_selfplay(epoch, env_step): + policy.policies[agents[:]].set_eps(args.eps_test) # Same as test_fn but for both agents instead of only learner + + + def reward_metric(rews): + return rews[:, args.agent_id - 1] + + # trainer + result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn if not args.self_play else train_fn_selfplay, + test_fn=test_fn if not args.self_play else train_fn_selfplay, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + update_per_step=args.update_per_step, + logger=logger, + test_in_train=False, + reward_metric=reward_metric, + ) + + return result, policy.policies[agents[args.agent_id - 1]] + + +# ======== a test function that tests a pre-trained agent ====== +def watch( + args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, +) -> None: + env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode, args=args)]) + policy, optim, agents = get_agents( + args, agent_learn=agent_learn, agent_opponent=agent_opponent + ) + policy.eval() + if not args.self_play: + policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) + else: + policy.policies[agents[:]].set_eps(args.eps_test) + collector = Collector(policy, env, exploration_noise=True) + + # First step (while loop stopping conditions are not defined until we run the first step) + result = collector.collect(n_step=1, render=args.render) + time.sleep(0.25) + + while not (collector.data.terminated or collector.data.truncated): + result = collector.collect(n_step=1, render=args.render) + time.sleep(0.25) # Slow down rendering so the actions can be seen sequentially (otherwise moves happen too fast) + + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") + +# TODO: Look more into Tianshou and see if self play is possible +# For watching I think it could just be the same policy for both agents, but for training I think self play would be different +def watch_selfplay(args, agent): + raise NotImplementedError() + env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode, debug=args.debug)]) + agent.set_eps(args.eps_test) + policy = 
MultiAgentPolicyManager([agent, deepcopy(agent)], env) # fixed here + policy.eval() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}") + + +# ======== allows the user to input moves and play vs a pre-trained agent ====== +def play( + args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, +) -> None: + env = DummyVectorEnv([lambda: get_env(render_mode=args.render_mode, args=args)]) + # env = get_env(render_mode=args.render_mode, args=args) # Throws error because collector looks for length, could just override though since I'm using my own collector + policy, optim, agents = get_agents( + args, agent_learn=agent_learn, agent_opponent=agent_opponent + ) + policy.eval() + policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) + + collector = ManualPolicyCollector(policy, env, exploration_noise=True) # Collector for CPU actions + + pettingzoo_env = env.workers[0].env.env # DummyVectorEnv -> Tianshou PettingZoo Wrapper -> PettingZoo Env + if args.record: + recorder = GIFRecorder() + else: + recorder = None + manual_policy = gobblet_v1.ManualPolicy(env=pettingzoo_env, agent_id=args.player, recorder=recorder) # Gobblet keyboard input requires access to raw_env (uses functions from board) + + # Get the first move from the CPU (human goes second)) + if args.player == 1: + result = collector.collect(n_step=1, render=args.render) + + # Get the first move from the player + else: + observation = {"observation": collector.data.obs.obs.flatten(), # Observation not used for manual_policy, bu + "action_mask": collector.data.obs.mask.flatten()} # Collector mask: [1,54], PettingZoo: [54,] + action = manual_policy(observation, pettingzoo_env.agents[0]) + + result = collector.collect_result(action=action.reshape(1), render=args.render) + + while not (collector.data.terminated or collector.data.truncated): + agent_id = collector.data.obs.agent_id + # If it is the players turn and there are less than 2 CPU players (at least one human player) + if agent_id == pettingzoo_env.agents[args.player]: + # action_mask = collector.data.obs.mask[0] + # action = np.random.choice(np.arange(len(action_mask)), p=action_mask / np.sum(action_mask)) + observation = {"observation": collector.data.obs.obs.flatten(), + "action_mask": collector.data.obs.mask.flatten()} # PettingZoo expects a dict with this format + action = manual_policy(observation, agent_id) + + result = collector.collect_result(action=action.reshape(1), render=args.render) + else: + result = collector.collect(n_step=1, render=args.render) + + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") + +if __name__ == "__main__": + # train the agent and watch its performance in a match! + args = get_args() + result, agent = train_agent(args) + if args.cpu_players == 2: + + watch(args, agent) + else: + play(args, agent) diff --git a/gobblet/game/board.py b/gobblet/game/board.py index 6939563..c0ef6b8 100644 --- a/gobblet/game/board.py +++ b/gobblet/game/board.py @@ -1,7 +1,7 @@ import numpy as np class Board: - def __init__(self): + def __init__(self, squares=None): # internally self.board.squares holds a representation of the gobblet board, consisting of three stacked 3x3 boards, one for each piece size. 
# We flatten it for simplicity, from [3,3,3] to [27,] @@ -92,11 +92,13 @@ def is_legal(self, action, agent_index=0): if any(board[piece_size-1] == piece * agent_multiplier): current_loc = np.where(board[piece_size-1] == piece * agent_multiplier)[0] # Returns array of values where piece is placed if len(current_loc) > 1: - raise Exception("Error: piece has been used twice") + raise Exception("PIECE HAS BEEN USED TWICE") + return False # Piece has been used twice (not valid) else: current_loc = current_loc[0] # Current location [0-27] # If this piece is currently covered, moving it is not a legal action - if self.check_covered()[current_loc] == 1: + + if self.check_covered().reshape(3,9)[piece_size - 1][current_loc] == 1: return False # If this piece has not been placed @@ -150,6 +152,9 @@ def calculate_winners(self): self.winning_combinations = winning_combinations + def print(self): + print(self.get_flatboard().reshape(3,3).transpose()) + # returns flattened board consisting of only top pieces (excluding pieces which are gobbled by other pieces) def get_flatboard(self): flatboard = np.zeros(9) @@ -174,10 +179,10 @@ def check_for_winner(self): for combination in self.winning_combinations: states = [] for index in combination: - states.append(flatboard[index]) # default: self.squares[index] + states.append(flatboard[index]) if all(x > 0 for x in states): winner = 1 - if all(x < 0 for x in states): # Change to -1 probably? + if all(x < 0 for x in states): winner = -1 return winner @@ -215,4 +220,4 @@ def print_pieces(self): print("squares with covered pieces: ", covered_squares) def __str__(self): - return str(self.squares) + return str(self.squares.reshape(3,3,3)) diff --git a/gobblet/game/gobblet.py b/gobblet/game/gobblet.py index d480b8b..6a021d9 100644 --- a/gobblet/game/gobblet.py +++ b/gobblet/game/gobblet.py @@ -204,7 +204,7 @@ def action_space(self, agent): def _legal_moves(self): legal_moves = [] for action in range(54): - if self.board.is_legal(action): + if self.board.is_legal(action, self.agents.index(self.agent_selection)): legal_moves.append(action) return legal_moves @@ -216,7 +216,6 @@ def step(self, action): ): return self._was_dead_step(action) # check if input action is a valid move (0 == empty spot) - # assert self.board.is_legal(action), "played illegal move" if not self.board.is_legal(action, self.agent_selection) and self.debug: print("piece: ", self.board.get_piece_from_action(action)) print("piece_size: ", self.board.get_piece_size_from_action(action)) diff --git a/gobblet/game/greedy_policy.py b/gobblet/game/greedy_policy.py new file mode 100644 index 0000000..091f06b --- /dev/null +++ b/gobblet/game/greedy_policy.py @@ -0,0 +1,141 @@ +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch + +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as +from tianshou.policy import BasePolicy + +from gobblet.game.board import Board + +class GreedyPolicy(BasePolicy): + """ + Basic greedy policy which checks if a move results in a victory, and if it sets the opponent up to win (or lose) in the next turn + """ + + def __init__( + self, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.board = None + + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + input: str = "obs", + **kwargs: Any, + ) -> Batch: + """Compute action over the given batch data. 
+ + If you need to mask the action, please add a "mask" into batch.obs, for + example, if we have an environment that has "0/1/2" three actions: + :: + + batch == Batch( + obs=Batch( + obs="original obs, with batch_size=1 for demonstration", + mask=np.array([[False, True, False]]), + # action 1 is available + # action 0 and 2 are unavailable + ), + ... + ) + + :return: A :class:`~tianshou.data.Batch` with "act" key, containing + the greedy action. + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for + more detailed explanation. + """ + obs = batch[input] + obs_next = obs.obs if hasattr(obs, "obs") else obs + + obs_player = obs_next[..., :6] # Index the last dimension + board_player = np.array([ (i+1) * obs_player[..., i] + (i+2) * obs_player[..., i + 1] for i in range(0, 6, 2)]) # Reshapes [3,3,6] to [3,3,3] + obs_opponent = obs_next[..., 6:12] + board_opponent = np.array([ (i+1) * obs_opponent[..., i] + (i+2) * obs_opponent[..., i + 1] for i in range(0, 6, 2)]) + board = np.where(board_player > board_opponent, board_player, -board_opponent) + + agents = ["player_1", "player_2"] + agent_index = agents.index(obs.agent_id[0]) + opponent_index = 1 - agent_index # If agent index is 1, we want 0; if agent index is 0, we want 1 + + # If we are playing as the second agent, we need to adjust the representation of the board to reflect that + # TODO: it would be cleaner to add a thirteenth channel that explicitly encodes the player whose turn it is + # avoids issues of the wrong agent index (easier debugging) and enforces turn order in a stricter way + if agent_index == 1: + board = -board + + self.board = Board() + self.board.squares = board.flatten().copy() + + # Map agent index to winner value that the board will return (agent idx: [0,1], winner_vals: [1,-1]) + winner_values = [1, -1] + + legal_actions = obs.mask.nonzero()[1] + next_actions = list(legal_actions) # Initialize the same as legal actions, then remove ones that cause a loss + chosen_action = None + + + for action in legal_actions: + if self.board.is_legal(agent_index=agent_index, action=action): + next_board = Board() + next_board.squares = self.board.squares.copy() + next_board.play_turn(agent_index=agent_index, action=action) + + # If we can win next turn, do it + winner = next_board.check_for_winner() + if winner == winner_values[agent_index]: # Win for our agent + chosen_action = action + break + elif winner == winner_values[opponent_index]: # Loss for our agent + if len(next_actions) > 1: + next_actions.remove(action) + else: + break # If there is nothing we can do to prevent them from winning, we have to do one of the moves + print() + else: + # Otherwise, check what the opponent can do after this potential move + legal_actions_next = [next_act for next_act in range(obs.mask.shape[-1]) if next_board.is_legal(agent_index=opponent_index, action=next_act)] + # print("legal actions next", legal_actions_next) + + winner_next = {} + for action_next in legal_actions_next: + next_next_board = Board() + next_next_board.squares = next_board.squares.copy() + next_next_board.play_turn(agent_index=opponent_index, action=action_next) + + winner_next[action_next] = next_next_board.check_for_winner() + + # If the opponent can win in the next move, we don't want to do this action + if winner_next[action_next] == winner_values[opponent_index]: # Check if it's possible for the opponent to win next turn after + if len(next_actions) > 1: + if action in next_actions: + next_actions.remove(action) + else: + break # If 
there is nothing we can do to prevent them from winning, we have to pick one + + # # If we can put our piece in the place that the opponent would have gone to win the game, do that + # # BUT this might miss us from winning the game ourselves in blcoking them, so don't exit the loop + if self.board.is_legal(action_next, agent_index=agent_index): + chosen_action = action_next + # break + + # Pick the move if it prevents the opponent from winning no matter what he does after our move + if all(winner != winner_values[opponent_index] for winner in winner_next.values()): + chosen_action = action + break + + if chosen_action is None: + chosen_action = np.random.choice(next_actions) + print(f"Choosing randomly between possible actions: {next_actions} --> {chosen_action}") + act = np.array(chosen_action).reshape(1) + return Batch(act=act) + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + """Since a random agent learns nothing, it returns an empty dict.""" + return {} \ No newline at end of file diff --git a/gobblet/game/manual_policy.py b/gobblet/game/manual_policy.py index 66b1c14..afd9e1f 100644 --- a/gobblet/game/manual_policy.py +++ b/gobblet/game/manual_policy.py @@ -67,8 +67,11 @@ def __call__(self, observation, agent): if piece_size_selected > 0: piece = piece else: - piece = unplaced[-1] - piece_size_selected = (piece + 1) // 2 + if len(unplaced) > 0: + piece = unplaced[-1] + piece_size_selected = (piece + 1) // 2 + else: + piece = -1 ''' READ KEYBOARD INPUT''' if event.type == pygame.KEYDOWN: @@ -79,7 +82,10 @@ def __call__(self, observation, agent): cycle_choices = np.unique( [(p + 1) // 2 for p in unplaced]) # Transform [1,2,3,4,5,6] to [1,2,3) - piece_size = cycle_choices[(np.amax(cycle_choices) - (piece_cycle + 1)) % len(cycle_choices)] + if len(cycle_choices) > 0: + piece_size = cycle_choices[(np.amax(cycle_choices) - (piece_cycle + 1)) % len(cycle_choices)] + # else: + # piece_size = -1 piece_size_selected = piece_size if (piece_size * 2) - 1 in unplaced: # Check if the first of this piece size is available @@ -147,7 +153,7 @@ def __call__(self, observation, agent): if recorder is not None: recorder.capture_frame(env.unwrapped.screen) - ''' PLACE A PIECE ''' + ''' PICK UP / PLACE A PIECE ''' if event.type == pygame.MOUSEBUTTONDOWN: # Pick up a piece (only able to if it has already been placed, and is not currently picked up) if flatboard[pos] in placed_pieces_agent and not picked_up: diff --git a/pyproject.toml b/pyproject.toml index 0194e84..e567e32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "gobblet-rl" -version = "1.2.3" +version = "1.3.0" authors = [ { name="Elliot Tower", email="elliot@elliottower.com" }, ] diff --git a/tests/__pycache__/test_covered_pieces_tianshou.cpython-38-pytest-7.1.2.pyc b/tests/__pycache__/test_covered_pieces_tianshou.cpython-38-pytest-7.1.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e353ce137239d4a046244d38bf5812c6710ea21c GIT binary patch literal 8065 zcmc&(TWB2D8J?NFk5;Rzb+OZU`yW$p`@ks`_Ig-c5Nd&O-Oe3 z{PUkP|9{T>m-C-<{`qQuzpmi-=XXCSersA${zjGVUj&tlDB|w`OfeOvGR;v_E^C!O)iggGJclrveGGN&Z2In(B}ow7&lnRRVaVftN_^|AEZ zQS&exVf`%gwrU<>nrt6ngJ_?%4K{RB;~&~bcNFuOeVh&NsH$Sm+Q;b5v^SOPsMW73 zihg6GEDZcDW4+-xhW(aRsXMk&TQ?e_TrC+l%2riu)EdU^vb$kiy>$IG!>t)Ex5`z+ zD%n+6oHB|v;&JVO#^R+$1rLn6#bp;KFTeJ3*08F~aB8IzI?d_&rG~pvE!*S0`iHE!AH5ch_#**Rk|x01`x>-u&3JB>28 zX-IKSuiLJRA+~EZ4%Jg!Jer&b8aZ2-hu}xt69xYzFv1RkmdNX 
z8x4q-RsGDBMy0a#HM{86_zTrrXz9Ne3}x18)S&tL^Omr$)tqv1%OAhmaNP1GQc%!x z)v7{W9DI7IwzdX!6>jBGJA?-f%L#}s*BnQVfCuRpxoxwp-~sA!Zo3U$Ez&!-f8@JV zE~AJi0bFHAX)3O|BTGqZs;=g#p0cAgHBa-D&B!|0UR}*b{X_}xU{`N(njbs?;Ku`t zW+Q%t9`wt$x|HZ$g%6_o@r}8!3YcNecARp}b!)e6K6j;7lol6rFP7bx8f$Z*gwJp& z*A}jjuu!bQfHo`C%XZNgg|^yr^)2WG?-C4t1fz%~fTG6Lw5qFo7j2qii`A5J82&-y%4&00HsVv#11m_gddXYVKSLG)p1#nO&C7#$* zvxl5O0&-$M9YxZw+EenUu8;EX~t0!C`pu5 zQiyvtK?PctHREU&~`yG1Nb_c!2&hq5dNx-mU*#P1H#* zr1xEoZ<8(#uz@CiLDbR)8vhkh6Mmo#20J#;#0Qc)Bggtsasvl9X?(!@WHuPRcZdz` zVcC)T9{vohF~IJ>?$`cf*f6vPw0M9u2108@#ZP*y5fT3-4M}T6x%BG;tPvHbq&{gz zMESDR;iIfE(y>NVY)b9{$?ZyR;IPJs#6OKSMjv5~`}IhA@749=we<2t9_S?!T4V4) zYYc|gh>K%A)`*E8k%pu-V*ID1l>@916Tc^FIX7bbJ*mS-S>sU08Zq%N$vp^b#JMK9 zfx{YyB>rivG4`2Qk{6D}N3GtHD zC(Vrnzb#%d101rgu(MrMKa^A> zlIq8j3OK0FvU7Jetdpbc%j~>2T2_G{3fIg-?KQI*4_Cp*o7WHXrTg0+&$q@7^sTYb zw{(&2@vW4&Mbb*&O7ZWJvdi^^Z4rDQwhCbpYtdW~j?72Hyr#Agy4b;JP@ouU)?#`eOwK;+N zIPR0)gf|Hv!_IEEPil*NECC<8L{FrTJ>QX07k?xf-5JQJ^Y=+E!Xe}H>~a?sNuTlh zCDlh#df=eCEOE>fw%=!8IRB~7zHt7x`@GWWvrk- zQ|v43MQ=*_?{w(D({2C7K5v5de9cMB)R$h+ud%BS ze^P?aG-pD8#ax!V=EHK=d<47Z)aFR@Fq>{3L77D{yu;oRZ`Lz*BF&@NQy=wa;LXRp zW4rxAW2^gn>fN*;Sp#oQ^Hq}j0dIc2qp`H$MD0bPu{8IjmViTJuS@(P`)b*KINWC+ z3GcJv-3aeTq3LT~qx?mVLe@)jngIh|V7+S+KL+YI?)Jkj$3t5j@3KYD{u_2VjB*5J z7R5k0+C0V-oX)PV_8h0^Topxm6GhPBOCG0?r?^9>)yS$D+h|m*sywS@gX37&kxd@n z=-lQf&pTFSjakoTQTr!~vOPlw(85~H;w*RGsTD0pJPRZsyOCARA^WYmQ{xsL_zTrq z8K+z1MU?GdS^}tRJaMwVm0!pT{kv_o;J6Ogd!bh`;i*tCf5nB`h}Ggn&^fE7#J)SIV*cfw#Q91~NT% z$#M)!819C>Wmr7SBvouyZd8mQUyw6SE-uR~(&8B+y^RcnvDR?!??`Tm{pfx~%{HjHM)uPl&>nQ>!36OWSzC_?@ z0;dSf5tt{CC$KJu)Um68W%d>ZM2jCmgP8zU5H2$ls1X98R!I^#Y%ooL z>`w_f3G&ig+aj;D34!eUza(4^+&Us$PB|}15#SN=kvdIbCIcz{183 z;Xd8thn+_CHn-|^n<47R;EB6HOoKsy0d81zJB)96SUZeKh+QFup%9&d%0d{bGW*F{ zN<6mBx{~eR_w#+Ta>@kqSpXTO@)9mecv`|U5}qZvAR`$xJIiUF%akEc5o_Xhjo(DV zGD!A~qsfoS5bgJMMQT4van_6=%=%G^vVQtveXFq6C?kuD84wRv4qr#BTu?0TPZR=I zDsa2b3nizvW;unbRk8h4L0IcZb=J9`xg=K+#L*WxuW^5*uwLU83yD{(I!Fo22%8DR zH+l-(J=Cusb<5~3CLG(Y`>{GNS6wsJF^!onbOj54sDQ=Ik#9o@d6>#itUEQ!U0(8I z;9Oqf*GMBVX-_i|nlpPMPz{AVnepWlgwayWCyBm~G!gi3zPqy2U0Lp~te8@JAx}Q2 z1#;&X`Ely`4Fa@$aEh&*!YZeT%3mh%3P8|dz_r+2In7@sn(G~!V4Q{SH#pPHaJJ_K z^85yIcE90#Pi>*6w$xKw?y2?2xzHo$!eZA3Mea@X{f!7oE5>M5QZqRp0<=@G8w~SxtsM;4LYf?L} z>RJ*u-cZued-QYH_$bCduO=0`;>vsSEz(+2O{+;Y`Clyp*nR&vo@!eyj-5YUX@z|B zPk4JyOQ)YK6e{@0SSV}{zUf+=KI$1IMAuv{hmSj}*%N-EP++xUq2MRmahiXVI4281 zkj@2x0kIbg)@@#P5t%zpK5b&Cjem=H;@ARV1Wx-!PBYVtl(s5T{uW(;5c ziW_FS^WBO*k)k9j6Uklp=@{^v0F|dOxV<$x$eEAitqgz_2NOkhV8$OKjQ6ec4kmsp= uAfs3EJiehd9Q#>Pke~n-O(&IMWe_%|D~)R~LJ$^@jK%e6W<9CJ@#KFVtK}g8 literal 0 HcmV?d00001 diff --git a/tests/__pycache__/test_tianshou_covered_pieces.cpython-38-pytest-7.1.2.pyc b/tests/__pycache__/test_tianshou_covered_pieces.cpython-38-pytest-7.1.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0be440427877f74969ff3d52960a2ebd03af38da GIT binary patch literal 11980 zcmc&)Yit|Yb)Fdxhvbl=p4Q84%hw{-uz)ipTw1S(2D! 
zN=#;oBb$nN$EvC#$x=0LYI0DLFcYFAX(sWGIjO2{>Y_|_($$QasScO}A#K*o3fe() zNW6#5VZ7tcNOja4t&W*vf=+V|RL9M60VkY^>ZCa-;G{EEoi?Yd2hD@k8FQw3$UG$Q zDd%wYh&#YjW=_C^&U4l0 z&F2L?9KifiECGB%oyW7uz7 z)rMmm^$nxxR%&JA&5BiXH|tH~e#P4~Ze6?cn&H(Ak6V?RVU_Kg=bkZ2b*kgp0gc7W z%_<6vhQ&n_vo~IQC1+SQW;pe78I2Zn{aVx8tn*dla}{(tyXkoicXe*Aj6Ry{1++Iu z1vi__vA1)&etR3#HRGO*2JvpXHd{3o3oC{Bv$}rA{!+8TZR*lp)f=|wp^H|%Zd|%# z%oi3G3TJh+`}|Y-Kxg@?Gv|d<)VrkbYdP6ZGTUyH>W%H3;wN8kc<9t}{P>$qOcSWo zjk@F5C9lr?kgH$GLU7_%)A1_Th_M0YmQ@3rxu3ZM2G*+q5)BU5H#QogdQk{N^gH~fU32Sv zy>4;#8h9;`u{GPH5xjvhS$z9*yy6MGUaqgNLu|!+^C&)mj+&MewBMf}3eqpV^5Px4 zROd|8%DlvFn{D^Up>f-5@>+>l*ZYHBNl836@VEsnKl0@)_g@Se6YwRKpgBP-8ekF@pSPBtM4yfk=KF`H4t= z68Wh}etO@C4x)S}Qho^e!|X`#E33@tkL2jSmX7VqANS(yM5LvY`+7Qs^3(XuvK)I3 zEkDoBJdB64%(DWUV`uw(oJW2kl3zq$OZc5DD{MJZ?>z7;_$>i{A%b5;{ZB>eU&L50 zvC9#@USU^3bB$f^&+i7>I>&BCXkNgGUt}*uFfR+tD-q1AUfk2R5jV47~NwX}TRg`QPRO_H$URByXs`w3=mSXdt|J)gH5Zr-z9ul%jwqrZ2rw9+Ax zOjC&ciz}^xe12Ul%e-4@*%f0=l3Hp$@9=W|#>~I}%kuPdhyR`0A$ZG+N1gzLs}AyU~pS*&55| zJ=j9q%QqX$^6dNux4%@mip4=a+KqZ?v%5fqLzpXS^XS4W@t?8$iWso!Mr5vstxjW{Q{Bix%hBHe_bM?U^Y` zi?p}2PT$z9*InE2Hf({7!4(Xq<3E7>&*292NwuV4d0&)hOZ1=@CuN^PSKWzwt?S@fVqXpL(-8%46; zt(*H>B3+`I_bSi`kmZTYn6+G^-e_Phj5^z>ybXX5HJI%jai-L03Q}xXU9yH<7(uCZ za;;fihx%Z1y>a_3gW3XtFT(E^`Rt?%9v098^J2FaE&4AV$XA9>^UiV2aUu@0SLazZm zpD_ed^ez~RkuNtFzX+s3?bNEQl#9=q^Z@(s(|^rhBSYMEeSP}nju8x^P)=_$8C zP5j&7^(q(|9l0QH4d(MbZ4iyNUWqo@EBfi$hH^0b9;V@|r6#k?7*s8n@{?H2{{v**seRPkPJuFr)|MV`a2f(>y+TI{)4 z#ZnzkoXv`jiVaB=8r!}OT`2|~okB0}*O5qaQXcC6R9TV4@5fqf7fVpDM30orhIW}@eDkA9YM<+pF>iKhSmF{bxAAw1|D|=i6=b}?hUy)OrDaki79Jh z`Q{W+-5(NDEZ(iv@7G8iED!CwX7nMWKHJFc<182N+o_eP;H_6 zem2PTbjpvDEvfPYs8ABssR{hi#*x4yM%=>q#fVi;ZmV007u!;|;tv#VP>#l5Qwi$C z+i}WmX^oW-O0=b75@jjiUKMY`=xv!vca!{kUK%c55|552gJ*y#57nJ4K~Ln4){Y6f zZ+e4ULtDdbZD-^OIf3MyDV>-%)&>Rg2k?yJnPBR}ct-{_$-mc@cc#3-ooUKFk)Pl@ zrmawG_|91c?jYVnAJKmTJtUsegHn4AEzF>WCGSvMejs<0w$kryCg`oEd50lo1(Ziv zoW8HzJAe(QiEpjXx8=4{RyxUAvaPgbg1n<`@alomNwrhxRoyza5rbwV*35K$9ZTt+ zMf(tTq)f0=7lWNz3l=zhZ^faniIKpKC$smFrLk-Ge7%D2E!VSZ zunxnK9*SU!pS7K0hv8DO0Nz$z|b}i9Riisg3pTZJL9y zN(_ZG4o0I>31O!C`?ERF({oXQ(-P!k)J`&31{G#%;nta^E;WI6b2>jyq$`wMpo9kN z$G|gw4UieD*`Ap}QtWlWKTQ-`u%+!^rNU$dP1ix@I!Giz`m~~Gn)NWvVB#WZD7)|I5J%N$OuRHE7U z{s@lai5MhL-h!%PVagAc2l5-zyUN$zR1Fv6AsY z`1x2z?Zi8pr?gdA3VaqmjCHVHw_;3#6;iiUtf}~phLWG#ue@A_t?8t^_*SB$w-Z~* zPP(1mN_8^r%$D96XzT5Pb_#1e+s?uwd;_1ZNiV&Vc_Q7_+z-5govassBEzz1fCmZ2 z$4-*qw*EwBsU9urK1bB;grHs$6o5gU5^(y7EZ&*$4SUM;NbOBgJ0)s!K?xYObpdBV zoqn7gmpX&q5HQ2-!6)GVS75hQsiWgb-+NwW5ad7<;BX#zJ{$?#do{#NBTb zkAytp{7;B4&yYvl{Q*%6dBnLdxcE`>IMA1e=4M50@a~buCq-?*kjDW5e-e3&|BU3Z ze|L0` zO!VcEaDQLaJ^*vp;fIo>m;G=hjya_fXP8_D$K{}lfVN!z~WyjcY1U^r) zQ@!9Og*I6j(q!`pWG;%(YQlKXIN) zHk2nSTj|d&>9)k&pcjvl{2ehLz?j<#yAYw;5mXa`>W>8#U{GCPtB)1zmy-yJUTjZR zWWZD5-Z|CVJ3D%~7k<2z{WM+LkNbGKHT_K8nhteKcTeomt(5z9nyt{S6#reC?K5;M z<^Hwc4fN?zivOM9%}43hr9NjoDyZ=nH^bUzx{u|AVV9fG5yV0jgyX5_4+t|$?rSX*tDuO|EL%{gL zWV>yqUGn~dHq-nvQS{rq8EMlbs`f$ACKK%dMw>SUj8U)`9&6Cr!xWHj9~K&XB-G#| zJq?D(H$&cE=MdKF%d}R7M!&-FE5|5}zJYm_eL9NiQJ~~CcI)XS+5fIqCyulG5YamoA3wv*skhX0P>5@2xbbpb!c!xs3Xp~rnJd>;?rC&Kqh zaQb$nm;VyI7@`;Y0vK!qMlZJod@y}2~-*@lI!3Ui9)(gT< z7agm*&aBISh2sClBeiBIidPgTiiL|#y<|D=WdH#gn(Tj5znPQGVMK_XI-3(zx8&NjB^7jmxtYdr|TY6AN>03qsZ&Mk<8gd)y&El2On65)kjWdI@J z`5p>Dh&tCp0SHlxXbz_^mfk5KGJ>>fxp(!>nP5!lET?yTrPlCM;aLj~cN6ud*JyfY zb4fqe*gi|i0ws%-oTKDCB^P`Z(>iNXy3ic~*F<;*%%RY=j}ub*r1x~J#~%wmHgs4gAU~#W~|!=PRi{v;_pG^8)pIt zZ;R9VDh|+^RU?eG8ncVbBIvz%o=EQ_jBBhnJ&Kk`>R5(@C?f*k6rjy<3d!;*BwbRk zF2xwSIFggP6in<=iMo%`R7T%ze)lo%QP?2MUkEp>3NY}pmaf4FL2i2397Y530X2T 
diff --git a/tests/test_manual_policy_collector.py b/tests/test_manual_policy_collector.py
new file mode 100644
index 0000000..9ac019b
--- /dev/null
+++ b/tests/test_manual_policy_collector.py
@@ -0,0 +1,145 @@
+# adapted from https://github.com/Farama-Foundation/PettingZoo/blob/master/tutorials/Tianshou/3_cli_and_logging.py
+"""
+Test of ManualPolicyCollector and the Gobblet covered-piece logic: checks the action mask after each scripted move and verifies that trying to move a covered piece is rejected without changing the board.
+
+Author: Will (https://github.com/WillDudley)
+
+Python version used: 3.8.10
+
+Requirements:
+pettingzoo == 1.22.0
+git+https://github.com/thu-ml/tianshou
+"""
+
+from typing import Optional, Tuple
+
+import gym
+import numpy as np
+import torch
+from tianshou.env import DummyVectorEnv
+from tianshou.env.pettingzoo_env import PettingZooEnv
+from tianshou.policy import BasePolicy, MultiAgentPolicyManager
+
+from gobblet import gobblet_v1
+from gobblet.game.collector_manual_policy import ManualPolicyCollector
+from gobblet.game.greedy_policy import GreedyPolicy
+import time
+
+
+def get_agents() -> Tuple[BasePolicy, list]:
+    env = get_env()
+    agents = [GreedyPolicy(), GreedyPolicy()]
+    policy = MultiAgentPolicyManager(agents, env)
+    return policy, env.agents
+
+
+def get_env(render_mode=None, args=None):
+    return PettingZooEnv(gobblet_v1.env(render_mode=render_mode, args=args))
+
+
+# ======== scripted game between two greedy policies, driven by explicit actions ========
+def test_collector() -> None:
+    env = DummyVectorEnv([lambda: get_env(render_mode="human", args=None)])
+    policy, agents = get_agents()
+    collector = ManualPolicyCollector(policy, env, exploration_noise=True)  # Collector for CPU actions
+    pettingzoo_env = env.workers[0].env.env
+
+    # Initial mask: all 54 actions are legal
+    output0 = np.array([[True, True, True, True, True, True, True, True, True, True, True, True,
+                         True, True, True, True, True, True, True, True, True, True, True, True,
+                         True, True, True, True, True, True, True, True, True, True, True, True,
+                         True, True, True, True, True, True, True, True, True, True, True, True,
+                         True, True, True, True, True, True]])
+    assert np.array_equal(collector.data.obs.mask, output0)
+
+    # PLAYER 1
+    action = np.array(18)
+    result = collector.collect_result(action=action.reshape(1), render=0.1)
+    output1 = np.array([[False, True, True, True, True, True, True, True, True, False, True, True,
+                         True, True, True, True, True, True, False, True, True, True, True, True,
+                         True, True, True, False, True, True, True, True, True, True, True, True,
+                         True, True, True, True, True, True, True, True, True, True, True, True,
+                         True, True, True, True, True, True]])
+    assert np.array_equal(collector.data.obs.mask, output1)
+
+    time.sleep(.25)
+
+    # PLAYER 2 (covers it)
+    action = np.array(36)
+    result = collector.collect_result(action=action.reshape(1), render=0.1)
+
+    output2 = np.array([[False, True, True, True, True, True, True, True, True,
+                         False, True, True, True, True, True, True, True, True,
+                         False, False, False, False, False, False, False, False, False,
+                         False, True, True, True, True, True, True, True, True,
+                         False, True, True, True, True, True, True, True, True,
+                         False, True, True, True, True, True, True, True, True]])
+    assert np.array_equal(collector.data.obs.mask, output2)
+
+    time.sleep(.25)
+
+    # PLAYER 1
+    action = np.array(27 + 1)
+    result = collector.collect_result(action=action.reshape(1), render=0.1)
+
+    output3 = np.array([[False, False, True, True, True, True, True, True, True,
+                         False, False, True, True, True, True, True, True, True,
+                         False, False, True, True, True, True, True, True, True,
+                         False, False, True, True, True, True, True, True, True,
+                         False, True, True, True, True, True, True, True, True,
+                         False, True, True, True, True, True, True, True, True]])
+    assert np.array_equal(collector.data.obs.mask, output3)
+
+    time.sleep(.25)
+
+    # PLAYER 2 (covers it)
+    action = np.array(45 + 1)
+    result = collector.collect_result(action=action.reshape(1), render=0.1)
+
+    output4 = np.array([[False, False, True, True, True, True, True, True, True,
+                         False, False, True, True, True, True, True, True, True,
+                         False, False, False, False, False, False, False, False, False,
+                         False, False, False, False, False, False, False, False, False,
+                         False, False, True, True, True, True, True, True, True,
+                         False, False, True, True, True, True, True, True, True]])
+    assert np.array_equal(collector.data.obs.mask, output4)
+
+    time.sleep(.25)
+
+    # PLAYER 1 tries to move a covered piece (ILLEGAL)
+    action = np.array(27 + 2)
+    # Actions 18-35 (the medium pieces) should now be illegal; 36 and 37 are also illegal, but those use a large piece.
+
+    output6 = [2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 38, 39, 40, 41, 42, 43, 44, 47, 48, 49, 50, 51, 52, 53]
+    legal_moves = pettingzoo_env.unwrapped._legal_moves()
+    assert output6 == legal_moves
+
+    output5 = np.array([[False, False, True, True, True, True, True, True, True,
+                         False, False, True, True, True, True, True, True, True,
+                         False, False, False, False, False, False, False, False, False,
+                         False, False, False, False, False, False, False, False, False,
+                         False, False, True, True, True, True, True, True, True,
+                         False, False, True, True, True, True, True, True, True]])
+    assert np.array_equal(collector.data.obs.mask, output5)
+
+    result = collector.collect_result(action=action.reshape(1), render=0.1)
+    # The result should be empty because this is an illegal move
+    output7 = {'n/ep': 0, 'n/st': 1, 'rews': np.array([], dtype=np.float64), 'lens': np.array([], dtype=np.int64), 'idxs': np.array([], dtype=np.int64), 'rew': 0, 'len': 0, 'rew_std': 0, 'len_std': 0}
+    assert str(result) == str(output7)
+
+    # The board state should be unchanged because the bot tried to execute an illegal move
+    output8 = np.array([[[ 0.,  0.,  0.],
+                         [ 0.,  0.,  0.],
+                         [ 0.,  0.,  0.]],
+                        [[ 3.,  4.,  0.],
+                         [ 0.,  0.,  0.],
+                         [ 0.,  0.,  0.]],
+                        [[-5., -6.,  0.],
+                         [ 0.,  0.,  0.],
+                         [ 0.,  0.,  0.]]])
+    assert np.array_equal(pettingzoo_env.unwrapped.board.squares.reshape(3, 3, 3), output8)
+
+
+if __name__ == "__main__":
+    # run the scripted covered-piece test as a standalone script
+    print("Starting game...")
+    test_collector()
\ No newline at end of file
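
A note on the hard-coded actions used in the test above (18, 36, 27+1, 45+1): the asserted masks are consistent with a flat 54-way action space laid out as piece slot times 9 board squares, with slots 0-1 the small pieces, 2-3 the medium pieces, and 4-5 the large pieces. The sketch below decodes an action index under that assumption; it is inferred from the expected masks rather than taken from the gobblet package, and the `decode_action` helper is hypothetical, not part of the library.

```
# Hypothetical helper, not part of the gobblet package: decodes the flat action
# index the test above appears to assume, i.e. action = piece_slot * 9 + square,
# with piece slots 0-1 = small, 2-3 = medium, 4-5 = large and squares 0-8.
def decode_action(action: int):
    piece_slot, square = divmod(action, 9)
    size = ("small", "small", "medium", "medium", "large", "large")[piece_slot]
    return size, piece_slot, square


if __name__ == "__main__":
    # The test's opening moves: a medium piece on square 0, then a large piece
    # that covers it, matching the masks asserted as output1 and output2 above.
    print(decode_action(18))  # ('medium', 2, 0)
    print(decode_action(36))  # ('large', 4, 0)
```

The test itself can be run directly with `python tests/test_manual_policy_collector.py` (it prints "Starting game..." and renders the scripted moves), and it should also be collected by pytest along with the rest of the test suite.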