diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index d07b40595bc..7ff3f23b7a5 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -50,7 +50,7 @@ ) # Anything from 2.5, incl. nightlies, allows for fullgraph -@pytest.fixture(scope="module") +@pytest.fixture(scope="module", autouse=True) def set_default_device(): cur_device = torch.get_default_device() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") diff --git a/sota-implementations/a2c/README.md b/sota-implementations/a2c/README.md index 513e6d70811..91c9099c8c9 100644 --- a/sota-implementations/a2c/README.md +++ b/sota-implementations/a2c/README.md @@ -19,11 +19,21 @@ Please note that each example is independent of each other for the sake of simpl You can execute the A2C algorithm on Atari environments by running the following command: ```bash -python a2c_atari.py +python a2c_atari.py compile.compile=1 compile.cudagraphs=1 ``` + You can execute the A2C algorithm on MuJoCo environments by running the following command: ```bash -python a2c_mujoco.py +python a2c_mujoco.py compile.compile=1 compile.cudagraphs=1 ``` + +## Runtimes + +Runtimes when executed on H100: + +| Environment | Eager | Compile | Compile+cudagraphs | +|-------------|-----------|-----------|--------------------| +| MUJOCO | < 25 mins | < 23 mins | < 20 mins | +| ATARI | < 85 mins | < 60 mins | < 45 mins | diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 39328570955..f6401b9946c 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -3,29 +3,37 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import hydra -from torchrl._utils import logger as torchrl_logger -from torchrl.record import VideoRecorder +import torch + +torch.set_float32_matmul_precision("high") @hydra.main(config_path="", config_name="config_atari", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 - import time + from copy import deepcopy import torch.optim import tqdm + from tensordict import from_module + from tensordict.nn import CudaGraphModule - from tensordict import TensorDict + from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector - from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type from torchrl.objectives import A2CLoss from torchrl.objectives.value.advantages import GAE + from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger from utils_atari import eval_model, make_parallel_env, make_ppo_models - device = "cpu" if not torch.cuda.device_count() else "cuda" + device = cfg.loss.device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) # Correct for frame_skip frame_skip = 4 @@ -35,28 +43,16 @@ def main(cfg: "DictConfig"): # noqa: F821 test_interval = cfg.logger.test_interval // frame_skip # Create models (check utils_atari.py) - actor, critic, critic_head = make_ppo_models(cfg.env.env_name) - actor, critic, critic_head = ( - actor.to(device), - critic.to(device), - critic_head.to(device), - ) - - # Create collector - collector = SyncDataCollector( - create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), - policy=actor, - frames_per_batch=frames_per_batch, - total_frames=total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, - ) + actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device) + with from_module(actor).data.to("meta").to_module(actor): + actor_eval = deepcopy(actor) + actor_eval.eval() + from_module(actor).data.to_module(actor_eval) # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(frames_per_batch), + storage=LazyTensorStorage(frames_per_batch, device=device), sampler=sampler, batch_size=mini_batch_size, ) @@ -67,6 +63,8 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.loss.gae_lambda, value_network=critic, average_gae=True, + vectorized=not cfg.compile.compile, + device=device, ) loss_module = A2CLoss( actor_network=actor, @@ -83,9 +81,10 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create optimizer optim = torch.optim.Adam( loss_module.parameters(), - lr=cfg.optim.lr, + lr=torch.tensor(cfg.optim.lr, device=device), weight_decay=cfg.optim.weight_decay, eps=cfg.optim.eps, + capturable=device.type == "cuda", ) # Create logger @@ -115,19 +114,71 @@ def main(cfg: "DictConfig"): # noqa: F821 ) test_env.eval() + # update function + def update(batch, max_grad_norm=cfg.optim.max_grad_norm): + # Forward pass A2C loss + loss = loss_module(batch) + + loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + + # Backward pass + loss_sum.backward() + gn = torch.nn.utils.clip_grad_norm_( + loss_module.parameters(), max_norm=max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad(set_to_none=True) + + return ( + loss.select("loss_critic", "loss_entropy", "loss_objective") + .detach() + .set("grad_norm", gn) + ) + + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + update = torch.compile(update, mode=compile_mode) + adv_module = torch.compile(adv_module, mode=compile_mode) + + if cfg.compile.cudagraphs: + update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5) + adv_module = CudaGraphModule(adv_module) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + policy_device=device, + compile_policy={"mode": compile_mode} if cfg.compile.compile else False, + cudagraph_policy=cfg.compile.cudagraphs, + ) + # Main loop collected_frames = 0 num_network_updates = 0 - start_time = time.time() pbar = tqdm.tqdm(total=total_frames) num_mini_batches = frames_per_batch // mini_batch_size total_network_updates = (total_frames // frames_per_batch) * num_mini_batches + lr = cfg.optim.lr - sampling_start = time.time() - for i, data in enumerate(collector): + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + data = next(c_iter) log_info = {} - sampling_time = time.time() - sampling_start frames_in_batch = data.numel() collected_frames += frames_in_batch * frame_skip pbar.update(data.numel()) @@ -144,94 +195,76 @@ def main(cfg: "DictConfig"): # noqa: F821 } ) - losses = TensorDict(batch_size=[num_mini_batches]) - training_start = time.time() + losses = [] # Compute GAE - with torch.no_grad(): + with torch.no_grad(), timeit("advantage"): + torch.compiler.cudagraph_mark_step_begin() data = adv_module(data) data_reshape = data.reshape(-1) # Update the data buffer - data_buffer.extend(data_reshape) - - for k, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(device) - - # Linearly decrease the learning rate and clip epsilon - alpha = 1.0 - if cfg.optim.anneal_lr: - alpha = 1 - (num_network_updates / total_network_updates) - for group in optim.param_groups: - group["lr"] = cfg.optim.lr * alpha - num_network_updates += 1 - - # Forward pass A2C loss - loss = loss_module(batch) - losses[k] = loss.select( - "loss_critic", "loss_entropy", "loss_objective" - ).detach() - loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] - ) + with timeit("rb - emptying"): + data_buffer.empty() + with timeit("rb - extending"): + data_buffer.extend(data_reshape) - # Backward pass - loss_sum.backward() - torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm - ) + with timeit("optim"): + for batch in data_buffer: - # Update the networks - optim.step() - optim.zero_grad() + # Linearly decrease the learning rate and clip epsilon + with timeit("optim - lr"): + alpha = 1.0 + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"].copy_(lr * alpha) + + num_network_updates += 1 + + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + loss = update(batch).clone() + losses.append(loss) # Get training losses - training_time = time.time() - training_start - losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + losses = torch.stack(losses).float().mean() + for key, value in losses.items(): log_info.update({f"train/{key}": value.item()}) log_info.update( { - "train/lr": alpha * cfg.optim.lr, - "train/sampling_time": sampling_time, - "train/training_time": training_time, + "train/lr": lr * alpha, } ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): + with torch.no_grad(), set_exploration_type( + ExplorationType.DETERMINISTIC + ), timeit("eval"): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: - actor.eval() - eval_start = time.time() test_rewards = eval_model( - actor, test_env, num_episodes=cfg.logger.num_test_episodes + actor_eval, test_env, num_episodes=cfg.logger.num_test_episodes ) - eval_time = time.time() - eval_start log_info.update( { "test/reward": test_rewards.mean(), - "test/eval_time": eval_time, } ) - actor.train() + if i % 200 == 0: + log_info.update(timeit.todict(prefix="time")) + timeit.print() + timeit.erase() if logger: for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) - collector.update_policy_weights_() - sampling_start = time.time() - collector.shutdown() if not test_env.is_closed: test_env.close() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index e30541c2691..b75a5224bc5 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -3,54 +3,59 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import hydra -from torchrl._utils import logger as torchrl_logger -from torchrl.record import VideoRecorder +import torch + +torch.set_float32_matmul_precision("high") @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 - import time + from copy import deepcopy import torch.optim import tqdm - from tensordict import TensorDict + from tensordict import from_module + from tensordict.nn import CudaGraphModule + + from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector - from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type - from torchrl.objectives import A2CLoss - from torchrl.objectives.value.advantages import GAE + from torchrl.objectives import A2CLoss, group_optimizers + from torchrl.objectives.value import GAE + from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger from utils_mujoco import eval_model, make_env, make_ppo_models # Define paper hyperparameters - device = "cpu" if not torch.cuda.device_count() else "cuda" + + device = cfg.loss.device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) + num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size total_network_updates = ( cfg.collector.total_frames // cfg.collector.frames_per_batch ) * num_mini_batches # Create models (check utils_mujoco.py) - actor, critic = make_ppo_models(cfg.env.env_name) - actor, critic = actor.to(device), critic.to(device) - - # Create collector - collector = SyncDataCollector( - create_env_fn=make_env(cfg.env.env_name, device), - policy=actor, - frames_per_batch=cfg.collector.frames_per_batch, - total_frames=cfg.collector.total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, + actor, critic = make_ppo_models( + cfg.env.env_name, device=device, compile=cfg.compile.compile ) + with from_module(actor).data.to("meta").to_module(actor): + actor_eval = deepcopy(actor) + actor_eval.eval() + from_module(actor).data.to_module(actor_eval) # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(cfg.collector.frames_per_batch), + storage=LazyTensorStorage(cfg.collector.frames_per_batch, device=device), sampler=sampler, batch_size=cfg.loss.mini_batch_size, ) @@ -61,6 +66,8 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.loss.gae_lambda, value_network=critic, average_gae=False, + vectorized=not cfg.compile.compile, + device=device, ) loss_module = A2CLoss( actor_network=actor, @@ -71,8 +78,18 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizers - actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr) - critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr) + actor_optim = torch.optim.Adam( + actor.parameters(), + lr=torch.tensor(cfg.optim.lr, device=device), + capturable=device.type == "cuda", + ) + critic_optim = torch.optim.Adam( + critic.parameters(), + lr=torch.tensor(cfg.optim.lr, device=device), + capturable=device.type == "cuda", + ) + optim = group_optimizers(actor_optim, critic_optim) + del actor_optim, critic_optim # Create logger logger = None @@ -99,19 +116,66 @@ def main(cfg: "DictConfig"): # noqa: F821 logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"] ), ) + + def update(batch): + # Forward pass A2C loss + loss = loss_module(batch) + critic_loss = loss["loss_critic"] + actor_loss = loss["loss_objective"] + loss.get("loss_entropy", 0.0) + + # Backward pass + (actor_loss + critic_loss).backward() + + # Update the networks + optim.step() + + optim.zero_grad(set_to_none=True) + return loss.select("loss_critic", "loss_objective").detach() # , "loss_entropy" + + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + + update = torch.compile(update, mode=compile_mode) + adv_module = torch.compile(adv_module, mode=compile_mode) + + if cfg.compile.cudagraphs: + update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=20) + adv_module = CudaGraphModule(adv_module, warmup=20) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + trust_policy=True, + compile_policy={"mode": compile_mode} if compile_mode is not None else False, + cudagraph_policy=cfg.compile.cudagraphs, + ) + test_env.eval() + lr = cfg.optim.lr # Main loop collected_frames = 0 num_network_updates = 0 - start_time = time.time() pbar = tqdm.tqdm(total=cfg.collector.total_frames) - sampling_start = time.time() - for i, data in enumerate(collector): + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + data = next(c_iter) log_info = {} - sampling_time = time.time() - sampling_start frames_in_batch = data.numel() collected_frames += frames_in_batch pbar.update(data.numel()) @@ -128,60 +192,43 @@ def main(cfg: "DictConfig"): # noqa: F821 } ) - losses = TensorDict(batch_size=[num_mini_batches]) - training_start = time.time() + losses = [] # Compute GAE - with torch.no_grad(): + with torch.no_grad(), timeit("advantage"): + torch.compiler.cudagraph_mark_step_begin() data = adv_module(data) data_reshape = data.reshape(-1) # Update the data buffer - data_buffer.extend(data_reshape) - - for k, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(device) - - # Linearly decrease the learning rate and clip epsilon - alpha = 1.0 - if cfg.optim.anneal_lr: - alpha = 1 - (num_network_updates / total_network_updates) - for group in actor_optim.param_groups: - group["lr"] = cfg.optim.lr * alpha - for group in critic_optim.param_groups: - group["lr"] = cfg.optim.lr * alpha - num_network_updates += 1 - - # Forward pass A2C loss - loss = loss_module(batch) - losses[k] = loss.select( - "loss_critic", "loss_objective" # , "loss_entropy" - ).detach() - critic_loss = loss["loss_critic"] - actor_loss = loss["loss_objective"] # + loss["loss_entropy"] - - # Backward pass - actor_loss.backward() - critic_loss.backward() - - # Update the networks - actor_optim.step() - critic_optim.step() - actor_optim.zero_grad() - critic_optim.zero_grad() + with timeit("emptying"): + data_buffer.empty() + with timeit("extending"): + data_buffer.extend(data_reshape) + + with timeit("optim"): + for batch in data_buffer: + + # Linearly decrease the learning rate and clip epsilon + with timeit("optim - lr"): + alpha = 1.0 + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"].copy_(lr * alpha) + num_network_updates += 1 + with timeit("optim - update"): + torch.compiler.cudagraph_mark_step_begin() + loss = update(batch).clone() + losses.append(loss) # Get training losses - training_time = time.time() - training_start - losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + losses = torch.stack(losses).float().mean() for key, value in losses.items(): log_info.update({f"train/{key}": value.item()}) log_info.update( { "train/lr": alpha * cfg.optim.lr, - "train/sampling_time": sampling_time, - "train/training_time": training_time, } ) @@ -192,32 +239,30 @@ def main(cfg: "DictConfig"): # noqa: F821 final = collected_frames >= collector.total_frames if prev_test_frame < cur_test_frame or final: actor.eval() - eval_start = time.time() test_rewards = eval_model( actor, test_env, num_episodes=cfg.logger.num_test_episodes ) - eval_time = time.time() - eval_start log_info.update( { "test/reward": test_rewards.mean(), - "test/eval_time": eval_time, } ) actor.train() + if i % 200 == 0: + log_info.update(timeit.todict(prefix="time")) + timeit.print() + timeit.erase() + if logger: for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) - collector.update_policy_weights_() - sampling_start = time.time() + torch.compiler.cudagraph_mark_step_begin() collector.shutdown() if not test_env.is_closed: test_env.close() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/a2c/config_atari.yaml b/sota-implementations/a2c/config_atari.yaml index dd0f43b52cb..59a0a621756 100644 --- a/sota-implementations/a2c/config_atari.yaml +++ b/sota-implementations/a2c/config_atari.yaml @@ -1,11 +1,11 @@ # Environment env: env_name: PongNoFrameskip-v4 - num_envs: 1 + num_envs: 16 # collector collector: - frames_per_batch: 80 + frames_per_batch: 800 total_frames: 40_000_000 # logger @@ -34,3 +34,9 @@ loss: critic_coef: 0.25 entropy_coef: 0.01 loss_critic_type: l2 + device: + +compile: + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/a2c/config_mujoco.yaml b/sota-implementations/a2c/config_mujoco.yaml index 03a0bde32c5..9e8f36a5995 100644 --- a/sota-implementations/a2c/config_mujoco.yaml +++ b/sota-implementations/a2c/config_mujoco.yaml @@ -4,7 +4,7 @@ env: # collector collector: - frames_per_batch: 64 + frames_per_batch: 640 total_frames: 1_000_000 # logger @@ -31,3 +31,9 @@ loss: critic_coef: 0.25 entropy_coef: 0.0 loss_critic_type: l2 + device: + +compile: + compile: False + compile_mode: default + cudagraphs: False diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index 6a09ff715e4..99c3ce2338c 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -64,10 +64,12 @@ def make_base_env( def make_parallel_env(env_name, num_envs, device, is_test=False): env = ParallelEnv( num_envs, - EnvCreator(lambda: make_base_env(env_name, device=device)), + EnvCreator(lambda: make_base_env(env_name)), serial_for_single=True, + device=device, ) env = TransformedEnv(env) + env.append_transform(DoubleToFloat()) env.append_transform(ToTensorImage()) env.append_transform(GrayScale()) env.append_transform(Resize(84, 84)) @@ -76,7 +78,6 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): env.append_transform(StepCounter(max_steps=4500)) if not is_test: env.append_transform(SignTransform(in_keys=["reward"])) - env.append_transform(DoubleToFloat()) env.append_transform(VecNorm(in_keys=["pixels"])) return env @@ -86,7 +87,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): # -------------------------------------------------------------------- -def make_ppo_modules_pixels(proof_environment): +def make_ppo_modules_pixels(proof_environment, device): # Define input shape input_shape = proof_environment.observation_spec["pixels"].shape @@ -100,8 +101,8 @@ def make_ppo_modules_pixels(proof_environment): num_outputs = proof_environment.action_spec.shape distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.action_spec.space.low.to(device), + "high": proof_environment.action_spec.space.high.to(device), } # Define input keys @@ -113,14 +114,16 @@ def make_ppo_modules_pixels(proof_environment): num_cells=[32, 64, 64], kernel_sizes=[8, 4, 3], strides=[4, 2, 1], + device=device, ) - common_cnn_output = common_cnn(torch.ones(input_shape)) + common_cnn_output = common_cnn(torch.ones(input_shape, device=device)) common_mlp = MLP( in_features=common_cnn_output.shape[-1], activation_class=torch.nn.ReLU, activate_last_layer=True, out_features=512, num_cells=[], + device=device, ) common_mlp_output = common_mlp(common_cnn_output) @@ -137,6 +140,7 @@ def make_ppo_modules_pixels(proof_environment): out_features=num_outputs, activation_class=torch.nn.ReLU, num_cells=[], + device=device, ) policy_module = TensorDictModule( module=policy_net, @@ -148,7 +152,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=Composite(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec.to(device)), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, @@ -161,6 +165,7 @@ def make_ppo_modules_pixels(proof_environment): in_features=common_mlp_output.shape[-1], out_features=1, num_cells=[], + device=device, ) value_module = ValueOperator( value_net, @@ -170,11 +175,11 @@ def make_ppo_modules_pixels(proof_environment): return common_module, policy_module, value_module -def make_ppo_models(env_name): +def make_ppo_models(env_name, device): proof_environment = make_parallel_env(env_name, 1, device="cpu") common_module, policy_module, value_module = make_ppo_modules_pixels( - proof_environment + proof_environment, device=device ) # Wrap modules in a single ActorCritic operator @@ -185,8 +190,8 @@ def make_ppo_models(env_name): ) with torch.no_grad(): - td = proof_environment.rollout(max_steps=100, break_when_any_done=False) - td = actor_critic(td) + td = proof_environment.fake_tensordict().expand(1) + td = actor_critic(td.to(device)) del td actor = actor_critic.get_policy_operator() diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index 996706ce4f9..87587d092f0 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -48,7 +48,7 @@ def make_env( # -------------------------------------------------------------------- -def make_ppo_models_state(proof_environment): +def make_ppo_models_state(proof_environment, device, *, compile: bool = False): # Define input shape input_shape = proof_environment.observation_spec["observation"].shape @@ -57,9 +57,10 @@ def make_ppo_models_state(proof_environment): num_outputs = proof_environment.action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.action_spec.space.low.to(device), + "high": proof_environment.action_spec.space.high.to(device), "tanh_loc": False, + "safe_tanh": True, } # Define policy architecture @@ -68,6 +69,7 @@ def make_ppo_models_state(proof_environment): activation_class=torch.nn.Tanh, out_features=num_outputs, # predict only loc num_cells=[64, 64], + device=device, ) # Initialize policy weights @@ -79,7 +81,9 @@ def make_ppo_models_state(proof_environment): # Add state-independent normal scale policy_mlp = torch.nn.Sequential( policy_mlp, - AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]), + AddStateIndependentNormalScale( + proof_environment.action_spec.shape[-1], device=device + ), ) # Add probabilistic sampling of the actions @@ -90,7 +94,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=Composite(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec.to(device)), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, @@ -103,6 +107,7 @@ def make_ppo_models_state(proof_environment): activation_class=torch.nn.Tanh, out_features=1, num_cells=[64, 64], + device=device, ) # Initialize value weights @@ -120,9 +125,11 @@ def make_ppo_models_state(proof_environment): return policy_module, value_module -def make_ppo_models(env_name): +def make_ppo_models(env_name, device, *, compile: bool = False): proof_environment = make_env(env_name, device="cpu") - actor, critic = make_ppo_models_state(proof_environment) + actor, critic = make_ppo_models_state( + proof_environment, device=device, compile=compile + ) return actor, critic diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 48492459315..683a0d60182 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -11,6 +11,7 @@ import os.path import time import unittest +import warnings from functools import wraps # Get relative file path @@ -20,9 +21,15 @@ import torch import torch.cuda -from tensordict import tensorclass, TensorDict -from torch import nn -from torchrl._utils import implement_for, logger as torchrl_logger, seed_generator +from tensordict import NestedKey, tensorclass, TensorDict, TensorDictBase +from tensordict.nn import TensorDictModuleBase +from torch import nn, vmap +from torchrl._utils import ( + implement_for, + logger as torchrl_logger, + RL_WARNINGS, + seed_generator, +) from torchrl.data.utils import CloudpickleWrapper from torchrl.envs import MultiThreadedEnv, ObservationNorm @@ -35,6 +42,7 @@ ToTensorImage, TransformedEnv, ) +from torchrl.objectives.value.advantages import _vmap_func # Specified for test_utils.py __version__ = "0.3" @@ -713,3 +721,89 @@ def forward( ): input = self.mlp(input) return self._lstm(input, hidden0_in, hidden1_in) + + +def _call_value_nets( + value_net: TensorDictModuleBase, + data: TensorDictBase, + params: TensorDictBase, + next_params: TensorDictBase, + single_call: bool, + value_key: NestedKey, + detach_next: bool, + vmap_randomness: str = "error", +): + in_keys = value_net.in_keys + if single_call: + for i, name in enumerate(data.names): + if name == "time": + ndim = i + 1 + break + else: + ndim = None + if ndim is not None: + # get data at t and last of t+1 + idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),) + idx = (slice(None),) * (ndim - 1) + (slice(None, -1),) + idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False)[idx0], + ], + ndim - 1, + ) + else: + if RL_WARNINGS: + warnings.warn( + "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. " + "This warning can be turned off by setting the environment variable RL_WARNINGS to False." + ) + ndim = data.ndim + idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),) + idx_ = (slice(None),) * (ndim - 1) + (slice(data.shape[ndim - 1], None),) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + ndim - 1, + ) + + # next_params should be None or be identical to params + if next_params is not None and next_params is not params: + raise ValueError( + "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed." + ) + if params is not None: + with params.to_module(value_net): + value_est = value_net(data_in).get(value_key) + else: + value_est = value_net(data_in).get(value_key) + value, value_ = value_est[idx], value_est[idx_] + else: + data_in = torch.stack( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + 0, + ) + if (params is not None) ^ (next_params is not None): + raise ValueError( + "params and next_params must be either both provided or not." + ) + elif params is not None: + params_stack = torch.stack([params, next_params], 0).contiguous() + data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( + data_in, params_stack + ) + else: + data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) + value_est = data_out.get(value_key) + value, value_ = value_est[0], value_est[1] + data.set(value_key, value) + data.set(("next", value_key), value_) + if detach_next: + value_ = value_.detach() + return value, value_ diff --git a/test/test_cost.py b/test/test_cost.py index 1b54b8bf111..1e157fd7a2f 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -106,7 +106,6 @@ ValueEstimators, ) from torchrl.objectives.value.advantages import ( - _call_value_nets, GAE, TD1Estimator, TDLambdaEstimator, @@ -135,6 +134,7 @@ if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import ( # noqa + _call_value_nets, dtype_fixture, get_available_devices, get_default_devices, @@ -142,6 +142,7 @@ from pytorch.rl.test.mocking_classes import ContinuousActionConvMockEnv else: from _utils_internal import ( # noqa + _call_value_nets, dtype_fixture, get_available_devices, get_default_devices, diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index b8c2709b583..fe1a796ea2d 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -67,6 +67,15 @@ set_exploration_type, ) +try: + from torch.compiler import cudagraph_mark_step_begin +except ImportError: + + def cudagraph_mark_step_begin(): + """Placeholder for missing cudagraph_mark_step_begin method.""" + raise NotImplementedError("cudagraph_mark_step_begin not implemented.") + + _TIMEOUT = 1.0 INSTANTIATE_TIMEOUT = 20 _MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory @@ -833,6 +842,8 @@ def _make_final_rollout(self): policy_input_clone = ( policy_input.clone() ) # to test if values have changed in-place + if self.compiled_policy: + cudagraph_mark_step_begin() policy_output = self.policy(policy_input) # check that we don't have exclusive keys, because they don't appear in keys @@ -1146,7 +1157,11 @@ def rollout(self) -> TensorDictBase: else: policy_input = self._shuttle # we still do the assignment for security + if self.compiled_policy: + cudagraph_mark_step_begin() policy_output = self.policy(policy_input) + if self.compiled_policy: + policy_output = policy_output.clone() if self._shuttle is not policy_output: # ad-hoc update shuttle self._shuttle.update( diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 995f245a8ac..07d339761b0 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -309,7 +309,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if _reward is not None: reward = reward + _reward - terminated, truncated, done, do_break = self.read_done( terminated=terminated, truncated=truncated, done=done ) @@ -323,7 +322,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # if truncated/terminated is not in the keys, we just don't pass it even if it # is defined. if terminated is None: - terminated = done + terminated = done.clone() if truncated is not None: obs_dict["truncated"] = truncated obs_dict["done"] = done diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 66acc8639d4..c83591acb63 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -1423,7 +1423,7 @@ def _make_compatible_policy( env_maker=None, env_maker_kwargs=None, ): - if trust_policy: + if trust_policy or isinstance(policy, torch._dynamo.eval_frame.OptimizedModule): return policy if policy is None: input_spec = None diff --git a/torchrl/modules/distributions/__init__.py b/torchrl/modules/distributions/__init__.py index 52f8f302a35..8f1b7da49a5 100644 --- a/torchrl/modules/distributions/__init__.py +++ b/torchrl/modules/distributions/__init__.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from tensordict.nn import NormalParamExtractor +from torch import distributions as torch_dist from .continuous import ( Delta, @@ -37,3 +38,16 @@ OneHotOrdinal, ) } + +HAS_ENTROPY = { + Delta: False, + IndependentNormal: True, + TanhDelta: False, + TanhNormal: False, + TruncatedNormal: False, + MaskedCategorical: False, + MaskedOneHotCategorical: False, + OneHotCategorical: True, + torch_dist.Categorical: True, + torch_dist.Normal: True, +} diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index cde2c95d30f..62b5df5d14b 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -42,9 +42,8 @@ except ImportError: from torch._dynamo import is_compiling as is_dynamo_compiling -TORCH_VERSION_PRE_2_6 = version.parse(torch.__version__).base_version < version.parse( - "2.6.0" -) +TORCH_VERSION = version.parse(torch.__version__).base_version +TORCH_VERSION_PRE_2_6 = version.parse(TORCH_VERSION) < version.parse("2.6.0") class IndependentNormal(D.Independent): @@ -408,15 +407,16 @@ def __init__( event_dims = min(1, loc.ndim) err_msg = "TanhNormal high values must be strictly greater than low values" - if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): - if not (high > low).all(): - raise RuntimeError(err_msg) - elif isinstance(high, Number) and isinstance(low, Number): - if not high > low: - raise RuntimeError(err_msg) - else: - if not all(high > low): - raise RuntimeError(err_msg) + if not is_dynamo_compiling(): + if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): + if not (high > low).all(): + raise RuntimeError(err_msg) + elif isinstance(high, Number) and isinstance(low, Number): + if not high > low: + raise RuntimeError(err_msg) + else: + if not all(high > low): + raise RuntimeError(err_msg) high = torch.as_tensor(high, device=loc.device) low = torch.as_tensor(low, device=loc.device) diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index 3dac06cd270..bfb7bc48f3c 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - from enum import Enum from functools import wraps from typing import Any, Optional, Sequence, Union @@ -11,6 +10,9 @@ import torch.distributions as D import torch.nn.functional as F +from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits + + __all__ = ["OneHotCategorical", "MaskedCategorical", "Ordinal", "OneHotOrdinal"] @@ -79,6 +81,17 @@ class OneHotCategorical(D.Categorical): """ + num_params: int = 1 + + # This is to make the compiler happy, see https://github.com/pytorch/pytorch/issues/140266 + @lazy_property + def logits(self): + return probs_to_logits(self.probs) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits) + def __init__( self, logits: Optional[torch.Tensor] = None, @@ -106,6 +119,12 @@ def mode(self) -> torch.Tensor: def deterministic_sample(self): return self.mode + def entropy(self): + min_real = torch.finfo(self.logits.dtype).min + logits = torch.clamp(self.logits, min=min_real) + p_log_p = logits * self.probs + return -p_log_p.sum(-1) + @_one_hot_wrapper(D.Categorical) def sample( self, sample_shape: Optional[Union[torch.Size, Sequence]] = None @@ -188,6 +207,14 @@ class MaskedCategorical(D.Categorical): -2.1972, -2.1972]) """ + @lazy_property + def logits(self): + return probs_to_logits(self.probs) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits) + def __init__( self, logits: Optional[torch.Tensor] = None, @@ -360,6 +387,14 @@ class MaskedOneHotCategorical(MaskedCategorical): -2.1972, -2.1972]) """ + @lazy_property + def logits(self): + return probs_to_logits(self.probs) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits) + def __init__( self, logits: Optional[torch.Tensor] = None, diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 96c1bc63eae..f324e491298 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -20,12 +20,14 @@ from tensordict.utils import NestedKey from torch import distributions as d +from torchrl.modules.distributions import HAS_ENTROPY from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _cache_values, _clip_value_loss, _GAMMA_LMBDA_DEPREC_ERROR, + _get_default_device, _reduce, default_value_kwargs, distance_loss, @@ -316,10 +318,7 @@ def __init__( self.entropy_bonus = entropy_bonus and entropy_coef self.reduction = reduction - try: - device = next(self.parameters()).device - except AttributeError: - device = torch.device("cpu") + device = _get_default_device(self) self.register_buffer( "entropy_coef", torch.as_tensor(entropy_coef, device=device) @@ -347,7 +346,11 @@ def __init__( raise ValueError( f"clip_value must be a float or a scalar tensor, got {clip_value}." ) - self.register_buffer("clip_value", clip_value) + self.register_buffer( + "clip_value", torch.as_tensor(clip_value, device=device) + ) + else: + self.clip_value = None @property def functional(self): @@ -398,9 +401,9 @@ def reset(self) -> None: pass def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: - try: + if HAS_ENTROPY.get(type(dist), False): entropy = dist.entropy() - except NotImplementedError: + else: x = dist.rsample((self.samples_mc_entropy,)) log_prob = dist.log_prob(x) if is_tensor_collection(log_prob): @@ -456,7 +459,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]: old_state_value = old_state_value.clone() # TODO: if the advantage is gathered by forward, this introduces an - # overhead that we could easily reduce. + # overhead that we could easily reduce. target_return = tensordict.get( self.tensor_keys.value_target, None ) # TODO: None soon to be removed @@ -487,7 +490,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]: loss_value, clip_fraction = _clip_value_loss( old_state_value, state_value, - self.clip_value.to(state_value.device), + self.clip_value, target_return, loss_value, self.loss_critic_type, @@ -541,6 +544,9 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams hp = dict(default_value_kwargs(value_type)) hp.update(hyperparams) + device = _get_default_device(self) + hp["device"] = device + if hasattr(self, "gamma"): hp["gamma"] = self.gamma if value_type == ValueEstimators.TD1: diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 17ab16cfefa..b0ba254d2b3 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -592,6 +592,13 @@ def _clip_value_loss( return loss_value, clip_fraction +def _get_default_device(net): + for p in net.parameters(): + return p.device + else: + return torch.get_default_device() + + def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer: """Groups multiple optimizers into a single one. diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index e396b7e1fcc..739fb9a018e 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -37,6 +37,10 @@ vtrace_advantage_estimate, ) +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling try: from torch import vmap @@ -69,92 +73,6 @@ def new_func(self, *args, **kwargs): return new_func -def _call_value_nets( - value_net: TensorDictModuleBase, - data: TensorDictBase, - params: TensorDictBase, - next_params: TensorDictBase, - single_call: bool, - value_key: NestedKey, - detach_next: bool, - vmap_randomness: str = "error", -): - in_keys = value_net.in_keys - if single_call: - for i, name in enumerate(data.names): - if name == "time": - ndim = i + 1 - break - else: - ndim = None - if ndim is not None: - # get data at t and last of t+1 - idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),) - idx = (slice(None),) * (ndim - 1) + (slice(None, -1),) - idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),) - data_in = torch.cat( - [ - data.select(*in_keys, value_key, strict=False), - data.get("next").select(*in_keys, value_key, strict=False)[idx0], - ], - ndim - 1, - ) - else: - if RL_WARNINGS: - warnings.warn( - "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. " - "This warning can be turned off by setting the environment variable RL_WARNINGS to False." - ) - ndim = data.ndim - idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),) - idx_ = (slice(None),) * (ndim - 1) + (slice(data.shape[ndim - 1], None),) - data_in = torch.cat( - [ - data.select(*in_keys, value_key, strict=False), - data.get("next").select(*in_keys, value_key, strict=False), - ], - ndim - 1, - ) - - # next_params should be None or be identical to params - if next_params is not None and next_params is not params: - raise ValueError( - "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed." - ) - if params is not None: - with params.to_module(value_net): - value_est = value_net(data_in).get(value_key) - else: - value_est = value_net(data_in).get(value_key) - value, value_ = value_est[idx], value_est[idx_] - else: - data_in = torch.stack( - [ - data.select(*in_keys, value_key, strict=False), - data.get("next").select(*in_keys, value_key, strict=False), - ], - 0, - ) - if (params is not None) ^ (next_params is not None): - raise ValueError( - "params and next_params must be either both provided or not." - ) - elif params is not None: - params_stack = torch.stack([params, next_params], 0).contiguous() - data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( - data_in, params_stack - ) - else: - data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) - value_est = data_out.get(value_key) - value, value_ = value_est[0], value_est[1] - data.set(value_key, value) - data.set(("next", value_key), value_) - if detach_next: - value_ = value_.detach() - return value, value_ - - def _call_actor_net( actor_net: TensorDictModuleBase, data: TensorDictBase, @@ -432,6 +350,9 @@ def _next_value(self, tensordict, target_params, kwargs): @property def vmap_randomness(self): if self._vmap_randomness is None: + if is_dynamo_compiling(): + self._vmap_randomness = "different" + return "different" do_break = False for val in self.__dict__.values(): if isinstance(val, torch.nn.Module): @@ -467,6 +388,99 @@ def _get_time_dim(self, time_dim: int | None, data: TensorDictBase): return i return data.ndim - 1 + def _call_value_nets( + self, + data: TensorDictBase, + params: TensorDictBase, + next_params: TensorDictBase, + single_call: bool, + value_key: NestedKey, + detach_next: bool, + vmap_randomness: str = "error", + *, + value_net: TensorDictModuleBase | None = None, + ): + if value_net is None: + value_net = self.value_network + in_keys = value_net.in_keys + if single_call: + for i, name in enumerate(data.names): + if name == "time": + ndim = i + 1 + break + else: + ndim = None + if ndim is not None: + # get data at t and last of t+1 + idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),) + idx = (slice(None),) * (ndim - 1) + (slice(None, -1),) + idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False)[ + idx0 + ], + ], + ndim - 1, + ) + else: + if RL_WARNINGS: + warnings.warn( + "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. " + "This warning can be turned off by setting the environment variable RL_WARNINGS to False." + ) + ndim = data.ndim + idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),) + idx_ = (slice(None),) * (ndim - 1) + ( + slice(data.shape[ndim - 1], None), + ) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + ndim - 1, + ) + + # next_params should be None or be identical to params + if next_params is not None and next_params is not params: + raise ValueError( + "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed." + ) + if params is not None: + with params.to_module(value_net): + value_est = value_net(data_in).get(value_key) + else: + value_est = value_net(data_in).get(value_key) + value, value_ = value_est[idx], value_est[idx_] + else: + data_in = torch.stack( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + 0, + ) + if (params is not None) ^ (next_params is not None): + raise ValueError( + "params and next_params must be either both provided or not." + ) + elif params is not None: + params_stack = torch.stack([params, next_params], 0).contiguous() + data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( + data_in, params_stack + ) + else: + data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) + value_est = data_out.get(value_key) + value, value_ = value_est[0], value_est[1] + data.set(value_key, value) + data.set(("next", value_key), value_) + if detach_next: + value_ = value_.detach() + return value, value_ + class TD0Estimator(ValueEstimatorBase): """Temporal Difference (TD(0)) estimate of advantage function. @@ -623,8 +637,7 @@ def forward( ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -651,7 +664,11 @@ def value_estimate( ): reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma = self.gamma.to(device) + + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + gamma = self.gamma + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -837,8 +854,7 @@ def forward( ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -867,7 +883,11 @@ def value_estimate( ): reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma = self.gamma.to(device) + + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + gamma = self.gamma + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1063,8 +1083,7 @@ def forward( ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -1092,7 +1111,11 @@ def value_estimate( ): reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma = self.gamma.to(device) + + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + gamma = self.gamma + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1155,7 +1178,7 @@ class GAE(ValueEstimatorBase): pass detached parameters for functional modules. vectorized (bool, optional): whether to use the vectorized version of the - lambda return. Default is `True`. + lambda return. Default is `True` if not compiling. skip_existing (bool, optional): if ``True``, the value network will skip modules which outputs are already present in the tensordict. Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()` @@ -1205,7 +1228,7 @@ def __init__( value_network: TensorDictModule, average_gae: bool = False, differentiable: bool = False, - vectorized: bool = True, + vectorized: bool | None = None, skip_existing: bool | None = None, advantage_key: NestedKey = None, value_target_key: NestedKey = None, @@ -1229,6 +1252,16 @@ def __init__( self.vectorized = vectorized self.time_dim = time_dim + @property + def vectorized(self): + if is_dynamo_compiling(): + return False + return self._vectorized + + @vectorized.setter + def vectorized(self, value): + self._vectorized = value + @_self_set_skip_existing @_self_set_grad_enabled @dispatch @@ -1315,7 +1348,13 @@ def forward( ) reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma, lmbda = self.gamma.to(device), self.lmbda.to(device) + + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + if self.lmbda.device != device: + self.lmbda = self.lmbda.to(device) + gamma, lmbda = self.gamma, self.lmbda + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1328,10 +1367,10 @@ def forward( with hold_out_net(self.value_network) if ( params is None and target_params is None ) else nullcontext(): + # with torch.no_grad(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -1396,7 +1435,13 @@ def value_estimate( ) reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma, lmbda = self.gamma.to(device), self.lmbda.to(device) + + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + if self.lmbda.device != device: + self.lmbda = self.lmbda.to(device) + gamma, lmbda = self.gamma, self.lmbda + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1417,8 +1462,7 @@ def value_estimate( ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -1668,7 +1712,11 @@ def forward( ) reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma = self.gamma.to(device) + + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + gamma = self.gamma + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1682,8 +1730,7 @@ def forward( with hold_out_net(self.value_network): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index e20e1cbb2c9..15e5d56d6bf 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -12,6 +12,10 @@ import torch +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling __all__ = [ "generalized_advantage_estimate", @@ -147,7 +151,7 @@ def generalized_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -181,19 +185,25 @@ def generalized_advantage_estimate( def _geom_series_like(t, r, thr): """Creates a geometric series of the form [1, gammalmbda, gammalmbda**2] with the shape of `t`. - Drops all elements which are smaller than `thr`. + Drops all elements which are smaller than `thr` (unless in compile mode). """ - if isinstance(r, torch.Tensor): - r = r.item() - - if r == 0.0: - return torch.zeros_like(t) - elif r >= 1.0: - lim = t.numel() + if is_dynamo_compiling(): + if isinstance(r, torch.Tensor): + rs = r.expand_as(t) + else: + rs = torch.full_like(t, r) else: - lim = int(math.log(thr) / math.log(r)) + if isinstance(r, torch.Tensor): + r = r.item() + + if r == 0.0: + return torch.zeros_like(t) + elif r >= 1.0: + lim = t.numel() + else: + lim = int(math.log(thr) / math.log(r)) - rs = torch.full_like(t[:lim], r) + rs = torch.full_like(t[:lim], r) rs[0] = 1.0 rs = rs.cumprod(0) rs = rs.unsqueeze(-1) @@ -292,7 +302,7 @@ def vec_generalized_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -391,7 +401,7 @@ def td0_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -435,7 +445,7 @@ def td0_return_estimate( """ if done is not None and terminated is None: - terminated = done + terminated = done.clone() warnings.warn( "done for td0_return_estimate is deprecated. Pass ``terminated`` instead." ) @@ -503,7 +513,7 @@ def td1_return_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) not_done = (~done).int() @@ -604,7 +614,7 @@ def td1_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -742,7 +752,7 @@ def vec_td1_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -824,7 +834,7 @@ def td_lambda_return_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) @@ -934,7 +944,7 @@ def td_lambda_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -1074,7 +1084,7 @@ def vec_td_lambda_return_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) @@ -1228,7 +1238,7 @@ def vec_td_lambda_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape diff --git a/torchrl/objectives/value/utils.py b/torchrl/objectives/value/utils.py index ec1d33069a5..5eb678e01e1 100644 --- a/torchrl/objectives/value/utils.py +++ b/torchrl/objectives/value/utils.py @@ -301,10 +301,13 @@ def _fill_tensor(tensor): device=tensor.device, ) mask_expand = expand_right(mask, (*mask.shape, *tensor.shape[1:])) + # We need to use masked-scatter to accommodate vmap return torch.masked_scatter(empty_tensor, mask_expand, tensor.reshape(-1)) + # empty_tensor[mask_expand] = tensor.reshape(-1) + # return empty_tensor if isinstance(tensor, TensorDictBase): - tensor = tensor.apply(_fill_tensor, batch_size=[*shape]) + tensor = tensor.apply(_fill_tensor, batch_size=list(shape)) else: tensor = _fill_tensor(tensor) if return_mask: