Skip to content

Commit

Permalink
Implement RandomPolicy
Browse files Browse the repository at this point in the history
  • Loading branch information
Junyoungpark committed Aug 30, 2023
1 parent c00c5c1 commit ae30f25
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions rl4co/models/nn/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from typing import Union

import torch
import torch.nn as nn
from tensordict import TensorDict

from rl4co.envs import RL4COEnvBase, get_env
from rl4co.utils import get_pylogger

log = get_pylogger(__name__)
Expand All @@ -14,9 +19,7 @@ def get_log_likelihood(log_p, actions, mask, return_sum: bool = True):
if mask is not None:
log_p[~mask] = 0

assert (
log_p > -1000
).data.all(), "Logprobs should not be -inf, check sampling procedure!"
assert (log_p > -1000).data.all(), "Logprobs should not be -inf, check sampling procedure!"

# Calculate log_likelihood
if return_sum:
Expand Down Expand Up @@ -69,3 +72,31 @@ def rollout(env, td, policy):
td,
torch.stack(actions, dim=1),
)


class RandomPolicy(nn.Module):

"""
Random Policy Class that randomly select actions from the action space
This policy can be useful to check the sanity of the environment during development
We match the function signature of forward to the one of the AutoregressivePolicy class
"""

def __init__(self, env_name=None):
super().__init__()
self.env_name = env_name

Check warning on line 89 in rl4co/models/nn/utils.py

View check run for this annotation

Codecov / codecov/patch

rl4co/models/nn/utils.py#L88-L89

Added lines #L88 - L89 were not covered by tests

def forward(
self,
td: TensorDict,
env: Union[str, RL4COEnvBase] = None,
):
# Instantiate environment if needed
if isinstance(env, str) or env is None:
env_name = self.env_name if env is None else env
log.info(f"Instantiated environment not provided; instantiating {env_name}")
env = get_env(env_name)

Check warning on line 100 in rl4co/models/nn/utils.py

View check run for this annotation

Codecov / codecov/patch

rl4co/models/nn/utils.py#L97-L100

Added lines #L97 - L100 were not covered by tests

return rollout(env, td, random_policy)

Check warning on line 102 in rl4co/models/nn/utils.py

View check run for this annotation

Codecov / codecov/patch

rl4co/models/nn/utils.py#L102

Added line #L102 was not covered by tests

0 comments on commit ae30f25

Please sign in to comment.