k-step learning and GAE in PPO
Summary:
This diff adds k-step learning to our codebase, as well as generalized advantage estimation (GAE).

In our current codebase, learning can happen either after each episode or after every environment step. However, some algorithms, such as PPO, learn every k steps, where k > 1. This diff enables learning every k environment steps by adding a new argument, learn_every_k_steps, to the online_learning function.

To support this new feature, this diff modifies OnPolicyEpisodicReplayBuffer, whose current implementation only works when learning happens after each episode, and makes it compatible with k-step learning. Some necessary changes to PPO and REINFORCE are also made.

Note that in the online_learning function, learn_every_k_steps takes effect only when learn_after_episode is False, and its default value is 1, so this diff does not change the behavior of external code that relies on online_learning.
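
As an illustration of the new behavior (not the actual online_learning implementation), here is a minimal sketch of a loop that learns every k environment steps. Only learn_after_episode and learn_every_k_steps come from this diff; the function name and the agent/env interface (act, observe, learn, Gym-style reset/step) are assumptions.

def online_learning_sketch(agent, env, number_of_steps, learn_after_episode=False, learn_every_k_steps=1):
    # Hypothetical loop; agent and env are assumed to follow a Pearl/Gym-like interface.
    observation, _ = env.reset()
    for step in range(1, number_of_steps + 1):
        action = agent.act(observation)
        observation, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        agent.observe(observation, reward, done)
        if learn_after_episode:
            # learn only at episode boundaries; learn_every_k_steps is ignored here
            if done:
                agent.learn()
        elif step % learn_every_k_steps == 0:
            # with the default learn_every_k_steps=1 this is learning after every step,
            # so existing callers keep their previous behavior
            agent.learn()
        if done:
            observation, _ = env.reset()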

This diff also adds GAE to PPO. The current PPO implementation uses n-step returns to compute advantages for policy updates. The original PPO algorithm uses GAE, which is the difference between the truncated lambda return and the current value estimate.
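
For reference, a sketch of the standard definitions being implemented here, where \(\gamma\) is the discount factor and \(\lambda\) corresponds to the trace_decay_param introduced in this diff:

\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)
\hat{A}_t = \sum_{l=0}^{T-t-1} (\gamma\lambda)^l \, \delta_{t+l} = \delta_t + \gamma\lambda\, \hat{A}_{t+1}
G_t^{\lambda} = \hat{A}_t + V(s_t) \quad \text{(truncated lambda return, used as the critic target)}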

Reviewed By: rodrigodesalvobraz

Differential Revision: D53838214

fbshipit-source-id: 216ceb6584a1a4c156fe5f29562f0ddd1c9970eb
yiwan-rl authored and facebook-github-bot committed Feb 24, 2024
1 parent c189690 commit 611928a
Showing 18 changed files with 419 additions and 254 deletions.
9 changes: 6 additions & 3 deletions pearl/policy_learners/policy_learner.py
@@ -156,11 +156,14 @@ def learn(
Returns:
A dictionary which includes useful metrics
"""
batch_size = self._batch_size if not self.on_policy else len(replay_buffer)

if len(replay_buffer) < batch_size or len(replay_buffer) == 0:
if len(replay_buffer) == 0:
return {}

if self._batch_size == -1 or len(replay_buffer) < self._batch_size:
batch_size = len(replay_buffer)
else:
batch_size = self._batch_size

report = {}
for _ in range(self._training_rounds):
self._training_steps += 1
117 changes: 89 additions & 28 deletions pearl/policy_learners/sequential_decision_making/ppo.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.
#

import copy
from typing import Any, Dict, List, Optional, Type

import torch
@@ -33,8 +32,12 @@
single_critic_state_value_loss,
)
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.sequential_decision_making.on_policy_replay_buffer import (
OnPolicyReplayBuffer,
OnPolicyTransition,
OnPolicyTransitionBatch,
)
from pearl.replay_buffers.transition import TransitionBatch
from torch import nn


class ProximalPolicyOptimization(ActorCriticBase):
@@ -57,6 +60,7 @@ def __init__(
training_rounds: int = 100,
batch_size: int = 128,
epsilon: float = 0.0,
trace_decay_param: float = 0.95,
entropy_bonus_scaling: float = 0.01,
action_representation_module: Optional[ActionRepresentationModule] = None,
) -> None:
@@ -85,15 +89,16 @@ def __init__(
action_representation_module=action_representation_module,
)
self._epsilon = epsilon
self._trace_decay_param = trace_decay_param
self._entropy_bonus_scaling = entropy_bonus_scaling
self._actor_old: nn.Module = copy.deepcopy(self._actor)

def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
"""
Loss = actor loss + critic loss + entropy_bonus_scaling * entropy loss
"""
# TODO need to support continuous action
# TODO: change the output shape of value networks
vs: torch.Tensor = self._critic(batch.state).view(-1) # shape (batch_size)
assert isinstance(batch, OnPolicyTransitionBatch)
action_probs = self._actor.get_action_prob(
state_batch=batch.state,
action_batch=batch.action,
@@ -103,47 +108,103 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
# shape (batch_size)

# actor loss
with torch.no_grad():
action_probs_old = self._actor_old.get_action_prob(
state_batch=batch.state,
action_batch=batch.action,
available_actions=batch.curr_available_actions,
unavailable_actions_mask=batch.curr_unavailable_actions_mask,
) # shape (batch_size)
action_probs_old = batch.action_probs
assert action_probs_old is not None
r_thelta = torch.div(action_probs, action_probs_old) # shape (batch_size)
clip = torch.clamp(
r_thelta, min=1.0 - self._epsilon, max=1.0 + self._epsilon
) # shape (batch_size)

# advantage estimator, in paper:
# A = sum(lambda^t*gamma^t*TD_error), while TD_error = reward + gamma * V(s+1) - V(s)
# when lambda = 1 and gamma = 1
# A = sum(TD_error) = return - V(s)
# TODO support lambda and gamma
with torch.no_grad():
advantage = batch.cum_reward - vs # shape (batch_size)

loss = torch.sum(-torch.min(r_thelta * batch.gae, clip * batch.gae))
# entropy
# Categorical is good for Cartpole Env where actions are discrete
# TODO need to support continuous action
entropy: torch.Tensor = torch.distributions.Categorical(
action_probs.detach()
).entropy()
loss = torch.sum(
-torch.min(r_thelta * advantage, clip * advantage)
) - torch.sum(self._entropy_bonus_scaling * entropy)
loss -= torch.sum(self._entropy_bonus_scaling * entropy)
return loss

def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
assert batch.cum_reward is not None
assert isinstance(batch, OnPolicyTransitionBatch)
assert batch.lam_return is not None
return single_critic_state_value_loss(
state_batch=batch.state,
expected_target_batch=batch.cum_reward,
expected_target_batch=batch.lam_return,
critic=self._critic,
)

def learn(self, replay_buffer: ReplayBuffer) -> Dict[str, Any]:
self.preprocess_replay_buffer(replay_buffer)
# sample from replay buffer and learn
result = super().learn(replay_buffer)
# update old actor with latest actor for next round
self._actor_old.load_state_dict(self._actor.state_dict())
return result

def preprocess_replay_buffer(self, replay_buffer: ReplayBuffer) -> None:
"""
Preprocess the replay buffer by calculating
and adding the generalized advantage estimates (gae),
truncated lambda returns (lam_return) and action probabilities (action_probs)
under the current policy.
See https://arxiv.org/abs/1707.06347 equation (11) for the definition of gae.
See "Reinforcement Learning: An Introduction" by Sutton and Barto (2018) equation (12.10)
for the definition of truncated lambda return.
"""
assert type(replay_buffer) is OnPolicyReplayBuffer
assert len(replay_buffer.memory) > 0
(
state_list,
action_list,
available_actions_list,
unavailable_actions_mask_list,
) = ([], [], [], [])
for transition in reversed(replay_buffer.memory):
state_list.append(transition.state)
action_list.append(transition.action)
available_actions_list.append(transition.curr_available_actions)
unavailable_actions_mask_list.append(
transition.curr_unavailable_actions_mask
)
history_summary_batch = self._history_summarization_module(
torch.cat(state_list)
).detach()
action_representation_batch = self._action_representation_module(
torch.cat(action_list)
)

state_values = self._critic(history_summary_batch).detach()
action_probs = (
self._actor.get_action_prob(
state_batch=history_summary_batch,
action_batch=action_representation_batch,
)
.detach()
.unsqueeze(-1)
)
# Obtain the value of the most recent state stored in the replay buffer.
# This value is used to compute the generalized advantage estimation (gae)
# and the truncated lambda return for all states in the replay buffer.
next_value = self._critic(
self._history_summarization_module(replay_buffer.memory[-1].next_state)
).detach()[
0
] # shape (1,)
gae = torch.tensor([0.0]).to(state_values.device)
for i, transition in enumerate(reversed(replay_buffer.memory)):
td_error = (
transition.reward
+ self._discount_factor * next_value * (~transition.done)
- state_values[i]
)
gae = (
td_error
+ self._discount_factor
* self._trace_decay_param
* (~transition.done)
* gae
)
assert isinstance(transition, OnPolicyTransition)
transition.gae = gae
# truncated lambda return of the state
transition.lam_return = gae + state_values[i]
# action probabilities from the current policy
transition.action_probs = action_probs[i]
next_value = state_values[i]
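
A self-contained toy example of the backward recursion implemented in preprocess_replay_buffer above; the reward and value numbers are made up, and only the recursion mirrors the code:

import torch

# Made-up three-step segment: rewards, value estimates V(s_t), and done flags.
rewards = torch.tensor([1.0, 0.0, 1.0])
values = torch.tensor([0.5, 0.4, 0.6])
dones = torch.tensor([False, False, True])
next_value = torch.tensor(0.0)   # V(s_3); irrelevant here since the last step is terminal
gamma, lam = 0.99, 0.95          # discount factor and trace decay parameter

gae = torch.tensor(0.0)
gaes = [torch.tensor(0.0)] * 3
lam_returns = [torch.tensor(0.0)] * 3
for t in reversed(range(3)):
    not_done = (~dones[t]).float()
    td_error = rewards[t] + gamma * next_value * not_done - values[t]
    gae = td_error + gamma * lam * not_done * gae
    gaes[t] = gae                      # advantage used in the clipped actor loss
    lam_returns[t] = gae + values[t]   # truncated lambda return, the critic target
    next_value = values[t]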
30 changes: 27 additions & 3 deletions pearl/policy_learners/sequential_decision_making/reinforce.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.
#

from typing import List, Optional, Type
from typing import Any, Dict, List, Optional, Type

from pearl.action_representation_modules.action_representation_module import (
ActionRepresentationModule,
@@ -36,6 +36,12 @@
ActorCriticBase,
single_critic_state_value_loss,
)
from pearl.replay_buffers.replay_buffer import ReplayBuffer
from pearl.replay_buffers.sequential_decision_making.on_policy_replay_buffer import (
OnPolicyReplayBuffer,
OnPolicyTransition,
OnPolicyTransitionBatch,
)
from pearl.replay_buffers.transition import TransitionBatch


@@ -58,7 +64,8 @@ def __init__(
critic_network_type: Type[ValueNetwork] = VanillaValueNetwork,
exploration_module: Optional[ExplorationModule] = None,
discount_factor: float = 0.99,
training_rounds: int = 1,
training_rounds: int = 8,
batch_size: int = 64,
action_representation_module: Optional[ActionRepresentationModule] = None,
) -> None:
super(REINFORCE, self).__init__(
@@ -80,13 +87,14 @@
else PropensityExploration(),
discount_factor=discount_factor,
training_rounds=training_rounds,
batch_size=0, # REINFORCE does not use batch size
batch_size=batch_size,
is_action_continuous=False,
on_policy=True,
action_representation_module=action_representation_module,
)

def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
assert isinstance(batch, OnPolicyTransitionBatch)
state_batch = (
batch.state
) # (batch_size x state_dim) note that here batch_size = episode length
@@ -108,9 +116,25 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:

def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
assert self._use_critic, "can not compute critic loss without critic"
assert isinstance(batch, OnPolicyTransitionBatch)
assert batch.cum_reward is not None
return single_critic_state_value_loss(
state_batch=batch.state,
expected_target_batch=batch.cum_reward,
critic=self._critic,
)

def learn(self, replay_buffer: ReplayBuffer) -> Dict[str, Any]:
assert type(replay_buffer) is OnPolicyReplayBuffer
assert len(replay_buffer.memory) > 0
# compute return for all states in the buffer
cum_reward = self._critic(
self._history_summarization_module(replay_buffer.memory[-1].next_state)
).detach() * (~replay_buffer.memory[-1].done)
for transition in reversed(replay_buffer.memory):
cum_reward += transition.reward
assert isinstance(transition, OnPolicyTransition)
transition.cum_reward = cum_reward
# sample from replay buffer and learn
result = super().learn(replay_buffer)
return result
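
The loop above fills in an undiscounted return target for each state in the buffer, bootstrapping from the critic when the last stored transition is not terminal; schematically (a sketch of the intent, with T the index of the last stored transition):

G_t = r_t + G_{t+1}, \qquad G_{T+1} = (1 - \mathrm{done}_T)\, V(s_{T+1})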
@@ -157,7 +157,7 @@ def _get_next_state_expected_values(self, batch: TransitionBatch) -> torch.Tensor:
# next_q = next_q1 if random_index == 0 else next_q2

next_state_action_values = next_q.view(
self.batch_size, -1
next_state_batch.shape[0], -1
) # (batch_size x action_space_size)

# Make sure that unavailable actions' Q values are assigned to 0.0
@@ -210,7 +210,7 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:
) # (batch_size x action_space_size)

state_action_values = q.view(
(self.batch_size, self.action_representation_module.max_number_actions)
(state_batch.shape[0], self.action_representation_module.max_number_actions)
) # (batch_size x action_space_size)

if unavailable_actions_mask is not None:
@@ -168,9 +168,7 @@ def _get_next_state_expected_values(self, batch: TransitionBatch) -> torch.Tensor:

# clipped double q-learning (reduce overestimation bias)
next_q = torch.minimum(next_q1, next_q2) # shape: (batch_size)
next_state_action_values = next_q.view(
self.batch_size, 1
) # shape: (batch_size x 1)
next_state_action_values = next_q.unsqueeze(-1) # shape: (batch_size x 1)

# add entropy regularization
next_state_action_values = next_state_action_values - (
@@ -196,7 +194,7 @@ def _actor_loss(self, batch: TransitionBatch) -> torch.Tensor:

# clipped double q learning (reduce overestimation bias)
q = torch.minimum(q1, q2) # shape: (batch_size)
state_action_values = q.view((self.batch_size, 1)) # shape: (batch_size x 1)
state_action_values = q.unsqueeze(-1) # shape: (batch_size x 1)

loss = (self._entropy_coef * action_batch_log_prob - state_action_values).mean()

4 changes: 2 additions & 2 deletions pearl/replay_buffers/sequential_decision_making/__init__.py
@@ -8,12 +8,12 @@
from .fifo_off_policy_replay_buffer import FIFOOffPolicyReplayBuffer
from .fifo_on_policy_replay_buffer import FIFOOnPolicyReplayBuffer
from .hindsight_experience_replay_buffer import HindsightExperienceReplayBuffer
from .on_policy_episodic_replay_buffer import OnPolicyEpisodicReplayBuffer
from .on_policy_replay_buffer import OnPolicyReplayBuffer

__all__ = [
"BootstrapReplayBuffer",
"FIFOOffPolicyReplayBuffer",
"FIFOOnPolicyReplayBuffer",
"HindsightExperienceReplayBuffer",
"OnPolicyEpisodicReplayBuffer",
"OnPolicyReplayBuffer",
]