From ae30f25536ae8504f087e5a2db23b84aa8abe89f Mon Sep 17 00:00:00 2001 From: junyoungpark Date: Wed, 30 Aug 2023 17:53:12 +0900 Subject: [PATCH] Implement RandomPolicy --- rl4co/models/nn/utils.py | 37 ++++++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/rl4co/models/nn/utils.py b/rl4co/models/nn/utils.py index 3c694556..ff072556 100644 --- a/rl4co/models/nn/utils.py +++ b/rl4co/models/nn/utils.py @@ -1,5 +1,10 @@ +from typing import Union + import torch +import torch.nn as nn +from tensordict import TensorDict +from rl4co.envs import RL4COEnvBase, get_env from rl4co.utils import get_pylogger log = get_pylogger(__name__) @@ -14,9 +19,7 @@ def get_log_likelihood(log_p, actions, mask, return_sum: bool = True): if mask is not None: log_p[~mask] = 0 - assert ( - log_p > -1000 - ).data.all(), "Logprobs should not be -inf, check sampling procedure!" + assert (log_p > -1000).data.all(), "Logprobs should not be -inf, check sampling procedure!" # Calculate log_likelihood if return_sum: @@ -69,3 +72,31 @@ def rollout(env, td, policy): td, torch.stack(actions, dim=1), ) + + +class RandomPolicy(nn.Module): + + """ + Random Policy Class that randomly select actions from the action space + This policy can be useful to check the sanity of the environment during development + + We match the function signature of forward to the one of the AutoregressivePolicy class + + """ + + def __init__(self, env_name=None): + super().__init__() + self.env_name = env_name + + def forward( + self, + td: TensorDict, + env: Union[str, RL4COEnvBase] = None, + ): + # Instantiate environment if needed + if isinstance(env, str) or env is None: + env_name = self.env_name if env is None else env + log.info(f"Instantiated environment not provided; instantiating {env_name}") + env = get_env(env_name) + + return rollout(env, td, random_policy)