Skip to content

Commit

Permalink
[Refactor] move A2C in separate folder for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Sep 13, 2023
1 parent d0144a2 commit 3e9a249
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 61 deletions.
32 changes: 32 additions & 0 deletions rl4co/models/rl/a2c/baseline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch.nn.functional as F
import torch.nn as nn

from rl4co.models.rl.common.critic import CriticNetwork
from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline
from rl4co import utils


log = utils.get_pylogger(__name__)


class CriticBaseline(REINFORCEBaseline):
    """Baseline whose value estimate comes from a critic network.

    A2C is kept apart from REINFORCE for clarity, even though both are
    essentially the same policy-gradient algorithm with different baselines.

    Args:
        critic: Critic network to use as baseline. When None, a new
            ``CriticNetwork`` is built from the environment in ``setup``.
    """

    def __init__(self, critic: nn.Module = None, **unused_kw):
        super().__init__()
        self.critic = critic

    def setup(self, model, env, **kwargs):
        """Create the critic from ``env.name`` unless one was already provided."""
        if self.critic is not None:
            return
        log.info("Creating critic network for {}".format(env.name))
        self.critic = CriticNetwork(env.name, **kwargs)

    def eval(self, x, c, env=None):
        """Return the detached critic value and the negated MSE between value and ``c``."""
        value = self.critic(x).squeeze(-1)
        target = c.detach()
        critic_mse = F.mse_loss(value, target)
        # The returned value is detached so the actor never backpropagates
        # through the baseline; only the negated loss term keeps the graph.
        return value.detach(), -critic_mse
38 changes: 38 additions & 0 deletions rl4co/models/rl/reinforce/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline, NoBaseline, SharedBaseline, ExponentialBaseline, RolloutBaseline, WarmupBaseline
from rl4co.models.rl.a2c.baseline import CriticBaseline


# Maps the baseline names accepted by `get_reinforce_baseline` to the class
# that is instantiated for them. The name "rollout" is not listed here: it is
# resolved separately inside `get_reinforce_baseline`.
REINFORCE_BASELINES_REGISTRY = {
    "no": NoBaseline,
    "shared": SharedBaseline,
    "exponential": ExponentialBaseline,
    "critic": CriticBaseline,
    "rollout_only": RolloutBaseline,
    "warmup": WarmupBaseline,
}


def get_reinforce_baseline(name, **kw):
    """Get a REINFORCE baseline by name.

    Two names are resolved before the registry lookup:

    - ``"warmup"``: wraps an inner baseline in a ``WarmupBaseline``. The inner
      baseline is read from ``kw["baseline"]`` (default ``"rollout"``) and is
      built recursively when it is not already a ``REINFORCEBaseline``.
    - ``"rollout"``: defaults to a warmup baseline with one epoch of
      exponential baseline followed by the greedy rollout. Reads ``n_epochs``
      (default 1), ``exp_beta`` (default 0.8) and ``bl_alpha`` (default 0.05)
      from ``kw``.

    Any other name is looked up in ``REINFORCE_BASELINES_REGISTRY`` and the
    matching class is instantiated with ``**kw``.

    Args:
        name: Name of the baseline to create.
        **kw: Keyword arguments forwarded to the baseline constructor.

    Returns:
        The instantiated baseline.

    Raises:
        ValueError: If ``name`` is not ``"warmup"``, ``"rollout"`` or a key of
            ``REINFORCE_BASELINES_REGISTRY``.
    """
    if name == "warmup":
        inner_baseline = kw.get("baseline", "rollout")
        if not isinstance(inner_baseline, REINFORCEBaseline):
            inner_baseline = get_reinforce_baseline(inner_baseline, **kw)
        # NOTE(review): `kw` is forwarded unchanged; if WarmupBaseline's first
        # parameter is named `baseline`, an explicit kw["baseline"] would
        # collide with the positional argument — TODO confirm.
        return WarmupBaseline(inner_baseline, **kw)
    elif name == "rollout":
        warmup_epochs = kw.get("n_epochs", 1)
        warmup_exp_beta = kw.get("exp_beta", 0.8)
        bl_alpha = kw.get("bl_alpha", 0.05)
        return WarmupBaseline(
            RolloutBaseline(bl_alpha=bl_alpha), warmup_epochs, warmup_exp_beta
        )

    baseline_cls = REINFORCE_BASELINES_REGISTRY.get(name, None)
    if baseline_cls is None:
        # Report the requested name: `baseline_cls` is always None on this path.
        available = list(REINFORCE_BASELINES_REGISTRY) + ["rollout"]
        raise ValueError(
            f"Unknown baseline {name!r}. Available baselines: {available}"
        )
    return baseline_cls(**kw)
60 changes: 0 additions & 60 deletions rl4co/models/rl/reinforce/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from scipy.stats import ttest_rel
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from rl4co import utils
from rl4co.data.dataset import ExtraKeyDataset, tensordict_collate_fn
from rl4co.models.rl.common.critic import CriticNetwork

log = utils.get_pylogger(__name__)

Expand Down Expand Up @@ -125,28 +123,6 @@ def epoch_callback(self, *args, **kw):
log.info("Set warmup alpha = {}".format(self.alpha))


class CriticBaseline(REINFORCEBaseline):
    """Critic baseline: a critic network supplies the baseline value.

    Args:
        critic: Critic network to use as baseline. If None, a new critic
            network is created from the environment during ``setup``.
    """

    def __init__(self, critic: nn.Module = None, **unused_kw):
        super().__init__()
        self.critic = critic

    def setup(self, model, env, **kwargs):
        """Build a ``CriticNetwork`` for ``env.name`` when no critic was given."""
        needs_critic = self.critic is None
        if needs_critic:
            log.info("Creating critic network for {}".format(env.name))
            self.critic = CriticNetwork(env.name, **kwargs)

    def eval(self, x, c, env=None):
        """Return ``(detached value, negated MSE loss)`` for input ``x`` and cost ``c``."""
        prediction = self.critic(x).squeeze(-1)
        neg_loss = -F.mse_loss(prediction, c.detach())
        # The actor must not backpropagate through the baseline, so only the
        # loss term keeps the non-detached prediction.
        return prediction.detach(), neg_loss


class RolloutBaseline(REINFORCEBaseline):
"""Rollout baseline: use greedy rollout as baseline
Expand Down Expand Up @@ -262,39 +238,3 @@ def __setstate__(self, state):
"""Restore datasets after unpickling. Will be restored in setup"""
self.__dict__.update(state)
self.dataset = None


# Name -> baseline class lookup used by `get_reinforce_baseline`. The
# "rollout" name is absent on purpose: `get_reinforce_baseline` handles it
# before consulting this registry.
REINFORCE_BASELINES_REGISTRY = {
    "no": NoBaseline,
    "shared": SharedBaseline,
    "exponential": ExponentialBaseline,
    "critic": CriticBaseline,
    "rollout_only": RolloutBaseline,
    "warmup": WarmupBaseline,
}


def get_reinforce_baseline(name, **kw):
    """Get a REINFORCE baseline by name.

    The rollout baseline defaults to a warmup baseline with one epoch of
    exponential baseline followed by the greedy rollout.

    Args:
        name: Name of the baseline. ``"warmup"`` and ``"rollout"`` are handled
            explicitly; every other name must be a key of
            ``REINFORCE_BASELINES_REGISTRY``.
        **kw: Keyword arguments forwarded to the baseline constructor. For
            ``"warmup"``, ``baseline`` (default ``"rollout"``) selects the
            inner baseline. For ``"rollout"``, ``n_epochs`` (default 1),
            ``exp_beta`` (default 0.8) and ``bl_alpha`` (default 0.05) are read.

    Returns:
        The instantiated baseline.

    Raises:
        ValueError: If ``name`` does not match any known baseline.
    """
    if name == "warmup":
        inner_baseline = kw.get("baseline", "rollout")
        # A name (rather than an instance) is resolved recursively.
        if not isinstance(inner_baseline, REINFORCEBaseline):
            inner_baseline = get_reinforce_baseline(inner_baseline, **kw)
        return WarmupBaseline(inner_baseline, **kw)
    elif name == "rollout":
        warmup_epochs = kw.get("n_epochs", 1)
        warmup_exp_beta = kw.get("exp_beta", 0.8)
        bl_alpha = kw.get("bl_alpha", 0.05)
        return WarmupBaseline(
            RolloutBaseline(bl_alpha=bl_alpha), warmup_epochs, warmup_exp_beta
        )

    baseline_cls = REINFORCE_BASELINES_REGISTRY.get(name, None)
    if baseline_cls is None:
        # Interpolate `name`, not `baseline_cls`, which is None whenever this
        # branch is reached.
        available = list(REINFORCE_BASELINES_REGISTRY) + ["rollout"]
        raise ValueError(
            f"Unknown baseline {name!r}. Available baselines: {available}"
        )
    return baseline_cls(**kw)
2 changes: 1 addition & 1 deletion rl4co/models/rl/reinforce/reinforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from rl4co.envs.common.base import RL4COEnvBase
from rl4co.models.rl.common.base import RL4COLitModule
from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline, get_reinforce_baseline
from rl4co.models.rl.reinforce import REINFORCEBaseline, get_reinforce_baseline
from rl4co.utils.lightning import get_lightning_device
from rl4co.utils.pylogger import get_pylogger

Expand Down

0 comments on commit 3e9a249

Please sign in to comment.