Commit 9147b63: Update
[ghstack-poisoned]
vmoens committed Dec 6, 2024 (1 parent: 68e64c6)

Showing 7 changed files with 231 additions and 9 deletions.
2 changes: 2 additions & 0 deletions docs/source/reference/envs.rst
@@ -347,6 +347,8 @@ TorchRL offers a series of custom built-in environments.

PendulumEnv
TicTacToeEnv
+    LLMHashingEnv


Multi-agent environments
------------------------
4 changes: 2 additions & 2 deletions docs/source/reference/trainers.rst
@@ -79,7 +79,7 @@ Hooks can be split into 3 categories: **data processing** (``"batch_process"`` a

- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write to the logger
  some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward
-  logger (``LogScaler``) and such. Hooks should return a dictionary (or a None value) containing the
+  logger (``LogScalar``) and such. Hooks should return a dictionary (or a None value) containing the
  data to log. The key ``"log_pbar"`` is reserved for boolean values indicating whether the logged
  value should be displayed on the progress bar printed on the training log.
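As an aside from the diff itself: a minimal sketch of such a logging hook, assuming the collected batch stores rewards under ``("next", "reward")`` (the hook name and the registration point below are illustrative, not from this commit):

```python
from tensordict import TensorDict

def log_mean_reward(batch: TensorDict) -> dict:
    # Reduce the batch to a scalar and ask the trainer to also display it
    # on the progress bar via the reserved "log_pbar" key.
    return {
        "r_training": batch.get(("next", "reward")).mean().item(),
        "log_pbar": True,
    }
```

Such a callable would be registered at one of the trainer's logging entry points (e.g. ``trainer.register_op("post_steps_log", log_mean_reward)``); check the hook names against your TorchRL version.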

@@ -174,7 +174,7 @@ Trainer and hooks
BatchSubSampler
ClearCudaCache
CountFramesLog
-    LogScaler
+    LogScalar
OptimizerHook
LogValidationReward
ReplayBufferTrainer
2 changes: 1 addition & 1 deletion torchrl/envs/__init__.py
@@ -5,7 +5,7 @@

from .batched_envs import ParallelEnv, SerialEnv
from .common import EnvBase, EnvMetaData, make_tensordict
-from .custom import PendulumEnv, TicTacToeEnv
+from .custom import LLMHashingEnv, PendulumEnv, TicTacToeEnv
from .env_creator import env_creator, EnvCreator, get_env_metadata
from .gym_like import default_info_dict_reader, GymLikeEnv
from .libs import (
30 changes: 25 additions & 5 deletions torchrl/envs/common.py
@@ -14,8 +14,14 @@
import numpy as np
import torch
import torch.nn as nn
-from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key
-from tensordict.utils import NestedKey
+from tensordict import (
+    is_tensor_collection,
+    LazyStackedTensorDict,
+    TensorDictBase,
+    unravel_key,
+)
+from tensordict.base import _is_leaf_nontensor
+from tensordict.utils import is_non_tensor, NestedKey
from torchrl._utils import (
_ends_with,
_make_ordinal_device,
@@ -25,7 +31,13 @@
seed_generator,
)

-from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded
+from torchrl.data.tensor_specs import (
+    Categorical,
+    Composite,
+    NonTensor,
+    TensorSpec,
+    Unbounded,
+)
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.utils import (
_make_compatible_policy,
@@ -430,7 +442,6 @@ def auto_specs_(
done_key: NestedKey | List[NestedKey] | None = None,
observation_key: NestedKey | List[NestedKey] = "observation",
reward_key: NestedKey | List[NestedKey] = "reward",
-        batch_size: torch.Size | None = None,
):
"""Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.
@@ -484,6 +495,7 @@
tensordict2,
named=True,
nested_keys=True,
+            is_leaf=_is_leaf_nontensor,
)
input_spec = Composite(input_spec_stack, batch_size=batch_size)
if not self.batch_locked and batch_size != self.batch_size:
@@ -501,6 +513,7 @@
nexts_1,
named=True,
nested_keys=True,
+            is_leaf=_is_leaf_nontensor,
)

output_spec = Composite(output_spec_stack, batch_size=batch_size)
@@ -523,7 +536,8 @@
full_observation_spec = output_spec.separates(*observation_key, default=None)
if not output_spec.is_empty(recurse=True):
raise RuntimeError(
f"Keys {list(output_spec.keys(True, True))} are unaccounted for."
f"Keys {list(output_spec.keys(True, True))} are unaccounted for. "
f"Make sure you have passed all the leaf names to the auto_specs_ method."
)

if full_action_spec is not None:
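A hedged usage sketch of ``auto_specs_`` (the example env, the policy module, and the positional ``policy`` argument are assumptions; the method body is abbreviated in this diff):

```python
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
policy = TensorDictModule(
    nn.LazyLinear(1), in_keys=["observation"], out_keys=["action"]
)
# Performs a short rollout with the policy and rebuilds the env specs from
# the collected leaves; with is_leaf=_is_leaf_nontensor, non-tensor leaves
# are kept whole and mapped to NonTensor entries in the resulting specs.
env.auto_specs_(policy)
```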
@@ -3572,6 +3586,12 @@ def _has_dynamic_specs(spec: Composite):


def _tensor_to_spec(name, leaf, leaf_compare=None, *, stack):
+    if not (isinstance(leaf, torch.Tensor) or is_tensor_collection(leaf)):
+        stack[name] = NonTensor(shape=())
+        return
+    elif is_non_tensor(leaf):
+        stack[name] = NonTensor(shape=leaf.shape)
+        return
shape = leaf.shape
if leaf_compare is not None:
shape_compare = leaf_compare.shape
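To illustrate the new branches, a small sketch that exercises the helper directly (illustrative only; ``_tensor_to_spec`` is private, and the import path is an assumption):

```python
from tensordict.tensorclass import NonTensorData
from torchrl.envs.common import _tensor_to_spec

stack = {}
# A plain Python object (neither a tensor nor a tensor collection)
# becomes a shapeless NonTensor spec:
_tensor_to_spec("text", "hello world", stack=stack)
# A non-tensor tensorclass leaf keeps its own shape:
_tensor_to_spec("meta", NonTensorData("x"), stack=stack)
# Both entries in `stack` are now NonTensor specs.
```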
1 change: 1 addition & 0 deletions torchrl/envs/custom/__init__.py
@@ -3,5 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

+from .llm import LLMHashingEnv
from .pendulum import PendulumEnv
from .tictactoeenv import TicTacToeEnv
199 changes: 199 additions & 0 deletions torchrl/envs/custom/llm.py
@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from typing import Callable, List, Union

import torch
from tensordict import NestedKey, TensorDictBase
from tensordict.tensorclass import NonTensorData, NonTensorStack

from torchrl.data import (
Categorical as CategoricalSpec,
Composite,
NonTensor,
SipHash,
Unbounded,
)
from torchrl.envs import EnvBase
from torchrl.envs.utils import _StepMDP


class LLMHashingEnv(EnvBase):
"""A text generation environment that uses a hashing module to identify unique observations.
The primary goal of this environment is to identify token chains using a hashing function.
This allows the data to be stored in a :class:`~torchrl.data.MCTSForest` using nothing but hashes as node
identifiers, or easily prune repeated token chains in a data structure.
The following figure gives an overview of this workflow:
.. figure:: /_static/img/rollout-llm.png
:alt: Data collection loop with our LLM environment.
.. seealso:: the :ref:`Beam Search <beam_search>` tutorial gives a practical example of how this env can be used.
Args:
vocab_size (int): The size of the vocabulary. Can be omitted if the tokenizer is passed.
Keyword Args:
hashing_module (Callable[[torch.Tensor], torch.Tensor], optional):
A hashing function that takes a tensor as input and returns a hashed tensor.
Defaults to :class:`~torchrl.data.SipHash` if not provided.
observation_key (NestedKey, optional): The key for the observation in the TensorDict.
Defaults to "observation".
text_output (bool, optional): Whether to include the text output in the observation.
Defaults to True.
tokenizer (transformers.Tokenizer | None, optional):
A tokenizer function that converts text to tensors.
Only used when `text_output` is `True`.
Must implement the following methods: `decode` and `batch_decode`.
Defaults to ``None``.
text_key (NestedKey | None, optional): The key for the text output in the TensorDict.
Defaults to "text".
Examples:
>>> from tensordict import TensorDict
>>> from torchrl.envs import LLMHashingEnv
>>> from transformers import GPT2Tokenizer
>>> tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
>>> x = tokenizer(["Check out TorchRL!"])["input_ids"]
>>> env = LLMHashingEnv(tokenizer=tokenizer)
>>> td = TensorDict(observation=x, batch_size=[1])
>>> td = env.reset(td)
>>> print(td)
TensorDict(
fields={
done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
hash: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
observation: Tensor(shape=torch.Size([1, 5]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
text: NonTensorStack(
['Check out TorchRL!'],
batch_size=torch.Size([1]),
device=None)},
batch_size=torch.Size([1]),
device=None,
is_shared=False)
"""

def __init__(
self,
vocab_size: int | None = None,
*,
        hashing_module: Callable[[torch.Tensor], torch.Tensor] | None = None,
observation_key: NestedKey = "observation",
text_output: bool = True,
tokenizer: Callable[[Union[str, List[str]]], torch.Tensor] | None = None,
text_key: NestedKey | None = "text",
):
super().__init__()
if vocab_size is None:
if tokenizer is None:
raise TypeError(
"You must provide a vocab_size integer if tokenizer is `None`."
)
vocab_size = tokenizer.vocab_size
self._batch_locked = False
if hashing_module is None:
hashing_module = SipHash()

self._hashing_module = hashing_module
self._tokenizer = tokenizer
self.observation_key = observation_key
observation_spec = {
observation_key: CategoricalSpec(n=vocab_size, shape=(-1,)),
"hash": Unbounded(shape=(1,), dtype=torch.int64),
}
self.text_output = text_output
if not text_output:
text_key = None
elif text_key is None:
text_key = "text"
if text_key is not None:
observation_spec[text_key] = NonTensor(shape=())
self.text_key = text_key
self.observation_spec = Composite(observation_spec)
self.action_spec = Composite(action=CategoricalSpec(vocab_size, shape=(1,)))
_StepMDP(self)

def _reset(self, tensordict: TensorDictBase):
"""Initializes the environment with a given observation.
Args:
tensordict (TensorDictBase): A TensorDict containing the initial observation.
Returns:
A TensorDict containing the initial observation, its hash, and other relevant information.
"""
out = tensordict.empty()
obs = tensordict.get(self.observation_key)
if self.text_output:
if obs.ndim > 1:
text = self._tokenizer.batch_decode(obs)
text = NonTensorStack.from_list(text)
else:
text = self._tokenizer.decode(obs)
text = NonTensorData(text)
out.set(self.text_key, text)

if obs.ndim > 1:
out.set("hash", self._hashing_module(obs).unsqueeze(-1))
else:
out.set("hash", self._hashing_module(obs.unsqueeze(0)).transpose(0, -1))

if not self.full_done_spec.is_empty():
out.update(self.full_done_spec.zero(tensordict.shape))
else:
out.set("done", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool))
out.set(
"terminated", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool)
)
return out

def _step(self, tensordict):
"""Takes an action (i.e., the next token to generate) and returns the next observation and reward.
Args:
tensordict: A TensorDict containing the current observation and action.
Returns:
A TensorDict containing the next observation, its hash, and other relevant information.
"""
out = tensordict.empty()
action = tensordict.get("action")
obs = torch.cat([tensordict.get(self.observation_key), action], -1)
kwargs = {self.observation_key: obs}

catval = torch.cat([tensordict.get("hash"), action], -1)
if obs.ndim > 1:
new_hash = self._hashing_module(catval).unsqueeze(-1)
else:
new_hash = self._hashing_module(catval.unsqueeze(0)).transpose(0, -1)

if self.text_output:
if obs.ndim > 1:
text = self._tokenizer.batch_decode(obs)
text = NonTensorStack.from_list(text)
else:
text = self._tokenizer.decode(obs)
text = NonTensorData(text)
kwargs[self.text_key] = text
kwargs.update(
{
"hash": new_hash,
"done": torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool),
"terminated": torch.zeros(
(*tensordict.batch_size, 1), dtype=torch.bool
),
}
)
return out.update(kwargs)

def _set_seed(self, *args):
"""Sets the seed for the environment's randomness.
.. note:: This environment has no randomness, so this method does nothing.
"""
pass
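Beyond the docstring example above, a short sketch of the hashing backbone and a tokenizer-free instantiation (the token values are made up; ``text_output=False`` removes the need for a tokenizer):

```python
import torch
from tensordict import TensorDict
from torchrl.data import SipHash
from torchrl.envs import LLMHashingEnv

# SipHash, the default hashing module, maps each row of tokens to a single
# int64 hash; identical token chains hash identically, which is what lets a
# forest of rollouts be deduplicated by hash alone.
sip = SipHash()
x = torch.tensor([[3, 14, 15, 92], [3, 14, 15, 92]])
h = sip(x)
assert h[0] == h[1]

# Without a tokenizer, vocab_size is mandatory and text decoding is skipped.
env = LLMHashingEnv(vocab_size=1024, text_output=False)
td = env.reset(TensorDict(observation=x, batch_size=[2]))
print(td["hash"].shape)  # one hash per batch element: torch.Size([2, 1])
```

Each call to ``_step`` then appends the chosen ``action`` token to the observation and re-hashes the extended chain, as the method above shows.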
2 changes: 1 addition & 1 deletion torchrl/envs/utils.py
@@ -76,7 +76,7 @@ def __get__(self, cls, owner):


class _StepMDP:
"""Stateful version of step_mdp.
"""Stateful version of :func:`~torchrl.envs.step_mdp`.
Precomputes the list of keys to include and exclude during a call to step_mdp
to reduce runtime.
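For context, ``step_mdp`` promotes the ``"next"`` entries of a stepped tensordict to the root so the result can be fed back to the policy; ``_StepMDP`` caches the include/exclude key lists this requires. A minimal sketch, assuming a Gym-backed env is available:

```python
from torchrl.envs import GymEnv
from torchrl.envs.utils import step_mdp

env = GymEnv("Pendulum-v1")
td = env.rand_step(env.reset())  # root holds t=0 data, "next" holds t=1 data
td_t1 = step_mdp(td)             # t=1 data moved to the root; action and reward dropped by default
assert (td_t1["observation"] == td["next", "observation"]).all()
```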
