-
Notifications
You must be signed in to change notification settings - Fork 175
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add wrappers to exploration module to allow following external policies
Summary: This diff adds wrappers to exploration modules to make it easy for users to specify the behavior of the agent (currently the agent can only follow a policy determined by the exploration module). Reviewed By: rodrigodesalvobraz Differential Revision: D65786549 fbshipit-source-id: 9bba246d97be6e1c304851738224f646e4b36c65
- Loading branch information
1 parent
8458dbc
commit ff2b13c
Showing
7 changed files
with
164 additions
and
5 deletions.
There are no files selected for viewing
53 changes: 53 additions & 0 deletions
53
pearl/policy_learners/exploration_modules/exploration_module_wrapper.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
|
||
# pyre-strict | ||
|
||
from typing import Optional | ||
|
||
import torch | ||
from pearl.api.action import Action | ||
from pearl.api.action_space import ActionSpace | ||
from pearl.history_summarization_modules.history_summarization_module import ( | ||
SubjectiveState, | ||
) | ||
from pearl.policy_learners.exploration_modules import ExplorationModule | ||
from pearl.replay_buffers.replay_buffer import ReplayBuffer | ||
|
||
|
||
class ExplorationModuleWrapper(ExplorationModule):
    """
    Base class for wrappers around an exploration module.

    Each method defined here forwards to the wrapped module, so a subclass
    only needs to override the calls whose behavior it wants to change.
    """

    def __init__(self, exploration_module: ExplorationModule) -> None:
        # The wrapped module that every call below is delegated to.
        self.exploration_module: ExplorationModule = exploration_module

    def reset(self) -> None:
        """Reset the wrapped exploration module."""
        wrapped = self.exploration_module
        wrapped.reset()

    def act(
        self,
        subjective_state: SubjectiveState,
        action_space: ActionSpace,
        values: Optional[torch.Tensor] = None,
        exploit_action: Optional[Action] = None,
        action_availability_mask: Optional[torch.Tensor] = None,
        representation: Optional[torch.nn.Module] = None,
    ) -> Action:
        """Return the action chosen by the wrapped exploration module."""
        wrapped = self.exploration_module
        # Arguments are forwarded positionally, in the order of this signature.
        chosen = wrapped.act(
            subjective_state,
            action_space,
            values,
            exploit_action,
            action_availability_mask,
            representation,
        )
        return chosen

    def learn(self, replay_buffer: ReplayBuffer) -> None:
        """Forward the replay buffer to the wrapped module's `learn`."""
        wrapped = self.exploration_module
        wrapped.learn(replay_buffer)
14 changes: 14 additions & 0 deletions
14
pearl/policy_learners/exploration_modules/wrappers/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from .warmup import Warmup | ||
|
||
|
||
__all__ = [ | ||
"Warmup", | ||
] |
60 changes: 60 additions & 0 deletions
60
pearl/policy_learners/exploration_modules/wrappers/warmup.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
|
||
# pyre-strict | ||
|
||
from typing import Optional | ||
|
||
import torch | ||
from pearl.api.action import Action | ||
from pearl.api.action_space import ActionSpace | ||
from pearl.history_summarization_modules.history_summarization_module import ( | ||
SubjectiveState, | ||
) | ||
from pearl.policy_learners.exploration_modules.exploration_module import ( | ||
ExplorationModule, | ||
) | ||
from pearl.policy_learners.exploration_modules.exploration_module_wrapper import ( | ||
ExplorationModuleWrapper, | ||
) | ||
|
||
|
||
class Warmup(ExplorationModuleWrapper):
    """
    Sample actions from the action space for the first `warmup_steps` calls
    to `act`, then defer to the wrapped exploration module.

    The step counter advances once per successful `act` call; this class does
    not define a `reset` of its own, so the counter is never rewound here.
    """

    def __init__(
        self, exploration_module: ExplorationModule, warmup_steps: int
    ) -> None:
        # Number of initial `act` calls answered with a sampled action.
        self.warmup_steps: int = warmup_steps
        # Count of `act` calls completed so far.
        self.time_step: int = 0
        super().__init__(exploration_module)

    def act(
        self,
        subjective_state: SubjectiveState,
        action_space: ActionSpace,
        values: Optional[torch.Tensor] = None,
        exploit_action: Optional[Action] = None,
        action_availability_mask: Optional[torch.Tensor] = None,
        representation: Optional[torch.nn.Module] = None,
    ) -> Action:
        """
        Return a sampled action while warming up, otherwise the wrapped
        module's action.
        """
        warming_up = self.time_step < self.warmup_steps
        # NOTE(review): `action_availability_mask` is not passed to `sample()`
        # during warmup - confirm whether that is intended.
        chosen = (
            action_space.sample()
            if warming_up
            else self.exploration_module.act(
                subjective_state=subjective_state,
                action_space=action_space,
                values=values,
                exploit_action=exploit_action,
                action_availability_mask=action_availability_mask,
                representation=representation,
            )
        )
        # Count the step only after an action was actually produced.
        self.time_step += 1
        return chosen
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters