
Commit

Update
[ghstack-poisoned]
vmoens committed Dec 15, 2024
1 parent 8145d6e commit 401d58b
Showing 6 changed files with 74 additions and 57 deletions.
4 changes: 3 additions & 1 deletion sota-implementations/iql/iql_online.py
@@ -103,7 +103,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
compile_mode = "reduce-overhead"

# Create collector
-    collector = make_collector(cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode)
+    collector = make_collector(
+        cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode
+    )

# Create loss
loss_module, target_net_updater = make_loss(cfg.loss, model)
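The iql_online.py hunk above only reflows the make_collector call to fit the line length. The PPO scripts below go further and replace torch.compile with torchrl's compile_with_warmup helper. As a rough, assumption-level sketch of the idea behind such a helper (not torchrl's implementation), a warmup-aware wrapper can run the first few calls eagerly before compiling:

import torch


def compile_with_warmup_sketch(fn, *, mode=None, warmup=1):
    """Call `fn` eagerly for the first `warmup` calls, then torch.compile it."""
    state = {"calls": 0, "compiled": None}

    def wrapper(*args, **kwargs):
        if state["calls"] < warmup:
            state["calls"] += 1
            return fn(*args, **kwargs)  # eager warmup call(s)
        if state["compiled"] is None:
            state["compiled"] = torch.compile(fn, mode=mode)
        return state["compiled"](*args, **kwargs)

    return wrapper


double = compile_with_warmup_sketch(lambda x: x * 2, mode="reduce-overhead", warmup=1)
print(double(torch.ones(2)))  # first call runs eagerly
print(double(torch.ones(2)))  # subsequent calls use the compiled function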
36 changes: 20 additions & 16 deletions sota-implementations/ppo/ppo_atari.py
@@ -12,10 +12,8 @@
import warnings

import hydra
-from tensordict.nn import CudaGraphModule

-from torchrl._utils import logger as torchrl_logger, timeit
-from torchrl.record import VideoRecorder
+from torchrl._utils import compile_with_warmup


@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
@@ -25,12 +23,16 @@ def main(cfg: "DictConfig"): # noqa: F821
import tqdm

from tensordict import TensorDict
+    from tensordict.nn import CudaGraphModule

+    from torchrl._utils import timeit
from torchrl.collectors import SyncDataCollector
-    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs import ExplorationType, set_exploration_type
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
+    from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger
from utils_atari import eval_model, make_parallel_env, make_ppo_models

@@ -79,9 +81,10 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create data buffer
sampler = SamplerWithoutReplacement()
data_buffer = TensorDictReplayBuffer(
-        storage=LazyMemmapStorage(frames_per_batch),
+        storage=LazyTensorStorage(frames_per_batch, compilable=cfg.compile.compile),
sampler=sampler,
batch_size=mini_batch_size,
+        compilable=cfg.compile.compile,
)

# Create loss and adv modules
@@ -141,7 +144,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# Main loop
collected_frames = 0
-    num_network_updates = 0
+    num_network_updates = torch.zeros((), dtype=torch.int64, device=device)
pbar = tqdm.tqdm(total=total_frames)
num_mini_batches = frames_per_batch // mini_batch_size
total_network_updates = (
@@ -152,7 +155,7 @@ def update(batch, num_network_updates):
optim.zero_grad(set_to_none=True)

# Linearly decrease the learning rate and clip epsilon
-        alpha = 1.0
+        alpha = torch.ones((), device=device)
if cfg_optim_anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
for group in optim.param_groups:
@@ -165,9 +168,7 @@ def update(batch, num_network_updates):

# Forward pass PPO loss
loss = loss_module(batch)
-        loss_sum = (
-            loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
-        )
+        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
# Backward pass
loss_sum.backward()
torch.nn.utils.clip_grad_norm_(
@@ -176,12 +177,11 @@ def update(batch, num_network_updates):

# Update the networks
optim.step()
-        return loss.detach().set("alpha", alpha)
-
+        return loss.detach().set("alpha", alpha), num_network_updates.clone()

if cfg.compile.compile:
-        update = torch.compile(update, mode=compile_mode)
-        adv_module = torch.compile(adv_module, mode=compile_mode)
+        update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+        adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1)

if cfg.compile.cudagraphs:
warnings.warn(
@@ -238,7 +238,9 @@ def update(batch, num_network_updates):

for k, batch in enumerate(data_buffer):

-                loss = update(batch, num_network_updates=num_network_updates)
+                loss, num_network_updates = update(
+                    batch, num_network_updates=num_network_updates
+                )
losses[j, k] = loss.select(
"loss_critic", "loss_entropy", "loss_objective"
)
@@ -255,7 +257,9 @@ def update(batch, num_network_updates):
)

# Get test rewards
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC), timeit("eval"):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
) // test_interval:
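One pattern worth calling out from ppo_atari.py above: num_network_updates becomes a zero-dim int64 tensor and update() now returns num_network_updates.clone(). A Python scalar that changes every call is awkward inside a compiled or CUDA-graph-captured update, whereas an on-device tensor can be incremented in place and handed back as a clone. A minimal standalone sketch of that pattern (illustrative only, names and the annealing formula are placeholders):

import torch


def update_sketch(x, num_updates):
    # anneal a factor from the on-device counter, then bump it in place
    alpha = 1.0 - num_updates / 100
    num_updates += 1
    return x * alpha, num_updates.clone()


update_c = torch.compile(update_sketch)
num_updates = torch.zeros((), dtype=torch.int64)
x = torch.ones(4)
for _ in range(3):
    x, num_updates = update_c(x, num_updates)
print(num_updates)  # tensor(3)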
42 changes: 23 additions & 19 deletions sota-implementations/ppo/ppo_mujoco.py
@@ -12,11 +12,8 @@
import warnings

import hydra
-from tensordict.nn import CudaGraphModule

-from torchrl._utils import logger as torchrl_logger, timeit
-from torchrl.objectives import group_optimizers
-from torchrl.record import VideoRecorder
+from torchrl._utils import compile_with_warmup


@hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
Expand All @@ -26,12 +23,16 @@ def main(cfg: "DictConfig"): # noqa: F821
import tqdm

from tensordict import TensorDict
+    from tensordict.nn import CudaGraphModule

+    from torchrl._utils import timeit
from torchrl.collectors import SyncDataCollector
-    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs import ExplorationType, set_exploration_type
-    from torchrl.objectives import ClipPPOLoss
+    from torchrl.objectives import ClipPPOLoss, group_optimizers
from torchrl.objectives.value.advantages import GAE
+    from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger
from utils_mujoco import eval_model, make_env, make_ppo_models

@@ -80,9 +81,12 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create data buffer
sampler = SamplerWithoutReplacement()
data_buffer = TensorDictReplayBuffer(
-        storage=LazyMemmapStorage(cfg.collector.frames_per_batch),
+        storage=LazyTensorStorage(
+            cfg.collector.frames_per_batch, compilable=cfg.compile.compile
+        ),
sampler=sampler,
batch_size=cfg.loss.mini_batch_size,
+        compilable=cfg.compile.compile,
)

# Create loss and adv modules
@@ -138,12 +142,10 @@ def main(cfg: "DictConfig"): # noqa: F821
def update(batch, num_network_updates):
optim.zero_grad(set_to_none=True)
# Linearly decrease the learning rate and clip epsilon
-        alpha = 1.0
+        alpha = torch.ones((), device=device)
if cfg_optim_anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
-        for group in actor_optim.param_groups:
-            group["lr"] = cfg_optim_lr * alpha
-        for group in critic_optim.param_groups:
+        for group in optim.param_groups:
group["lr"] = cfg_optim_lr * alpha
if cfg_loss_anneal_clip_eps:
loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
@@ -160,12 +162,11 @@ def update(batch, num_network_updates):

# Update the networks
optim.step()
-        return loss.detach().set("alpha", alpha)
-
+        return loss.detach().set("alpha", alpha), num_network_updates.clone()

if cfg.compile.compile:
-        update = torch.compile(update, mode=compile_mode)
-        adv_module = torch.compile(adv_module, mode=compile_mode)
+        update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+        adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1)

if cfg.compile.cudagraphs:
warnings.warn(
Expand All @@ -177,10 +178,9 @@ def update(batch, num_network_updates):

# Main loop
collected_frames = 0
-    num_network_updates = 0
+    num_network_updates = torch.zeros((), dtype=torch.int64, device=device)
pbar = tqdm.tqdm(total=cfg.collector.total_frames)


# extract cfg variables
cfg_loss_ppo_epochs = cfg.loss.ppo_epochs
cfg_optim_anneal_lr = cfg.optim.anneal_lr
@@ -226,7 +226,9 @@ def update(batch, num_network_updates):
data_buffer.extend(data_reshape)

for k, batch in enumerate(data_buffer):
-                loss = update(batch, num_network_updates=num_network_updates)
+                loss, num_network_updates = update(
+                    batch, num_network_updates=num_network_updates
+                )
losses[j, k] = loss.select(
"loss_critic", "loss_entropy", "loss_objective"
)
@@ -245,7 +247,9 @@ def update(batch, num_network_updates):
)

# Get test rewards
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC), timeit("eval"):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
i * frames_in_batch
) // cfg_logger_test_interval:
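Both PPO scripts swap the memory-mapped LazyMemmapStorage for an in-memory LazyTensorStorage and pass compilable=cfg.compile.compile to both the storage and the buffer, as shown in the hunks above. A standalone sketch of the same buffer setup, with placeholder sizes and a dummy "obs" entry:

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

frames_per_batch, mini_batch_size = 1024, 256
data_buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(frames_per_batch, compilable=True),
    sampler=SamplerWithoutReplacement(),
    batch_size=mini_batch_size,
    compilable=True,
)
# fill the buffer with one collector batch worth of dummy transitions
data_buffer.extend(
    TensorDict({"obs": torch.randn(frames_per_batch, 4)}, [frames_per_batch])
)
for batch in data_buffer:  # iterates over mini-batches without replacement
    print(batch.shape)     # torch.Size([256])
    break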
27 changes: 16 additions & 11 deletions sota-implementations/ppo/utils_atari.py
@@ -31,7 +31,6 @@
ActorValueOperator,
ConvNet,
MLP,
-    OneHotCategorical,
ProbabilisticActor,
TanhNormal,
ValueOperator,
@@ -51,6 +50,7 @@ def make_base_env(env_name="BreakoutNoFrameskip-v4", frame_skip=4, is_test=False
from_pixels=True,
pixels_only=False,
device="cpu",
+        categorical_action_encoding=True,
)
env = TransformedEnv(env)
env.append_transform(NoopResetEnv(noops=30, random=True))
@@ -86,22 +86,22 @@ def make_parallel_env(env_name, num_envs, device, is_test=False):
# --------------------------------------------------------------------


-def make_ppo_modules_pixels(proof_environment):
+def make_ppo_modules_pixels(proof_environment, device):

# Define input shape
input_shape = proof_environment.observation_spec["pixels"].shape

# Define distribution class and kwargs
if isinstance(proof_environment.action_spec_unbatched.space, CategoricalBox):
num_outputs = proof_environment.action_spec_unbatched.space.n
-        distribution_class = OneHotCategorical
+        distribution_class = torch.distributions.Categorical
distribution_kwargs = {}
else: # is ContinuousBox
num_outputs = proof_environment.action_spec_unbatched.shape
distribution_class = TanhNormal
        distribution_kwargs = {
-            "low": proof_environment.action_spec_unbatched.space.low,
-            "high": proof_environment.action_spec_unbatched.space.high,
+            "low": proof_environment.action_spec_unbatched.space.low.to(device),
+            "high": proof_environment.action_spec_unbatched.space.high.to(device),
}

# Define input keys
@@ -113,6 +113,7 @@ def make_ppo_modules_pixels(proof_environment):
num_cells=[32, 64, 64],
kernel_sizes=[8, 4, 3],
strides=[4, 2, 1],
+        device=device,
)
common_cnn_output = common_cnn(torch.ones(input_shape))
common_mlp = MLP(
@@ -121,6 +122,7 @@ def make_ppo_modules_pixels(proof_environment):
activate_last_layer=True,
out_features=512,
num_cells=[],
+        device=device,
)
common_mlp_output = common_mlp(common_cnn_output)

@@ -137,6 +139,7 @@ def make_ppo_modules_pixels(proof_environment):
out_features=num_outputs,
activation_class=torch.nn.ReLU,
num_cells=[],
+        device=device,
)
policy_module = TensorDictModule(
module=policy_net,
@@ -148,7 +151,7 @@ def make_ppo_modules_pixels(proof_environment):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["logits"],
-        spec=proof_environment.full_action_spec_unbatched,
+        spec=proof_environment.full_action_spec_unbatched.to(device),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
@@ -161,6 +164,7 @@ def make_ppo_modules_pixels(proof_environment):
in_features=common_mlp_output.shape[-1],
out_features=1,
num_cells=[],
+        device=device,
)
value_module = ValueOperator(
value_net,
@@ -170,11 +174,12 @@ def make_ppo_modules_pixels(proof_environment):
return common_module, policy_module, value_module


-def make_ppo_models(env_name):
+def make_ppo_models(env_name, device):

-    proof_environment = make_parallel_env(env_name, 1, device="cpu")
+    proof_environment = make_parallel_env(env_name, 1, device=device)
common_module, policy_module, value_module = make_ppo_modules_pixels(
-        proof_environment
+        proof_environment,
+        device=device,
)

# Wrap modules in a single ActorCritic operator
Expand All @@ -185,8 +190,8 @@ def make_ppo_models(env_name):
)

with torch.no_grad():
-        td = proof_environment.rollout(max_steps=100, break_when_any_done=False)
-        td = actor_critic(td)
+        td = proof_environment.fake_tensordict().expand(10)
+        actor_critic(td)
del td

actor = actor_critic.get_policy_operator()
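In utils_atari.py, the environment is now built with categorical_action_encoding=True and the policy head switches from torchrl's OneHotCategorical to torch.distributions.Categorical, so actions are sampled as integer indices rather than one-hot vectors. A plain-PyTorch illustration of the difference (not the repo's code):

import torch
from torch.distributions import Categorical, OneHotCategorical

logits = torch.randn(8, 6)  # batch of 8 observations, 6 discrete actions
onehot_actions = OneHotCategorical(logits=logits).sample()  # shape [8, 6], one-hot rows
index_actions = Categorical(logits=logits).sample()         # shape [8], integer indices
print(onehot_actions.shape, index_actions.shape)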
16 changes: 9 additions & 7 deletions sota-implementations/ppo/utils_mujoco.py
@@ -43,7 +43,7 @@ def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False)
# --------------------------------------------------------------------


-def make_ppo_models_state(proof_environment):
+def make_ppo_models_state(proof_environment, device):

# Define input shape
input_shape = proof_environment.observation_spec["observation"].shape
@@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment):
num_outputs = proof_environment.action_spec_unbatched.shape[-1]
distribution_class = TanhNormal
    distribution_kwargs = {
-        "low": proof_environment.action_spec_unbatched.space.low,
-        "high": proof_environment.action_spec_unbatched.space.high,
+        "low": proof_environment.action_spec_unbatched.space.low.to(device),
+        "high": proof_environment.action_spec_unbatched.space.high.to(device),
"tanh_loc": False,
}

@@ -63,6 +63,7 @@ def make_ppo_models_state(proof_environment):
activation_class=torch.nn.Tanh,
out_features=num_outputs, # predict only loc
num_cells=[64, 64],
+        device=device,
)

# Initialize policy weights
@@ -87,7 +88,7 @@ def make_ppo_models_state(proof_environment):
out_keys=["loc", "scale"],
),
in_keys=["loc", "scale"],
-        spec=proof_environment.full_action_spec_unbatched,
+        spec=proof_environment.full_action_spec_unbatched.to(device),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
@@ -100,6 +101,7 @@ def make_ppo_models_state(proof_environment):
activation_class=torch.nn.Tanh,
out_features=1,
num_cells=[64, 64],
+        device=device,
)

# Initialize value weights
@@ -117,9 +119,9 @@ def make_ppo_models_state(proof_environment):
return policy_module, value_module


-def make_ppo_models(env_name):
-    proof_environment = make_env(env_name, device="cpu")
-    actor, critic = make_ppo_models_state(proof_environment)
+def make_ppo_models(env_name, device):
+    proof_environment = make_env(env_name, device=device)
+    actor, critic = make_ppo_models_state(proof_environment, device=device)
return actor, critic


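Both utils modules now take a device argument and build networks, specs, and action bounds directly on that device (device=... on MLP/ConvNet, .to(device) on the specs) instead of constructing everything on CPU and moving it afterwards. A small sketch of the same idea; the sizes (17 observations, 6 actions) are placeholders, not values taken from the repo:

import torch
from torchrl.modules import MLP

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy_net = MLP(
    in_features=17,                  # placeholder observation size
    out_features=6,                  # placeholder action size (loc only)
    num_cells=[64, 64],
    activation_class=torch.nn.Tanh,
    device=device,                   # parameters are allocated on `device` directly
)
action_low = torch.full((6,), -1.0, device=device)  # bounds kept on the same device
print(next(policy_net.parameters()).device, action_low.device)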
