diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 84f7be715ad..a9bc74aad3c 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -415,11 +415,9 @@ TransformedEnv, ) -base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False) +base_env = GymEnv("HalfCheetah-v4", frame_skip=3, from_pixels=True, pixels_only=False) env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage())) -env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1)) - -############################################################################### +env = env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1)) env.reset() diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index 34189396ee9..3ea3f6af13d 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -694,9 +694,11 @@ def env_make(env_name): make_env = EnvCreator(lambda: TransformedEnv(GymEnv("CartPole-v1"), VecNorm(decay=1.0))) env = ParallelEnv(3, make_env) -make_env.state_dict()["_extra_state"]["td"]["observation_count"].fill_(0.0) -make_env.state_dict()["_extra_state"]["td"]["observation_ssq"].fill_(0.0) -make_env.state_dict()["_extra_state"]["td"]["observation_sum"].fill_(0.0) +print("env state dict:") +sd = TensorDict(make_env.state_dict()) +print(sd) +# Zeroes all tensors +sd *= 0 tensordict = env.rollout(max_steps=5)