From 7bb29d1827d1955675ff99a6723f7b5d9dfcd06e Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 9 Feb 2024 13:54:10 +0000 Subject: [PATCH] amend --- docs/source/index.rst | 5 +++ torchrl/record/loggers/csv.py | 4 +- .../sphinx-tutorials/getting-started-5.py | 37 ++++++++++++++++++- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 8343577547d..ab1cee681db 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -65,6 +65,11 @@ or via a ``git clone`` if you're willing to contribute to the library: Getting started =============== +A series of quick tutorials to get ramped up with the basic features of the +library. If you're in a hurry, you can start by +:ref:`the last item of the series <gs_first_training>` +and navigate to the previous ones whenever you want to learn more! + .. toctree:: :maxdepth: 1 diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index 256d0a2e840..6bcd3f50c86 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import os from collections import defaultdict from pathlib import Path @@ -126,7 +128,7 @@ class CSVLogger(Logger): def __init__( self, exp_name: str, - log_dir: Optional[str] = None, + log_dir: str | None = None, video_format: str = "pt", video_fps: int = 30, ) -> None: diff --git a/tutorials/sphinx-tutorials/getting-started-5.py b/tutorials/sphinx-tutorials/getting-started-5.py index 27f4fddc8f8..7d00ef7d0d3 100644 --- a/tutorials/sphinx-tutorials/getting-started-5.py +++ b/tutorials/sphinx-tutorials/getting-started-5.py @@ -27,6 +27,8 @@ # transform. These features are presented in # :ref:`the environment tutorial <gs_env_ted>`. 
# +import pathlib + import torch torch.manual_seed(0) @@ -76,7 +78,13 @@ init_rand_steps = 5000 frames_per_batch = 100 optim_steps = 10 -collector = SyncDataCollector(env, policy, frames_per_batch=frames_per_batch, total_frames=-1, init_random_frames=init_rand_steps) +collector = SyncDataCollector( + env, + policy, + frames_per_batch=frames_per_batch, + total_frames=-1, + init_random_frames=init_rand_steps, +) rb = ReplayBuffer(storage=LazyTensorStorage(100_000)) from torch.optim import Adam @@ -94,9 +102,19 @@ ################################# # Logger # ------ -# TODO +# +# We'll be using a CSV logger to log our results, and save rendered videos. +# from torchrl._utils import logger as torchrl_logger +from torchrl.record import CSVLogger, VideoRecorder + +path = pathlib.Path(__file__).parent / "training_loop" +logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4") +video_recorder = VideoRecorder(logger, tag="video") +record_env = TransformedEnv( + GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder +) ################################# # Training loop @@ -134,3 +152,18 @@ torchrl_logger.info( f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s." ) + +################################# +# Rendering +# --------- +# +# Finally, we run the environment for as many steps as we can and save the +# video locally (notice that we are not exploring). + +record_env.rollout(max_steps=1000, policy=policy) +video_recorder.dump() + +################################# +# +# .. figure:: ./training_loop/dqn/videos/video_0.mp4 +#