Skip to content

Commit

Permalink
Add max_steps param to have explicit control over the number ofdecodi…
Browse files Browse the repository at this point in the history
…ng iterations
  • Loading branch information
Junyoungpark committed Aug 30, 2023
1 parent ae30f25 commit 02dc606
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions rl4co/models/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,19 @@ def random_policy(td):
return td


def rollout(env, td, policy):
def rollout(env, td, policy, max_steps: int = 10_000):
"""Helper function to rollout a policy. Currently, TorchRL does not allow to step
over envs when done with `env.rollout()`. We need this because for environements that complete at different steps.
"""
actions = []
steps = 0
while not td["done"].all():
td = policy(td)
actions.append(td["action"])
td = env.step(td)["next"]
steps += 1
if steps > max_steps:
break

Check warning on line 73 in rl4co/models/nn/utils.py

View check run for this annotation

Codecov / codecov/patch

rl4co/models/nn/utils.py#L73

Added line #L73 was not covered by tests
return (
env.get_reward(td, torch.stack(actions, dim=1)),
td,
Expand All @@ -92,11 +96,12 @@ def forward(
self,
td: TensorDict,
env: Union[str, RL4COEnvBase] = None,
max_steps: int = 10_000,
):
# Instantiate environment if needed
if isinstance(env, str) or env is None:
env_name = self.env_name if env is None else env
log.info(f"Instantiated environment not provided; instantiating {env_name}")
env = get_env(env_name)

Check warning on line 105 in rl4co/models/nn/utils.py

View check run for this annotation

Codecov / codecov/patch

rl4co/models/nn/utils.py#L102-L105

Added lines #L102 - L105 were not covered by tests

return rollout(env, td, random_policy)
return rollout(env, td, random_policy, max_steps=max_steps)

Check warning on line 107 in rl4co/models/nn/utils.py

View check run for this annotation

Codecov / codecov/patch

rl4co/models/nn/utils.py#L107

Added line #L107 was not covered by tests

0 comments on commit 02dc606

Please sign in to comment.