Commit 61734fc: Update

[ghstack-poisoned]

vmoens committed Nov 18, 2024
2 parents 64d1d92 + 7e96847

Showing 51 changed files with 439 additions and 401 deletions.
8 changes: 4 additions & 4 deletions docs/source/index.rst
@@ -22,15 +22,15 @@ TorchRL provides pytorch and python-first, low and high level abstractions for R
The code is aimed at supporting research in RL. Most of it is written in python in a highly modular way, such that researchers can easily swap components, transform them or write new ones with little effort.

This repo attempts to align with the existing pytorch ecosystem libraries in that it has a "dataset pillar"
-:doc:`(environments) <reference/envs>`,
-:ref:`transforms <reference/envs:Transforms>`,
-:doc:`models <reference/modules>`,
+:ref:`(environments) <Environment-API>`,
+:ref:`transforms <transforms>`,
+:ref:`models <ref_modules>`,
data utilities (e.g. collectors and containers), etc.
TorchRL aims at having as few dependencies as possible (python standard library, numpy and pytorch).
Common environment libraries (e.g. OpenAI gym) are only optional.

On the low-level end, torchrl comes with a set of highly re-usable functionals
-for :doc:`cost functions <reference/objectives>`, :ref:`returns <reference/objectives:Returns>` and data processing.
+for :ref:`cost functions <ref_objectives>`, :ref:`returns <ref_returns>` and data processing.

TorchRL aims at a high modularity and good runtime performance.

4 changes: 2 additions & 2 deletions docs/source/reference/data.rst
@@ -944,7 +944,7 @@ not predictable.
MultiCategorical
MultiOneHot
NonTensor
-OneHotDiscrete
+OneHot
Stacked
StackedComposite
Unbounded
@@ -1050,7 +1050,7 @@ and the tree can be expanded for each of these. The following figure shows how t

BinaryToDecimal
HashToInt
-MCTSForeset
+MCTSForest
QueryModule
RandomProjectionHash
SipHash
3 changes: 3 additions & 0 deletions docs/source/reference/objectives.rst
@@ -274,6 +274,9 @@ QMixer

Returns
-------

.. _ref_returns:

.. currentmodule:: torchrl.objectives.value

.. autosummary::
14 changes: 12 additions & 2 deletions sota-implementations/a2c/README.md
@@ -19,11 +19,21 @@ Please note that each example is independent of each other for the sake of simpl
You can execute the A2C algorithm on Atari environments by running the following command:

```bash
-python a2c_atari.py
+python a2c_atari.py compile.compile=1 compile.cudagraphs=1
```


You can execute the A2C algorithm on MuJoCo environments by running the following command:

```bash
-python a2c_mujoco.py
+python a2c_mujoco.py compile.compile=1 compile.cudagraphs=1
```

## Runtimes

Runtimes when executed on H100:

| Environment | Eager | Compile | Compile+cudagraphs |
|-------------|-------|---------|--------------------|
| MUJOCO | | | |
| ATARI | | 60 mins | 43 mins |
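The `compile.compile` and `compile.cudagraphs` overrides enable `torch.compile` and CUDA-graph capture of the update step. Below is a minimal sketch of how such flags are typically consumed; the exact config handling inside `a2c_atari.py` may differ, and the names are taken from the CLI overrides above.

```python
import torch

def maybe_accelerate(update_fn, cfg):
    # cfg.compile.compile / cfg.compile.cudagraphs mirror the CLI overrides above
    if cfg.compile.compile:
        # compile the per-batch update once; later calls reuse the compiled graph
        update_fn = torch.compile(update_fn, mode="reduce-overhead")
    if cfg.compile.cudagraphs:
        # CudaGraphModule (from tensordict.nn) replays the update inside a CUDA graph,
        # removing per-step kernel-launch overhead, which is where the runtime
        # gains reported in the table come from
        from tensordict.nn import CudaGraphModule
        update_fn = CudaGraphModule(update_fn, warmup=5)
    return update_fn
```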
8 changes: 4 additions & 4 deletions sota-implementations/decision_transformer/lamb.py
@@ -15,15 +15,15 @@ class Lamb(Optimizer):
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
-lr (float, optional): learning rate. (default: 1e-3)
+lr (:obj:`float`, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
-eps (float, optional): term added to the denominator to improve
+eps (:obj:`float`, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
-weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+weight_decay (:obj:`float`, optional): weight decay (L2 penalty) (default: 0)
grad_averaging (bool, optional): whether apply (1-beta2) to grad when
calculating running averages of gradient. (default: True)
-max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
+max_grad_norm (:obj:`float`, optional): value used to clip global grad norm (default: 1.0)
trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
weight decay parameter (default: False)
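For context, a minimal usage sketch for the optimizer whose docstring is edited above. It assumes the `Lamb` class from `lamb.py` is in scope (the file sits in a hyphenated example directory, so the import path is not shown); argument names and defaults follow the docstring.

```python
import torch
from torch import nn

model = nn.Linear(10, 2)
# hyperparameters mirror the documented defaults
optimizer = Lamb(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                 eps=1e-8, weight_decay=0.0, max_grad_norm=1.0)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```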
22 changes: 22 additions & 0 deletions test/test_collector.py
@@ -3172,6 +3172,28 @@ def make_and_test_policy(
)


@pytest.mark.parametrize(
"ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector]
)
def test_no_stopiteration(ctype):
# Tests that there is no StopIteration raised and that the length of the collector is properly set
if ctype is SyncDataCollector:
envs = SerialEnv(16, CountingEnv)
else:
envs = [SerialEnv(8, CountingEnv), SerialEnv(8, CountingEnv)]

collector = ctype(create_env_fn=envs, frames_per_batch=173, total_frames=300)
try:
c_iter = iter(collector)
for i in range(len(collector)): # noqa: B007
c = next(c_iter)
assert c is not None
assert i == 1
finally:
collector.shutdown()
del collector


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
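The test above drives the collector for exactly `len(collector)` batches with `total_frames=300` and `frames_per_batch=173`. As the next diff shows, that length is a ceiling division implemented with the negative-floor-division idiom; a small illustration:

```python
import math

total_frames, frames_per_batch = 300, 173
# -(a // -b) is integer ceiling division, i.e. math.ceil(a / b) without floats
n_batches = -(total_frames // -frames_per_batch)
assert n_batches == math.ceil(total_frames / frames_per_batch) == 2
```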
21 changes: 13 additions & 8 deletions torchrl/collectors/collectors.py
@@ -147,6 +147,7 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
_iterator = None
total_frames: int
frames_per_batch: int
requested_frames_per_batch: int
trust_policy: bool
compiled_policy: bool
cudagraphed_policy: bool
@@ -305,7 +306,7 @@ def __class_getitem__(self, index):

def __len__(self) -> int:
if self.total_frames > 0:
-return -(self.total_frames // -self.frames_per_batch)
+return -(self.total_frames // -self.requested_frames_per_batch)
raise RuntimeError("Non-terminating collectors do not have a length")


@@ -700,7 +701,7 @@ def __init__(
remainder = total_frames % frames_per_batch
if remainder != 0 and RL_WARNINGS:
warnings.warn(
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch})."
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
f"This means {frames_per_batch - remainder} additional frames will be collected."
"To silence this message, set the environment variable RL_WARNINGS to False."
)
@@ -1399,11 +1400,13 @@ class _MultiDataCollector(DataCollectorBase):
instances) it will be wrapped in a `nn.Module` first.
Then, the collector will try to assess if these
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
- If the policy forward signature matches any of ``forward(self, tensordict)``,
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
any typing with a single argument typed as a subclass of ``TensorDictBase``)
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
-- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
+- In all other cases an attempt to wrap it will be undergone as such:
+``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
Keyword Args:
frames_per_batch (int): A keyword-only argument representing the
@@ -1491,7 +1494,7 @@ class _MultiDataCollector(DataCollectorBase):
update_at_each_batch (bool, optional): if ``True``, :meth:`~.update_policy_weight_()`
will be called before (sync) or after (async) each data collection.
Defaults to ``False``.
-preemptive_threshold (float, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
+preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
that will be allowed to finish collecting their rollout before the rest are forced to end early.
num_threads (int, optional): number of threads for this process.
Defaults to the number of workers.
@@ -2108,11 +2111,13 @@ class MultiSyncDataCollector(_MultiDataCollector):
trajectory and the start of the next collection.
This class can be safely used with online RL sota-implementations.
-.. note:: Python requires multiprocessed code to be instantiated within a main guard:
+.. note::
+Python requires multiprocessed code to be instantiated within a main guard:
>>> from torchrl.collectors import MultiSyncDataCollector
>>> if __name__ == "__main__":
... # Create your collector here
... collector = MultiSyncDataCollector(...)
See https://docs.python.org/3/library/multiprocessing.html for more info.
@@ -2140,8 +2145,8 @@ class MultiSyncDataCollector(_MultiDataCollector):
... if i == 2:
... print(data)
... break
-... collector.shutdown()
-... del collector
+>>> collector.shutdown()
+>>> del collector
TensorDict(
fields={
action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
@@ -2768,7 +2773,7 @@ class aSyncDataCollector(MultiaSyncDataCollector):
update_at_each_batch (bool, optional): if ``True``, :meth:`~.update_policy_weight_()`
will be called before (sync) or after (async) each data collection.
Defaults to ``False``.
-preemptive_threshold (float, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
+preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
that will be allowed to finish collecting their rollout before the rest are forced to end early.
num_threads (int, optional): number of threads for this process.
Defaults to the number of workers.
2 changes: 1 addition & 1 deletion torchrl/data/map/tree.py
@@ -48,7 +48,7 @@ class Tree(TensorClass["nocast"]):
If there are multiple actions taken at this node, subtrees are stored in the corresponding
entry. Rollouts can be reconstructed using the :meth:`~.rollout_from_path` method.
node (TensorDict): Data defining this node (e.g., observations) before the next branching.
-Entries usually matches the ``in_keys`` in ``MCTSForeset.node_map``.
+Entries usually matches the ``in_keys`` in ``MCTSForest.node_map``.
subtree (Tree): A stack of subtrees produced when actions are taken.
num_children (int): The number of child nodes (read-only).
is_terminal (bool): whether the tree has children nodes (read-only).
2 changes: 1 addition & 1 deletion torchrl/data/postprocs/postprocs.py
@@ -90,7 +90,7 @@ class MultiStep(nn.Module):
It is an identity transform whenever :attr:`n_steps` is 0.
Args:
-gamma (float): Discount factor for return computation
+gamma (:obj:`float`): Discount factor for return computation
n_steps (integer): maximum look-ahead steps.
.. note:: This class is meant to be used within a ``DataCollector``.
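For reference, a minimal sketch of the n-step discounted return that a ``MultiStep``-style post-processing computes from ``gamma`` and ``n_steps``. This is an illustration only, not the class's actual implementation; termination and bootstrapping handling are omitted.

```python
def n_step_return(rewards, gamma, n_steps):
    # sum of the next n_steps rewards, each discounted by gamma**k
    return sum(gamma**k * r for k, r in enumerate(rewards[:n_steps]))

# example: three-step look-ahead with gamma = 0.9
assert abs(n_step_return([1.0, 1.0, 1.0, 1.0], gamma=0.9, n_steps=3) - (1 + 0.9 + 0.81)) < 1e-9
```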
18 changes: 8 additions & 10 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -897,16 +897,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
All arguments are keyword-only arguments.
-Presented in
-"Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
-Prioritized experience replay."
-(https://arxiv.org/abs/1511.05952)
+Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
+Prioritized experience replay." (https://arxiv.org/abs/1511.05952)
Args:
-alpha (float): exponent α determines how much prioritization is used,
+alpha (:obj:`float`): exponent α determines how much prioritization is used,
with α = 0 corresponding to the uniform case.
-beta (float): importance sampling negative exponent.
-eps (float): delta added to the priorities to ensure that the buffer
+beta (:obj:`float`): importance sampling negative exponent.
+eps (:obj:`float`): delta added to the priorities to ensure that the buffer
does not contain null priorities.
storage (Storage, optional): the storage to be used. If none is provided
a default :class:`~torchrl.data.replay_buffers.ListStorage` with
@@ -1366,10 +1364,10 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
tensordict to be passed to it with its new priority value.
Keyword Args:
-alpha (float): exponent α determines how much prioritization is used,
+alpha (:obj:`float`): exponent α determines how much prioritization is used,
with α = 0 corresponding to the uniform case.
-beta (float): importance sampling negative exponent.
-eps (float): delta added to the priorities to ensure that the buffer
+beta (:obj:`float`): importance sampling negative exponent.
+eps (:obj:`float`): delta added to the priorities to ensure that the buffer
does not contain null priorities.
storage (Storage, optional): the storage to be used. If none is provided
a default :class:`~torchrl.data.replay_buffers.ListStorage` with
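A short usage sketch for the prioritized-buffer arguments documented above. Values are illustrative; the import paths and the ``"td_error"`` priority entry are assumptions based on TorchRL's public API rather than something stated in this diff.

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictPrioritizedReplayBuffer

rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,                      # 0 -> uniform sampling; higher -> stronger prioritization
    beta=0.9,                       # importance-sampling correction exponent
    eps=1e-8,                       # keeps priorities strictly positive
    storage=LazyTensorStorage(1000),
)
data = TensorDict(
    {"obs": torch.randn(10, 4), "td_error": torch.rand(10)},
    batch_size=[10],
)
rb.extend(data)
sample = rb.sample(5)               # sampled indices follow the stored priorities
```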
12 changes: 6 additions & 6 deletions torchrl/data/replay_buffers/samplers.py
@@ -298,10 +298,10 @@ class PrioritizedSampler(Sampler):
Args:
max_capacity (int): maximum capacity of the buffer.
-alpha (float): exponent α determines how much prioritization is used,
+alpha (:obj:`float`): exponent α determines how much prioritization is used,
with α = 0 corresponding to the uniform case.
-beta (float): importance sampling negative exponent.
-eps (float, optional): delta added to the priorities to ensure that the buffer
+beta (:obj:`float`): importance sampling negative exponent.
+eps (:obj:`float`, optional): delta added to the priorities to ensure that the buffer
does not contain null priorities. Defaults to 1e-8.
reduction (str, optional): the reduction method for multidimensional
tensordicts (ie stored trajectory). Can be one of "max", "min",
@@ -1652,10 +1652,10 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
:meth:`~.update_priority`.
Args:
-alpha (float): exponent α determines how much prioritization is used,
+alpha (:obj:`float`): exponent α determines how much prioritization is used,
with α = 0 corresponding to the uniform case.
-beta (float): importance sampling negative exponent.
-eps (float, optional): delta added to the priorities to ensure that the buffer
+beta (:obj:`float`): importance sampling negative exponent.
+eps (:obj:`float`, optional): delta added to the priorities to ensure that the buffer
does not contain null priorities. Defaults to 1e-8.
reduction (str, optional): the reduction method for multidimensional
tensordicts (i.e., stored trajectory). Can be one of "max", "min",
10 changes: 5 additions & 5 deletions torchrl/data/rlhf/utils.py
@@ -41,7 +41,7 @@ class ConstantKLController(KLControllerBase):
with.
Keyword Arguments:
-kl_coef (float): The coefficient to multiply KL with when calculating the
+kl_coef (:obj:`float`): The coefficient to multiply KL with when calculating the
reward.
model (nn.Module, optional): wrapped model that needs to be controlled.
Must have an attribute ``"kl_coef"``. If provided, the ``"kl_coef"`` will
@@ -73,8 +73,8 @@ class AdaptiveKLController(KLControllerBase):
"""Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences".
Keyword Arguments:
-init_kl_coef (float): The starting value of the coefficient.
-target (float): The target KL value. When the observed KL is smaller, the
+init_kl_coef (:obj:`float`): The starting value of the coefficient.
+target (:obj:`float`): The target KL value. When the observed KL is smaller, the
coefficient is decreased, thereby relaxing the KL penalty in the training
objective and allowing the model to stray further from the reference model.
When the observed KL is greater than the target, the KL coefficient is
@@ -146,10 +146,10 @@ class RolloutFromModel:
reward_model: (nn.Module, tensordict.nn.TensorDictModule): a model which, given
``input_ids`` and ``attention_mask``, calculates rewards for each token and
end_scores (the reward for the final token in each sequence).
-kl_coef: (float, optional): initial kl coefficient.
+kl_coef: (:obj:`float`, optional): initial kl coefficient.
max_new_tokens (int, optional): the maximum length of the sequence.
Defaults to 50.
-score_clip (float, optional): Scores from the reward model are clipped to the
+score_clip (:obj:`float`, optional): Scores from the reward model are clipped to the
range ``(-score_clip, score_clip)``. Defaults to 10.
kl_scheduler (KLControllerBase, optional): the KL coefficient scheduler.
num_steps (int, optional): number of steps between two optimization.
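A hedged sketch of the proportional update rule from Ziegler et al. that an adaptive controller of this kind implements; the exact method name and horizon handling inside ``AdaptiveKLController`` may differ.

```python
def adaptive_kl_update(kl_coef, observed_kl, target, horizon, n_steps):
    # clipped proportional error between the observed KL and the target
    error = max(min((observed_kl - target) / target, 0.2), -0.2)
    # the coefficient grows when KL overshoots the target and shrinks when it undershoots
    return kl_coef * (1.0 + error * n_steps / horizon)
```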
20 changes: 10 additions & 10 deletions torchrl/envs/common.py
@@ -840,15 +840,15 @@ def full_action_spec(self) -> Composite:
... break
>>> env = BraxEnv(envname)
>>> env.full_action_spec
Composite(
action: BoundedContinuous(
shape=torch.Size([8]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([]))
Composite(
action: BoundedContinuous(
shape=torch.Size([8]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, contiguous=True),
high=Tensor(shape=torch.Size([8]), device=cpu, dtype=torch.float32, contiguous=True)),
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([]))
"""
full_action_spec = self.input_spec.get("full_action_spec", None)
@@ -1791,7 +1791,7 @@ def register_gym(
(results are tensors).
This arg can be passed during a call to :func:`~gym.make` (see
example below).
-reward_threshold (float, optional): [Gym kwarg] The reward threshold
+reward_threshold (:obj:`float`, optional): [Gym kwarg] The reward threshold
considered to have learnt an environment.
nondeterministic (bool, optional): [Gym kwarg] If the environment is nondeterministic
(even with knowledge of the initial seed and all actions). Defaults to
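As an illustration of the ``register_gym`` keyword documented above, a hedged sketch of registering a TorchRL env under a gym id and passing ``reward_threshold``. The env class and id are hypothetical, the gymnasium backend is assumed, and ``register_gym`` may expose additional keywords not shown here.

```python
import gymnasium as gym

# MyTorchRLEnv stands in for a hypothetical EnvBase subclass
MyTorchRLEnv.register_gym("MyTorchRLEnv-v0", reward_threshold=195.0)
env = gym.make("MyTorchRLEnv-v0")
```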
1 change: 1 addition & 0 deletions torchrl/envs/custom/tictactoeenv.py
@@ -34,6 +34,7 @@ class TicTacToeEnv(EnvBase):
output entry).
Specs:
>>> print(env.specs)
Composite(
output_spec: Composite(
full_observation_spec: Composite(
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/rb_transforms.py
@@ -36,7 +36,7 @@ class MultiStepTransform(Transform):
n_steps (int): Number of steps in multi-step. The number of steps can be
dynamically changed by changing the ``n_steps`` attribute of this
transform.
-gamma (float): Discount factor.
+gamma (:obj:`float`): Discount factor.
Keyword Args:
reward_keys (list of NestedKey, optional): the reward keys in the input tensordict.
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/rlhf.py
@@ -25,7 +25,7 @@ class KLRewardTransform(Transform):
have the following features: it must have a set of input (``in_keys``)
and output keys (``out_keys``). It must have a ``get_dist`` method
that outputs the distribution of the action.
-coef (float): the coefficient of the KL term. Defaults to ``1.0``.
+coef (:obj:`float`): the coefficient of the KL term. Defaults to ``1.0``.
in_keys (str or list of str/tuples of str): the input key where the
reward should be fetched. Defaults to ``"reward"``.
out_keys (str or list of str/tuples of str): the output key where the
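For context on the ``coef`` argument above, a minimal sketch of the KL-regularised reward that such a transform computes. This is an illustration of the shaping, not the transform's actual code.

```python
def kl_shaped_reward(task_reward, log_prob_policy, log_prob_ref, coef=1.0):
    # per-token KL estimate between the trained policy and the frozen reference
    kl = log_prob_policy - log_prob_ref
    # the task reward is penalised by coef * KL, discouraging drift from the reference
    return task_reward - coef * kl
```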