From 3e9a249c56611fa8bb330df3f93840cff5654e93 Mon Sep 17 00:00:00 2001 From: Federico Berto Date: Wed, 13 Sep 2023 12:08:36 +0900 Subject: [PATCH] [Refactor] move A2C in separate folder for clarity --- rl4co/models/rl/a2c/baseline.py | 32 ++++++++++++++ rl4co/models/rl/reinforce/__init__.py | 38 ++++++++++++++++ rl4co/models/rl/reinforce/baselines.py | 60 -------------------------- rl4co/models/rl/reinforce/reinforce.py | 2 +- 4 files changed, 71 insertions(+), 61 deletions(-) create mode 100644 rl4co/models/rl/a2c/baseline.py create mode 100644 rl4co/models/rl/reinforce/__init__.py diff --git a/rl4co/models/rl/a2c/baseline.py b/rl4co/models/rl/a2c/baseline.py new file mode 100644 index 00000000..8b8a0379 --- /dev/null +++ b/rl4co/models/rl/a2c/baseline.py @@ -0,0 +1,32 @@ +import torch.nn.functional as F +import torch.nn as nn + +from rl4co.models.rl.common.critic import CriticNetwork +from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline +from rl4co import utils + + +log = utils.get_pylogger(__name__) + + +class CriticBaseline(REINFORCEBaseline): + """Critic baseline: use critic network as baseline for REINFORCE (Policy Gradients). + We separate A2C from REINFORCE for clarity, although they are essentially the same algorithm with different baselines. + + Args: + critic: Critic network to use as baseline. 
If None, create a new critic network based on the environment + """ + + def __init__(self, critic: nn.Module = None, **unused_kw): + super(CriticBaseline, self).__init__() + self.critic = critic + + def setup(self, model, env, **kwargs): + if self.critic is None: + log.info("Creating critic network for {}".format(env.name)) + self.critic = CriticNetwork(env.name, **kwargs) + + def eval(self, x, c, env=None): + v = self.critic(x).squeeze(-1) + # detach v since actor should not backprop through baseline, only for neg_loss + return v.detach(), -F.mse_loss(v, c.detach()) diff --git a/rl4co/models/rl/reinforce/__init__.py b/rl4co/models/rl/reinforce/__init__.py new file mode 100644 index 00000000..a76630af --- /dev/null +++ b/rl4co/models/rl/reinforce/__init__.py @@ -0,0 +1,38 @@ +from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline, NoBaseline, SharedBaseline, ExponentialBaseline, RolloutBaseline, WarmupBaseline +from rl4co.models.rl.a2c.baseline import CriticBaseline + + +REINFORCE_BASELINES_REGISTRY = { + "no": NoBaseline, + "shared": SharedBaseline, + "exponential": ExponentialBaseline, + "critic": CriticBaseline, + "rollout_only": RolloutBaseline, + "warmup": WarmupBaseline, +} + + +def get_reinforce_baseline(name, **kw): + """Get a REINFORCE baseline by name + The rollout baseline defaults to a warmup baseline with one epoch of + exponential baseline and the greedy rollout. Raises ValueError if name is unknown. + """ + if name == "warmup": + inner_baseline = kw.get("baseline", "rollout") + if not isinstance(inner_baseline, REINFORCEBaseline): + inner_baseline = get_reinforce_baseline(inner_baseline, **kw) + return WarmupBaseline(inner_baseline, **kw) + elif name == "rollout": + warmup_epochs = kw.get("n_epochs", 1) + warmup_exp_beta = kw.get("exp_beta", 0.8) + bl_alpha = kw.get("bl_alpha", 0.05) + return WarmupBaseline( + RolloutBaseline(bl_alpha=bl_alpha), warmup_epochs, warmup_exp_beta + ) + + baseline_cls = REINFORCE_BASELINES_REGISTRY.get(name, None) + if baseline_cls is None: + raise 
ValueError( + f"Unknown baseline {name}. Available baselines: {REINFORCE_BASELINES_REGISTRY.keys()}" + ) + return baseline_cls(**kw) \ No newline at end of file diff --git a/rl4co/models/rl/reinforce/baselines.py b/rl4co/models/rl/reinforce/baselines.py index 81d60fe0..c0041ca1 100644 --- a/rl4co/models/rl/reinforce/baselines.py +++ b/rl4co/models/rl/reinforce/baselines.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from scipy.stats import ttest_rel from torch.utils.data import DataLoader @@ -10,7 +9,6 @@ from rl4co import utils from rl4co.data.dataset import ExtraKeyDataset, tensordict_collate_fn -from rl4co.models.rl.common.critic import CriticNetwork log = utils.get_pylogger(__name__) @@ -125,28 +123,6 @@ def epoch_callback(self, *args, **kw): log.info("Set warmup alpha = {}".format(self.alpha)) -class CriticBaseline(REINFORCEBaseline): - """Critic baseline: use critic network as baseline - - Args: - critic: Critic network to use as baseline. If None, create a new critic network based on the environment - """ - - def __init__(self, critic: nn.Module = None, **unused_kw): - super(CriticBaseline, self).__init__() - self.critic = critic - - def setup(self, model, env, **kwargs): - if self.critic is None: - log.info("Creating critic network for {}".format(env.name)) - self.critic = CriticNetwork(env.name, **kwargs) - - def eval(self, x, c, env=None): - v = self.critic(x).squeeze(-1) - # detach v since actor should not backprop through baseline, only for neg_loss - return v.detach(), -F.mse_loss(v, c.detach()) - - class RolloutBaseline(REINFORCEBaseline): """Rollout baseline: use greedy rollout as baseline @@ -262,39 +238,3 @@ def __setstate__(self, state): """Restore datasets after unpickling. 
Will be restored in setup""" self.__dict__.update(state) self.dataset = None - - -REINFORCE_BASELINES_REGISTRY = { - "no": NoBaseline, - "shared": SharedBaseline, - "exponential": ExponentialBaseline, - "critic": CriticBaseline, - "rollout_only": RolloutBaseline, - "warmup": WarmupBaseline, -} - - -def get_reinforce_baseline(name, **kw): - """Get a REINFORCE baseline by name - The rollout baseline default to warmup baseline with one epoch of - exponential baseline and the greedy rollout - """ - if name == "warmup": - inner_baseline = kw.get("baseline", "rollout") - if not isinstance(inner_baseline, REINFORCEBaseline): - inner_baseline = get_reinforce_baseline(inner_baseline, **kw) - return WarmupBaseline(inner_baseline, **kw) - elif name == "rollout": - warmup_epochs = kw.get("n_epochs", 1) - warmup_exp_beta = kw.get("exp_beta", 0.8) - bl_alpha = kw.get("bl_alpha", 0.05) - return WarmupBaseline( - RolloutBaseline(bl_alpha=bl_alpha), warmup_epochs, warmup_exp_beta - ) - - baseline_cls = REINFORCE_BASELINES_REGISTRY.get(name, None) - if baseline_cls is None: - raise ValueError( - f"Unknown baseline {baseline_cls}. Available baselines: {REINFORCE_BASELINES_REGISTRY.keys()}" - ) - return baseline_cls(**kw) diff --git a/rl4co/models/rl/reinforce/reinforce.py b/rl4co/models/rl/reinforce/reinforce.py index 1003198c..16a2c7ba 100644 --- a/rl4co/models/rl/reinforce/reinforce.py +++ b/rl4co/models/rl/reinforce/reinforce.py @@ -10,7 +10,7 @@ from rl4co.envs.common.base import RL4COEnvBase from rl4co.models.rl.common.base import RL4COLitModule -from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline, get_reinforce_baseline +from rl4co.models.rl.reinforce import REINFORCEBaseline, get_reinforce_baseline from rl4co.utils.lightning import get_lightning_device from rl4co.utils.pylogger import get_pylogger