From 2f3b4cd4da69dc80c3ff1d4ca9dfad7cff90df68 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 13 Nov 2024 18:31:38 +0000 Subject: [PATCH] [Doc] Fix tutorials ghstack-source-id: 6c9114384015e76e96b3bbd0c8893cc42344537a Pull Request resolved: https://github.com/pytorch/rl/pull/2560 --- tutorials/sphinx-tutorials/torchrl_demo.py | 6 ++---- tutorials/sphinx-tutorials/torchrl_envs.py | 8 +++++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 84f7be715ad..a9bc74aad3c 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -415,11 +415,9 @@ TransformedEnv, ) -base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False) +base_env = GymEnv("HalfCheetah-v4", frame_skip=3, from_pixels=True, pixels_only=False) env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage())) -env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1)) - -############################################################################### +env = env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1)) env.reset() diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index 34189396ee9..3ea3f6af13d 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -694,9 +694,11 @@ def env_make(env_name): make_env = EnvCreator(lambda: TransformedEnv(GymEnv("CartPole-v1"), VecNorm(decay=1.0))) env = ParallelEnv(3, make_env) -make_env.state_dict()["_extra_state"]["td"]["observation_count"].fill_(0.0) -make_env.state_dict()["_extra_state"]["td"]["observation_ssq"].fill_(0.0) -make_env.state_dict()["_extra_state"]["td"]["observation_sum"].fill_(0.0) +print("env state dict:") +sd = TensorDict(make_env.state_dict()) +print(sd) +# Zeroes all tensors +sd *= 0 tensordict = env.rollout(max_steps=5)