Merge pull request #89 from kaist-silab/random_rollout
Implement RandomPolicy
Junyoungpark authored Aug 30, 2023
2 parents c00c5c1 + 1cef082 commit 0455cbe
Showing 1 changed file with 44 additions and 5 deletions.
49 changes: 44 additions & 5 deletions rl4co/models/nn/utils.py
@@ -1,5 +1,10 @@
from typing import Union

import torch
import torch.nn as nn
from tensordict import TensorDict

from rl4co.envs import RL4COEnvBase, get_env
from rl4co.utils import get_pylogger

log = get_pylogger(__name__)
@@ -14,9 +19,7 @@ def get_log_likelihood(log_p, actions, mask, return_sum: bool = True):
    if mask is not None:
        log_p[~mask] = 0

    assert (
        log_p > -1000
    ).data.all(), "Logprobs should not be -inf, check sampling procedure!"
    assert (log_p > -1000).data.all(), "Logprobs should not be -inf, check sampling procedure!"

    # Calculate log_likelihood
    if return_sum:
@@ -55,17 +58,53 @@ def random_policy(td):
    return td


def rollout(env, td, policy):
def rollout(env, td, policy, max_steps: int = None):
"""Helper function to rollout a policy. Currently, TorchRL does not allow to step
over envs when done with `env.rollout()`. We need this because for environements that complete at different steps.
over envs when done with `env.rollout()`. We need this because for environments that complete at different steps.
"""

    max_steps = float("inf") if max_steps is None else max_steps
    actions = []
    steps = 0

    while not td["done"].all():
        td = policy(td)
        actions.append(td["action"])
        td = env.step(td)["next"]
        steps += 1
        if steps > max_steps:
            break
    return (
        env.get_reward(td, torch.stack(actions, dim=1)),
        td,
        torch.stack(actions, dim=1),
    )


class RandomPolicy(nn.Module):

"""
Random Policy Class that randomly select actions from the action space
This policy can be useful to check the sanity of the environment during development
We match the function signature of forward to the one of the AutoregressivePolicy class
"""

    def __init__(self, env_name=None):
        super().__init__()
        self.env_name = env_name

    def forward(
        self,
        td: TensorDict,
        env: Union[str, RL4COEnvBase] = None,
        max_steps: int = None,
    ):
        # Instantiate environment if needed
        if isinstance(env, str) or env is None:
            env_name = self.env_name if env is None else env
            log.info(f"Instantiated environment not provided; instantiating {env_name}")
            env = get_env(env_name)

        return rollout(env, td, random_policy, max_steps=max_steps)

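For context, a minimal usage sketch of the new rollout helper and RandomPolicy (not part of the diff). It assumes rl4co's TSPEnv with a num_loc argument and the usual env.reset(batch_size=...) signature; the shapes in the final comment are illustrative.

from rl4co.envs import TSPEnv
from rl4co.models.nn.utils import RandomPolicy, random_policy, rollout

# Assumed setup: a small TSP environment and a batch of 4 instances
env = TSPEnv(num_loc=20)
td = env.reset(batch_size=[4])

# Roll out the stateless random_policy helper directly
reward, td_final, actions = rollout(env, td, random_policy)

# Or use the nn.Module wrapper, whose forward mirrors AutoregressivePolicy;
# passing an env instance skips the get_env instantiation path
policy = RandomPolicy(env_name="tsp")
reward, td_final, actions = policy(env.reset(batch_size=[4]), env=env)

print(reward.shape, actions.shape)  # e.g. torch.Size([4]) and torch.Size([4, 20])

If only a string name (or nothing) is passed as env, forward instantiates the environment via get_env using self.env_name.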