Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 4, 2024
1 parent f335d72 commit b768fbc
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,8 @@ TorchRL offers a series of custom built-in environments.

PendulumEnv
TicTacToeEnv
LLMHashingEnv


Multi-agent environments
------------------------
Expand Down
45 changes: 43 additions & 2 deletions torchrl/envs/custom/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,20 @@
class LLMHashingEnv(EnvBase):
"""A text generation environment that uses a hashing module to identify unique observations.
The primary goal of this environment is to identify token chains using a hashing function.
This allows the data to be stored in a :class:`~torchrl.data.MCTSForest` using nothing but hashes as node
identifiers, and makes it easy to prune repeated token chains from a data structure.
The following figure gives an overview of this workflow:
.. figure:: /_static/img/rollout-llm.png
:alt: Data collection loop with our LLM environment.
.. seealso:: The :ref:`Beam Search <beam_search>` tutorial gives a practical example of how this env can be used.
Args:
vocab_size (int): The size of the vocabulary.
vocab_size (int): The size of the vocabulary. Can be omitted if the tokenizer is passed.
Keyword Args:
hashing_module (Callable[[torch.Tensor], torch.Tensor], optional):
A hashing function that takes a tensor as input and returns a hashed tensor.
Defaults to :class:`~torchrl.data.SipHash` if not provided.
Expand All @@ -39,18 +51,47 @@ class LLMHashingEnv(EnvBase):
text_key (NestedKey | None, optional): The key for the text output in the TensorDict.
Defaults to "text".
Examples:
>>> from tensordict import TensorDict
>>> from torchrl.envs import LLMHashingEnv
>>> from transformers import GPT2Tokenizer
>>> tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
>>> x = tokenizer(["Check out TorchRL!"])["input_ids"]
>>> env = LLMHashingEnv(tokenizer=tokenizer)
>>> td = TensorDict(observation=x, batch_size=[1])
>>> td = env.reset(td)
>>> print(td)
TensorDict(
fields={
done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
hash: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
observation: Tensor(shape=torch.Size([1, 5]), device=cpu, dtype=torch.int64, is_shared=False),
terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
text: NonTensorStack(
['Check out TorchRL!'],
batch_size=torch.Size([1]),
device=None)},
batch_size=torch.Size([1]),
device=None,
is_shared=False)
"""

def __init__(
self,
vocab_size: int,
vocab_size: int | None=None,
*,
hashing_module: Callable[[torch.Tensor], torch.Tensor] = None,
observation_key: NestedKey = "observation",
text_output: bool = True,
tokenizer: Callable[[Union[str, List[str]]], torch.Tensor] | None = None,
text_key: NestedKey | None = "text",
):
super().__init__()
if vocab_size is None:
if tokenizer is None:
raise TypeError("You must provide a vocab_size integer if tokenizer is `None`.")
vocab_size = tokenizer.vocab_size
self._batch_locked = False
if hashing_module is None:
hashing_module = SipHash()
Expand Down
3 changes: 3 additions & 0 deletions tutorials/sphinx-tutorials/beam_search_with_gpt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""
Beam Search with TorchRL
========================
**Author**: `Vincent Moens <https://github.com/vmoens>`_
.. _beam_search:
Key learning
------------
Expand Down

0 comments on commit b768fbc

Please sign in to comment.