
Commit

Merge remote-tracking branch 'origin/main' into extend-rb-dim1
vmoens committed Feb 13, 2024
2 parents 238e9b3 + 899af07 commit eb731a4
Showing 48 changed files with 2,204 additions and 233 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -33,6 +33,11 @@ On the low-level end, torchrl comes with a set of highly re-usable functionals f

TorchRL aims at (1) high modularity and (2) good runtime performance. Read the [full paper](https://arxiv.org/abs/2306.00577) for a more curated description of the library.

## Getting started

Check our [Getting Started tutorials](https://pytorch.org/rl/index.html#getting-started) to quickly ramp up with the basic
features of the library!

## Documentation and knowledge base

The TorchRL documentation can be found [here](https://pytorch.org/rl).
17 changes: 17 additions & 0 deletions docs/source/index.rst
@@ -62,6 +62,23 @@ or via a ``git clone`` if you're willing to contribute to the library:
$ cd ../rl
$ python setup.py develop
Getting started
===============

A series of quick tutorials to get ramped up with the basic features of the
library. If you're in a hurry, you can start with
:ref:`the last item of the series <gs_first_training>`
and navigate to the previous ones whenever you want to learn more!

.. toctree::
:maxdepth: 1

tutorials/getting-started-0
tutorials/getting-started-1
tutorials/getting-started-2
tutorials/getting-started-3
tutorials/getting-started-4
tutorials/getting-started-5

Tutorials
=========
2 changes: 2 additions & 0 deletions docs/source/reference/collectors.rst
@@ -3,6 +3,8 @@
torchrl.collectors package
==========================

.. _ref_collectors:

Data collectors are somewhat equivalent to pytorch dataloaders, except that (1) they
collect data over non-static data sources and (2) the data is collected using a model
(likely a version of the model that is being trained).
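As a minimal sketch of that idea (assuming gymnasium and its CartPole environment are installed; the frame counts and keys are illustrative only):

    from torchrl.collectors import SyncDataCollector
    from torchrl.envs import GymEnv

    # the collector repeatedly runs the policy in the environment and yields batches
    collector = SyncDataCollector(
        GymEnv("CartPole-v1"),   # a non-static data source
        policy=None,             # None falls back to a random policy over the action spec
        frames_per_batch=64,     # size of each yielded batch of transitions
        total_frames=256,        # stop iterating after this many frames
    )
    for batch in collector:      # each batch is a TensorDict of transitions
        print(batch["next", "reward"].shape)
    collector.shutdown()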
6 changes: 6 additions & 0 deletions docs/source/reference/data.rst
@@ -3,6 +3,8 @@
torchrl.data package
====================

.. _ref_data:

Replay Buffers
--------------

@@ -134,6 +136,7 @@ using the following components:

Sampler
PrioritizedSampler
PrioritizedSliceSampler
RandomSampler
SamplerWithoutReplacement
SliceSampler
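A minimal sketch of plugging one of these samplers into a buffer, shown here with PrioritizedSampler (the capacity, alpha/beta values and entry names are illustrative assumptions; the newly listed PrioritizedSliceSampler is, per its name, the slice-aware counterpart):

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, ReplayBuffer
    from torchrl.data.replay_buffers.samplers import PrioritizedSampler

    rb = ReplayBuffer(
        storage=LazyTensorStorage(1000),
        sampler=PrioritizedSampler(max_capacity=1000, alpha=0.7, beta=0.9),
    )
    data = TensorDict({"obs": torch.randn(10, 4), "reward": torch.randn(10, 1)}, batch_size=[10])
    rb.extend(data)            # write a batch of transitions into the storage
    sample = rb.sample(5)      # draw 5 transitions according to their priorities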
@@ -699,6 +702,9 @@ efficient sampling.
TokenizedDatasetLoader
create_infinite_iterator
get_dataloader
ConstantKLController
AdaptiveKLController


Utils
-----
3 changes: 3 additions & 0 deletions docs/source/reference/envs.rst
@@ -475,6 +475,9 @@ single agent standards.

Transforms
----------

.. _transforms:

.. currentmodule:: torchrl.envs.transforms

In most cases, the raw output of an environment must be treated before being passed to another object (such as a policy).
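A minimal sketch of appending transforms to an environment (assuming gymnasium's Pendulum environment is available; the chosen transforms are illustrative):

    from torchrl.envs import Compose, DoubleToFloat, GymEnv, StepCounter, TransformedEnv

    env = TransformedEnv(
        GymEnv("Pendulum-v1"),
        Compose(
            DoubleToFloat(),   # cast float64 entries to float32
            StepCounter(),     # add a "step_count" entry to each step
        ),
    )
    rollout = env.rollout(3)   # the transforms are applied to every collected step
    print(rollout["step_count"])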
4 changes: 4 additions & 0 deletions docs/source/reference/modules.rst
@@ -3,9 +3,13 @@
torchrl.modules package
=======================

.. _ref_modules:

TensorDict modules: Actors, exploration, value models and generative models
---------------------------------------------------------------------------

.. _tdmodules:

TorchRL offers a series of module wrappers aimed at making it easy to build
RL models from the ground up. These wrappers are exclusively based on
:class:`tensordict.nn.TensorDictModule` and :class:`tensordict.nn.TensorDictSequential`.
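A minimal sketch of the wrapping pattern these classes enable (the key names and layer sizes are arbitrary assumptions):

    import torch
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule, TensorDictSequential

    # wrap plain nn.Modules so they read from / write to entries of a TensorDict
    backbone = TensorDictModule(nn.Linear(4, 32), in_keys=["observation"], out_keys=["hidden"])
    head = TensorDictModule(nn.Linear(32, 2), in_keys=["hidden"], out_keys=["action"])
    policy = TensorDictSequential(backbone, head)

    td = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
    policy(td)                   # writes "hidden" and "action" into the same TensorDict
    print(td["action"].shape)    # torch.Size([8, 2])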
2 changes: 2 additions & 0 deletions docs/source/reference/objectives.rst
@@ -3,6 +3,8 @@
torchrl.objectives package
==========================

.. _ref_objectives:

TorchRL provides a series of losses to use in your training scripts.
The aim is to have losses that are easily reusable/swappable and that have
a simple signature.
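A minimal sketch of that call convention, using DDPGLoss with a toy actor/critic pair (the shapes, keys and networks are illustrative assumptions, not the library's recommended setup):

    import torch
    from torch import nn
    from tensordict import TensorDict
    from torchrl.modules import Actor, ValueOperator
    from torchrl.objectives import DDPGLoss

    n_obs, n_act, batch = 3, 2, 8
    actor = Actor(nn.Linear(n_obs, n_act))  # reads "observation", writes "action"
    value = ValueOperator(nn.Linear(n_obs + n_act, 1), in_keys=["observation", "action"])
    loss_module = DDPGLoss(actor_network=actor, value_network=value)

    data = TensorDict(
        {
            "observation": torch.randn(batch, n_obs),
            "action": torch.randn(batch, n_act),
            ("next", "observation"): torch.randn(batch, n_obs),
            ("next", "reward"): torch.randn(batch, 1),
            ("next", "done"): torch.zeros(batch, 1, dtype=torch.bool),
            ("next", "terminated"): torch.zeros(batch, 1, dtype=torch.bool),
        },
        batch_size=[batch],
    )
    loss_vals = loss_module(data)  # a TensorDict with "loss_actor" and "loss_value" entries
    total_loss = sum(v for k, v in loss_vals.items() if k.startswith("loss_"))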
2 changes: 2 additions & 0 deletions docs/source/reference/trainers.rst
@@ -218,6 +218,8 @@ Utils
Loggers
-------

.. _ref_loggers:

.. currentmodule:: torchrl.record.loggers

.. autosummary::
4 changes: 2 additions & 2 deletions examples/a2c/utils_atari.py
@@ -20,8 +20,8 @@
NoopResetEnv,
ParallelEnv,
Resize,
RewardClipping,
RewardSum,
SignTransform,
StepCounter,
ToTensorImage,
TransformedEnv,
@@ -73,7 +73,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False):
env.append_transform(RewardSum())
env.append_transform(StepCounter(max_steps=4500))
if not is_test:
env.append_transform(RewardClipping(-1, 1))
env.append_transform(SignTransform(in_keys=["reward"]))
env.append_transform(DoubleToFloat())
env.append_transform(VecNorm(in_keys=["pixels"]))
return env
4 changes: 2 additions & 2 deletions examples/dqn/utils_atari.py
@@ -14,8 +14,8 @@
GymEnv,
NoopResetEnv,
Resize,
RewardClipping,
RewardSum,
SignTransform,
StepCounter,
ToTensorImage,
TransformedEnv,
@@ -42,7 +42,7 @@ def make_env(env_name, frame_skip, device, is_test=False):
env.append_transform(NoopResetEnv(noops=30, random=True))
if not is_test:
env.append_transform(EndOfLifeTransform())
env.append_transform(RewardClipping(-1, 1))
env.append_transform(SignTransform(in_keys=["reward"]))
env.append_transform(ToTensorImage())
env.append_transform(GrayScale())
env.append_transform(Resize(84, 84))
4 changes: 2 additions & 2 deletions examples/impala/utils.py
@@ -16,8 +16,8 @@
GymEnv,
NoopResetEnv,
Resize,
RewardClipping,
RewardSum,
SignTransform,
StepCounter,
ToTensorImage,
TransformedEnv,
@@ -46,7 +46,7 @@ def make_env(env_name, device, is_test=False):
env.append_transform(NoopResetEnv(noops=30, random=True))
if not is_test:
env.append_transform(EndOfLifeTransform())
env.append_transform(RewardClipping(-1, 1))
env.append_transform(SignTransform(in_keys=["reward"]))
env.append_transform(ToTensorImage(from_int=False))
env.append_transform(GrayScale())
env.append_transform(Resize(84, 84))
4 changes: 2 additions & 2 deletions examples/ppo/utils_atari.py
@@ -19,8 +19,8 @@
NoopResetEnv,
ParallelEnv,
Resize,
RewardClipping,
RewardSum,
SignTransform,
StepCounter,
ToTensorImage,
TransformedEnv,
@@ -71,7 +71,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False):
env.append_transform(RewardSum())
env.append_transform(StepCounter(max_steps=4500))
if not is_test:
env.append_transform(RewardClipping(-1, 1))
env.append_transform(SignTransform(in_keys=["reward"]))
env.append_transform(DoubleToFloat())
env.append_transform(VecNorm(in_keys=["pixels"]))
return env
4 changes: 1 addition & 3 deletions examples/rlhf/train_rlhf.py
@@ -100,9 +100,7 @@ def main(cfg):
# using a Gym-like API (querying steps etc) introduces some
# extra code that we can spare.
#
kl_scheduler = AdaptiveKLController(
model, init_kl_coef=0.1, target=6, horizon=10000
)
kl_scheduler = AdaptiveKLController(init_kl_coef=0.1, target=6, horizon=10000)
rollout_from_model = RolloutFromModel(
model,
ref_model,
7 changes: 6 additions & 1 deletion test/conftest.py
@@ -53,7 +53,7 @@ def fin():
request.addfinalizer(fin)


@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(autouse=True)
def set_warnings() -> None:
warnings.filterwarnings(
"ignore",
@@ -65,6 +65,11 @@ def set_warnings() -> None:
category=UserWarning,
message=r"Couldn't cast the policy onto the desired device on remote process",
)
warnings.filterwarnings(
"ignore",
category=UserWarning,
message=r"Skipping device Apple Paravirtual device",
)
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,