add wrappers to exploration module to allow following external policies
Summary: This diff adds wrappers to exploration modules to make it easy for users to specify the behavior of the agent (currently the agent can only follow a policy determined by the exploration module).

Reviewed By: rodrigodesalvobraz

Differential Revision: D65786549

fbshipit-source-id: 9bba246d97be6e1c304851738224f646e4b36c65
yiwan-rl authored and facebook-github-bot committed Nov 13, 2024
1 parent 8458dbc commit ff2b13c
Showing 7 changed files with 164 additions and 5 deletions.
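
As a rough usage sketch (not part of this commit's diff), a wrapper composes with an existing exploration module and is passed to a policy learner in its place. Only `Warmup` and `EGreedyExploration` come from this repository; the `epsilon` value and the surrounding setup are illustrative assumptions.

```python
# Hypothetical usage: follow a uniformly random policy for the first 1000
# steps, then fall back to epsilon-greedy exploration.
from pearl.policy_learners.exploration_modules.common.epsilon_greedy_exploration import (
    EGreedyExploration,
)
from pearl.policy_learners.exploration_modules.wrappers import Warmup

exploration_module = Warmup(
    exploration_module=EGreedyExploration(epsilon=0.05),
    warmup_steps=1000,
)
# Because the wrapper subclasses ExplorationModule, it can be handed to a
# policy learner anywhere an unwrapped exploration module is accepted.
```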
53 changes: 53 additions & 0 deletions pearl/policy_learners/exploration_modules/exploration_module_wrapper.py
@@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

# pyre-strict

from typing import Optional

import torch
from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.history_summarization_modules.history_summarization_module import (
SubjectiveState,
)
from pearl.policy_learners.exploration_modules import ExplorationModule
from pearl.replay_buffers.replay_buffer import ReplayBuffer


class ExplorationModuleWrapper(ExplorationModule):
"""
This class is the base class for all exploration module wrappers.
"""

def __init__(self, exploration_module: ExplorationModule) -> None:
self.exploration_module: ExplorationModule = exploration_module

def reset(self) -> None: # noqa: B027
self.exploration_module.reset()

def act(
self,
subjective_state: SubjectiveState,
action_space: ActionSpace,
values: Optional[torch.Tensor] = None,
exploit_action: Optional[Action] = None,
action_availability_mask: Optional[torch.Tensor] = None,
representation: Optional[torch.nn.Module] = None,
) -> Action:
return self.exploration_module.act(
subjective_state,
action_space,
values,
exploit_action,
action_availability_mask,
representation,
)

def learn(self, replay_buffer: ReplayBuffer) -> None: # noqa: B027
"""Learns from the replay buffer. Default implementation does nothing."""
self.exploration_module.learn(replay_buffer)
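
To illustrate the extension point, here is a hypothetical wrapper (not part of this commit) that makes the agent follow an externally supplied policy, matching the motivation in the commit summary; `FollowExternalPolicy` and `external_policy` are illustrative names.

```python
from typing import Callable, Optional

import torch
from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.history_summarization_modules.history_summarization_module import (
    SubjectiveState,
)
from pearl.policy_learners.exploration_modules import ExplorationModule
from pearl.policy_learners.exploration_modules.exploration_module_wrapper import (
    ExplorationModuleWrapper,
)


class FollowExternalPolicy(ExplorationModuleWrapper):
    """Hypothetical wrapper: ignore the wrapped module and follow a given policy."""

    def __init__(
        self,
        exploration_module: ExplorationModule,
        external_policy: Callable[[SubjectiveState, ActionSpace], Action],
    ) -> None:
        super().__init__(exploration_module)
        self.external_policy = external_policy

    def act(
        self,
        subjective_state: SubjectiveState,
        action_space: ActionSpace,
        values: Optional[torch.Tensor] = None,
        exploit_action: Optional[Action] = None,
        action_availability_mask: Optional[torch.Tensor] = None,
        representation: Optional[torch.nn.Module] = None,
    ) -> Action:
        # Follow the external policy instead of the wrapped module's choice.
        return self.external_policy(subjective_state, action_space)
```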
14 changes: 14 additions & 0 deletions pearl/policy_learners/exploration_modules/wrappers/__init__.py
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from .warmup import Warmup


__all__ = [
"Warmup",
]
60 changes: 60 additions & 0 deletions pearl/policy_learners/exploration_modules/wrappers/warmup.py
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

# pyre-strict

from typing import Optional

import torch
from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.history_summarization_modules.history_summarization_module import (
SubjectiveState,
)
from pearl.policy_learners.exploration_modules.exploration_module import (
ExplorationModule,
)
from pearl.policy_learners.exploration_modules.exploration_module_wrapper import (
ExplorationModuleWrapper,
)


class Warmup(ExplorationModuleWrapper):
"""
Follow the random policy for the first `warmup_steps` steps,
then switch to the actions from the base exploration module.
"""

def __init__(
self, exploration_module: ExplorationModule, warmup_steps: int
) -> None:
self.warmup_steps = warmup_steps
self.time_step = 0
super().__init__(exploration_module)

def act(
self,
subjective_state: SubjectiveState,
action_space: ActionSpace,
values: Optional[torch.Tensor] = None,
exploit_action: Optional[Action] = None,
action_availability_mask: Optional[torch.Tensor] = None,
representation: Optional[torch.nn.Module] = None,
) -> Action:
if self.time_step < self.warmup_steps:
action = action_space.sample()
else:
action = self.exploration_module.act(
subjective_state=subjective_state,
action_space=action_space,
values=values,
exploit_action=exploit_action,
action_availability_mask=action_availability_mask,
representation=representation,
)
self.time_step += 1
return action
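
A minimal behavioral sketch of `Warmup` follows, assuming (as the `# noqa: B027` markers above suggest) that `act` is the only abstract method of `ExplorationModule`; the stub classes are illustrative test doubles, not Pearl APIs.

```python
from pearl.policy_learners.exploration_modules.exploration_module import (
    ExplorationModule,
)
from pearl.policy_learners.exploration_modules.wrappers import Warmup


class _StubActionSpace:
    """Stand-in action space; Warmup only calls `sample()` during warmup."""

    def sample(self):
        return "random_action"


class _StubExploration(ExplorationModule):
    """Stand-in base module that always returns the same action."""

    def act(
        self,
        subjective_state,
        action_space,
        values=None,
        exploit_action=None,
        action_availability_mask=None,
        representation=None,
    ):
        return "base_action"


warmup = Warmup(exploration_module=_StubExploration(), warmup_steps=2)
space = _StubActionSpace()
print([warmup.act(None, space) for _ in range(4)])
# Expected: ['random_action', 'random_action', 'base_action', 'base_action']
```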
@@ -80,6 +80,7 @@ def online_learning(
print_every_x_steps: Optional[int] = None,
seed: Optional[int] = None,
record_period: int = 1,
learning_start_step: int = 0,
# TODO: use LearningLogger similarly to offline_learning
) -> Dict[str, Any]:
"""
@@ -97,6 +98,8 @@
If number_of_episodes is used, report every record_period episodes.
If number_of_steps is used, report every record_period steps
Episodic statistics collected within this period are averaged and then recorded.
learning_start_step (int, optional): the agent starts to learn at learning_start_step.
Defaults to 0.
"""
assert (number_of_episodes is None and number_of_steps is not None) or (
number_of_episodes is not None and number_of_steps is None
@@ -120,6 +123,7 @@
learn_every_k_steps=learn_every_k_steps,
total_steps=old_total_steps,
seed=seed,
learning_start_step=learning_start_step,
)
if number_of_steps is not None and episode_total_steps > record_period:
print(
@@ -234,6 +238,7 @@ def run_episode(
learn_every_k_steps: int = 1,
total_steps: int = 0,
seed: Optional[int] = None,
learning_start_step: int = 0,
) -> Tuple[Dict[str, Any], int]:
"""
Runs one episode and returns an info dict and number of steps taken.
@@ -248,6 +253,8 @@
learn_every_k_steps (int, optional): asks the agent to learn every k steps.
total_steps (int, optional): the total number of steps taken so far. Defaults to 0.
seed (int, optional): the seed for the environment. Defaults to None.
learning_start_step (int, optional): the agent starts to learn at learning_start_step.
Defaults to 0.
Returns:
Tuple[Dict[str, Any], int]: the return of the episode and the number of steps taken.
"""
@@ -285,7 +292,7 @@ def run_episode(
agent.observe(action_result)
done = action_result.done
episode_steps += 1
if learn:
if learn and (total_steps + episode_steps >= learning_start_step):
if learn_after_episode:
# when learn_after_episode is True, we learn only at the end of the episode,
# regardless of the value of learn_every_k_steps.
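A sketch of how the new argument might be used is below; only `online_learning` and its `learning_start_step` parameter come from this diff, while the `agent` and `env` objects, the import path, and the specific numbers are assumed for illustration.

```python
# Import path assumed for illustration.
from pearl.utils.functional_utils.train_and_eval.online_learning import (
    online_learning,
)

# Hypothetical call: collect experience for the first 25,000 steps without
# any gradient updates (e.g. while a Warmup-wrapped exploration module is
# still acting randomly), then learn every step afterwards. Assumes `agent`
# (a PearlAgent) and `env` (an Environment) were built elsewhere.
info = online_learning(
    agent=agent,
    env=env,
    number_of_steps=1_000_000,
    learn_every_k_steps=1,
    record_period=10_000,
    learning_start_step=25_000,
)
```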
10 changes: 6 additions & 4 deletions pearl/utils/instantiations/spaces/box.py
@@ -37,8 +37,10 @@ class BoxSpace(Space):

def __init__(
self,
low: Union[float, Tensor],
high: Union[float, Tensor],
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
low: Union[float, np.ndarray, Tensor],
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
high: Union[float, np.ndarray, Tensor],
seed: Optional[Union[int, np.random.Generator]] = None,
) -> None:
"""Constructs a `BoxSpace`.
@@ -51,10 +53,10 @@ def __init__(
"""
# pyre-fixme[9]: low has type `Union[float, Tensor]`; used as `ndarray[Any,
# Any]`.
low = low.numpy(force=True) if isinstance(low, Tensor) else np.array([low])
low = low.numpy(force=True) if isinstance(low, Tensor) else low
# pyre-fixme[9]: high has type `Union[float, Tensor]`; used as `ndarray[Any,
# Any]`.
high = high.numpy(force=True) if isinstance(high, Tensor) else np.array([high])
high = high.numpy(force=True) if isinstance(high, Tensor) else high
self._gym_space = Box(low=low, high=high, seed=seed)

@property
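A small sketch of the call pattern this change enables: NumPy arrays are now forwarded to the underlying `Box` as-is instead of being re-wrapped (the bounds below are illustrative).

```python
import numpy as np

from pearl.utils.instantiations.spaces.box import BoxSpace

# Construct a BoxSpace directly from NumPy arrays; after this change the
# arrays are passed through to Box unchanged.
action_space = BoxSpace(
    low=np.array([-1.0, -2.0]),
    high=np.array([1.0, 2.0]),
    seed=0,
)
```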
8 changes: 8 additions & 0 deletions pearl/utils/scripts/benchmark.py
@@ -127,6 +127,13 @@ def evaluate_single(
policy_learner_args["exploration_module"] = method["exploration_module"](
**method["exploration_module_args"]
)
if "exploration_module_wrapper" in method:
policy_learner_args["exploration_module"] = method[
"exploration_module_wrapper"
](
exploration_module=policy_learner_args["exploration_module"],
**method["exploration_module_wrapper_args"],
)
if "replay_buffer" in method and "replay_buffer_args" in method:
agent_args["replay_buffer"] = method["replay_buffer"](
**method["replay_buffer_args"]
@@ -239,6 +246,7 @@ def evaluate_single(
learn_every_k_steps=learn_every_k_steps,
seed=run_idx,
record_period=record_period,
learning_start_step=method.get("learning_start_step", 0),
)
dir = f"outputs/{env_name}/{method_name}"
os.makedirs(dir, exist_ok=True)
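The new keys consumed by `evaluate_single` above correspond to method-config entries like those added to benchmark_config.py below. A trimmed sketch of such an entry follows (other required keys and the imports are omitted; the class names are the ones imported in benchmark_config.py).

```python
# Trimmed method-config sketch showing only the keys this commit adds or
# consumes; a real entry also specifies the policy learner, networks, etc.
method_sketch = {
    "exploration_module": NormalDistributionExploration,
    "exploration_module_args": {"mean": 0, "std_dev": 0.1},
    "exploration_module_wrapper": Warmup,
    "exploration_module_wrapper_args": {"warmup_steps": 25000},
    "learning_start_step": 25000,
    "replay_buffer": BasicReplayBuffer,
    "replay_buffer_args": {"capacity": 100000},
}
```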
15 changes: 15 additions & 0 deletions pearl/utils/scripts/benchmark_config.py
@@ -30,12 +30,16 @@
from pearl.policy_learners.exploration_modules.common.epsilon_greedy_exploration import ( # noqa E501
EGreedyExploration,
)
from pearl.policy_learners.exploration_modules.common.no_exploration import (
NoExploration,
)
from pearl.policy_learners.exploration_modules.common.normal_distribution_exploration import ( # noqa E501
NormalDistributionExploration,
)
from pearl.policy_learners.exploration_modules.common.propensity_exploration import (
PropensityExploration,
)
from pearl.policy_learners.exploration_modules.wrappers.warmup import Warmup
from pearl.policy_learners.sequential_decision_making.bootstrapped_dqn import (
BootstrappedDQN,
)
@@ -546,7 +550,10 @@
"mean": 0,
"std_dev": 0.1,
},
"exploration_module_wrapper": Warmup,
"exploration_module_wrapper_args": {"warmup_steps": 25000},
"replay_buffer": BasicReplayBuffer,
"learning_start_step": 25000,
"replay_buffer_args": {"capacity": 100000},
}

@@ -605,7 +612,10 @@
"mean": 0,
"std_dev": 0.1,
},
"exploration_module_wrapper": Warmup,
"exploration_module_wrapper_args": {"warmup_steps": 25000},
"replay_buffer": BasicReplayBuffer,
"learning_start_step": 25000,
"replay_buffer_args": {"capacity": 100000},
}

@@ -660,7 +670,12 @@
"critic_network_type": VanillaQValueNetwork,
"discount_factor": 0.99,
},
"exploration_module": NoExploration,
"exploration_module_args": {},
"exploration_module_wrapper": Warmup,
"exploration_module_wrapper_args": {"warmup_steps": 5000},
"replay_buffer": BasicReplayBuffer,
"learning_start_step": 5000,
"replay_buffer_args": {"capacity": 100000},
}

