-
Notifications
You must be signed in to change notification settings - Fork 84
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Refactor] move A2C in separate folder for clarity
- Loading branch information
Showing
4 changed files
with
71 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import torch.nn.functional as F | ||
import torch.nn as nn | ||
|
||
from rl4co.models.rl.common.critic import CriticNetwork | ||
from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline | ||
from rl4co import utils | ||
|
||
|
||
log = utils.get_pylogger(__name__) | ||
|
||
|
||
class CriticBaseline(REINFORCEBaseline):
    """Baseline for REINFORCE (policy gradients) backed by a critic network.

    A2C is kept apart from REINFORCE for clarity, even though the two are
    essentially the same algorithm with different baselines.

    Args:
        critic: Critic network to use as baseline. If None, a new critic
            network is created from the environment when `setup` is called.
    """

    def __init__(self, critic: nn.Module = None, **unused_kw):
        super().__init__()
        self.critic = critic

    def setup(self, model, env, **kwargs):
        """Create the critic from `env.name` unless one was already supplied.

        Extra keyword arguments are forwarded to `CriticNetwork`.
        """
        # A critic handed to the constructor (or built earlier) is kept as is.
        if self.critic is not None:
            return
        env_name = env.name
        log.info("Creating critic network for {}".format(env_name))
        self.critic = CriticNetwork(env_name, **kwargs)

    def eval(self, x, c, env=None):
        """Return the detached critic value for `x` and the negated MSE to `c`."""
        value = self.critic(x).squeeze(-1)
        # The loss term keeps its graph so the critic itself can be trained.
        neg_loss = -F.mse_loss(value, c.detach())
        # The value is detached: the actor must not backprop through the baseline.
        return value.detach(), neg_loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline, NoBaseline, SharedBaseline, ExponentialBaseline, RolloutBaseline, WarmupBaseline | ||
from rl4co.models.rl.a2c.baseline import CriticBaseline | ||
|
||
|
||
# Maps a baseline name to the class that implements it.
REINFORCE_BASELINES_REGISTRY = dict(
    no=NoBaseline,
    shared=SharedBaseline,
    exponential=ExponentialBaseline,
    critic=CriticBaseline,
    rollout_only=RolloutBaseline,
    warmup=WarmupBaseline,
)
|
||
|
||
def get_reinforce_baseline(name, **kw):
    """Get a REINFORCE baseline by name.

    The "rollout" baseline defaults to a warmup baseline wrapping a
    `RolloutBaseline`, with one warmup epoch and an exponential beta of 0.8.

    Args:
        name: Baseline name. Either "warmup", "rollout", or a key of
            `REINFORCE_BASELINES_REGISTRY`.
        **kw: Keyword arguments for the baseline. For "warmup", `baseline`
            selects the inner baseline (a name or a `REINFORCEBaseline`
            instance, default "rollout"). For "rollout", `n_epochs`
            (default 1), `exp_beta` (default 0.8) and `bl_alpha`
            (default 0.05) are read. For any other name, `kw` is passed
            to the registered class as is.

    Returns:
        The constructed baseline instance.

    Raises:
        ValueError: If `name` is not a known baseline.
    """
    if name == "warmup":
        inner_baseline = kw.get("baseline", "rollout")
        # The inner baseline may be given by name; build it in that case.
        if not isinstance(inner_baseline, REINFORCEBaseline):
            inner_baseline = get_reinforce_baseline(inner_baseline, **kw)
        # NOTE(review): kw is forwarded unchanged, so a "baseline" key also
        # reaches WarmupBaseline and the recursive call - confirm they accept it.
        return WarmupBaseline(inner_baseline, **kw)
    elif name == "rollout":
        warmup_epochs = kw.get("n_epochs", 1)
        warmup_exp_beta = kw.get("exp_beta", 0.8)
        bl_alpha = kw.get("bl_alpha", 0.05)
        return WarmupBaseline(
            RolloutBaseline(bl_alpha=bl_alpha), warmup_epochs, warmup_exp_beta
        )

    baseline_cls = REINFORCE_BASELINES_REGISTRY.get(name, None)
    if baseline_cls is None:
        # Report the requested name (not the failed lookup result, which is
        # always None here) and include the "rollout" alias handled above.
        available = [*REINFORCE_BASELINES_REGISTRY, "rollout"]
        raise ValueError(
            f"Unknown baseline {name!r}. Available baselines: {available}"
        )
    return baseline_cls(**kw)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters