Skip to content

Commit

Permalink
[Doc] Fix tutorials
Browse files Browse the repository at this point in the history
ghstack-source-id: 6c9114384015e76e96b3bbd0c8893cc42344537a
Pull Request resolved: #2560
  • Loading branch information
vmoens committed Nov 13, 2024
1 parent 7051238 commit 2f3b4cd
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 2 additions & 4 deletions tutorials/sphinx-tutorials/torchrl_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,11 +415,9 @@
TransformedEnv,
)

base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
base_env = GymEnv("HalfCheetah-v4", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))

###############################################################################
env = env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1))

env.reset()

Expand Down
8 changes: 5 additions & 3 deletions tutorials/sphinx-tutorials/torchrl_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,9 +694,11 @@ def env_make(env_name):

make_env = EnvCreator(lambda: TransformedEnv(GymEnv("CartPole-v1"), VecNorm(decay=1.0)))
env = ParallelEnv(3, make_env)
make_env.state_dict()["_extra_state"]["td"]["observation_count"].fill_(0.0)
make_env.state_dict()["_extra_state"]["td"]["observation_ssq"].fill_(0.0)
make_env.state_dict()["_extra_state"]["td"]["observation_sum"].fill_(0.0)
print("env state dict:")
sd = TensorDict(make_env.state_dict())
print(sd)
# Zeroes all tensors
sd *= 0

tensordict = env.rollout(max_steps=5)

Expand Down

0 comments on commit 2f3b4cd

Please sign in to comment.