Commit

Update
[ghstack-poisoned]
vmoens committed Dec 6, 2024
1 parent 6f15101 commit bfbfaa7
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions torchrl/envs/common.py
@@ -3013,6 +3013,52 @@ def add_truncated_keys(self) -> EnvBase:
        self.__dict__["_done_keys"] = None
        return self

    def step_mdp(self, next_tensordict: TensorDictBase) -> TensorDictBase:
        """Advances the environment state by one step using the provided `next_tensordict`.

        This method updates the environment's state by transitioning from the current
        state to the next, as defined by the `next_tensordict`. The resulting tensordict
        includes updated observations and any other relevant state information, with
        keys managed according to the environment's specifications.

        Internally, this method uses a precomputed :class:`~torchrl.envs.utils._StepMDP`
        instance to handle the transition of state, observation, action, reward, and
        done keys. :class:`~torchrl.envs.utils._StepMDP` precomputes the keys to include
        and exclude, which reduces runtime overhead during repeated calls. The instance
        is created with `exclude_action=False`, meaning that action keys are retained in
        the root tensordict.

        Args:
            next_tensordict (TensorDictBase): A tensordict containing the state of the
                environment at the next time step. This tensordict should include keys
                for observations, actions, rewards, and done flags, as defined by the
                environment's specifications.

        Returns:
            TensorDictBase: A new tensordict representing the environment state after
            advancing by one step.

        .. note:: The method validates the environment's key specifications against
            the provided `next_tensordict` and issues warnings if discrepancies
            are found.

        .. note:: This method is designed to work efficiently with environments that have
            consistent key specifications, leveraging the `_StepMDP` class to minimize
            overhead.

        Example:
            >>> from torchrl.envs import GymEnv
            >>> env = GymEnv("Pendulum-v1")
            >>> data = env.reset()
            >>> for i in range(10):
            ...     # Sample a random action and write it into `data`
            ...     env.rand_action(data)
            ...     # Perform the action
            ...     next_data = env.step(data)
            ...     # Move to the next step
            ...     data = env.step_mdp(next_data)

        """
        return self._step_mdp(next_tensordict)

    @property
    def _step_mdp(self):
        step_func = self.__dict__.get("_step_mdp_value")
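
For readers of this commit, the key handling described in the `step_mdp` docstring can be sketched with plain dictionaries. This is an illustration only, not the torchrl implementation: `StepMDPSketch` and its arguments are hypothetical names, a nested `dict` stands in for a tensordict, and the defaults for dropping the reward and keeping the done flags are assumptions based on how the docstring describes `exclude_action=False`.

```python
class StepMDPSketch:
    """Hypothetical stand-in for the precomputed step helper (plain dicts only)."""

    def __init__(self, action_keys, reward_keys, done_keys,
                 exclude_action=False, exclude_reward=True, exclude_done=False):
        # Precompute, once, which root entries are dropped before promotion.
        # The "next" container itself and stale done/reward entries never carry over.
        self.drop_from_root = {"next"} | set(done_keys) | set(reward_keys)
        if exclude_action:
            self.drop_from_root |= set(action_keys)
        # Precompute which entries under "next" are NOT promoted to the root.
        self.drop_from_next = set()
        if exclude_reward:
            self.drop_from_next |= set(reward_keys)
        if exclude_done:
            self.drop_from_next |= set(done_keys)

    def __call__(self, data):
        # Keep the surviving root entries (here: the action).
        out = {k: v for k, v in data.items() if k not in self.drop_from_root}
        # Promote the surviving "next" entries to the root.
        out.update(
            {k: v for k, v in data["next"].items() if k not in self.drop_from_next}
        )
        return out


# Assumed key names for a single-agent environment.
step = StepMDPSketch(action_keys=["action"], reward_keys=["reward"], done_keys=["done"])

data = {
    "observation": 0.0,
    "action": 1.0,
    "done": False,
    "next": {"observation": 0.5, "reward": -1.0, "done": False},
}
out = step(data)
```

Because the include and exclude sets are built in `__init__`, each call only filters and merges, which is the kind of per-call saving the docstring attributes to precomputation. With `exclude_action=False` the action stays in the root of `out`, next to the promoted observation and done flag.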