Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 9, 2024
1 parent 2abe615 commit 7bb29d1
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 3 deletions.
5 changes: 5 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ or via a ``git clone`` if you're willing to contribute to the library:
Getting started
===============

A series of quick tutorials to get ramped up with the basic features of the
library. If you're in a hurry, you can start with
:ref:`the last item of the series <gs_first_training>`
and navigate to the previous ones whenever you want to learn more!

.. toctree::
:maxdepth: 1

Expand Down
4 changes: 3 additions & 1 deletion torchrl/record/loggers/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import os
from collections import defaultdict
from pathlib import Path
Expand Down Expand Up @@ -126,7 +128,7 @@ class CSVLogger(Logger):
def __init__(
self,
exp_name: str,
log_dir: Optional[str] = None,
log_dir: str | None = None,
video_format: str = "pt",
video_fps: int = 30,
) -> None:
Expand Down
37 changes: 35 additions & 2 deletions tutorials/sphinx-tutorials/getting-started-5.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
# transform. These features are presented in
# :ref:`the environment tutorial <gs_env_ted>`.
#
import pathlib

import torch

torch.manual_seed(0)
Expand Down Expand Up @@ -76,7 +78,13 @@
init_rand_steps = 5000
frames_per_batch = 100
optim_steps = 10
collector = SyncDataCollector(env, policy, frames_per_batch=frames_per_batch, total_frames=-1, init_random_frames=init_rand_steps)
collector = SyncDataCollector(
env,
policy,
frames_per_batch=frames_per_batch,
total_frames=-1,
init_random_frames=init_rand_steps,
)
rb = ReplayBuffer(storage=LazyTensorStorage(100_000))

from torch.optim import Adam
Expand All @@ -94,9 +102,19 @@
#################################
# Logger
# ------
# TODO
#
# We'll be using a CSV logger to log our results and to save the rendered
# videos.
#

from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder

path = pathlib.Path(__file__).parent / "training_loop"
logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")
record_env = TransformedEnv(
GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder
)

#################################
# Training loop
Expand Down Expand Up @@ -134,3 +152,18 @@
torchrl_logger.info(
f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s."
)

#################################
# Rendering
# ---------
#
# Finally, we run the environment for as many steps as we can (capped at
# 1000 by ``max_steps``) and save the video locally (notice that we are not
# exploring).

record_env.rollout(max_steps=1000, policy=policy)
video_recorder.dump()

#################################
#
# .. figure:: ./training_loop/dqn/videos/video_0.mp4
#

0 comments on commit 7bb29d1

Please sign in to comment.