diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py
index eb063fcd139..456e31f04ad 100644
--- a/sota-implementations/iql/iql_online.py
+++ b/sota-implementations/iql/iql_online.py
@@ -103,7 +103,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 compile_mode = "reduce-overhead"
 
     # Create collector
-    collector = make_collector(cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode)
+    collector = make_collector(
+        cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode
+    )
 
     # Create loss
     loss_module, target_net_updater = make_loss(cfg.loss, model)
diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py
index ec6de455ae1..abc6eb4b537 100644
--- a/sota-implementations/ppo/ppo_atari.py
+++ b/sota-implementations/ppo/ppo_atari.py
@@ -12,10 +12,8 @@
 import warnings
 
 import hydra
-from tensordict.nn import CudaGraphModule
 
-from torchrl._utils import logger as torchrl_logger, timeit
-from torchrl.record import VideoRecorder
+from torchrl._utils import compile_with_warmup
 
 
 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
@@ -25,12 +23,16 @@ def main(cfg: "DictConfig"):  # noqa: F821
     import tqdm
 
     from tensordict import TensorDict
+    from tensordict.nn import CudaGraphModule
+
+    from torchrl._utils import timeit
     from torchrl.collectors import SyncDataCollector
-    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
     from torchrl.envs import ExplorationType, set_exploration_type
     from torchrl.objectives import ClipPPOLoss
     from torchrl.objectives.value.advantages import GAE
+    from torchrl.record import VideoRecorder
     from torchrl.record.loggers import generate_exp_name, get_logger
 
     from utils_atari import eval_model, make_parallel_env, make_ppo_models
@@ -79,9 +81,10 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Create data buffer
     sampler = SamplerWithoutReplacement()
     data_buffer = TensorDictReplayBuffer(
-        storage=LazyMemmapStorage(frames_per_batch),
+        storage=LazyTensorStorage(frames_per_batch, compilable=cfg.compile.compile),
         sampler=sampler,
         batch_size=mini_batch_size,
+        compilable=cfg.compile.compile,
     )
 
     # Create loss and adv modules
@@ -141,7 +144,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
     # Main loop
     collected_frames = 0
-    num_network_updates = 0
+    num_network_updates = torch.zeros((), dtype=torch.int64, device=device)
     pbar = tqdm.tqdm(total=total_frames)
     num_mini_batches = frames_per_batch // mini_batch_size
     total_network_updates = (
@@ -152,7 +155,7 @@ def update(batch, num_network_updates):
         optim.zero_grad(set_to_none=True)
         # Linearly decrease the learning rate and clip epsilon
-        alpha = 1.0
+        alpha = torch.ones((), device=device)
         if cfg_optim_anneal_lr:
             alpha = 1 - (num_network_updates / total_network_updates)
             for group in optim.param_groups:
                 group["lr"] = cfg_optim_lr * alpha
@@ -165,9 +168,7 @@ def update(batch, num_network_updates):
 
         # Forward pass PPO loss
         loss = loss_module(batch)
-        loss_sum = (
-            loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
-        )
+        loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
         # Backward pass
         loss_sum.backward()
         torch.nn.utils.clip_grad_norm_(
@@ -176,12 +177,11 @@ def update(batch, num_network_updates):
 
         # Update the networks
         optim.step()
-        return loss.detach().set("alpha", alpha)
-
+        return loss.detach().set("alpha", alpha), num_network_updates.clone()
 
     if cfg.compile.compile:
-        update = torch.compile(update, mode=compile_mode)
-        adv_module = torch.compile(adv_module, mode=compile_mode)
+        update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+        adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1)
 
     if cfg.compile.cudagraphs:
         warnings.warn(
@@ -238,7 +238,9 @@ def update(batch, num_network_updates):
 
             for k, batch in enumerate(data_buffer):
-                loss = update(batch, num_network_updates=num_network_updates)
+                loss, num_network_updates = update(
+                    batch, num_network_updates=num_network_updates
+                )
                 losses[j, k] = loss.select(
                     "loss_critic", "loss_entropy", "loss_objective"
                 )
 
@@ -255,7 +257,9 @@ def update(batch, num_network_updates):
         )
 
         # Get test rewards
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC), timeit("eval"):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
             if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
                 i * frames_in_batch * frame_skip
             ) // test_interval:
diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py
index 402099c987d..b460bb2866d 100644
--- a/sota-implementations/ppo/ppo_mujoco.py
+++ b/sota-implementations/ppo/ppo_mujoco.py
@@ -12,11 +12,8 @@
 import warnings
 
 import hydra
-from tensordict.nn import CudaGraphModule
 
-from torchrl._utils import logger as torchrl_logger, timeit
-from torchrl.objectives import group_optimizers
-from torchrl.record import VideoRecorder
+from torchrl._utils import compile_with_warmup
 
 
 @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
@@ -26,12 +23,16 @@ def main(cfg: "DictConfig"):  # noqa: F821
     import tqdm
 
     from tensordict import TensorDict
+    from tensordict.nn import CudaGraphModule
+
+    from torchrl._utils import timeit
     from torchrl.collectors import SyncDataCollector
-    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
     from torchrl.envs import ExplorationType, set_exploration_type
-    from torchrl.objectives import ClipPPOLoss
+    from torchrl.objectives import ClipPPOLoss, group_optimizers
     from torchrl.objectives.value.advantages import GAE
+    from torchrl.record import VideoRecorder
     from torchrl.record.loggers import generate_exp_name, get_logger
 
     from utils_mujoco import eval_model, make_env, make_ppo_models
@@ -80,9 +81,12 @@ def main(cfg: "DictConfig"):  # noqa: F821
     # Create data buffer
     sampler = SamplerWithoutReplacement()
     data_buffer = TensorDictReplayBuffer(
-        storage=LazyMemmapStorage(cfg.collector.frames_per_batch),
+        storage=LazyTensorStorage(
+            cfg.collector.frames_per_batch, compilable=cfg.compile.compile
+        ),
         sampler=sampler,
         batch_size=cfg.loss.mini_batch_size,
+        compilable=cfg.compile.compile,
     )
 
     # Create loss and adv modules
@@ -138,12 +142,10 @@ def main(cfg: "DictConfig"):  # noqa: F821
     def update(batch, num_network_updates):
         optim.zero_grad(set_to_none=True)
         # Linearly decrease the learning rate and clip epsilon
-        alpha = 1.0
+        alpha = torch.ones((), device=device)
         if cfg_optim_anneal_lr:
             alpha = 1 - (num_network_updates / total_network_updates)
-            for group in actor_optim.param_groups:
-                group["lr"] = cfg_optim_lr * alpha
-            for group in critic_optim.param_groups:
+            for group in optim.param_groups:
                 group["lr"] = cfg_optim_lr * alpha
         if cfg_loss_anneal_clip_eps:
             loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
@@ -160,12 +162,11 @@ def update(batch, num_network_updates):
 
         # Update the networks
         optim.step()
-        return loss.detach().set("alpha", alpha)
-
+        return loss.detach().set("alpha", alpha), num_network_updates.clone()
 
     if cfg.compile.compile:
-        update = torch.compile(update, mode=compile_mode)
-        adv_module = torch.compile(adv_module, mode=compile_mode)
+        update = compile_with_warmup(update, mode=compile_mode, warmup=1)
+        adv_module = compile_with_warmup(adv_module, mode=compile_mode, warmup=1)
 
     if cfg.compile.cudagraphs:
         warnings.warn(
@@ -177,10 +178,9 @@ def update(batch, num_network_updates):
 
     # Main loop
     collected_frames = 0
-    num_network_updates = 0
+    num_network_updates = torch.zeros((), dtype=torch.int64, device=device)
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
 
-
     # extract cfg variables
     cfg_loss_ppo_epochs = cfg.loss.ppo_epochs
     cfg_optim_anneal_lr = cfg.optim.anneal_lr
@@ -226,7 +226,9 @@ def update(batch, num_network_updates):
             data_buffer.extend(data_reshape)
 
             for k, batch in enumerate(data_buffer):
-                loss = update(batch, num_network_updates=num_network_updates)
+                loss, num_network_updates = update(
+                    batch, num_network_updates=num_network_updates
+                )
                 losses[j, k] = loss.select(
                     "loss_critic", "loss_entropy", "loss_objective"
                 )
@@ -245,7 +247,9 @@ def update(batch, num_network_updates):
         )
 
         # Get test rewards
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC), timeit("eval"):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
             if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
                 i * frames_in_batch
             ) // cfg_logger_test_interval:
diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py
index 040259377ad..ab39c102106 100644
--- a/sota-implementations/ppo/utils_atari.py
+++ b/sota-implementations/ppo/utils_atari.py
@@ -31,7 +31,6 @@
     ActorValueOperator,
     ConvNet,
     MLP,
-    OneHotCategorical,
     ProbabilisticActor,
     TanhNormal,
     ValueOperator,
@@ -51,6 +50,7 @@ def make_base_env(env_name="BreakoutNoFrameskip-v4", frame_skip=4, is_test=False
         from_pixels=True,
         pixels_only=False,
         device="cpu",
+        categorical_action_encoding=True,
     )
     env = TransformedEnv(env)
     env.append_transform(NoopResetEnv(noops=30, random=True))
@@ -86,7 +86,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False):
 # --------------------------------------------------------------------
 
 
-def make_ppo_modules_pixels(proof_environment):
+def make_ppo_modules_pixels(proof_environment, device):
 
     # Define input shape
     input_shape = proof_environment.observation_spec["pixels"].shape
@@ -94,14 +94,14 @@ def make_ppo_modules_pixels(proof_environment):
 
     # Define distribution class and kwargs
     if isinstance(proof_environment.action_spec_unbatched.space, CategoricalBox):
         num_outputs = proof_environment.action_spec_unbatched.space.n
-        distribution_class = OneHotCategorical
+        distribution_class = torch.distributions.Categorical
         distribution_kwargs = {}
     else:  # is ContinuousBox
         num_outputs = proof_environment.action_spec_unbatched.shape
         distribution_class = TanhNormal
         distribution_kwargs = {
-            "low": proof_environment.action_spec_unbatched.space.low,
-            "high": proof_environment.action_spec_unbatched.space.high,
+            "low": proof_environment.action_spec_unbatched.space.low.to(device),
+            "high": proof_environment.action_spec_unbatched.space.high.to(device),
         }
 
@@ -113,6 +113,7 @@ def make_ppo_modules_pixels(proof_environment):
         num_cells=[32, 64, 64],
         kernel_sizes=[8, 4, 3],
         strides=[4, 2, 1],
+        device=device,
     )
     common_cnn_output = common_cnn(torch.ones(input_shape))
     common_mlp = MLP(
@@ -121,6 +122,7 @@ def make_ppo_modules_pixels(proof_environment):
         activate_last_layer=True,
         out_features=512,
         num_cells=[],
+        device=device,
     )
     common_mlp_output = common_mlp(common_cnn_output)
 
@@ -137,6 +139,7 @@ def make_ppo_modules_pixels(proof_environment):
         out_features=num_outputs,
         activation_class=torch.nn.ReLU,
         num_cells=[],
+        device=device,
     )
     policy_module = TensorDictModule(
         module=policy_net,
@@ -148,7 +151,7 @@ def make_ppo_modules_pixels(proof_environment):
     policy_module = ProbabilisticActor(
         policy_module,
         in_keys=["logits"],
-        spec=proof_environment.full_action_spec_unbatched,
+        spec=proof_environment.full_action_spec_unbatched.to(device),
         distribution_class=distribution_class,
         distribution_kwargs=distribution_kwargs,
         return_log_prob=True,
@@ -161,6 +164,7 @@ def make_ppo_modules_pixels(proof_environment):
         in_features=common_mlp_output.shape[-1],
         out_features=1,
         num_cells=[],
+        device=device,
     )
     value_module = ValueOperator(
         value_net,
@@ -170,11 +174,12 @@ def make_ppo_modules_pixels(proof_environment):
     return common_module, policy_module, value_module
 
 
-def make_ppo_models(env_name):
+def make_ppo_models(env_name, device):
 
-    proof_environment = make_parallel_env(env_name, 1, device="cpu")
+    proof_environment = make_parallel_env(env_name, 1, device=device)
     common_module, policy_module, value_module = make_ppo_modules_pixels(
-        proof_environment
+        proof_environment,
+        device=device,
     )
 
     # Wrap modules in a single ActorCritic operator
@@ -185,8 +190,8 @@ def make_ppo_models(env_name):
     )
 
     with torch.no_grad():
-        td = proof_environment.rollout(max_steps=100, break_when_any_done=False)
-        td = actor_critic(td)
+        td = proof_environment.fake_tensordict().expand(10)
+        actor_critic(td)
         del td
 
     actor = actor_critic.get_policy_operator()
diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py
index f2e08ffb129..584945013dc 100644
--- a/sota-implementations/ppo/utils_mujoco.py
+++ b/sota-implementations/ppo/utils_mujoco.py
@@ -43,7 +43,7 @@ def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False)
 # --------------------------------------------------------------------
 
 
-def make_ppo_models_state(proof_environment):
+def make_ppo_models_state(proof_environment, device):
 
     # Define input shape
     input_shape = proof_environment.observation_spec["observation"].shape
@@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment):
     num_outputs = proof_environment.action_spec_unbatched.shape[-1]
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "low": proof_environment.action_spec_unbatched.space.low,
-        "high": proof_environment.action_spec_unbatched.space.high,
+        "low": proof_environment.action_spec_unbatched.space.low.to(device),
+        "high": proof_environment.action_spec_unbatched.space.high.to(device),
         "tanh_loc": False,
     }
 
@@ -63,6 +63,7 @@ def make_ppo_models_state(proof_environment):
         activation_class=torch.nn.Tanh,
         out_features=num_outputs,  # predict only loc
         num_cells=[64, 64],
+        device=device,
     )
 
     # Initialize policy weights
@@ -87,7 +88,7 @@ def make_ppo_models_state(proof_environment):
             out_keys=["loc", "scale"],
         ),
         in_keys=["loc", "scale"],
-        spec=proof_environment.full_action_spec_unbatched,
+        spec=proof_environment.full_action_spec_unbatched.to(device),
         distribution_class=distribution_class,
         distribution_kwargs=distribution_kwargs,
         return_log_prob=True,
@@ -100,6 +101,7 @@ def make_ppo_models_state(proof_environment):
         activation_class=torch.nn.Tanh,
         out_features=1,
         num_cells=[64, 64],
+        device=device,
     )
 
     # Initialize value weights
@@ -117,9 +119,9 @@ def make_ppo_models_state(proof_environment):
     return policy_module, value_module
 
 
-def make_ppo_models(env_name):
-    proof_environment = make_env(env_name, device="cpu")
-    actor, critic = make_ppo_models_state(proof_environment)
+def make_ppo_models(env_name, device):
+    proof_environment = make_env(env_name, device=device)
+    actor, critic = make_ppo_models_state(proof_environment, device=device)
     return actor, critic
 
 
diff --git a/torchrl/_utils.py b/torchrl/_utils.py
index 45f8c433725..cc1621d8723 100644
--- a/torchrl/_utils.py
+++ b/torchrl/_utils.py
@@ -854,7 +854,7 @@ def set_mode(self, type: Any | None) -> None:
 
 
 @wraps(torch.compile)
-def compile_with_warmup(*args, warmup: int, **kwargs):
+def compile_with_warmup(*args, warmup: int = 1, **kwargs):
     """Compile a model with warm-up.
 
     This function wraps :func:`~torch.compile` to add a warm-up phase. During the warm-up phase,
@@ -863,7 +863,7 @@ def compile_with_warmup(*args, warmup: int, **kwargs):
 
     Args:
         *args: Arguments to be passed to `torch.compile`.
-        warmup (int): Number of calls to the model before compiling it.
+        warmup (int): Number of calls to the model before compiling it. Defaults to 1.
         **kwargs: Keyword arguments to be passed to `torch.compile`.
 
     Returns:
@@ -888,7 +888,7 @@ def compile_with_warmup(*args, warmup: int, **kwargs):
     if model is None:
         return lambda model: compile_with_warmup(model, warmup=warmup, **kwargs)
     else:
-        count = 0
+        count = -1
         compiled_model = model
 
         @wraps(model)
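
The final hunks adjust `compile_with_warmup` in `torchrl/_utils.py`: `warmup` now defaults to 1, the internal counter starts at -1, and the docstring states that `warmup` is the number of calls to the model before compiling it. Below is a minimal usage sketch of that helper, assuming only what these hunks show; the body of `update` is a hypothetical placeholder, not the trainer's real step:

import torch

from torchrl._utils import compile_with_warmup


def update(batch, num_network_updates):
    # Hypothetical stand-in for the real PPO step (loss, backward, optim.step()).
    return batch, num_network_updates + 1


# warmup=1: the first call runs eagerly, so lazy initialization can settle
# before torch.compile takes over; extra kwargs such as `mode` are forwarded
# to torch.compile, mirroring compile_with_warmup(update, mode=compile_mode,
# warmup=1) in the ppo_atari.py and ppo_mujoco.py hunks above.
update = compile_with_warmup(update, mode="reduce-overhead", warmup=1)

num_network_updates = torch.zeros((), dtype=torch.int64)
batch, num_network_updates = update(torch.randn(4, 3), num_network_updates)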