From b768fbc7518959ed56a08a16a8d1712f987e9203 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 4 Dec 2024 11:20:53 +0000
Subject: [PATCH] Update

[ghstack-poisoned]
---
 docs/source/reference/envs.rst                |  2 +
 torchrl/envs/custom/llm.py                    | 45 ++++++++++++++++++-
 .../sphinx-tutorials/beam_search_with_gpt.py  |  3 ++
 3 files changed, 48 insertions(+), 2 deletions(-)

diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index 4519900ae8b..70fdf03c0ff 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -347,6 +347,8 @@ TorchRL offers a series of custom built-in environments.
     PendulumEnv
     TicTacToeEnv
+    LLMHashingEnv
+
 
 Multi-agent environments
 ------------------------
diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py
index 0413671f32c..4e9e5b7d3c0 100644
--- a/torchrl/envs/custom/llm.py
+++ b/torchrl/envs/custom/llm.py
@@ -22,8 +22,20 @@ class LLMHashingEnv(EnvBase):
     """A text generation environment that uses a hashing module to identify unique observations.
 
+    The primary goal of this environment is to identify token chains using a hashing function.
+    This allows the data to be stored in a :class:`~torchrl.data.MCTSForest` using nothing but hashes as node
+    identifiers, or easily prune repeated token chains in a data structure.
+    The following figure gives an overview of this workflow:
+
+    .. figure:: /_static/img/rollout-llm.png
+        :alt: Data collection loop with our LLM environment.
+
+    .. seealso:: the :ref:`Beam Search <beam_search>` tutorial gives a practical example of how this env can be used.
+
     Args:
-        vocab_size (int): The size of the vocabulary.
+        vocab_size (int): The size of the vocabulary. Can be omitted if the tokenizer is passed.
+
+    Keyword Args:
         hashing_module (Callable[[torch.Tensor], torch.Tensor], optional): A hashing function
             that takes a tensor as input and returns a hashed tensor. Defaults to
             :class:`~torchrl.data.SipHash` if not provided.
@@ -39,11 +51,36 @@ class LLMHashingEnv(EnvBase):
         text_key (NestedKey | None, optional): The key for the text output in the TensorDict. Defaults to "text".
 
+    Examples:
+        >>> from tensordict import TensorDict
+        >>> from torchrl.envs import LLMHashingEnv
+        >>> from transformers import GPT2Tokenizer
+        >>> tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+        >>> x = tokenizer(["Check out TorchRL!"])["input_ids"]
+        >>> env = LLMHashingEnv(tokenizer=tokenizer)
+        >>> td = TensorDict(observation=x, batch_size=[1])
+        >>> td = env.reset(td)
+        >>> print(td)
+        TensorDict(
+            fields={
+                done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                hash: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                observation: Tensor(shape=torch.Size([1, 5]), device=cpu, dtype=torch.int64, is_shared=False),
+                terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                text: NonTensorStack(
+                    ['Check out TorchRL!'],
+                    batch_size=torch.Size([1]),
+                    device=None)},
+            batch_size=torch.Size([1]),
+            device=None,
+            is_shared=False)
+
     """

     def __init__(
         self,
-        vocab_size: int,
+        vocab_size: int | None = None,
+        *,
         hashing_module: Callable[[torch.Tensor], torch.Tensor] = None,
         observation_key: NestedKey = "observation",
         text_output: bool = True,
@@ -51,6 +88,10 @@ def __init__(
         text_key: NestedKey | None = "text",
     ):
         super().__init__()
+        if vocab_size is None:
+            if tokenizer is None:
+                raise TypeError("You must provide a vocab_size integer if tokenizer is `None`.")
+            vocab_size = tokenizer.vocab_size
         self._batch_locked = False
         if hashing_module is None:
             hashing_module = SipHash()
diff --git a/tutorials/sphinx-tutorials/beam_search_with_gpt.py b/tutorials/sphinx-tutorials/beam_search_with_gpt.py
index a3214e89b4e..c2ef0baeca3 100644
--- a/tutorials/sphinx-tutorials/beam_search_with_gpt.py
+++ b/tutorials/sphinx-tutorials/beam_search_with_gpt.py
@@ -1,6 +1,9 @@
 """
 Beam Search with TorchRL
 ========================
+**Author**: `Vincent Moens <https://github.com/vmoens>`_
+
+.. _beam_search:
 
 Key learning
 ------------
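
Usage sketch (not part of the patch itself): with this change, ``vocab_size``
can be omitted as long as a tokenizer is passed, in which case it is read from
``tokenizer.vocab_size``, and passing neither raises a ``TypeError``. The
snippet below mirrors the docstring example above and assumes the
``transformers`` package is available::

    from tensordict import TensorDict
    from torchrl.envs import LLMHashingEnv
    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")

    # New behavior: vocab_size is inferred from tokenizer.vocab_size
    env = LLMHashingEnv(tokenizer=tokenizer)

    # Equivalent explicit form, as required before this patch
    env = LLMHashingEnv(vocab_size=tokenizer.vocab_size, tokenizer=tokenizer)

    # LLMHashingEnv() with neither argument now raises a TypeError

    x = tokenizer(["Check out TorchRL!"])["input_ids"]
    td = env.reset(TensorDict(observation=x, batch_size=[1]))
    print(td["hash"])  # one identifier per token chain, from SipHash by default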