diff --git a/rl4co/models/nn/utils.py b/rl4co/models/nn/utils.py index ff072556..09d7b47a 100644 --- a/rl4co/models/nn/utils.py +++ b/rl4co/models/nn/utils.py @@ -58,15 +58,19 @@ def random_policy(td): return td -def rollout(env, td, policy): +def rollout(env, td, policy, max_steps: int = 10_000): """Helper function to rollout a policy. Currently, TorchRL does not allow to step over envs when done with `env.rollout()`. We need this because for environements that complete at different steps. """ actions = [] + steps = 0 while not td["done"].all(): td = policy(td) actions.append(td["action"]) td = env.step(td)["next"] + steps += 1 + if steps > max_steps: + break return ( env.get_reward(td, torch.stack(actions, dim=1)), td, @@ -92,6 +96,7 @@ def forward( self, td: TensorDict, env: Union[str, RL4COEnvBase] = None, + max_steps: int = 10_000, ): # Instantiate environment if needed if isinstance(env, str) or env is None: @@ -99,4 +104,4 @@ def forward( log.info(f"Instantiated environment not provided; instantiating {env_name}") env = get_env(env_name) - return rollout(env, td, random_policy) + return rollout(env, td, random_policy, max_steps=max_steps)