From f2b442d7df2b037223632cc768d87590f20309bb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 12 Dec 2024 11:19:22 -0800 Subject: [PATCH] [Feature] ChessEnv ghstack-source-id: 42d198f63f23b5ebdff5d30870c2e4210692a1b0 Pull Request resolved: https://github.com/pytorch/rl/pull/2641 --- torchrl/envs/__init__.py | 2 +- torchrl/envs/custom/__init__.py | 1 + torchrl/envs/custom/chess.py | 197 ++++++++++++++++++++++++++++++++ 3 files changed, 199 insertions(+), 1 deletion(-) create mode 100644 torchrl/envs/custom/chess.py diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index f3dec221ce0..b863ad0801c 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -5,7 +5,7 @@ from .batched_envs import ParallelEnv, SerialEnv from .common import EnvBase, EnvMetaData, make_tensordict -from .custom import LLMHashingEnv, PendulumEnv, TicTacToeEnv +from .custom import ChessEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv from .env_creator import env_creator, EnvCreator, get_env_metadata from .gym_like import default_info_dict_reader, GymLikeEnv from .libs import ( diff --git a/torchrl/envs/custom/__init__.py b/torchrl/envs/custom/__init__.py index 375a0e23a57..d2c85a7198f 100644 --- a/torchrl/envs/custom/__init__.py +++ b/torchrl/envs/custom/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from .chess import ChessEnv from .llm import LLMHashingEnv from .pendulum import PendulumEnv from .tictactoeenv import TicTacToeEnv diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py new file mode 100644 index 00000000000..2e7feb610c9 --- /dev/null +++ b/torchrl/envs/custom/chess.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import Optional + +import torch +from tensordict import TensorDict, TensorDictBase +from torchrl.data import Categorical, Composite, NonTensor, Unbounded + +from torchrl.envs import EnvBase + +from torchrl.envs.utils import _classproperty +from typing import Dict + +class ChessEnv(EnvBase): + """A chess environment that follows the TorchRL API. + + Requires: the `chess` library. More info `here `__. + + Args: + stateful (bool): Whether to keep track of the internal state of the board. + If False, the state will be stored in the observation and passed back + to the environment on each call. Default: ``False``. + + .. note:: the action spec is a :class:`~torchrl.data.Categorical` spec with a ``-1`` shape. + Unless :meth:`~torchrl.data.Categorical.set_provisional_n` is called with the cardinality of the legal moves, + valid random actions cannot be taken. :meth:`~torchrl.envs.EnvBase.rand_action` has been adapted to account for + this behavior. + + Examples: + >>> env = ChessEnv() + >>> r = env.reset() + >>> env.rand_step(r) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None), + hashing: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/2N5/PPPPPPPP/R1BQKBNR b KQkq - 1 1, batch_size=torch.Size([]), device=None), + hashing: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> env.rollout(1000) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False), + fen: NonTensorStack( + ['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ..., + batch_size=torch.Size([322]), + device=None), + hashing: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False), + fen: NonTensorStack( + ['rnbqkbnr/pppppppp/8/8/2P5/8/PP1PPPPP/RNBQKBNR b ..., + batch_size=torch.Size([322]), + device=None), + hashing: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False), + reward: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.int32, is_shared=False), + terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False), + turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([322]), + device=None, + is_shared=False), + terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False), + turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([322]), + device=None, + is_shared=False) + + + """ + + _hash_table: Dict[int, str] = {} + + @_classproperty + def lib(cls): + try: + import chess + except ImportError: + raise ImportError( + "The `chess` library could not be found. Make sure you installed it through `pip install chess`." + ) + return chess + + def __init__(self, stateful: bool = False): + chess = self.lib + super().__init__() + self.full_observation_spec = Composite( + hashing=Unbounded(shape=(), dtype=torch.int64), + fen=NonTensor(shape=()), + turn=Categorical(n=2, dtype=torch.bool, shape=()), + ) + self.stateful = stateful + if not self.stateful: + self.full_state_spec = self.full_observation_spec.clone() + self.full_action_spec = Composite( + action=Categorical(n=-1, shape=(), dtype=torch.int64) + ) + self.full_reward_spec = Composite( + reward=Unbounded(shape=(1,), dtype=torch.int32) + ) + # done spec generated automatically + self.board = chess.Board() + if self.stateful: + self.action_spec.set_provisional_n(len(list(self.board.legal_moves))) + + def rand_action(self, tensordict: Optional[TensorDictBase] = None): + self._set_action_space(tensordict) + return super().rand_action(tensordict) + + def _reset(self, tensordict=None): + fen = None + if tensordict is not None: + fen = self._get_fen(tensordict) + dest = tensordict.empty() + else: + dest = TensorDict() + + if fen is None: + self.board.reset() + fen = self.board.fen() + else: + self.board.set_fen(fen.data) + + hashing = hash(fen) + + self._set_action_space() + turn = self.board.turn + return dest.set("fen", fen).set("hashing", hashing).set("turn", turn) + + def _set_action_space(self, tensordict: TensorDict | None = None): + if not self.stateful and tensordict is not None: + fen = self._get_fen(tensordict).data + self.board.set_fen(fen) + self.action_spec.set_provisional_n(self.board.legal_moves.count()) + + @classmethod + def _get_fen(cls, tensordict): + fen = tensordict.get("fen", None) + if fen is None: + hashing = tensordict.get("hashing", None) + if hashing is not None: + fen = cls._hash_table.get(hashing.item()) + return fen + + def _step(self, tensordict): + # action + action = tensordict.get("action") + board = self.board + if not self.stateful: + fen = self._get_fen(tensordict).data + board.set_fen(fen) + action = str(list(board.legal_moves)[action]) + # assert chess.Move.from_uci(action) in board.legal_moves + board.push_san(action) + self._set_action_space() + + # Collect data + fen = self.board.fen() + dest = tensordict.empty() + hashing = hash(fen) + dest.set("fen", fen) + dest.set("hashing", hashing) + + done = board.is_checkmate() + turn = torch.tensor(board.turn) + reward = torch.tensor([done]).int() * (turn.int() * 2 - 1) + done = done | board.is_stalemate() | board.is_game_over() + dest.set("reward", reward) + dest.set("turn", turn) + dest.set("done", [done]) + dest.set("terminated", [done]) + return dest + + def _set_seed(self, *args, **kwargs): + ... + + def cardinality(self, tensordict: TensorDictBase|None=None) -> int: + self._set_action_space(tensordict) + return self.action_spec.cardinality()