Skip to content

Commit

Permalink
[Feature] GAIL compatibility with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: 98c7602ec0343d7a83cb19bddeb579484c42e77e
Pull Request resolved: #2573
  • Loading branch information
vmoens committed Dec 15, 2024
1 parent f149811 commit 6482766
Show file tree
Hide file tree
Showing 24 changed files with 249 additions and 119 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/a2c/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def make_ppo_modules_pixels(proof_environment, device):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["logits"],
spec=proof_environment.single_full_action_spec.to(device),
spec=proof_environment.full_action_spec_unbatched.to(device),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/a2c/utils_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
out_keys=["loc", "scale"],
),
in_keys=["loc", "scale"],
spec=proof_environment.single_full_action_spec.to(device),
spec=proof_environment.full_action_spec_unbatched.to(device),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/discrete_cql_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ collector:
multi_step: 0
init_random_frames: 1000
env_per_collector: 1
device: cpu
device:
max_frames_per_traj: 200
annealing_frames: 10000
eps_start: 1.0
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/online_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ collector:
multi_step: 0
init_random_frames: 5_000
env_per_collector: 1
device: cpu
device:
max_frames_per_traj: 1000


Expand Down
8 changes: 7 additions & 1 deletion sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,20 @@ def make_collector(
cudagraph=False,
):
"""Make collector."""
device = cfg.collector.device
if device in ("", None):
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
collector = SyncDataCollector(
train_env,
actor_model_explore,
init_random_frames=cfg.collector.init_random_frames,
frames_per_batch=cfg.collector.frames_per_batch,
max_frames_per_traj=cfg.collector.max_frames_per_traj,
total_frames=cfg.collector.total_frames,
device=cfg.collector.device,
device=device,
compile_policy={"mode": compile_mode} if compile else False,
cudagraph_policy=cudagraph,
)
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/dreamer/dreamer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def _dreamer_make_actor_real(
default_interaction_type=InteractionType.DETERMINISTIC,
distribution_class=TanhNormal,
distribution_kwargs={"tanh_loc": True},
spec=proof_environment.single_full_action_spec.to("cpu"),
spec=proof_environment.full_action_spec_unbatched.to("cpu"),
),
),
SafeModule(
Expand Down
5 changes: 5 additions & 0 deletions sota-implementations/gail/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ gail:
gp_lambda: 10.0
device: null

compile:
compile: False
compile_mode: default
cudagraphs: False

replay_buffer:
dataset: halfcheetah-expert-v2
batch_size: 256
185 changes: 113 additions & 72 deletions sota-implementations/gail/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,33 @@
"""
from __future__ import annotations

import warnings

import hydra
import numpy as np
import torch
import tqdm

from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
from ppo_utils import eval_model, make_env, make_ppo_models
from tensordict.nn import CudaGraphModule

from torchrl._utils import compile_with_warmup
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import ClipPPOLoss, GAILLoss
from torchrl.objectives import ClipPPOLoss, GAILLoss, group_optimizers
from torchrl.objectives.value.advantages import GAE
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger


torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="config")
def main(cfg: "DictConfig"): # noqa: F821
set_gym_backend(cfg.env.backend).set()
Expand Down Expand Up @@ -71,25 +79,20 @@ def main(cfg: "DictConfig"): # noqa: F821
np.random.seed(cfg.env.seed)

# Create models (check utils_mujoco.py)
actor, critic = make_ppo_models(cfg.env.env_name)
actor, critic = actor.to(device), critic.to(device)

# Create collector
collector = SyncDataCollector(
create_env_fn=make_env(cfg.env.env_name, device),
policy=actor,
frames_per_batch=cfg.ppo.collector.frames_per_batch,
total_frames=cfg.ppo.collector.total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
actor, critic = make_ppo_models(
cfg.env.env_name, compile=cfg.compile.compile, device=device
)

# Create data buffer
data_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch),
storage=LazyTensorStorage(
cfg.ppo.collector.frames_per_batch,
device=device,
compilable=cfg.compile.compile,
),
sampler=SamplerWithoutReplacement(),
batch_size=cfg.ppo.loss.mini_batch_size,
compilable=cfg.compile.compile,
)

# Create loss and adv modules
Expand All @@ -98,6 +101,7 @@ def main(cfg: "DictConfig"): # noqa: F821
lmbda=cfg.ppo.loss.gae_lambda,
value_network=critic,
average_gae=False,
device=device,
)

loss_module = ClipPPOLoss(
Expand All @@ -111,8 +115,35 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create optimizers
actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
actor_optim = torch.optim.Adam(
actor.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5
)
critic_optim = torch.optim.Adam(
critic.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5
)
optim = group_optimizers(actor_optim, critic_optim)
del actor_optim, critic_optim

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create collector
collector = SyncDataCollector(
create_env_fn=make_env(cfg.env.env_name, device),
policy=actor,
frames_per_batch=cfg.ppo.collector.frames_per_batch,
total_frames=cfg.ppo.collector.total_frames,
device=device,
max_frames_per_traj=-1,
compile_policy={"mode": compile_mode} if compile_mode is not None else False,
cudagraph_policy=cfg.compile.cudagraphs,
)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
Expand Down Expand Up @@ -140,32 +171,9 @@ def main(cfg: "DictConfig"): # noqa: F821
VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
)
test_env.eval()
num_network_updates = torch.zeros((), dtype=torch.int64, device=device)

# Training loop
collected_frames = 0
num_network_updates = 0
pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)

# extract cfg variables
cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
cfg_optim_lr = cfg.ppo.optim.lr
cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
cfg_logger_test_interval = cfg.logger.test_interval
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes

for i, data in enumerate(collector):

log_info = {}
frames_in_batch = data.numel()
collected_frames += frames_in_batch
pbar.update(data.numel())

# Update discriminator
# Get expert data
expert_data = replay_buffer.sample()
expert_data = expert_data.to(device)
def update(data, expert_data, num_network_updates=num_network_updates):
# Add collector data to expert data
expert_data.set(
discriminator_loss.tensor_keys.collector_action,
Expand All @@ -178,9 +186,9 @@ def main(cfg: "DictConfig"): # noqa: F821
d_loss = discriminator_loss(expert_data)

# Backward pass
discriminator_optim.zero_grad()
d_loss.get("loss").backward()
discriminator_optim.step()
discriminator_optim.zero_grad(set_to_none=True)

# Compute discriminator reward
with torch.no_grad():
Expand All @@ -190,40 +198,25 @@ def main(cfg: "DictConfig"): # noqa: F821
# Set discriminator rewards to tensordict
data.set(("next", "reward"), d_rewards)

# Get training rewards and episode lengths
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
episode_length = data["next", "step_count"][data["next", "done"]]
log_info.update(
{
"train/reward": episode_rewards.mean().item(),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
}
)
# Update PPO
for _ in range(cfg_loss_ppo_epochs):

# Compute GAE
with torch.no_grad():
data = adv_module(data)
data_reshape = data.reshape(-1)

# Update the data buffer
data_buffer.empty()
data_buffer.extend(data_reshape)

for _, batch in enumerate(data_buffer):

# Get a data batch
batch = batch.to(device)
for batch in data_buffer:
optim.zero_grad(set_to_none=True)

# Linearly decrease the learning rate and clip epsilon
alpha = 1.0
alpha = torch.ones((), device=device)
if cfg_optim_anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
for group in actor_optim.param_groups:
group["lr"] = cfg_optim_lr * alpha
for group in critic_optim.param_groups:
for group in optim.param_groups:
group["lr"] = cfg_optim_lr * alpha
if cfg_loss_anneal_clip_eps:
loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
Expand All @@ -235,20 +228,68 @@ def main(cfg: "DictConfig"): # noqa: F821
actor_loss = loss["loss_objective"] + loss["loss_entropy"]

# Backward pass
actor_loss.backward()
critic_loss.backward()
(actor_loss + critic_loss).backward()

# Update the networks
actor_optim.step()
critic_optim.step()
actor_optim.zero_grad()
critic_optim.zero_grad()
optim.step()
return {"dloss": d_loss, "alpha": alpha}

if cfg.compile.compile:
update = compile_with_warmup(update, warmup=2, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Training loop
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)

# extract cfg variables
cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
cfg_optim_lr = cfg.ppo.optim.lr
cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
cfg_logger_test_interval = cfg.logger.test_interval
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes

for i, data in enumerate(collector):

log_info = {}
frames_in_batch = data.numel()
collected_frames += frames_in_batch
pbar.update(data.numel())

# Update discriminator
# Get expert data
expert_data = replay_buffer.sample()
expert_data = expert_data.to(device)

metadata = update(data, expert_data)
d_loss = metadata["dloss"]
alpha = metadata["alpha"]

# Get training rewards and episode lengths
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
episode_length = data["next", "step_count"][data["next", "done"]]

log_info.update(
{
"train/reward": episode_rewards.mean().item(),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
}
)

log_info.update(
{
"train/actor_loss": actor_loss.item(),
"train/critic_loss": critic_loss.item(),
"train/discriminator_loss": d_loss["loss"].item(),
# "train/actor_loss": actor_loss.item(),
# "train/critic_loss": critic_loss.item(),
"train/discriminator_loss": d_loss["loss"],
"train/lr": alpha * cfg_optim_lr,
"train/clip_epsilon": (
alpha * cfg_loss_clip_epsilon
Expand Down
Loading

1 comment on commit 6482766

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 6482766 Previous: f149811 Ratio
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 463.69768903993446 iter/sec (stddev: 0.03966088341468626) 1506.5820596806993 iter/sec (stddev: 0.00005460179029479441) 3.25

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.