[Feature] ChessEnv #2641
@@ -0,0 +1,197 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
from tensordict import TensorDict, TensorDictBase
from torchrl.data import Categorical, Composite, NonTensor, Unbounded

from torchrl.envs import EnvBase

from torchrl.envs.utils import _classproperty


class ChessEnv(EnvBase):
"""A chess environment that follows the TorchRL API. | ||
|
||
Requires: the `chess` library. More info `here <https://python-chess.readthedocs.io/en/latest/>`__. | ||
|
||
Args: | ||
stateful (bool): Whether to keep track of the internal state of the board. | ||
If False, the state will be stored in the observation and passed back | ||
to the environment on each call. Default: ``False``. | ||
|
||
.. note:: the action spec is a :class:`~torchrl.data.Categorical` spec with a ``-1`` shape. | ||
Unless :meth:`~torchrl.data.Categorical.set_provisional_n` is called with the cardinality of the legal moves, | ||
valid random actions cannot be taken. :meth:`~torchrl.envs.EnvBase.rand_action` has been adapted to account for | ||
this behavior. | ||
|
||
Examples: | ||
>>> env = ChessEnv() | ||
>>> r = env.reset() | ||
>>> env.rand_step(r) | ||
TensorDict( | ||
fields={ | ||
action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), | ||
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), | ||
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None), | ||
hashing: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), | ||
next: TensorDict( | ||
fields={ | ||
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), | ||
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/2N5/PPPPPPPP/R1BQKBNR b KQkq - 1 1, batch_size=torch.Size([]), device=None), | ||
hashing: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), | ||
reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int32, is_shared=False), | ||
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), | ||
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)}, | ||
batch_size=torch.Size([]), | ||
device=None, | ||
is_shared=False), | ||
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), | ||
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)}, | ||
batch_size=torch.Size([]), | ||
device=None, | ||
is_shared=False) | ||
>>> env.rollout(1000) | ||
TensorDict( | ||
fields={ | ||
action: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False), | ||
done: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False), | ||
fen: NonTensorStack( | ||
['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ..., | ||
batch_size=torch.Size([322]), | ||
device=None), | ||
hashing: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False), | ||
next: TensorDict( | ||
fields={ | ||
done: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False), | ||
fen: NonTensorStack( | ||
['rnbqkbnr/pppppppp/8/8/2P5/8/PP1PPPPP/RNBQKBNR b ..., | ||
batch_size=torch.Size([322]), | ||
device=None), | ||
hashing: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False), | ||
reward: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.int32, is_shared=False), | ||
terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False), | ||
turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)}, | ||
batch_size=torch.Size([322]), | ||
device=None, | ||
is_shared=False), | ||
terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False), | ||
turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)}, | ||
batch_size=torch.Size([322]), | ||
device=None, | ||
is_shared=False) | ||
|
||
|
||
""" | ||

    _hash_table = {}

    @_classproperty
    def lib(cls):
        try:
            import chess
        except ImportError:
            raise ImportError(
                "The `chess` library could not be found. Make sure you installed it through `pip install chess`."
            )
        return chess

    def __init__(self, stateful: bool = False):
        chess = self.lib
        super().__init__()
        self.full_observation_spec = Composite(
            hashing=Unbounded(shape=(), dtype=torch.int64),
            fen=NonTensor(shape=()),
            turn=Categorical(n=2, dtype=torch.bool, shape=()),
        )
        self.stateful = stateful
        if not self.stateful:
            self.full_state_spec = self.full_observation_spec.clone()
        self.full_action_spec = Composite(
            action=Categorical(n=-1, shape=(), dtype=torch.int64)
        )
        self.full_reward_spec = Composite(
            reward=Unbounded(shape=(1,), dtype=torch.int32)
        )
        # done spec generated automatically
        self.board = chess.Board()
        if self.stateful:
            self.action_spec.set_provisional_n(len(list(self.board.legal_moves)))

    def rand_action(self, tensordict: Optional[TensorDictBase] = None):
        self._set_action_space(tensordict)
        return super().rand_action(tensordict)

    def _reset(self, tensordict=None):
        fen = None
        if tensordict is not None:
            fen = self._get_fen(tensordict)
            dest = tensordict.empty()
        else:
            dest = TensorDict()

        if fen is None:
            self.board.reset()
            fen = self.board.fen()
        else:
            self.board.set_fen(fen.data)

        hashing = hash(fen)

        self._set_action_space()
        turn = self.board.turn
        return dest.set("fen", fen).set("hashing", hashing).set("turn", turn)

    def _set_action_space(self, tensordict: TensorDict | None = None):
        if not self.stateful and tensordict is not None:
            fen = self._get_fen(tensordict).data
            self.board.set_fen(fen)
        self.action_spec.set_provisional_n(self.board.legal_moves.count())

    @classmethod
    def _get_fen(cls, tensordict):
        fen = tensordict.get("fen", None)
        if fen is None:
            hashing = tensordict.get("hashing", None)
            if hashing is not None:
                fen = cls._hash_table.get(hashing.item())
        return fen

    def _step(self, tensordict):
        # action
        action = tensordict.get("action")
        board = self.board
        if not self.stateful:
            fen = self._get_fen(tensordict).data
            board.set_fen(fen)
        action = str(list(board.legal_moves)[action])
        # assert chess.Move.from_uci(action) in board.legal_moves
        board.push_san(action)
        self._set_action_space()

        # Collect data
        fen = self.board.fen()
        dest = tensordict.empty()
        hashing = hash(fen)
        dest.set("fen", fen)
        dest.set("hashing", hashing)

        done = board.is_checkmate()
        turn = torch.tensor(board.turn)
        reward = torch.tensor([done]).int() * (turn.int() * 2 - 1)
Just a nit: I think this line is a little hard to read. Maybe something like this is a little easier to grasp?
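The suggested snippet itself is not preserved in this thread. As a purely hypothetical illustration (not the reviewer's actual suggestion), a behavior-equivalent but more explicit version of the reward line could look like this:

```python
# Hypothetical, behavior-equivalent rewrite of:
#   reward = torch.tensor([done]).int() * (turn.int() * 2 - 1)
if done:
    # `turn` is the side to move in the final position, i.e. the mated side:
    # True (white to move) yields +1, False (black to move) yields -1.
    reward = torch.tensor([1 if turn else -1], dtype=torch.int32)
else:
    reward = torch.tensor([0], dtype=torch.int32)
```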
Since there are multiple ways of ending the game, maybe we should be a bit more comprehensive there (see the sketch after this thread): …

I guess it's worth considering that usually in tournaments, if there's a checkmate or resignation, the winner gets 1 point and the loser gets 0, but if it's a draw of any kind, both players get 0.5 points (source). But I think whatever the values of win/lose are, the value of draw should probably be the average of win and lose.

Also, chess evaluation engines like Stockfish use 0 if the position is equal (including draws), a negative score for a black advantage, and a positive score for a white advantage.

Ok, so what about 0 / 0.5 / 1? Also worth considering: a more granular reward.

I think that -1/0/1 is probably more consistent with existing chess engines, and it can encapsulate the value of the game for both black and white in a single number. 0/0.5/1 is probably also good, but then I guess we'd need two rewards, one for black and one for white. That is, unless we always set the reward to 0.5 for non-terminal states, but I'm not sure that would be ideal for …
I think that could promote moves that do not optimize actually winning the game. For instance, if the reward decreases when losing a piece, then it could discourage making a sacrifice (or a combination of sacrifices) that leads to a forced checkmate. Likewise, if the reward increases when taking pieces, it could potentially encourage taking a piece rather than properly defending against an imminent checkmate threat. Although it depends on what you want. Sometimes you don't want a chess bot to play optimally, like if it is meant to play against humans who are no match for the best chess bots.

I agree with all of this. With the …
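The code snippets originally attached to this thread are not preserved. As a hedged sketch of what a more comprehensive end-of-game check with a -1/0/1 reward could look like, using python-chess's `Board.outcome()` (variable names are illustrative, not from the PR):

```python
# Sketch only: Board.outcome() covers checkmate, stalemate, insufficient
# material, the 75-move rule and fivefold repetition in a single call.
outcome = board.outcome()  # None while the game is still in progress
done = outcome is not None
if done and outcome.winner is not None:
    # chess.WHITE is True: +1 for a white win, -1 for a black win
    reward_val = 1 if outcome.winner else -1
else:
    # any kind of draw, or game still in progress
    reward_val = 0
reward = torch.tensor([reward_val], dtype=torch.int32)
```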
        done = done | board.is_stalemate() | board.is_game_over()
        dest.set("reward", reward)
        dest.set("turn", turn)
        dest.set("done", [done])
        dest.set("terminated", [done])
        return dest

    def _set_seed(self, *args, **kwargs):
        ...

    def cardinality(self, tensordict: TensorDictBase | None = None) -> int:
        self._set_action_space(tensordict)
        return self.action_spec.cardinality()
If I understand correctly, `hash` is not guaranteed to give unique values for string inputs, since the number of possible strings is infinite, greater than the number of hash values. The output of `hash` only has 2^64 possible values. Is uniqueness required for this hash?
I think the questions to ask are: …

If we build a forest with, say, 10M elements, the number of combinations is still 10^13 times bigger than the capacity of the forest, so I think it's safe to assume that the risk of failures to rebuild a tree due to hash collision is going to be small.

Another option on the safe side could also be to tokenize the fen. We could pad the tokens if they're shorter to make sure they all fit contiguously in memory.

Another question to solve is the one of reproducibility, which worries me more than collision: if you restart your Python process, the hash map will not hold anymore, so any saved data will be meaningless. IIRC there's a way to set the "seed" of the hash, but that'd be acting on a global variable which we may want to avoid anyway!
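For context, the "seed" in question is presumably `PYTHONHASHSEED`: CPython randomizes `str` hashes on every interpreter start unless that environment variable is pinned, which is exactly the kind of global, process-wide setting described above. A minimal demonstration (a sketch, not code from the PR):

```python
# Each child interpreter computes a different hash for the same FEN string,
# unless PYTHONHASHSEED is fixed in its environment (a global, per-process knob).
import os
import subprocess
import sys

fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
cmd = [sys.executable, "-c", f"print(hash({fen!r}))"]

print(subprocess.check_output(cmd, text=True))  # run 1
print(subprocess.check_output(cmd, text=True))  # run 2: almost surely different
env = {**os.environ, "PYTHONHASHSEED": "0"}
print(subprocess.check_output(cmd, text=True, env=env))  # stable across runs
```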
True, at 10 million, the risk of collision is insignificant. I was curious where the limit is, so I looked into it. If my reasoning is correct, the risk starts to become significant around the order of 1 billion generated hashes. I think at 1 billion, the probability of collision is 2.7%. At 5 billion, the probability is almost 50%. (I put up some notes here.)

So if we expect the tree to have significantly fewer than 1 billion nodes, then Python `hash` should be good enough, I suppose.
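Those figures match the standard birthday-bound approximation for 64-bit hashes, P(collision) ≈ 1 − exp(−n²/2⁶⁵), which is easy to check (a sketch of the calculation, not code from the thread):

```python
import math

def collision_probability(n: int, bits: int = 64) -> float:
    # Birthday-bound approximation: 1 - exp(-n^2 / (2 * 2**bits))
    return 1.0 - math.exp(-(n * n) / 2.0 ** (bits + 1))

print(collision_probability(10_000_000))     # ~2.7e-06: negligible at 10M
print(collision_probability(1_000_000_000))  # ~0.027:   about 2.7%
print(collision_probability(5_000_000_000))  # ~0.49:    almost 50%
```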
It seems like one option for reproducibility is to use `hashlib`, which I think also has features for generating unique hashes if we decide we need that.
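A sketch of what that could look like (`fen_hash` is a hypothetical helper, not part of the PR): derive a stable 64-bit integer from the FEN string with `hashlib`, so stored hashes survive interpreter restarts.

```python
import hashlib

def fen_hash(fen: str) -> int:
    """Process-independent hash of a FEN string (sketch only)."""
    digest = hashlib.sha256(fen.encode("utf-8")).digest()
    # Keep 63 bits so the value always fits in a signed torch.int64 entry.
    return int.from_bytes(digest[:8], "big") & ((1 << 63) - 1)
```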
I'm open to it!