Skip to content

Commit

Permalink
[Feature] env.step_mdp
Browse files Browse the repository at this point in the history
ghstack-source-id: 145e37cd772fdd74e35e5ffe6accc5c81ad689f3
Pull Request resolved: #2636
  • Loading branch information
vmoens committed Dec 12, 2024
1 parent 30d21e5 commit 4bc40a8
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3015,6 +3015,52 @@ def add_truncated_keys(self) -> EnvBase:
self.__dict__["_done_keys"] = None
return self

def step_mdp(self, next_tensordict: TensorDictBase) -> TensorDictBase:
"""Advances the environment state by one step using the provided `next_tensordict`.
This method updates the environment's state by transitioning from the current
state to the next, as defined by the `next_tensordict`. The resulting tensordict
includes updated observations and any other relevant state information, with
keys managed according to the environment's specifications.
Internally, this method utilizes a precomputed :class:`~torchrl.envs.utils._StepMDP` instance to efficiently
handle the transition of state, observation, action, reward, and done keys. The
:class:`~torchrl.envs.utils._StepMDP` class optimizes the process by precomputing the keys to include and
exclude, reducing runtime overhead during repeated calls. The :class:`~torchrl.envs.utils._StepMDP` instance
is created with `exclude_action=False`, meaning that action keys are retained in
the root tensordict.
Args:
next_tensordict (TensorDictBase): A tensordict containing the state of the
environment at the next time step. This tensordict should include keys
for observations, actions, rewards, and done flags, as defined by the
environment's specifications.
Returns:
TensorDictBase: A new tensordict representing the environment state after
advancing by one step.
.. note:: The method ensures that the environment's key specifications are validated
against the provided `next_tensordict`, issuing warnings if discrepancies
are found.
.. note:: This method is designed to work efficiently with environments that have
consistent key specifications, leveraging the `_StepMDP` class to minimize
overhead.
Example:
>>> from torchrl.envs import GymEnv
>>> env = GymEnv("Pendulum-1")
>>> data = env.reset()
>>> for i in range(10):
... # compute action
... env.rand_action(data)
... # Perform action
... next_data = env.step(reset_data)
... data = env.step_mdp(next_data)
"""
return self._step_mdp(next_tensordict)

@property
def _step_mdp(self):
step_func = self.__dict__.get("_step_mdp_value")
Expand Down

1 comment on commit 4bc40a8

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 4bc40a8 Previous: 57dc25a Ratio
benchmarks/test_replaybuffer_benchmark.py::test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 623.253577832539 iter/sec (stddev: 0.031147975287124898) 1506.9847088208473 iter/sec (stddev: 0.000041641550313143115) 2.42

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.