[Feature] ChessEnv #2641

Merged · 5 commits · Dec 13, 2024
Changes from 1 commit
2 changes: 1 addition & 1 deletion torchrl/envs/__init__.py
@@ -5,7 +5,7 @@

from .batched_envs import ParallelEnv, SerialEnv
from .common import EnvBase, EnvMetaData, make_tensordict
from .custom import LLMHashingEnv, PendulumEnv, TicTacToeEnv
from .custom import ChessEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv
from .env_creator import env_creator, EnvCreator, get_env_metadata
from .gym_like import default_info_dict_reader, GymLikeEnv
from .libs import (
1 change: 1 addition & 0 deletions torchrl/envs/custom/__init__.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .chess import ChessEnv
from .llm import LLMHashingEnv
from .pendulum import PendulumEnv
from .tictactoeenv import TicTacToeEnv
197 changes: 197 additions & 0 deletions torchrl/envs/custom/chess.py
@@ -0,0 +1,197 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
from tensordict import TensorDict, TensorDictBase
from torchrl.data import Categorical, Composite, NonTensor, Unbounded

from torchrl.envs import EnvBase

from torchrl.envs.utils import _classproperty


class ChessEnv(EnvBase):
"""A chess environment that follows the TorchRL API.

Requires: the `chess` library. More info `here <https://python-chess.readthedocs.io/en/latest/>`__.

Args:
stateful (bool): Whether to keep track of the internal state of the board.
If False, the state will be stored in the observation and passed back
to the environment on each call. Default: ``False``.

.. note:: the action spec is a :class:`~torchrl.data.Categorical` spec with a ``-1`` shape.
Unless :meth:`~torchrl.data.Categorical.set_provisional_n` is called with the cardinality of the legal moves,
valid random actions cannot be taken. :meth:`~torchrl.envs.EnvBase.rand_action` has been adapted to account for
this behavior.

Examples:
>>> env = ChessEnv()
>>> r = env.reset()
>>> env.rand_step(r)
TensorDict(
fields={
action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None),
hashing: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/2N5/PPPPPPPP/R1BQKBNR b KQkq - 1 1, batch_size=torch.Size([]), device=None),
hashing: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int32, is_shared=False),
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False),
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> env.rollout(1000)
TensorDict(
fields={
action: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False),
done: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
fen: NonTensorStack(
['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ...,
batch_size=torch.Size([322]),
device=None),
hashing: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
fen: NonTensorStack(
['rnbqkbnr/pppppppp/8/8/2P5/8/PP1PPPPP/RNBQKBNR b ...,
batch_size=torch.Size([322]),
device=None),
hashing: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False),
reward: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.int32, is_shared=False),
terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([322]),
device=None,
is_shared=False),
terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False),
turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([322]),
device=None,
is_shared=False)


"""

_hash_table = {}

@_classproperty
def lib(cls):
try:
import chess
except ImportError:
raise ImportError(
"The `chess` library could not be found. Make sure you installed it through `pip install chess`."
)
return chess

def __init__(self, stateful: bool = False):
chess = self.lib
super().__init__()
self.full_observation_spec = Composite(
hashing=Unbounded(shape=(), dtype=torch.int64),
fen=NonTensor(shape=()),
turn=Categorical(n=2, dtype=torch.bool, shape=()),
)
self.stateful = stateful
if not self.stateful:
self.full_state_spec = self.full_observation_spec.clone()
self.full_action_spec = Composite(
action=Categorical(n=-1, shape=(), dtype=torch.int64)
)
self.full_reward_spec = Composite(
reward=Unbounded(shape=(1,), dtype=torch.int32)
)
# done spec generated automatically
self.board = chess.Board()
if self.stateful:
self.action_spec.set_provisional_n(len(list(self.board.legal_moves)))

def rand_action(self, tensordict: Optional[TensorDictBase] = None):
self._set_action_space(tensordict)
return super().rand_action(tensordict)

def _reset(self, tensordict=None):
fen = None
if tensordict is not None:
fen = self._get_fen(tensordict)
dest = tensordict.empty()
else:
dest = TensorDict()

if fen is None:
self.board.reset()
fen = self.board.fen()
else:
self.board.set_fen(fen.data)

hashing = hash(fen)
Collaborator:
If I understand correctly, hash is not guaranteed to give unique values for string inputs, since the number of possible strings is infinite, whereas the output of hash only has 2^64 possible values.

>>> import sys
>>> sys.hash_info.width
64

Is uniqueness required for this hash?

Contributor Author:
I think the questions to ask are

  1. what are the chances that a collision will occur
  2. what happens if a collision occurs

If we build a forest with, say, 10M elements, the number of possible hash values is still roughly 10^12 times bigger than the capacity of the forest, so I think it's safe to assume that the risk of failing to rebuild a tree due to a hash collision is small.

Another option, on the safe side, could be to tokenize the FEN. We could pad the tokens if they're shorter to make sure they all fit contiguously in memory.

Another question to solve is reproducibility, which worries me more than collisions: if you restart your Python process, the hash map will no longer hold, so any saved data will be meaningless. IIRC there's a way to set the "seed" of the hash, but that would act on a global variable, which we may want to avoid anyway!
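
For illustration, padded FEN tokenization could look roughly like this (a sketch; the max length is an assumption, not something settled in this PR):

import torch

MAX_FEN_LEN = 92  # assumed upper bound; legal FEN strings are shorter in practice

def tokenize_fen(fen: str, max_len: int = MAX_FEN_LEN) -> torch.Tensor:
    # Encode each character as its ASCII byte value and right-pad with zeros
    # so every position is stored in a fixed-size, contiguous tensor.
    tokens = torch.zeros(max_len, dtype=torch.uint8)
    encoded = torch.tensor(list(fen.encode("ascii")), dtype=torch.uint8)
    tokens[: len(encoded)] = encoded
    return tokens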

@kurtamohler (Collaborator), Dec 13, 2024:
True, at 10 million, the risk of collision is insignificant. I was curious where the limit is, so I looked into it. If my reasoning is correct, the risk starts to become significant around the order of 1 billion generated hashes. I think at 1 billion, the probability of collision is 2.7%. At 5 billion, the probability is almost 50%. (I put up some notes here).

So if we expect the tree to have significantly fewer than 1 billion nodes, then Python hash should be good enough I suppose.
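
For reference, a quick birthday-bound estimate reproduces those figures (illustrative only; it assumes the 64-bit hash outputs are uniformly distributed):

import math

def collision_probability(n: float, bits: int = 64) -> float:
    # P(at least one collision among n hashes) ~= 1 - exp(-n^2 / (2 * 2^bits))
    return 1.0 - math.exp(-(n * n) / (2.0 * 2.0 ** bits))

print(f"{collision_probability(1e9):.1%}")  # ~2.7%
print(f"{collision_probability(5e9):.1%}")  # ~49%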

Collaborator:

It seems like one option for reproducibility is to use hashlib, which I think also has features for generating unique hashes if we decide we need that.
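
Something along these lines could give process-stable hashes, assuming we truncate a SHA-256 digest to 64 bits (a sketch, not what the PR currently does):

import hashlib

def stable_fen_hash(fen: str) -> int:
    # Deterministic across Python processes, unlike the built-in hash(),
    # which salts string hashing per process unless PYTHONHASHSEED is set.
    digest = hashlib.sha256(fen.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "little", signed=True)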

Contributor Author:

I'm open to it!


self._set_action_space()
turn = self.board.turn
return dest.set("fen", fen).set("hashing", hashing).set("turn", turn)

def _set_action_space(self, tensordict: TensorDict | None = None):
if not self.stateful and tensordict is not None:
fen = self._get_fen(tensordict).data
self.board.set_fen(fen)
self.action_spec.set_provisional_n(self.board.legal_moves.count())

@classmethod
def _get_fen(cls, tensordict):
fen = tensordict.get("fen", None)
if fen is None:
hashing = tensordict.get("hashing", None)
if hashing is not None:
fen = cls._hash_table.get(hashing.item())
return fen

def _step(self, tensordict):
# action
action = tensordict.get("action")
board = self.board
if not self.stateful:
fen = self._get_fen(tensordict).data
board.set_fen(fen)
action = str(list(board.legal_moves)[action])
# assert chess.Move.from_uci(action) in board.legal_moves
board.push_san(action)
self._set_action_space()

# Collect data
fen = self.board.fen()
dest = tensordict.empty()
hashing = hash(fen)
dest.set("fen", fen)
dest.set("hashing", hashing)

done = board.is_checkmate()
turn = torch.tensor(board.turn)
reward = torch.tensor([done]).int() * (turn.int() * 2 - 1)
Collaborator:

Just a nit: I think this line is a little hard to read. Maybe something like this is a little easier to grasp?

winner = not board.turn
reward = torch.tensor([0 if not done else (-1 if winner == chess.BLACK else 1)])

Contributor Author:

Since there are multiple ways of ending the game, maybe we should be a bit more comprehensive here:
Win = 1
Lose = -1
Stalemate ?

@kurtamohler (Collaborator), Dec 12, 2024:

I guess it's worth considering that usually in tournaments, if there's a checkmate or resignation, the winner gets 1 point and the loser gets 0, but if it's a draw of any kind, both players get 0.5 points (source).

But I think whatever the values of win/lose are, the value of a draw should probably be the average of win and lose.

@kurtamohler (Collaborator), Dec 12, 2024:

Also, chess evaluation engines like Stockfish use 0 if the position is equal (including draws), a negative score for a Black advantage, and a positive score for a White advantage.

Contributor Author:

Ok so what about 0 / 0.5 / 1?

Also worth considering: a more granular reward:
https://www.restack.io/p/reinforcement-learning-answer-chess-bot-cat-ai
Since we have the opportunity of having multiple rewards, we could add another tensor that assigns a reward for taking/losing pieces.
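
If we went that route, a per-step material term could be as simple as the following sketch (purely illustrative, using standard piece values; not a proposal for this PR):

import chess

PIECE_VALUES = {chess.PAWN: 1, chess.KNIGHT: 3, chess.BISHOP: 3,
                chess.ROOK: 5, chess.QUEEN: 9, chess.KING: 0}

def material_balance(board: chess.Board) -> int:
    # Material balance from White's point of view (positive = White is up).
    balance = 0
    for piece in board.piece_map().values():
        value = PIECE_VALUES[piece.piece_type]
        balance += value if piece.color == chess.WHITE else -value
    return balance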

@kurtamohler (Collaborator), Dec 12, 2024:

I think that -1/0/1 is probably more consistent with existing chess engines, and it can encapsulate the value of the game for both black and white in a single number.

0/0.5/1 is probably also good, but then I guess we'd need two rewards, one for black and one for white. That is, unless we always set the reward to 0.5 for non-terminal states, but I'm not sure that would be ideal for transforms.RewardSum.

"a reward for taking / losing pieces"

I think that could promote moves that do not optimize actually winning the game. For instance, if the reward decreases when losing a piece, then it could discourage making a sacrifice (or a combination of sacrifices) that leads to a forced checkmate. Likewise, if the reward increases when taking pieces, it could potentially encourage taking a piece rather than properly defending against an imminent checkmate threat.

Although it depends on what you want. Sometimes you don't want a chess bot to play optimally, like if it's meant to play against humans who are no match for the best chess bots.

Contributor Author:

I agree with all of this. With the chess lib, it's always white-centric, no? I.e., a win means White wins. Or is there a way to flip things around and be black-centric?
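
For what it's worth, python-chess reports the result through Board.outcome(), whose winner field is True for White, False for Black, and None for a draw, so either convention can be derived from it. A rough white-centric -1/0/1 sketch (not the PR's implementation):

import chess
import torch

def white_centric_reward(board: chess.Board) -> torch.Tensor:
    # +1 if White won, -1 if Black won, 0 for a draw or an unfinished game.
    outcome = board.outcome()
    if outcome is None or outcome.winner is None:
        return torch.tensor([0], dtype=torch.int32)
    return torch.tensor([1 if outcome.winner == chess.WHITE else -1], dtype=torch.int32)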

done = done | board.is_stalemate() | board.is_game_over()
dest.set("reward", reward)
dest.set("turn", turn)
dest.set("done", [done])
dest.set("terminated", [done])
return dest

def _set_seed(self, *args, **kwargs):
...

def cardinality(self, tensordict: TensorDictBase|None=None) -> int:
self._set_action_space(tensordict)
return self.action_spec.cardinality()