
Commit

fix several problems related to history summarization
Summary:
This diff fixes several issues related to history summarization.

1) Contextual bandit algorithms currently do not update the parameters of the history summarization module. This diff makes those parameters learnable in the neural_bandit and neural_linear_bandit algorithms (a sketch of the pattern follows below). The other contextual bandit algorithms are tabular or linear and are not currently based on nn.Module, so I assume they will not use an nn.Module-based history summarization module.
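
For illustration, a minimal sketch of the pattern (not Pearl code; the module names are hypothetical stand-ins): the history summarization module's parameters are registered with the bandit's existing optimizer via add_param_group, so every subsequent optimizer step updates them as well.

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins for the bandit's reward model and an
    # nn.Module-based history summarization module.
    reward_model = nn.Linear(8, 1)
    history_summarizer = nn.LSTM(4, 8, batch_first=True)

    optimizer = torch.optim.AdamW(reward_model.parameters(), lr=1e-3)

    # Registering the summarizer's parameters with the same optimizer means
    # every optimizer.step() in learn_batch also updates the summarizer.
    optimizer.add_param_group({"params": history_summarizer.parameters()})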

2) Actor-critic methods now have a separate optimizer for the history summarization module. Previously the history summarization module shared the actor's optimizer, but there seems to be no reason for that choice. Moreover, the TD3 rule for updating the history summarization parameters (calling self._actor_optimizer.step()) is problematic: it ends up stepping optimizers such as Adam and RMSprop with a zero gradient, which gets folded into their momentum and other running statistics when it should simply be ignored (see the toy example below). I also considered a separate history summarization optimizer for value-based methods, but I don't think it is needed there: the history summarization module can share the value function's optimizer.
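
A toy example (not Pearl code) of the optimizer-statistics issue: calling step() on Adam when a parameter's gradient is a zero tensor is not a no-op; the zero gradient is mixed into the running moments and the parameter still moves because of the decayed momentum.

    import torch

    param = torch.nn.Parameter(torch.tensor([1.0]))
    optimizer = torch.optim.Adam([param], lr=0.1)

    # One real update to populate Adam's first/second moment statistics.
    loss = (param ** 2).sum()
    loss.backward()
    optimizer.step()

    # A later step where this parameter received no gradient signal, but its
    # .grad is a zero tensor rather than None.
    param.grad = torch.zeros_like(param)
    before = param.detach().clone()
    optimizer.step()

    # The parameter still changes: Adam applies its decayed momentum and folds
    # the zero gradient into its running statistics instead of ignoring it.
    print(torch.equal(param.detach(), before))  # False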

3) set_history_summarization_module is now an abstract method of PolicyLearner, so every policy learner must implement it and we cannot forget to do so, as happened with the contextual bandit algorithms.

4) benchmark.py did not handle StackingHistorySummarizationModule correctly; this diff fixes that.

5) I also considered adding a target network for the history summarization module. I previously thought this would be straightforward, but after thinking it through a bit more, I believe it is actually tricky, and it is not clear to me whether it is worth including.

Reviewed By: rodrigodesalvobraz

Differential Revision: D65760816

fbshipit-source-id: cc338d418015622f17a0ef504fe3e5d401aeda22
yiwan-rl authored and facebook-github-bot committed Nov 13, 2024
1 parent b94c5bd commit 8458dbc
Showing 17 changed files with 99 additions and 28 deletions.
@@ -20,6 +20,7 @@
from pearl.api.action_space import ActionSpace
from pearl.api.reward import Value
from pearl.history_summarization_modules.history_summarization_module import (
HistorySummarizationModule,
SubjectiveState,
)
from pearl.policy_learners.exploration_modules.exploration_module import (
9 changes: 9 additions & 0 deletions pearl/policy_learners/contextual_bandits/disjoint_bandit.py
@@ -14,6 +14,7 @@
from pearl.api.action import Action
from pearl.api.action_space import ActionSpace
from pearl.history_summarization_modules.history_summarization_module import (
HistorySummarizationModule,
SubjectiveState,
)
from pearl.neural_networks.common.utils import ensemble_forward
@@ -228,3 +229,11 @@ def get_scores(
@property
def optimizer(self) -> torch.optim.Optimizer:
return self._optimizer

def set_history_summarization_module(
self, value: HistorySummarizationModule
) -> None:
# usually, this method would also add the parameters of the history summarization module
# to the optimizer of the bandit, but disjoint bandits do not use a pytorch optimizer.
# Instead, the optimization uses Pearl's own linear regression module.
self._history_summarization_module = value
@@ -13,6 +13,7 @@

from pearl.api.action import Action
from pearl.history_summarization_modules.history_summarization_module import (
HistorySummarizationModule,
SubjectiveState,
)
from pearl.neural_networks.common.utils import ensemble_forward
@@ -145,3 +146,11 @@ def get_scores(
subjective_state: SubjectiveState,
) -> torch.Tensor:
raise NotImplementedError("Implement when necessary")

def set_history_summarization_module(
self, value: HistorySummarizationModule
) -> None:
# usually, this method would also add the parameters of the history summarization module
# to the optimizer of the bandit, but disjoint bandits do not use a pytorch optimizer.
# Instead, the optimization uses Pearl's own linear regression module.
self._history_summarization_module = value
8 changes: 8 additions & 0 deletions pearl/policy_learners/contextual_bandits/linear_bandit.py
@@ -15,6 +15,7 @@
)
from pearl.api.action import Action
from pearl.history_summarization_modules.history_summarization_module import (
HistorySummarizationModule,
SubjectiveState,
)
from pearl.neural_networks.contextual_bandit.linear_regression import LinearRegression
@@ -190,3 +191,10 @@ def get_scores(
action_space=action_space,
representation=self.model,
).squeeze(-1)

def set_history_summarization_module(
self, value: HistorySummarizationModule
) -> None:
# currently linear bandit algorithm does not update
# parameters of the history summarization module
self._history_summarization_module = value
7 changes: 7 additions & 0 deletions pearl/policy_learners/contextual_bandits/neural_bandit.py
@@ -18,6 +18,7 @@
from pearl.api.action_space import ActionSpace

from pearl.history_summarization_modules.history_summarization_module import (
HistorySummarizationModule,
SubjectiveState,
)
from pearl.neural_networks.common.utils import LOSS_TYPES
@@ -111,6 +112,12 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
"loss": loss.detach(),
}

def set_history_summarization_module(
self, value: HistorySummarizationModule
) -> None:
self._optimizer.add_param_group({"params": value.parameters()})
self._history_summarization_module = value

def act(
self,
subjective_state: SubjectiveState,
@@ -18,6 +18,7 @@
from pearl.api.action_space import ActionSpace

from pearl.history_summarization_modules.history_summarization_module import (
HistorySummarizationModule,
SubjectiveState,
)
from pearl.neural_networks.common.utils import LOSS_TYPES
@@ -146,6 +147,12 @@ def _maybe_apply_discounting(self) -> None:
def optimizer(self) -> torch.optim.Optimizer:
return self._optimizer

def set_history_summarization_module(
self, value: HistorySummarizationModule
) -> None:
self._optimizer.add_param_group({"params": value.parameters()})
self._history_summarization_module = value

def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:

# get scores for logging purpose
3 changes: 2 additions & 1 deletion pearl/policy_learners/policy_learner.py
@@ -131,10 +131,11 @@ def exploration_module(self, new_exploration_module: ExplorationModule) -> None:
def get_action_representation_module(self) -> ActionRepresentationModule:
return self._action_representation_module

@abstractmethod
def set_history_summarization_module(
self, value: HistorySummarizationModule
) -> None:
self._history_summarization_module = value
pass

def reset(self, action_space: ActionSpace) -> None:
"""Resets policy maker for a new episode. Default implementation does nothing."""
@@ -72,6 +72,8 @@ def __init__(
action_space: Optional[ActionSpace] = None,
actor_learning_rate: float = 1e-3,
critic_learning_rate: float = 1e-3,
# used only for learnable history summarization module
history_summarization_learning_rate: float = 1e-3,
actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
critic_network_type: Union[
Type[ValueNetwork], Type[QValueNetwork]
@@ -187,11 +189,23 @@ def __init__(
self._critic_target: nn.Module = copy.deepcopy(self._critic)

self._discount_factor = discount_factor
self._history_summarization_learning_rate = history_summarization_learning_rate

def set_history_summarization_module(
self, value: HistorySummarizationModule
) -> None:
self._actor_optimizer.add_param_group({"params": value.parameters()})
"""
The history summarization module uses its own optimizer.
"""
self._history_summarization_optimizer: optim.Optimizer = optim.AdamW(
[
{
"params": value.parameters(),
"lr": self._history_summarization_learning_rate,
"amsgrad": True,
}
]
)
self._history_summarization_module = value

def act(
@@ -279,6 +293,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
Dict[str, Any]: A dictionary containing the loss reports from the critic
and actor updates. These can be useful to track for debugging purposes.
"""
self._history_summarization_optimizer.zero_grad()
actor_loss = self._actor_loss(batch)
self._actor_optimizer.zero_grad()
"""
@@ -289,25 +304,15 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
After the graph is cleared, critic_loss.backward() fails.
"""
actor_loss.backward(retain_graph=True)
self._actor_optimizer.step()
report = {"actor_loss": actor_loss.item()}
if self._use_critic:
self._critic_optimizer.zero_grad()
critic_loss = self._critic_loss(batch)
"""
This backward operation needs to happen before the actor_optimizer.step().
This is because actor_optimizer.step() updates the history summarization neural network
and critic_loss.backward() fails
once parameters involved in critic_loss's computational graph change.
"""
critic_loss.backward()
self._actor_optimizer.step()
self._critic_optimizer.step()
report = {
"actor_loss": actor_loss.item(),
"critic_loss": critic_loss.item(),
}
else:
self._actor_optimizer.step()
report = {"actor_loss": actor_loss.item()}
report["critic_loss"] = critic_loss.item()
self._history_summarization_optimizer.step()

if self._use_critic_target:
update_critic_target_network(
2 changes: 2 additions & 0 deletions pearl/policy_learners/sequential_decision_making/ddpg.py
@@ -55,6 +55,7 @@ def __init__(
exploration_module: Optional[ExplorationModule] = None,
actor_learning_rate: float = 1e-3,
critic_learning_rate: float = 1e-3,
history_summarization_learning_rate: float = 1e-3,
actor_network_type: Type[ActorNetwork] = VanillaContinuousActorNetwork,
critic_network_type: Type[QValueNetwork] = VanillaQValueNetwork,
actor_soft_update_tau: float = 0.005,
@@ -73,6 +74,7 @@ def __init__(
critic_hidden_dims=critic_hidden_dims,
actor_learning_rate=actor_learning_rate,
critic_learning_rate=critic_learning_rate,
history_summarization_learning_rate=history_summarization_learning_rate,
actor_network_type=actor_network_type,
critic_network_type=critic_network_type,
use_actor_target=True,
@@ -87,6 +87,7 @@ def __init__(
value_critic_learning_rate: float = 1e-3,
actor_learning_rate: float = 1e-3,
critic_learning_rate: float = 1e-3,
history_summarization_learning_rate: float = 1e-3,
critic_soft_update_tau: float = 0.05,
discount_factor: float = 0.99,
training_rounds: int = 5,
@@ -107,6 +108,7 @@ def __init__(
critic_hidden_dims=critic_hidden_dims,
actor_learning_rate=actor_learning_rate,
critic_learning_rate=critic_learning_rate,
history_summarization_learning_rate=history_summarization_learning_rate,
actor_network_type=actor_network_type,
critic_network_type=critic_network_type,
use_actor_target=False,
@@ -147,16 +149,11 @@ def __init__(
amsgrad=True,
)

def set_history_summarization_module(
self, value: HistorySummarizationModule
) -> None:
self._actor_optimizer.add_param_group({"params": value.parameters()})
self._history_summarization_module = value

def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
value_loss = self._value_loss(batch)
critic_loss = self._critic_loss(batch)
actor_loss = self._actor_loss(batch)
self._history_summarization_optimizer.zero_grad()
self._value_network_optimizer.zero_grad()
self._actor_optimizer.zero_grad()
self._critic_optimizer.zero_grad()
@@ -165,7 +162,7 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
self._value_network_optimizer.step()
self._actor_optimizer.step()
self._critic_optimizer.step()

self._history_summarization_optimizer.step()
# update critic and target Twin networks;
update_target_networks(
self._critic_target._critic_networks_combined,
2 changes: 2 additions & 0 deletions pearl/policy_learners/sequential_decision_making/ppo.py
@@ -98,6 +98,7 @@ def __init__(
critic_hidden_dims: Optional[List[int]] = None,
actor_learning_rate: float = 1e-4,
critic_learning_rate: float = 1e-4,
history_summarization_learning_rate: float = 1e-4,
exploration_module: Optional[ExplorationModule] = None,
actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
critic_network_type: Type[ValueNetwork] = VanillaValueNetwork,
@@ -119,6 +120,7 @@ def __init__(
critic_hidden_dims=critic_hidden_dims,
actor_learning_rate=actor_learning_rate,
critic_learning_rate=critic_learning_rate,
history_summarization_learning_rate=history_summarization_learning_rate,
actor_network_type=actor_network_type,
critic_network_type=critic_network_type,
use_actor_target=False,
2 changes: 2 additions & 0 deletions pearl/policy_learners/sequential_decision_making/reinforce.py
@@ -97,6 +97,7 @@ def __init__(
action_space: Optional[ActionSpace] = None,
actor_learning_rate: float = 1e-4,
critic_learning_rate: float = 1e-4,
history_summarization_learning_rate: float = 1e-4,
actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
critic_network_type: Type[ValueNetwork] = VanillaValueNetwork,
exploration_module: Optional[ExplorationModule] = None,
@@ -115,6 +116,7 @@ def __init__(
critic_hidden_dims=critic_hidden_dims,
actor_learning_rate=actor_learning_rate,
critic_learning_rate=critic_learning_rate,
history_summarization_learning_rate=history_summarization_learning_rate,
actor_network_type=actor_network_type,
critic_network_type=critic_network_type,
use_actor_target=False,
@@ -59,6 +59,7 @@ def __init__(
critic_hidden_dims: Optional[List[int]] = None,
actor_learning_rate: float = 1e-4,
critic_learning_rate: float = 1e-4,
history_summarization_learning_rate: float = 1e-4,
actor_network_type: Type[ActorNetwork] = VanillaActorNetwork,
critic_network_type: Type[QValueNetwork] = VanillaQValueNetwork,
critic_soft_update_tau: float = 0.005,
@@ -78,6 +79,7 @@ def __init__(
critic_hidden_dims=critic_hidden_dims,
actor_learning_rate=actor_learning_rate,
critic_learning_rate=critic_learning_rate,
history_summarization_learning_rate=history_summarization_learning_rate,
actor_network_type=actor_network_type,
critic_network_type=critic_network_type,
use_actor_target=False,
@@ -52,6 +52,7 @@ def __init__(
critic_hidden_dims: Optional[List[int]] = None,
actor_learning_rate: float = 1e-3,
critic_learning_rate: float = 1e-3,
history_summarization_learning_rate: float = 1e-3,
actor_network_type: Type[ActorNetwork] = GaussianActorNetwork,
critic_network_type: Type[QValueNetwork] = VanillaQValueNetwork,
critic_soft_update_tau: float = 0.005,
@@ -72,6 +73,7 @@ def __init__(
critic_hidden_dims=critic_hidden_dims,
actor_learning_rate=actor_learning_rate,
critic_learning_rate=critic_learning_rate,
history_summarization_learning_rate=history_summarization_learning_rate,
actor_network_type=actor_network_type,
critic_network_type=critic_network_type,
use_actor_target=False,
@@ -13,6 +13,7 @@
from pearl.api.action_space import ActionSpace
from pearl.api.reward import Reward, Value
from pearl.history_summarization_modules.history_summarization_module import (
HistorySummarizationModule,
SubjectiveState,
)
from pearl.policy_learners.exploration_modules.common.epsilon_greedy_exploration import (
Expand Down Expand Up @@ -75,6 +76,12 @@ def reset(self, action_space: ActionSpace) -> None:
f"action.item() == action's index. "
)

def set_history_summarization_module(
self, value: HistorySummarizationModule
) -> None:
# tabular q learning is assumed to not update parameters of the history summarization module
self._history_summarization_module = value

def act(
self,
subjective_state: SubjectiveState,
7 changes: 4 additions & 3 deletions pearl/policy_learners/sequential_decision_making/td3.py
@@ -104,20 +104,21 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
self._critic_update_count += 1
report = {}
# delayed actor update
self._actor_optimizer.zero_grad()
self._history_summarization_optimizer.zero_grad()
if self._critic_update_count % self._actor_update_freq == 0:
# see ddpg base class for actor update details
self._actor_optimizer.zero_grad()
actor_loss = self._actor_loss(batch)
actor_loss.backward(retain_graph=True)
self._actor_optimizer.step()
self._last_actor_loss = actor_loss.item()
report["actor_loss"] = self._last_actor_loss

self._critic_optimizer.zero_grad()
critic_loss = self._critic_loss(batch) # critic update
critic_loss.backward()
self._actor_optimizer.step()
self._critic_optimizer.step()
report["critic_loss"] = critic_loss.item()
self._history_summarization_optimizer.step()

if self._critic_update_count % self._actor_update_freq == 0:
# update targets of critics using soft updates
13 changes: 11 additions & 2 deletions pearl/utils/scripts/benchmark.py
@@ -164,10 +164,19 @@ def evaluate_single(
):
if (
method["history_summarization_module"].__name__
== "StackHistorySummarizationModule"
== "StackingHistorySummarizationModule"
):
method["history_summarization_module_args"]["observation_dim"] = (
env.observation_space.shape[0]
)
method["history_summarization_module_args"]["action_dim"] = (
policy_learner_args["action_representation_module"].representation_dim
if "action_representation_module" in policy_learner_args
else env.action_space.action_dim
)
policy_learner_args["state_dim"] = (
env.observation_space.shape[0] + env.action_space.n
method["history_summarization_module_args"]["observation_dim"]
+ method["history_summarization_module_args"]["action_dim"]
) * method["history_summarization_module_args"]["history_length"]
elif (
method["history_summarization_module"].__name__