Commit 7a4cc1e

Update
[ghstack-poisoned]
vmoens committed Dec 14, 2024
1 parent 89b1cb3 commit 7a4cc1e
Showing 2 changed files with 28 additions and 17 deletions.
28 changes: 18 additions & 10 deletions sota-implementations/gail/gail.py
@@ -11,13 +11,16 @@
 """
 from __future__ import annotations
 
+import warnings
+
 import hydra
 import numpy as np
 import torch
 import tqdm
 
 from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
 from ppo_utils import eval_model, make_env, make_ppo_models
+from tensordict import TensorDict
 from tensordict.nn import CudaGraphModule
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
@@ -72,8 +75,9 @@ def main(cfg: "DictConfig"): # noqa: F821
     np.random.seed(cfg.env.seed)
 
     # Create models (check utils_mujoco.py)
-    actor, critic = make_ppo_models(cfg.env.env_name, compile=cfg.compile.compile)
-    actor, critic = actor.to(device), critic.to(device)
+    actor, critic = make_ppo_models(
+        cfg.env.env_name, compile=cfg.compile.compile, device=device
+    )
 
     # Create data buffer
     data_buffer = TensorDictReplayBuffer(
@@ -101,8 +105,12 @@ def main(cfg: "DictConfig"): # noqa: F821
     )
 
     # Create optimizers
-    actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
-    critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
+    actor_optim = torch.optim.Adam(
+        actor.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5
+    )
+    critic_optim = torch.optim.Adam(
+        critic.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5
+    )
     optim = group_optimizers(actor_optim, critic_optim)
     del actor_optim, critic_optim
 
@@ -196,12 +204,10 @@ def update(data, expert_data, num_network_updates=num_network_updates):
         optim.zero_grad(set_to_none=True)
 
         # Linearly decrease the learning rate and clip epsilon
-        alpha = 1.0
+        alpha = torch.ones((), device=device)
         if cfg_optim_anneal_lr:
             alpha = 1 - (num_network_updates / total_network_updates)
-            for group in actor_optim.param_groups:
-                group["lr"] = cfg_optim_lr * alpha
-            for group in critic_optim.param_groups:
+            for group in optim.param_groups:
                 group["lr"] = cfg_optim_lr * alpha
         if cfg_loss_anneal_clip_eps:
             loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
@@ -217,7 +223,7 @@ def update(data, expert_data, num_network_updates=num_network_updates):
 
         # Update the networks
         optim.step()
-        return d_loss.detach()
+        return TensorDict(dloss=d_loss, alpha=alpha).detach()
 
     if cfg.compile.compile:
         update = torch.compile(update, mode=compile_mode)
@@ -253,7 +259,9 @@ def update(data, expert_data, num_network_updates=num_network_updates):
         expert_data = replay_buffer.sample()
         expert_data = expert_data.to(device)
 
-        d_loss = update(data, expert_data)
+        metadata = update(data, expert_data)
+        d_loss = metadata["dloss"]
+        alpha = metadata["alpha"]
 
         # Get training rewards and episode lengths
         episode_rewards = data["next", "episode_reward"][data["next", "done"]]
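Taken together, the gail.py changes move the update step toward something torch.compile or CudaGraphModule can capture cleanly: the learning rate is handed to Adam as a tensor living on the training device, the annealing coefficient stays a device tensor, both optimizers are driven through a single grouped optimizer, and the values to log are returned packed in a TensorDict rather than as plain Python floats. Below is a minimal sketch of that pattern, independent of GAIL, with made-up model, data, and hyperparameter names; it assumes a PyTorch version that accepts a Tensor learning rate, as the diff itself does.

import torch
from tensordict import TensorDict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(4, 1, device=device)

lr = 3e-4
total_updates = 1_000

# Learning rate passed as a tensor on the training device, as in the diff.
optim = torch.optim.Adam(
    model.parameters(), lr=torch.tensor(lr, device=device), eps=1e-5
)


def update(batch, num_update):
    optim.zero_grad(set_to_none=True)

    # Linear annealing; alpha stays a 0-dim tensor on `device`.
    alpha = 1 - num_update / total_updates
    for group in optim.param_groups:
        group["lr"] = lr * alpha

    loss = model(batch).pow(2).mean()
    loss.backward()
    optim.step()

    # Pack everything worth logging into a TensorDict, mirroring
    # `return TensorDict(dloss=d_loss, alpha=alpha).detach()` above.
    return TensorDict(loss=loss, alpha=alpha).detach()


metadata = update(torch.randn(8, 4, device=device), torch.zeros((), device=device))
loss_val, alpha_val = metadata["loss"], metadata["alpha"]

Keeping alpha and the learning rate as device tensors avoids host/device round-trips inside the captured region, which is the main reason for the lr-as-tensor change above.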
17 changes: 10 additions & 7 deletions sota-implementations/gail/ppo_utils.py
@@ -43,7 +43,7 @@ def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False)
 # --------------------------------------------------------------------
 
 
-def make_ppo_models_state(proof_environment, compile):
+def make_ppo_models_state(proof_environment, compile, device):
 
     # Define input shape
     input_shape = proof_environment.observation_spec["observation"].shape
@@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment, compile):
     num_outputs = proof_environment.action_spec_unbatched.shape[-1]
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec_unbatched.space.low,
-        "high": proof_environment.action_spec_unbatched.space.high,
+        "low": proof_environment.action_spec_unbatched.space.low.to(device),
+        "high": proof_environment.action_spec_unbatched.space.high.to(device),
         "tanh_loc": False,
         "safe_tanh": not compile,
     }
@@ -64,6 +64,7 @@ def make_ppo_models_state(proof_environment, compile):
         activation_class=torch.nn.Tanh,
         out_features=num_outputs,  # predict only loc
         num_cells=[64, 64],
+        device=device,
     )
 
     # Initialize policy weights
@@ -88,7 +89,7 @@ def make_ppo_models_state(proof_environment, compile):
             out_keys=["loc", "scale"],
         ),
         in_keys=["loc", "scale"],
-        spec=proof_environment.single_full_action_spec,
+        spec=proof_environment.full_action_spec_unbatched.to(device),
         distribution_class=distribution_class,
         distribution_kwargs=distribution_kwargs,
         return_log_prob=True,
@@ -118,9 +119,11 @@ def make_ppo_models_state(proof_environment, compile):
     return policy_module, value_module
 
 
-def make_ppo_models(env_name, compile):
-    proof_environment = make_env(env_name, device="cpu")
-    actor, critic = make_ppo_models_state(proof_environment, compile=compile)
+def make_ppo_models(env_name, compile, device):
+    proof_environment = make_env(env_name, device=device)
+    actor, critic = make_ppo_models_state(
+        proof_environment, compile=compile, device=device
+    )
     return actor, critic
 
 
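With these ppo_utils.py changes the device is threaded through model construction instead of building on CPU and calling .to(device) afterwards, so callers now pass it explicitly. A hedged usage sketch follows; the environment name and device choice are placeholders, and it assumes the Gym/MuJoCo dependencies these scripts rely on are installed.

import torch

from ppo_utils import make_env, make_ppo_models

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Actor and critic are created directly on `device`, matching the new
# make_ppo_models(env_name, compile, device) signature.
actor, critic = make_ppo_models("HalfCheetah-v4", compile=False, device=device)
assert all(p.device == device for p in actor.parameters())

# The environment used for rollouts can live on the same device.
env = make_env("HalfCheetah-v4", device=device)
rollout = env.rollout(3, actor)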
