Fix all policy learners to return action rather than action index
Summary:
Policy learners have been incorrectly returning action indices rather than actions themselves.

This surfaced after greedy exploration was fixed to return actions rather than action indices, causing an inconsistency that is reflected in the recommendation system tutorial.

Reviewed By: yiwan-rl

Differential Revision: D55763685

fbshipit-source-id: ed3d096398cbcb317765464092890573bd49ff29
rodrigodesalvobraz authored and facebook-github-bot committed Apr 16, 2024
1 parent 0c2d4c7 commit 399785e
Showing 11 changed files with 401 additions and 27 deletions.
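
The fix follows the same pattern in every learner: take the argmax over Q-values (or action probabilities) as an index, then look up the corresponding action in the action space instead of returning the index. Below is a minimal standalone sketch of that pattern; SimpleDiscreteActionSpace and exploit_action_from_q_values are illustrative stand-ins, not Pearl's actual classes.

import torch

# Illustrative stand-in for a discrete action space that stores its actions
# as a list of tensors, with `n` actions in total.
class SimpleDiscreteActionSpace:
    def __init__(self, actions):
        self.actions = actions
        self.n = len(actions)

def exploit_action_from_q_values(q_values, action_space):
    # Before the fix: learners returned the argmax *index*.
    # After the fix: the index is used to look up the action itself.
    exploit_action_index = torch.argmax(q_values)
    return action_space.actions[exploit_action_index]

action_space = SimpleDiscreteActionSpace([torch.tensor(0.0), torch.tensor(1.0)])
q_values = torch.tensor([0.2, 0.7])
print(exploit_action_from_q_values(q_values, action_space))  # tensor(1.)
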
5 changes: 1 addition & 4 deletions pearl/pearl_agent.py
@@ -158,10 +158,7 @@ def act(self, exploit: bool = False) -> Action:
             subjective_state_to_be_used, safe_action_space, exploit=exploit # pyre-fixme[6]
         )

-        if isinstance(safe_action_space, DiscreteActionSpace):
-            self._latest_action = safe_action_space.actions_batch[int(action.item())]
-        else:
-            self._latest_action = action
+        self._latest_action = action

         return action

@@ -27,7 +27,6 @@ class PropensityExploration(ExplorationModule):
     def __init__(self) -> None:
         super(PropensityExploration, self).__init__()

-    # TODO: We should make discrete action space itself iterable
     def act(
         self,
         subjective_state: SubjectiveState,

@@ -77,7 +77,9 @@ def act(
         # this does a forward pass since all available
         # actions are already stacked together

-        return torch.argmax(q_values).view((-1))
+        action_index = torch.argmax(q_values)
+        action = action_space.actions[action_index]
+        return action

     def reset(self) -> None:  # noqa: B027
         # sample a new epistemic index (i.e., a Q-network) at the beginning of a

@@ -244,7 +244,8 @@ def act(
             available_actions=actions,
         )
         # (action_space_size)
-        exploit_action = torch.argmax(action_probabilities)
+        exploit_action_index = torch.argmax(action_probabilities)
+        exploit_action = available_action_space.actions[exploit_action_index]

         # Step 2: return exploit action if no exploration,
         # else pass through the exploration module

@@ -221,7 +221,8 @@ def act(
         # this does a forward pass since all avaialble
         # actions are already stacked together

-        exploit_action = torch.argmax(q_values).view((-1))
+        exploit_action_index = torch.argmax(q_values)
+        exploit_action = available_action_space.actions[exploit_action_index]

         if exploit:
             return exploit_action

@@ -149,7 +149,8 @@ def act(
         q_values = self.safety_module.get_q_values_under_risk_metric(
             states_repeated, actions, self._Q
         )
-        exploit_action = torch.argmax(q_values).view((-1))
+        exploit_action_index = torch.argmax(q_values)
+        exploit_action = available_action_space.actions[exploit_action_index]

         if exploit:
             return exploit_action

@@ -7,9 +7,7 @@

 # pyre-strict

-from typing import Any, Dict, Iterable, List, Tuple
-
-import torch
+from typing import Any, Dict, Iterable, Tuple

 from pearl.api.action import Action
 from pearl.api.action_space import ActionSpace

@@ -45,6 +43,9 @@ def __init__(
     ) -> None:
         """
         Initializes the tabular Q-learning policy learner.
+        Currently, tabular Q-learning assumes
+        a discrete action space, and assumes that for each action
+        int(action.item()) == action's index.

         Args:
             learning_rate (float, optional): the learning rate. Defaults to 0.01.

@@ -66,6 +67,13 @@ def __init__(

     def reset(self, action_space: ActionSpace) -> None:
         self._action_space = action_space
+        for i, action in enumerate(self._action_space):
+            if int(action.item()) != i:
+                raise ValueError(
+                    f"{self.__class__.__name__} only supports "
+                    f"action spaces that are a DiscreteSpace where for each action "
+                    f"action.item() == action's index. "
+                )

     def act(
         self,

@@ -74,21 +82,22 @@ def act(
         exploit: bool = False,
     ) -> Action:
         assert isinstance(available_action_space, DiscreteSpace)
-        # FIXME: this conversion should be eliminated once Action
-        # is no longer constrained to be a Tensor.
-        actions_as_ints: List[int] = [int(a.item()) for a in available_action_space]
+        # TODO: if we substitute DiscreteActionSpace for DiscreteSpace
+        # we get Pyre errors. It would be nice to fix this.

         # Choose the action with the highest Q-value for the current state.
-        q_values_for_state = {
-            action: self.q_values.get((subjective_state, action), 0)
-            for action in actions_as_ints
+        action_q_values_for_state = {
+            action_index: self.q_values.get((subjective_state, action_index), 0)
+            for action_index in range(available_action_space.n)
         }
-        max_q_value = max(q_values_for_state.values())
-        exploit_action = first_item(
-            action
-            for action, q_value in q_values_for_state.items()
-            if q_value == max_q_value
+        max_q_value_for_state = max(action_q_values_for_state.values())
+        exploit_action_index = first_item(
+            action_index
+            for action_index, q_value in action_q_values_for_state.items()
+            if q_value == max_q_value_for_state
         )
-        exploit_action = torch.tensor([exploit_action])
+        exploit_action = available_action_space.actions[exploit_action_index]

         if exploit:
             return exploit_action

@@ -102,6 +111,7 @@ def learn(
         self,
         replay_buffer: ReplayBuffer,
     ) -> Dict[str, Any]:
+
         # We know the sampling result from SingleTransitionReplayBuffer
         # is a list with a single tuple.
         transitions = replay_buffer.sample(1)

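For the tabular learner, Q-values are now keyed by action index while act() returns the action tensor itself, and reset() rejects action spaces in which an action's integer value differs from its index. A small standalone illustration of that assumption follows; it is not Pearl's API, just a sketch.

import torch

# Illustrative only: tabular Q-learning now assumes int(action.item()) equals
# the action's index in the space, e.g. [tensor(0), tensor(1), tensor(2)].
actions = [torch.tensor(0), torch.tensor(1), torch.tensor(2)]
for i, action in enumerate(actions):
    if int(action.item()) != i:
        raise ValueError(f"action {action} at index {i} breaks the assumption")

# Q-values are keyed by (state, action_index); acting returns the action tensor.
q_values = {("s0", 0): 0.1, ("s0", 1): 0.9}
best_index = max(range(len(actions)), key=lambda i: q_values.get(("s0", i), 0))
print(actions[best_index])  # tensor(1)
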
2 changes: 0 additions & 2 deletions pearl/safety_modules/identity_safety_module.py
@@ -7,8 +7,6 @@

 # pyre-strict

-from typing import Optional
-
 from pearl.api.action_space import ActionSpace
 from pearl.history_summarization_modules.history_summarization_module import (
     SubjectiveState,

2 changes: 1 addition & 1 deletion pearl/utils/instantiations/environments/environments.py
@@ -40,7 +40,7 @@ def __init__(self, number_of_steps: int = 100) -> None:
         self.number_of_steps_so_far = 0
         self.number_of_steps: int = number_of_steps
         self._action_space = DiscreteActionSpace(
-            [torch.tensor(True), torch.tensor(False)]
+            [torch.tensor(False), torch.tensor(True)]
         )

     def step(self, action: Action) -> ActionResult:

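The reordering in environments.py appears to follow the same convention: with [tensor(False), tensor(True)], each action's integer value equals its index (False -> 0, True -> 1), which is what the tabular Q-learning check above requires. A quick illustration under that assumption, not repository code:

import torch

# With the reordered space, int(action.item()) matches each action's index.
actions = [torch.tensor(False), torch.tensor(True)]
assert all(int(a.item()) == i for i, a in enumerate(actions))
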