Skip to content

Commit

Permalink
[Minor] log , , and by default
Browse files Browse the repository at this point in the history
  • Loading branch information
Junyoungpark committed Sep 11, 2023
1 parent 90a2246 commit c27c429
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions rl4co/models/rl/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader

from rl4co.envs.common.base import RL4COEnvBase
Expand Down Expand Up @@ -69,15 +68,16 @@ def __init__(
entropy_lambda: float = 0.0, # lambda of entropy bonus
normalize_adv: bool = False, # whether to normalize advantage
max_grad_norm: float = 0.5, # max gradient norm
metrics: dict = {
"train": ["loss", "surrogate_loss", "value_loss", "entropy"],
},
**kwargs,
):
super().__init__(env, policy, **kwargs)
super().__init__(env, policy, metrics=metrics, **kwargs)
self.automatic_optimization = False # PPO uses custom optimization routine
self.critic = critic

if isinstance(mini_batch_size, float) and (
mini_batch_size <= 0 or mini_batch_size > 1
):
if isinstance(mini_batch_size, float) and (mini_batch_size <= 0 or mini_batch_size > 1):
default_mini_batch_fraction = 0.25
log.warning(
f"mini_batch_size must be an integer or a float in the range (0, 1], got {mini_batch_size}. Setting mini_batch_size to {default_mini_batch_fraction}."
Expand Down Expand Up @@ -147,14 +147,10 @@ def shared_step(self, batch: Any, batch_idx: int, phase: str):

for _ in range(self.ppo_cfg["ppo_epochs"]): # PPO inner epoch, K
for sub_td in dataloader:
ll, entropy = self.policy.evaluate_action(
sub_td, action=sub_td["action"]
)
ll, entropy = self.policy.evaluate_action(sub_td, action=sub_td["action"])

# Compute the ratio of probabilities of new and old actions
ratio = torch.exp(ll.sum(dim=-1) - sub_td["log_prob"]).view(
-1, 1
) # [batch, 1]
ratio = torch.exp(ll.sum(dim=-1) - sub_td["log_prob"]).view(-1, 1) # [batch, 1]

# Compute the advantage
value_pred = self.critic(sub_td) # [batch, 1]
Expand Down

0 comments on commit c27c429

Please sign in to comment.