From 507766a8846da6c2f1907dab89f0480a28a06b57 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Nov 2024 16:26:30 +0000 Subject: [PATCH 1/3] [Feature] A2C compatibility with compile ghstack-source-id: 66a7f0d1dd82d6463d61c1671e8e0a14ac9a55e7 Pull Request resolved: https://github.com/pytorch/rl/pull/2464 --- benchmarks/test_objectives_benchmarks.py | 2 +- sota-implementations/a2c/README.md | 14 +- sota-implementations/a2c/a2c_atari.py | 195 +++++++++------ sota-implementations/a2c/a2c_mujoco.py | 197 +++++++++------ sota-implementations/a2c/config_atari.yaml | 10 +- sota-implementations/a2c/config_mujoco.yaml | 8 +- sota-implementations/a2c/utils_atari.py | 27 +- sota-implementations/a2c/utils_mujoco.py | 21 +- test/_utils_internal.py | 100 +++++++- test/test_cost.py | 3 +- torchrl/collectors/collectors.py | 15 ++ torchrl/envs/gym_like.py | 3 +- torchrl/envs/utils.py | 2 +- torchrl/modules/distributions/__init__.py | 14 ++ torchrl/modules/distributions/continuous.py | 24 +- torchrl/modules/distributions/discrete.py | 37 ++- torchrl/objectives/a2c.py | 24 +- torchrl/objectives/utils.py | 7 + torchrl/objectives/value/advantages.py | 259 ++++++++++++-------- torchrl/objectives/value/functional.py | 52 ++-- torchrl/objectives/value/utils.py | 5 +- 21 files changed, 681 insertions(+), 338 deletions(-) diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index d07b40595bc..7ff3f23b7a5 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -50,7 +50,7 @@ ) # Anything from 2.5, incl. nightlies, allows for fullgraph -@pytest.fixture(scope="module") +@pytest.fixture(scope="module", autouse=True) def set_default_device(): cur_device = torch.get_default_device() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") diff --git a/sota-implementations/a2c/README.md b/sota-implementations/a2c/README.md index 513e6d70811..91c9099c8c9 100644 --- a/sota-implementations/a2c/README.md +++ b/sota-implementations/a2c/README.md @@ -19,11 +19,21 @@ Please note that each example is independent of each other for the sake of simpl You can execute the A2C algorithm on Atari environments by running the following command: ```bash -python a2c_atari.py +python a2c_atari.py compile.compile=1 compile.cudagraphs=1 ``` + You can execute the A2C algorithm on MuJoCo environments by running the following command: ```bash -python a2c_mujoco.py +python a2c_mujoco.py compile.compile=1 compile.cudagraphs=1 ``` + +## Runtimes + +Runtimes when executed on H100: + +| Environment | Eager | Compile | Compile+cudagraphs | +|-------------|-----------|-----------|--------------------| +| MUJOCO | < 25 mins | < 23 mins | < 20 mins | +| ATARI | < 85 mins | < 60 mins | < 45 mins | diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 39328570955..f6401b9946c 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -3,29 +3,37 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import hydra -from torchrl._utils import logger as torchrl_logger -from torchrl.record import VideoRecorder +import torch + +torch.set_float32_matmul_precision("high") @hydra.main(config_path="", config_name="config_atari", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 - import time + from copy import deepcopy import torch.optim import tqdm + from tensordict import from_module + from tensordict.nn import CudaGraphModule - from tensordict import TensorDict + from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector - from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type from torchrl.objectives import A2CLoss from torchrl.objectives.value.advantages import GAE + from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger from utils_atari import eval_model, make_parallel_env, make_ppo_models - device = "cpu" if not torch.cuda.device_count() else "cuda" + device = cfg.loss.device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) # Correct for frame_skip frame_skip = 4 @@ -35,28 +43,16 @@ def main(cfg: "DictConfig"): # noqa: F821 test_interval = cfg.logger.test_interval // frame_skip # Create models (check utils_atari.py) - actor, critic, critic_head = make_ppo_models(cfg.env.env_name) - actor, critic, critic_head = ( - actor.to(device), - critic.to(device), - critic_head.to(device), - ) - - # Create collector - collector = SyncDataCollector( - create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), - policy=actor, - frames_per_batch=frames_per_batch, - total_frames=total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, - ) + actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device) + with from_module(actor).data.to("meta").to_module(actor): + actor_eval = deepcopy(actor) + actor_eval.eval() + from_module(actor).data.to_module(actor_eval) # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(frames_per_batch), + storage=LazyTensorStorage(frames_per_batch, device=device), sampler=sampler, batch_size=mini_batch_size, ) @@ -67,6 +63,8 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.loss.gae_lambda, value_network=critic, average_gae=True, + vectorized=not cfg.compile.compile, + device=device, ) loss_module = A2CLoss( actor_network=actor, @@ -83,9 +81,10 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create optimizer optim = torch.optim.Adam( loss_module.parameters(), - lr=cfg.optim.lr, + lr=torch.tensor(cfg.optim.lr, device=device), weight_decay=cfg.optim.weight_decay, eps=cfg.optim.eps, + capturable=device.type == "cuda", ) # Create logger @@ -115,19 +114,71 @@ def main(cfg: "DictConfig"): # noqa: F821 ) test_env.eval() + # update function + def update(batch, max_grad_norm=cfg.optim.max_grad_norm): + # Forward pass A2C loss + loss = loss_module(batch) + + loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + + # Backward pass + loss_sum.backward() + gn = torch.nn.utils.clip_grad_norm_( + loss_module.parameters(), max_norm=max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad(set_to_none=True) + + return ( + loss.select("loss_critic", "loss_entropy", "loss_objective") + .detach() + .set("grad_norm", gn) + ) + + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + update = torch.compile(update, mode=compile_mode) + adv_module = torch.compile(adv_module, mode=compile_mode) + + if cfg.compile.cudagraphs: + update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5) + adv_module = CudaGraphModule(adv_module) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + policy_device=device, + compile_policy={"mode": compile_mode} if cfg.compile.compile else False, + cudagraph_policy=cfg.compile.cudagraphs, + ) + # Main loop collected_frames = 0 num_network_updates = 0 - start_time = time.time() pbar = tqdm.tqdm(total=total_frames) num_mini_batches = frames_per_batch // mini_batch_size total_network_updates = (total_frames // frames_per_batch) * num_mini_batches + lr = cfg.optim.lr - sampling_start = time.time() - for i, data in enumerate(collector): + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + data = next(c_iter) log_info = {} - sampling_time = time.time() - sampling_start frames_in_batch = data.numel() collected_frames += frames_in_batch * frame_skip pbar.update(data.numel()) @@ -144,94 +195,76 @@ def main(cfg: "DictConfig"): # noqa: F821 } ) - losses = TensorDict(batch_size=[num_mini_batches]) - training_start = time.time() + losses = [] # Compute GAE - with torch.no_grad(): + with torch.no_grad(), timeit("advantage"): + torch.compiler.cudagraph_mark_step_begin() data = adv_module(data) data_reshape = data.reshape(-1) # Update the data buffer - data_buffer.extend(data_reshape) - - for k, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(device) - - # Linearly decrease the learning rate and clip epsilon - alpha = 1.0 - if cfg.optim.anneal_lr: - alpha = 1 - (num_network_updates / total_network_updates) - for group in optim.param_groups: - group["lr"] = cfg.optim.lr * alpha - num_network_updates += 1 - - # Forward pass A2C loss - loss = loss_module(batch) - losses[k] = loss.select( - "loss_critic", "loss_entropy", "loss_objective" - ).detach() - loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] - ) + with timeit("rb - emptying"): + data_buffer.empty() + with timeit("rb - extending"): + data_buffer.extend(data_reshape) - # Backward pass - loss_sum.backward() - torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm - ) + with timeit("optim"): + for batch in data_buffer: - # Update the networks - optim.step() - optim.zero_grad() + # Linearly decrease the learning rate and clip epsilon + with timeit("optim - lr"): + alpha = 1.0 + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"].copy_(lr * alpha) + + num_network_updates += 1 + + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + loss = update(batch).clone() + losses.append(loss) # Get training losses - training_time = time.time() - training_start - losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + losses = torch.stack(losses).float().mean() + for key, value in losses.items(): log_info.update({f"train/{key}": value.item()}) log_info.update( { - "train/lr": alpha * cfg.optim.lr, - "train/sampling_time": sampling_time, - "train/training_time": training_time, + "train/lr": lr * alpha, } ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): + with torch.no_grad(), set_exploration_type( + ExplorationType.DETERMINISTIC + ), timeit("eval"): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: - actor.eval() - eval_start = time.time() test_rewards = eval_model( - actor, test_env, num_episodes=cfg.logger.num_test_episodes + actor_eval, test_env, num_episodes=cfg.logger.num_test_episodes ) - eval_time = time.time() - eval_start log_info.update( { "test/reward": test_rewards.mean(), - "test/eval_time": eval_time, } ) - actor.train() + if i % 200 == 0: + log_info.update(timeit.todict(prefix="time")) + timeit.print() + timeit.erase() if logger: for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) - collector.update_policy_weights_() - sampling_start = time.time() - collector.shutdown() if not test_env.is_closed: test_env.close() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index e30541c2691..b75a5224bc5 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -3,54 +3,59 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import hydra -from torchrl._utils import logger as torchrl_logger -from torchrl.record import VideoRecorder +import torch + +torch.set_float32_matmul_precision("high") @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 - import time + from copy import deepcopy import torch.optim import tqdm - from tensordict import TensorDict + from tensordict import from_module + from tensordict.nn import CudaGraphModule + + from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector - from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type - from torchrl.objectives import A2CLoss - from torchrl.objectives.value.advantages import GAE + from torchrl.objectives import A2CLoss, group_optimizers + from torchrl.objectives.value import GAE + from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger from utils_mujoco import eval_model, make_env, make_ppo_models # Define paper hyperparameters - device = "cpu" if not torch.cuda.device_count() else "cuda" + + device = cfg.loss.device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) + num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size total_network_updates = ( cfg.collector.total_frames // cfg.collector.frames_per_batch ) * num_mini_batches # Create models (check utils_mujoco.py) - actor, critic = make_ppo_models(cfg.env.env_name) - actor, critic = actor.to(device), critic.to(device) - - # Create collector - collector = SyncDataCollector( - create_env_fn=make_env(cfg.env.env_name, device), - policy=actor, - frames_per_batch=cfg.collector.frames_per_batch, - total_frames=cfg.collector.total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, + actor, critic = make_ppo_models( + cfg.env.env_name, device=device, compile=cfg.compile.compile ) + with from_module(actor).data.to("meta").to_module(actor): + actor_eval = deepcopy(actor) + actor_eval.eval() + from_module(actor).data.to_module(actor_eval) # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(cfg.collector.frames_per_batch), + storage=LazyTensorStorage(cfg.collector.frames_per_batch, device=device), sampler=sampler, batch_size=cfg.loss.mini_batch_size, ) @@ -61,6 +66,8 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.loss.gae_lambda, value_network=critic, average_gae=False, + vectorized=not cfg.compile.compile, + device=device, ) loss_module = A2CLoss( actor_network=actor, @@ -71,8 +78,18 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizers - actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr) - critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr) + actor_optim = torch.optim.Adam( + actor.parameters(), + lr=torch.tensor(cfg.optim.lr, device=device), + capturable=device.type == "cuda", + ) + critic_optim = torch.optim.Adam( + critic.parameters(), + lr=torch.tensor(cfg.optim.lr, device=device), + capturable=device.type == "cuda", + ) + optim = group_optimizers(actor_optim, critic_optim) + del actor_optim, critic_optim # Create logger logger = None @@ -99,19 +116,66 @@ def main(cfg: "DictConfig"): # noqa: F821 logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"] ), ) + + def update(batch): + # Forward pass A2C loss + loss = loss_module(batch) + critic_loss = loss["loss_critic"] + actor_loss = loss["loss_objective"] + loss.get("loss_entropy", 0.0) + + # Backward pass + (actor_loss + critic_loss).backward() + + # Update the networks + optim.step() + + optim.zero_grad(set_to_none=True) + return loss.select("loss_critic", "loss_objective").detach() # , "loss_entropy" + + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + + update = torch.compile(update, mode=compile_mode) + adv_module = torch.compile(adv_module, mode=compile_mode) + + if cfg.compile.cudagraphs: + update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=20) + adv_module = CudaGraphModule(adv_module, warmup=20) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + trust_policy=True, + compile_policy={"mode": compile_mode} if compile_mode is not None else False, + cudagraph_policy=cfg.compile.cudagraphs, + ) + test_env.eval() + lr = cfg.optim.lr # Main loop collected_frames = 0 num_network_updates = 0 - start_time = time.time() pbar = tqdm.tqdm(total=cfg.collector.total_frames) - sampling_start = time.time() - for i, data in enumerate(collector): + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + data = next(c_iter) log_info = {} - sampling_time = time.time() - sampling_start frames_in_batch = data.numel() collected_frames += frames_in_batch pbar.update(data.numel()) @@ -128,60 +192,43 @@ def main(cfg: "DictConfig"): # noqa: F821 } ) - losses = TensorDict(batch_size=[num_mini_batches]) - training_start = time.time() + losses = [] # Compute GAE - with torch.no_grad(): + with torch.no_grad(), timeit("advantage"): + torch.compiler.cudagraph_mark_step_begin() data = adv_module(data) data_reshape = data.reshape(-1) # Update the data buffer - data_buffer.extend(data_reshape) - - for k, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(device) - - # Linearly decrease the learning rate and clip epsilon - alpha = 1.0 - if cfg.optim.anneal_lr: - alpha = 1 - (num_network_updates / total_network_updates) - for group in actor_optim.param_groups: - group["lr"] = cfg.optim.lr * alpha - for group in critic_optim.param_groups: - group["lr"] = cfg.optim.lr * alpha - num_network_updates += 1 - - # Forward pass A2C loss - loss = loss_module(batch) - losses[k] = loss.select( - "loss_critic", "loss_objective" # , "loss_entropy" - ).detach() - critic_loss = loss["loss_critic"] - actor_loss = loss["loss_objective"] # + loss["loss_entropy"] - - # Backward pass - actor_loss.backward() - critic_loss.backward() - - # Update the networks - actor_optim.step() - critic_optim.step() - actor_optim.zero_grad() - critic_optim.zero_grad() + with timeit("emptying"): + data_buffer.empty() + with timeit("extending"): + data_buffer.extend(data_reshape) + + with timeit("optim"): + for batch in data_buffer: + + # Linearly decrease the learning rate and clip epsilon + with timeit("optim - lr"): + alpha = 1.0 + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"].copy_(lr * alpha) + num_network_updates += 1 + with timeit("optim - update"): + torch.compiler.cudagraph_mark_step_begin() + loss = update(batch).clone() + losses.append(loss) # Get training losses - training_time = time.time() - training_start - losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + losses = torch.stack(losses).float().mean() for key, value in losses.items(): log_info.update({f"train/{key}": value.item()}) log_info.update( { "train/lr": alpha * cfg.optim.lr, - "train/sampling_time": sampling_time, - "train/training_time": training_time, } ) @@ -192,32 +239,30 @@ def main(cfg: "DictConfig"): # noqa: F821 final = collected_frames >= collector.total_frames if prev_test_frame < cur_test_frame or final: actor.eval() - eval_start = time.time() test_rewards = eval_model( actor, test_env, num_episodes=cfg.logger.num_test_episodes ) - eval_time = time.time() - eval_start log_info.update( { "test/reward": test_rewards.mean(), - "test/eval_time": eval_time, } ) actor.train() + if i % 200 == 0: + log_info.update(timeit.todict(prefix="time")) + timeit.print() + timeit.erase() + if logger: for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) - collector.update_policy_weights_() - sampling_start = time.time() + torch.compiler.cudagraph_mark_step_begin() collector.shutdown() if not test_env.is_closed: test_env.close() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/a2c/config_atari.yaml b/sota-implementations/a2c/config_atari.yaml index dd0f43b52cb..59a0a621756 100644 --- a/sota-implementations/a2c/config_atari.yaml +++ b/sota-implementations/a2c/config_atari.yaml @@ -1,11 +1,11 @@ # Environment env: env_name: PongNoFrameskip-v4 - num_envs: 1 + num_envs: 16 # collector collector: - frames_per_batch: 80 + frames_per_batch: 800 total_frames: 40_000_000 # logger @@ -34,3 +34,9 @@ loss: critic_coef: 0.25 entropy_coef: 0.01 loss_critic_type: l2 + device: + +compile: + compile: False + compile_mode: + cudagraphs: False diff --git a/sota-implementations/a2c/config_mujoco.yaml b/sota-implementations/a2c/config_mujoco.yaml index 03a0bde32c5..9e8f36a5995 100644 --- a/sota-implementations/a2c/config_mujoco.yaml +++ b/sota-implementations/a2c/config_mujoco.yaml @@ -4,7 +4,7 @@ env: # collector collector: - frames_per_batch: 64 + frames_per_batch: 640 total_frames: 1_000_000 # logger @@ -31,3 +31,9 @@ loss: critic_coef: 0.25 entropy_coef: 0.0 loss_critic_type: l2 + device: + +compile: + compile: False + compile_mode: default + cudagraphs: False diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index 6a09ff715e4..99c3ce2338c 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -64,10 +64,12 @@ def make_base_env( def make_parallel_env(env_name, num_envs, device, is_test=False): env = ParallelEnv( num_envs, - EnvCreator(lambda: make_base_env(env_name, device=device)), + EnvCreator(lambda: make_base_env(env_name)), serial_for_single=True, + device=device, ) env = TransformedEnv(env) + env.append_transform(DoubleToFloat()) env.append_transform(ToTensorImage()) env.append_transform(GrayScale()) env.append_transform(Resize(84, 84)) @@ -76,7 +78,6 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): env.append_transform(StepCounter(max_steps=4500)) if not is_test: env.append_transform(SignTransform(in_keys=["reward"])) - env.append_transform(DoubleToFloat()) env.append_transform(VecNorm(in_keys=["pixels"])) return env @@ -86,7 +87,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): # -------------------------------------------------------------------- -def make_ppo_modules_pixels(proof_environment): +def make_ppo_modules_pixels(proof_environment, device): # Define input shape input_shape = proof_environment.observation_spec["pixels"].shape @@ -100,8 +101,8 @@ def make_ppo_modules_pixels(proof_environment): num_outputs = proof_environment.action_spec.shape distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.action_spec.space.low.to(device), + "high": proof_environment.action_spec.space.high.to(device), } # Define input keys @@ -113,14 +114,16 @@ def make_ppo_modules_pixels(proof_environment): num_cells=[32, 64, 64], kernel_sizes=[8, 4, 3], strides=[4, 2, 1], + device=device, ) - common_cnn_output = common_cnn(torch.ones(input_shape)) + common_cnn_output = common_cnn(torch.ones(input_shape, device=device)) common_mlp = MLP( in_features=common_cnn_output.shape[-1], activation_class=torch.nn.ReLU, activate_last_layer=True, out_features=512, num_cells=[], + device=device, ) common_mlp_output = common_mlp(common_cnn_output) @@ -137,6 +140,7 @@ def make_ppo_modules_pixels(proof_environment): out_features=num_outputs, activation_class=torch.nn.ReLU, num_cells=[], + device=device, ) policy_module = TensorDictModule( module=policy_net, @@ -148,7 +152,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=Composite(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec.to(device)), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, @@ -161,6 +165,7 @@ def make_ppo_modules_pixels(proof_environment): in_features=common_mlp_output.shape[-1], out_features=1, num_cells=[], + device=device, ) value_module = ValueOperator( value_net, @@ -170,11 +175,11 @@ def make_ppo_modules_pixels(proof_environment): return common_module, policy_module, value_module -def make_ppo_models(env_name): +def make_ppo_models(env_name, device): proof_environment = make_parallel_env(env_name, 1, device="cpu") common_module, policy_module, value_module = make_ppo_modules_pixels( - proof_environment + proof_environment, device=device ) # Wrap modules in a single ActorCritic operator @@ -185,8 +190,8 @@ def make_ppo_models(env_name): ) with torch.no_grad(): - td = proof_environment.rollout(max_steps=100, break_when_any_done=False) - td = actor_critic(td) + td = proof_environment.fake_tensordict().expand(1) + td = actor_critic(td.to(device)) del td actor = actor_critic.get_policy_operator() diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index 996706ce4f9..87587d092f0 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -48,7 +48,7 @@ def make_env( # -------------------------------------------------------------------- -def make_ppo_models_state(proof_environment): +def make_ppo_models_state(proof_environment, device, *, compile: bool = False): # Define input shape input_shape = proof_environment.observation_spec["observation"].shape @@ -57,9 +57,10 @@ def make_ppo_models_state(proof_environment): num_outputs = proof_environment.action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.action_spec.space.low.to(device), + "high": proof_environment.action_spec.space.high.to(device), "tanh_loc": False, + "safe_tanh": True, } # Define policy architecture @@ -68,6 +69,7 @@ def make_ppo_models_state(proof_environment): activation_class=torch.nn.Tanh, out_features=num_outputs, # predict only loc num_cells=[64, 64], + device=device, ) # Initialize policy weights @@ -79,7 +81,9 @@ def make_ppo_models_state(proof_environment): # Add state-independent normal scale policy_mlp = torch.nn.Sequential( policy_mlp, - AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]), + AddStateIndependentNormalScale( + proof_environment.action_spec.shape[-1], device=device + ), ) # Add probabilistic sampling of the actions @@ -90,7 +94,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=Composite(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec.to(device)), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, @@ -103,6 +107,7 @@ def make_ppo_models_state(proof_environment): activation_class=torch.nn.Tanh, out_features=1, num_cells=[64, 64], + device=device, ) # Initialize value weights @@ -120,9 +125,11 @@ def make_ppo_models_state(proof_environment): return policy_module, value_module -def make_ppo_models(env_name): +def make_ppo_models(env_name, device, *, compile: bool = False): proof_environment = make_env(env_name, device="cpu") - actor, critic = make_ppo_models_state(proof_environment) + actor, critic = make_ppo_models_state( + proof_environment, device=device, compile=compile + ) return actor, critic diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 48492459315..683a0d60182 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -11,6 +11,7 @@ import os.path import time import unittest +import warnings from functools import wraps # Get relative file path @@ -20,9 +21,15 @@ import torch import torch.cuda -from tensordict import tensorclass, TensorDict -from torch import nn -from torchrl._utils import implement_for, logger as torchrl_logger, seed_generator +from tensordict import NestedKey, tensorclass, TensorDict, TensorDictBase +from tensordict.nn import TensorDictModuleBase +from torch import nn, vmap +from torchrl._utils import ( + implement_for, + logger as torchrl_logger, + RL_WARNINGS, + seed_generator, +) from torchrl.data.utils import CloudpickleWrapper from torchrl.envs import MultiThreadedEnv, ObservationNorm @@ -35,6 +42,7 @@ ToTensorImage, TransformedEnv, ) +from torchrl.objectives.value.advantages import _vmap_func # Specified for test_utils.py __version__ = "0.3" @@ -713,3 +721,89 @@ def forward( ): input = self.mlp(input) return self._lstm(input, hidden0_in, hidden1_in) + + +def _call_value_nets( + value_net: TensorDictModuleBase, + data: TensorDictBase, + params: TensorDictBase, + next_params: TensorDictBase, + single_call: bool, + value_key: NestedKey, + detach_next: bool, + vmap_randomness: str = "error", +): + in_keys = value_net.in_keys + if single_call: + for i, name in enumerate(data.names): + if name == "time": + ndim = i + 1 + break + else: + ndim = None + if ndim is not None: + # get data at t and last of t+1 + idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),) + idx = (slice(None),) * (ndim - 1) + (slice(None, -1),) + idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False)[idx0], + ], + ndim - 1, + ) + else: + if RL_WARNINGS: + warnings.warn( + "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. " + "This warning can be turned off by setting the environment variable RL_WARNINGS to False." + ) + ndim = data.ndim + idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),) + idx_ = (slice(None),) * (ndim - 1) + (slice(data.shape[ndim - 1], None),) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + ndim - 1, + ) + + # next_params should be None or be identical to params + if next_params is not None and next_params is not params: + raise ValueError( + "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed." + ) + if params is not None: + with params.to_module(value_net): + value_est = value_net(data_in).get(value_key) + else: + value_est = value_net(data_in).get(value_key) + value, value_ = value_est[idx], value_est[idx_] + else: + data_in = torch.stack( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + 0, + ) + if (params is not None) ^ (next_params is not None): + raise ValueError( + "params and next_params must be either both provided or not." + ) + elif params is not None: + params_stack = torch.stack([params, next_params], 0).contiguous() + data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( + data_in, params_stack + ) + else: + data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) + value_est = data_out.get(value_key) + value, value_ = value_est[0], value_est[1] + data.set(value_key, value) + data.set(("next", value_key), value_) + if detach_next: + value_ = value_.detach() + return value, value_ diff --git a/test/test_cost.py b/test/test_cost.py index 1b54b8bf111..1e157fd7a2f 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -106,7 +106,6 @@ ValueEstimators, ) from torchrl.objectives.value.advantages import ( - _call_value_nets, GAE, TD1Estimator, TDLambdaEstimator, @@ -135,6 +134,7 @@ if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import ( # noqa + _call_value_nets, dtype_fixture, get_available_devices, get_default_devices, @@ -142,6 +142,7 @@ from pytorch.rl.test.mocking_classes import ContinuousActionConvMockEnv else: from _utils_internal import ( # noqa + _call_value_nets, dtype_fixture, get_available_devices, get_default_devices, diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index b8c2709b583..fe1a796ea2d 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -67,6 +67,15 @@ set_exploration_type, ) +try: + from torch.compiler import cudagraph_mark_step_begin +except ImportError: + + def cudagraph_mark_step_begin(): + """Placeholder for missing cudagraph_mark_step_begin method.""" + raise NotImplementedError("cudagraph_mark_step_begin not implemented.") + + _TIMEOUT = 1.0 INSTANTIATE_TIMEOUT = 20 _MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory @@ -833,6 +842,8 @@ def _make_final_rollout(self): policy_input_clone = ( policy_input.clone() ) # to test if values have changed in-place + if self.compiled_policy: + cudagraph_mark_step_begin() policy_output = self.policy(policy_input) # check that we don't have exclusive keys, because they don't appear in keys @@ -1146,7 +1157,11 @@ def rollout(self) -> TensorDictBase: else: policy_input = self._shuttle # we still do the assignment for security + if self.compiled_policy: + cudagraph_mark_step_begin() policy_output = self.policy(policy_input) + if self.compiled_policy: + policy_output = policy_output.clone() if self._shuttle is not policy_output: # ad-hoc update shuttle self._shuttle.update( diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 995f245a8ac..07d339761b0 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -309,7 +309,6 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if _reward is not None: reward = reward + _reward - terminated, truncated, done, do_break = self.read_done( terminated=terminated, truncated=truncated, done=done ) @@ -323,7 +322,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # if truncated/terminated is not in the keys, we just don't pass it even if it # is defined. if terminated is None: - terminated = done + terminated = done.clone() if truncated is not None: obs_dict["truncated"] = truncated obs_dict["done"] = done diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 66acc8639d4..c83591acb63 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -1423,7 +1423,7 @@ def _make_compatible_policy( env_maker=None, env_maker_kwargs=None, ): - if trust_policy: + if trust_policy or isinstance(policy, torch._dynamo.eval_frame.OptimizedModule): return policy if policy is None: input_spec = None diff --git a/torchrl/modules/distributions/__init__.py b/torchrl/modules/distributions/__init__.py index 52f8f302a35..8f1b7da49a5 100644 --- a/torchrl/modules/distributions/__init__.py +++ b/torchrl/modules/distributions/__init__.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from tensordict.nn import NormalParamExtractor +from torch import distributions as torch_dist from .continuous import ( Delta, @@ -37,3 +38,16 @@ OneHotOrdinal, ) } + +HAS_ENTROPY = { + Delta: False, + IndependentNormal: True, + TanhDelta: False, + TanhNormal: False, + TruncatedNormal: False, + MaskedCategorical: False, + MaskedOneHotCategorical: False, + OneHotCategorical: True, + torch_dist.Categorical: True, + torch_dist.Normal: True, +} diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index cde2c95d30f..62b5df5d14b 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -42,9 +42,8 @@ except ImportError: from torch._dynamo import is_compiling as is_dynamo_compiling -TORCH_VERSION_PRE_2_6 = version.parse(torch.__version__).base_version < version.parse( - "2.6.0" -) +TORCH_VERSION = version.parse(torch.__version__).base_version +TORCH_VERSION_PRE_2_6 = version.parse(TORCH_VERSION) < version.parse("2.6.0") class IndependentNormal(D.Independent): @@ -408,15 +407,16 @@ def __init__( event_dims = min(1, loc.ndim) err_msg = "TanhNormal high values must be strictly greater than low values" - if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): - if not (high > low).all(): - raise RuntimeError(err_msg) - elif isinstance(high, Number) and isinstance(low, Number): - if not high > low: - raise RuntimeError(err_msg) - else: - if not all(high > low): - raise RuntimeError(err_msg) + if not is_dynamo_compiling(): + if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): + if not (high > low).all(): + raise RuntimeError(err_msg) + elif isinstance(high, Number) and isinstance(low, Number): + if not high > low: + raise RuntimeError(err_msg) + else: + if not all(high > low): + raise RuntimeError(err_msg) high = torch.as_tensor(high, device=loc.device) low = torch.as_tensor(low, device=loc.device) diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index 3dac06cd270..bfb7bc48f3c 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - from enum import Enum from functools import wraps from typing import Any, Optional, Sequence, Union @@ -11,6 +10,9 @@ import torch.distributions as D import torch.nn.functional as F +from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits + + __all__ = ["OneHotCategorical", "MaskedCategorical", "Ordinal", "OneHotOrdinal"] @@ -79,6 +81,17 @@ class OneHotCategorical(D.Categorical): """ + num_params: int = 1 + + # This is to make the compiler happy, see https://github.com/pytorch/pytorch/issues/140266 + @lazy_property + def logits(self): + return probs_to_logits(self.probs) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits) + def __init__( self, logits: Optional[torch.Tensor] = None, @@ -106,6 +119,12 @@ def mode(self) -> torch.Tensor: def deterministic_sample(self): return self.mode + def entropy(self): + min_real = torch.finfo(self.logits.dtype).min + logits = torch.clamp(self.logits, min=min_real) + p_log_p = logits * self.probs + return -p_log_p.sum(-1) + @_one_hot_wrapper(D.Categorical) def sample( self, sample_shape: Optional[Union[torch.Size, Sequence]] = None @@ -188,6 +207,14 @@ class MaskedCategorical(D.Categorical): -2.1972, -2.1972]) """ + @lazy_property + def logits(self): + return probs_to_logits(self.probs) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits) + def __init__( self, logits: Optional[torch.Tensor] = None, @@ -360,6 +387,14 @@ class MaskedOneHotCategorical(MaskedCategorical): -2.1972, -2.1972]) """ + @lazy_property + def logits(self): + return probs_to_logits(self.probs) + + @lazy_property + def probs(self): + return logits_to_probs(self.logits) + def __init__( self, logits: Optional[torch.Tensor] = None, diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 96c1bc63eae..f324e491298 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -20,12 +20,14 @@ from tensordict.utils import NestedKey from torch import distributions as d +from torchrl.modules.distributions import HAS_ENTROPY from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( _cache_values, _clip_value_loss, _GAMMA_LMBDA_DEPREC_ERROR, + _get_default_device, _reduce, default_value_kwargs, distance_loss, @@ -316,10 +318,7 @@ def __init__( self.entropy_bonus = entropy_bonus and entropy_coef self.reduction = reduction - try: - device = next(self.parameters()).device - except AttributeError: - device = torch.device("cpu") + device = _get_default_device(self) self.register_buffer( "entropy_coef", torch.as_tensor(entropy_coef, device=device) @@ -347,7 +346,11 @@ def __init__( raise ValueError( f"clip_value must be a float or a scalar tensor, got {clip_value}." ) - self.register_buffer("clip_value", clip_value) + self.register_buffer( + "clip_value", torch.as_tensor(clip_value, device=device) + ) + else: + self.clip_value = None @property def functional(self): @@ -398,9 +401,9 @@ def reset(self) -> None: pass def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: - try: + if HAS_ENTROPY.get(type(dist), False): entropy = dist.entropy() - except NotImplementedError: + else: x = dist.rsample((self.samples_mc_entropy,)) log_prob = dist.log_prob(x) if is_tensor_collection(log_prob): @@ -456,7 +459,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]: old_state_value = old_state_value.clone() # TODO: if the advantage is gathered by forward, this introduces an - # overhead that we could easily reduce. + # overhead that we could easily reduce. target_return = tensordict.get( self.tensor_keys.value_target, None ) # TODO: None soon to be removed @@ -487,7 +490,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]: loss_value, clip_fraction = _clip_value_loss( old_state_value, state_value, - self.clip_value.to(state_value.device), + self.clip_value, target_return, loss_value, self.loss_critic_type, @@ -541,6 +544,9 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams hp = dict(default_value_kwargs(value_type)) hp.update(hyperparams) + device = _get_default_device(self) + hp["device"] = device + if hasattr(self, "gamma"): hp["gamma"] = self.gamma if value_type == ValueEstimators.TD1: diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 17ab16cfefa..b0ba254d2b3 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -592,6 +592,13 @@ def _clip_value_loss( return loss_value, clip_fraction +def _get_default_device(net): + for p in net.parameters(): + return p.device + else: + return torch.get_default_device() + + def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer: """Groups multiple optimizers into a single one. diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index e396b7e1fcc..739fb9a018e 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -37,6 +37,10 @@ vtrace_advantage_estimate, ) +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling try: from torch import vmap @@ -69,92 +73,6 @@ def new_func(self, *args, **kwargs): return new_func -def _call_value_nets( - value_net: TensorDictModuleBase, - data: TensorDictBase, - params: TensorDictBase, - next_params: TensorDictBase, - single_call: bool, - value_key: NestedKey, - detach_next: bool, - vmap_randomness: str = "error", -): - in_keys = value_net.in_keys - if single_call: - for i, name in enumerate(data.names): - if name == "time": - ndim = i + 1 - break - else: - ndim = None - if ndim is not None: - # get data at t and last of t+1 - idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),) - idx = (slice(None),) * (ndim - 1) + (slice(None, -1),) - idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),) - data_in = torch.cat( - [ - data.select(*in_keys, value_key, strict=False), - data.get("next").select(*in_keys, value_key, strict=False)[idx0], - ], - ndim - 1, - ) - else: - if RL_WARNINGS: - warnings.warn( - "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. " - "This warning can be turned off by setting the environment variable RL_WARNINGS to False." - ) - ndim = data.ndim - idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),) - idx_ = (slice(None),) * (ndim - 1) + (slice(data.shape[ndim - 1], None),) - data_in = torch.cat( - [ - data.select(*in_keys, value_key, strict=False), - data.get("next").select(*in_keys, value_key, strict=False), - ], - ndim - 1, - ) - - # next_params should be None or be identical to params - if next_params is not None and next_params is not params: - raise ValueError( - "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed." - ) - if params is not None: - with params.to_module(value_net): - value_est = value_net(data_in).get(value_key) - else: - value_est = value_net(data_in).get(value_key) - value, value_ = value_est[idx], value_est[idx_] - else: - data_in = torch.stack( - [ - data.select(*in_keys, value_key, strict=False), - data.get("next").select(*in_keys, value_key, strict=False), - ], - 0, - ) - if (params is not None) ^ (next_params is not None): - raise ValueError( - "params and next_params must be either both provided or not." - ) - elif params is not None: - params_stack = torch.stack([params, next_params], 0).contiguous() - data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( - data_in, params_stack - ) - else: - data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) - value_est = data_out.get(value_key) - value, value_ = value_est[0], value_est[1] - data.set(value_key, value) - data.set(("next", value_key), value_) - if detach_next: - value_ = value_.detach() - return value, value_ - - def _call_actor_net( actor_net: TensorDictModuleBase, data: TensorDictBase, @@ -432,6 +350,9 @@ def _next_value(self, tensordict, target_params, kwargs): @property def vmap_randomness(self): if self._vmap_randomness is None: + if is_dynamo_compiling(): + self._vmap_randomness = "different" + return "different" do_break = False for val in self.__dict__.values(): if isinstance(val, torch.nn.Module): @@ -467,6 +388,99 @@ def _get_time_dim(self, time_dim: int | None, data: TensorDictBase): return i return data.ndim - 1 + def _call_value_nets( + self, + data: TensorDictBase, + params: TensorDictBase, + next_params: TensorDictBase, + single_call: bool, + value_key: NestedKey, + detach_next: bool, + vmap_randomness: str = "error", + *, + value_net: TensorDictModuleBase | None = None, + ): + if value_net is None: + value_net = self.value_network + in_keys = value_net.in_keys + if single_call: + for i, name in enumerate(data.names): + if name == "time": + ndim = i + 1 + break + else: + ndim = None + if ndim is not None: + # get data at t and last of t+1 + idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),) + idx = (slice(None),) * (ndim - 1) + (slice(None, -1),) + idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False)[ + idx0 + ], + ], + ndim - 1, + ) + else: + if RL_WARNINGS: + warnings.warn( + "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. " + "This warning can be turned off by setting the environment variable RL_WARNINGS to False." + ) + ndim = data.ndim + idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),) + idx_ = (slice(None),) * (ndim - 1) + ( + slice(data.shape[ndim - 1], None), + ) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + ndim - 1, + ) + + # next_params should be None or be identical to params + if next_params is not None and next_params is not params: + raise ValueError( + "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed." + ) + if params is not None: + with params.to_module(value_net): + value_est = value_net(data_in).get(value_key) + else: + value_est = value_net(data_in).get(value_key) + value, value_ = value_est[idx], value_est[idx_] + else: + data_in = torch.stack( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + 0, + ) + if (params is not None) ^ (next_params is not None): + raise ValueError( + "params and next_params must be either both provided or not." + ) + elif params is not None: + params_stack = torch.stack([params, next_params], 0).contiguous() + data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( + data_in, params_stack + ) + else: + data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) + value_est = data_out.get(value_key) + value, value_ = value_est[0], value_est[1] + data.set(value_key, value) + data.set(("next", value_key), value_) + if detach_next: + value_ = value_.detach() + return value, value_ + class TD0Estimator(ValueEstimatorBase): """Temporal Difference (TD(0)) estimate of advantage function. @@ -623,8 +637,7 @@ def forward( ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -651,7 +664,11 @@ def value_estimate( ): reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma = self.gamma.to(device) + + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + gamma = self.gamma + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -837,8 +854,7 @@ def forward( ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -867,7 +883,11 @@ def value_estimate( ): reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma = self.gamma.to(device) + + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + gamma = self.gamma + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1063,8 +1083,7 @@ def forward( ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -1092,7 +1111,11 @@ def value_estimate( ): reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma = self.gamma.to(device) + + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + gamma = self.gamma + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1155,7 +1178,7 @@ class GAE(ValueEstimatorBase): pass detached parameters for functional modules. vectorized (bool, optional): whether to use the vectorized version of the - lambda return. Default is `True`. + lambda return. Default is `True` if not compiling. skip_existing (bool, optional): if ``True``, the value network will skip modules which outputs are already present in the tensordict. Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()` @@ -1205,7 +1228,7 @@ def __init__( value_network: TensorDictModule, average_gae: bool = False, differentiable: bool = False, - vectorized: bool = True, + vectorized: bool | None = None, skip_existing: bool | None = None, advantage_key: NestedKey = None, value_target_key: NestedKey = None, @@ -1229,6 +1252,16 @@ def __init__( self.vectorized = vectorized self.time_dim = time_dim + @property + def vectorized(self): + if is_dynamo_compiling(): + return False + return self._vectorized + + @vectorized.setter + def vectorized(self, value): + self._vectorized = value + @_self_set_skip_existing @_self_set_grad_enabled @dispatch @@ -1315,7 +1348,13 @@ def forward( ) reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma, lmbda = self.gamma.to(device), self.lmbda.to(device) + + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + if self.lmbda.device != device: + self.lmbda = self.lmbda.to(device) + gamma, lmbda = self.gamma, self.lmbda + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1328,10 +1367,10 @@ def forward( with hold_out_net(self.value_network) if ( params is None and target_params is None ) else nullcontext(): + # with torch.no_grad(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -1396,7 +1435,13 @@ def value_estimate( ) reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma, lmbda = self.gamma.to(device), self.lmbda.to(device) + + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + if self.lmbda.device != device: + self.lmbda = self.lmbda.to(device) + gamma, lmbda = self.gamma, self.lmbda + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1417,8 +1462,7 @@ def value_estimate( ) else nullcontext(): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, @@ -1668,7 +1712,11 @@ def forward( ) reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma = self.gamma.to(device) + + if self.gamma.device != device: + self.gamma = self.gamma.to(device) + gamma = self.gamma + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1682,8 +1730,7 @@ def forward( with hold_out_net(self.value_network): # we may still need to pass gradient, but we don't want to assign grads to # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, + value, next_value = self._call_value_nets( data=tensordict, params=params, next_params=target_params, diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index e20e1cbb2c9..15e5d56d6bf 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -12,6 +12,10 @@ import torch +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling __all__ = [ "generalized_advantage_estimate", @@ -147,7 +151,7 @@ def generalized_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -181,19 +185,25 @@ def generalized_advantage_estimate( def _geom_series_like(t, r, thr): """Creates a geometric series of the form [1, gammalmbda, gammalmbda**2] with the shape of `t`. - Drops all elements which are smaller than `thr`. + Drops all elements which are smaller than `thr` (unless in compile mode). """ - if isinstance(r, torch.Tensor): - r = r.item() - - if r == 0.0: - return torch.zeros_like(t) - elif r >= 1.0: - lim = t.numel() + if is_dynamo_compiling(): + if isinstance(r, torch.Tensor): + rs = r.expand_as(t) + else: + rs = torch.full_like(t, r) else: - lim = int(math.log(thr) / math.log(r)) + if isinstance(r, torch.Tensor): + r = r.item() + + if r == 0.0: + return torch.zeros_like(t) + elif r >= 1.0: + lim = t.numel() + else: + lim = int(math.log(thr) / math.log(r)) - rs = torch.full_like(t[:lim], r) + rs = torch.full_like(t[:lim], r) rs[0] = 1.0 rs = rs.cumprod(0) rs = rs.unsqueeze(-1) @@ -292,7 +302,7 @@ def vec_generalized_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -391,7 +401,7 @@ def td0_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -435,7 +445,7 @@ def td0_return_estimate( """ if done is not None and terminated is None: - terminated = done + terminated = done.clone() warnings.warn( "done for td0_return_estimate is deprecated. Pass ``terminated`` instead." ) @@ -503,7 +513,7 @@ def td1_return_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) not_done = (~done).int() @@ -604,7 +614,7 @@ def td1_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -742,7 +752,7 @@ def vec_td1_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -824,7 +834,7 @@ def td_lambda_return_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) @@ -934,7 +944,7 @@ def td_lambda_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -1074,7 +1084,7 @@ def vec_td_lambda_return_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) @@ -1228,7 +1238,7 @@ def vec_td_lambda_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape diff --git a/torchrl/objectives/value/utils.py b/torchrl/objectives/value/utils.py index ec1d33069a5..5eb678e01e1 100644 --- a/torchrl/objectives/value/utils.py +++ b/torchrl/objectives/value/utils.py @@ -301,10 +301,13 @@ def _fill_tensor(tensor): device=tensor.device, ) mask_expand = expand_right(mask, (*mask.shape, *tensor.shape[1:])) + # We need to use masked-scatter to accommodate vmap return torch.masked_scatter(empty_tensor, mask_expand, tensor.reshape(-1)) + # empty_tensor[mask_expand] = tensor.reshape(-1) + # return empty_tensor if isinstance(tensor, TensorDictBase): - tensor = tensor.apply(_fill_tensor, batch_size=[*shape]) + tensor = tensor.apply(_fill_tensor, batch_size=list(shape)) else: tensor = _fill_tensor(tensor) if return_mask: From db7f08d76c0b1a99cd9fe5f3c586ecd879d379ad Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Nov 2024 16:40:57 +0000 Subject: [PATCH 2/3] [Refactor] compile compatibility improvements ghstack-source-id: 95f8241b56e42b80e828485cb5f377288bff6f5e Pull Request resolved: https://github.com/pytorch/rl/pull/2578 --- test/test_collector.py | 22 ---- torchrl/collectors/collectors.py | 5 +- torchrl/data/tensor_specs.py | 6 +- torchrl/envs/batched_envs.py | 11 +- torchrl/modules/distributions/continuous.py | 21 ++- torchrl/modules/distributions/utils.py | 13 +- .../modules/models/decision_transformer.py | 5 + torchrl/modules/tensordict_module/actors.py | 4 + torchrl/modules/tensordict_module/common.py | 3 +- .../modules/tensordict_module/exploration.py | 122 ++++++++++-------- torchrl/objectives/common.py | 18 ++- torchrl/objectives/cql.py | 1 + torchrl/objectives/crossq.py | 14 +- torchrl/objectives/decision_transformer.py | 8 +- torchrl/objectives/value/advantages.py | 50 ++++--- 15 files changed, 176 insertions(+), 127 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 7c185830a92..1309254ce2d 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -3172,28 +3172,6 @@ def make_and_test_policy( ) -@pytest.mark.parametrize( - "ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector] -) -def test_no_stopiteration(ctype): - # Tests that there is no StopIteration raised and that the length of the collector is properly set - if ctype is SyncDataCollector: - envs = SerialEnv(16, CountingEnv) - else: - envs = [SerialEnv(8, CountingEnv), SerialEnv(8, CountingEnv)] - - collector = ctype(create_env_fn=envs, frames_per_batch=173, total_frames=300) - try: - c_iter = iter(collector) - for i in range(len(collector)): # noqa: B007 - c = next(c_iter) - assert c is not None - assert i == 1 - finally: - collector.shutdown() - del collector - - if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index fe1a796ea2d..319722a552e 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -147,7 +147,6 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): _iterator = None total_frames: int frames_per_batch: int - requested_frames_per_batch: int trust_policy: bool compiled_policy: bool cudagraphed_policy: bool @@ -306,7 +305,7 @@ def __class_getitem__(self, index): def __len__(self) -> int: if self.total_frames > 0: - return -(self.total_frames // -self.requested_frames_per_batch) + return -(self.total_frames // -self.frames_per_batch) raise RuntimeError("Non-terminating collectors do not have a length") @@ -701,7 +700,7 @@ def __init__( remainder = total_frames % frames_per_batch if remainder != 0 and RL_WARNINGS: warnings.warn( - f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). " + f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch})." f"This means {frames_per_batch - remainder} additional frames will be collected." "To silence this message, set the environment variable RL_WARNINGS to False." ) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 7fbfaab3280..2ef74bb4521 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -2312,10 +2312,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Bounded: dest_device = torch.device(dest) if dest_device == self.device and dest_dtype == self.dtype: return self - self.space.device = dest_device + space = self.space.to(dest_device) return Bounded( - low=self.space.low, - high=self.space.high, + low=space.low, + high=space.high, shape=self.shape, device=dest_device, dtype=dest_dtype, diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 9e59e0f69d6..17bd28c8390 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1356,12 +1356,15 @@ def _start_workers(self) -> None: from torchrl.envs.env_creator import EnvCreator + num_threads = max( + 1, torch.get_num_threads() - self.num_workers + ) # 1 more thread for this proc + if self.num_threads is None: - self.num_threads = max( - 1, torch.get_num_threads() - self.num_workers - ) # 1 more thread for this proc + self.num_threads = num_threads - torch.set_num_threads(self.num_threads) + if self.num_threads != torch.get_num_threads(): + torch.set_num_threads(self.num_threads) if self._mp_start_method is not None: ctx = mp.get_context(self._mp_start_method) diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 62b5df5d14b..6c200c15ee4 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -397,7 +397,6 @@ def __init__( event_dims: int | None = None, tanh_loc: bool = False, safe_tanh: bool = True, - **kwargs, ): if not isinstance(loc, torch.Tensor): loc = torch.as_tensor(loc, dtype=torch.get_default_dtype()) @@ -683,6 +682,7 @@ def __init__( event_dims: int = 1, atol: float = 1e-6, rtol: float = 1e-6, + safe: bool = True, ): minmax_msg = "high value has been found to be equal or less than low value" if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): @@ -695,12 +695,19 @@ def __init__( if not all(high > low): raise ValueError(minmax_msg) - t = SafeTanhTransform() - non_trivial_min = (isinstance(low, torch.Tensor) and (low != -1.0).any()) or ( - not isinstance(low, torch.Tensor) and low != -1.0 + if safe: + if is_dynamo_compiling(): + _err_compile_safetanh() + t = SafeTanhTransform() + else: + t = torch.distributions.TanhTransform() + non_trivial_min = is_dynamo_compiling or ( + (isinstance(low, torch.Tensor) and (low != -1.0).any()) + or (not isinstance(low, torch.Tensor) and low != -1.0) ) - non_trivial_max = (isinstance(high, torch.Tensor) and (high != 1.0).any()) or ( - not isinstance(high, torch.Tensor) and high != 1.0 + non_trivial_max = is_dynamo_compiling or ( + (isinstance(high, torch.Tensor) and (high != 1.0).any()) + or (not isinstance(high, torch.Tensor) and high != 1.0) ) self.non_trivial = non_trivial_min or non_trivial_max @@ -778,7 +785,7 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor: def _err_compile_safetanh(): raise RuntimeError( "safe_tanh=True in TanhNormal is not compatible with torch.compile with torch pre 2.6.0. " - "To deactivate it, pass safe_tanh=False. " + " To deactivate it, pass safe_tanh=False. " "If you are using a ProbabilisticTensorDictModule, this can be done via " "`distribution_kwargs={'safe_tanh': False}`. " "See https://github.com/pytorch/pytorch/issues/133529 for more details." diff --git a/torchrl/modules/distributions/utils.py b/torchrl/modules/distributions/utils.py index 546d93cb228..8c332c4efed 100644 --- a/torchrl/modules/distributions/utils.py +++ b/torchrl/modules/distributions/utils.py @@ -9,6 +9,11 @@ from torch import autograd, distributions as d from torch.distributions import Independent, Transform, TransformedDistribution +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling + def _cast_device(elt: Union[torch.Tensor, float], device) -> Union[torch.Tensor, float]: if isinstance(elt, torch.Tensor): @@ -40,10 +45,12 @@ class FasterTransformedDistribution(TransformedDistribution): __doc__ = __doc__ + TransformedDistribution.__doc__ def __init__(self, base_distribution, transforms, validate_args=None): + if is_dynamo_compiling(): + return super().__init__( + base_distribution, transforms, validate_args=validate_args + ) if isinstance(transforms, Transform): - self.transforms = [ - transforms, - ] + self.transforms = [transforms] elif isinstance(transforms, list): raise ValueError("Make a ComposeTransform first.") else: diff --git a/torchrl/modules/models/decision_transformer.py b/torchrl/modules/models/decision_transformer.py index 8eb72f1f9ea..8a20ad2eba8 100644 --- a/torchrl/modules/models/decision_transformer.py +++ b/torchrl/modules/models/decision_transformer.py @@ -90,7 +90,12 @@ def __init__( state_dim, action_dim, config: dict | DTConfig = None, + device: torch.device | None = None, ): + if device is not None: + with torch.device(device): + return self.__init__(state_dim, action_dim, config) + if not _has_transformers: raise ImportError( "transformers is not installed. Please install it with `pip install transformers`." diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 2ad5918861d..888729835b5 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -1783,6 +1783,7 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper): For example for an observation input of shape [batch_size, context, obs_dim] with context=20 and inference_context=5, the first 15 entries of the context will be masked. Defaults to 5. spec (Optional[TensorSpec]): The spec of the input TensorDict. If None, it will be inferred from the policy module. + device (torch.device, optional): if provided, the device where the buffers / specs will be placed. Examples: >>> import torch @@ -1836,6 +1837,7 @@ def __init__( *, inference_context: int = 5, spec: Optional[TensorSpec] = None, + device: torch.device | None = None, ): super().__init__(policy) self.observation_key = "observation" @@ -1857,6 +1859,8 @@ def __init__( self._spec[self.action_key] = None else: self._spec = Composite({key: None for key in policy.out_keys}) + if device is not None: + self._spec = self._spec.to(device) self.checked = False @property diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 4018589bfa1..f722bc2bd7d 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -69,7 +69,8 @@ def _forward_hook_safe_action(module, tensordict_in, tensordict_out): keys = [out_key] values = [spec] else: - keys = list(spec.keys(True, True)) + # Make dynamo happy with the list creation + keys = [key for key in spec.keys(True, True)] # noqa: C416 values = [spec[key] for key in keys] for _spec, _key in zip(values, keys): if _spec is None: diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 2ccdf599f2d..df947236970 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -133,11 +133,14 @@ def step(self, frames: int = 1) -> None: """ for _ in range(frames): - self.eps.data[0] = max( - self.eps_end.item(), - ( - self.eps - (self.eps_init - self.eps_end) / self.annealing_num_steps - ).item(), + self.eps.data.copy_( + torch.maximum( + self.eps_end, + ( + self.eps + - (self.eps_init - self.eps_end) / self.annealing_num_steps + ), + ) ) def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -150,7 +153,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: action_key = self.action_key out = action_tensordict.get(action_key) - eps = self.eps.item() + eps = self.eps cond = torch.rand(action_tensordict.shape, device=out.device) < eps # cond = torch.zeros(action_tensordict.shape, device=out.device, dtype=torch.bool).bernoulli_(eps) cond = expand_as_right(cond, out) @@ -307,19 +310,20 @@ def step(self, frames: int = 1) -> None: """ for _ in range(frames): - self.sigma.data[0] = max( - self.sigma_end.item(), - ( - self.sigma - - (self.sigma_init - self.sigma_end) / self.annealing_num_steps - ).item(), + self.sigma.data.copy_( + torch.maximum( + self.sigma_end( + self.sigma + - (self.sigma_init - self.sigma_end) / self.annealing_num_steps + ), + ) ) def _add_noise(self, action: torch.Tensor) -> torch.Tensor: - sigma = self.sigma.item() + sigma = self.sigma noise = torch.normal( - mean=torch.ones(action.shape) * self.mean.item(), - std=torch.ones(action.shape) * self.std.item(), + mean=torch.ones(action.shape) * self.mean, + std=torch.ones(action.shape) * self.std, ).to(action.device) action = action + noise * sigma spec = self.spec @@ -365,6 +369,9 @@ class AdditiveGaussianModule(TensorDictModuleBase): its output spec will be of type Composite. One needs to know where to find the action spec. default: "action" + safe (bool): if ``True``, actions that are out of bounds given the action specs will be projected in the space + given the :obj:`TensorSpec.project` heuristic. + default: True .. note:: It is @@ -386,6 +393,7 @@ def __init__( std: float = 1.0, *, action_key: Optional[NestedKey] = "action", + safe: bool = True, ): if not isinstance(sigma_init, float): warnings.warn("eps_init should be a float.") @@ -410,7 +418,9 @@ def __init__( else: raise RuntimeError("spec cannot be None.") self._spec = spec - self.register_forward_hook(_forward_hook_safe_action) + self.safe = safe + if self.safe: + self.register_forward_hook(_forward_hook_safe_action) @property def spec(self): @@ -426,19 +436,21 @@ def step(self, frames: int = 1) -> None: """ for _ in range(frames): - self.sigma.data[0] = max( - self.sigma_end.item(), - ( - self.sigma - - (self.sigma_init - self.sigma_end) / self.annealing_num_steps - ).item(), + self.sigma.data.copy_( + torch.maximum( + self.sigma_end, + ( + self.sigma + - (self.sigma_init - self.sigma_end) / self.annealing_num_steps + ), + ) ) def _add_noise(self, action: torch.Tensor) -> torch.Tensor: - sigma = self.sigma.item() + sigma = self.sigma noise = torch.normal( - mean=torch.ones(action.shape) * self.mean.item(), - std=torch.ones(action.shape) * self.std.item(), + mean=torch.ones(action.shape) * self.mean, + std=torch.ones(action.shape) * self.std, ).to(action.device) action = action + noise * sigma spec = self.spec[self.action_key] @@ -636,12 +648,14 @@ def step(self, frames: int = 1) -> None: """ for _ in range(frames): if self.annealing_num_steps > 0: - self.eps.data[0] = max( - self.eps_end.item(), - ( - self.eps - - (self.eps_init - self.eps_end) / self.annealing_num_steps - ).item(), + self.eps.data.copy_( + torch.maximum( + self.eps_end, + ( + self.eps + - (self.eps_init - self.eps_end) / self.annealing_num_steps + ), + ) ) else: raise ValueError( @@ -664,9 +678,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker " f"transform to your environment with `env = TransformedEnv(env, InitTracker())`." ) - tensordict = self.ou.add_sample( - tensordict, self.eps.item(), is_init=is_init - ) + tensordict = self.ou.add_sample(tensordict, self.eps, is_init=is_init) return tensordict @@ -730,6 +742,10 @@ class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase): default: "action" is_init_key (NestedKey, optional): key where to find the is_init flag used to reset the noise steps. default: "is_init" + safe (boolean, optional): if False, the TensorSpec can be None. If it + is set to False but the spec is passed, the projection will still + happen. + Default is True. Examples: >>> import torch @@ -772,6 +788,7 @@ def __init__( *, action_key: Optional[NestedKey] = "action", is_init_key: Optional[NestedKey] = "is_init", + safe: bool = True, ): super().__init__() @@ -815,7 +832,9 @@ def __init__( self._spec.update(ou_specs) if len(set(self.out_keys)) != len(self.out_keys): raise RuntimeError(f"Got multiple identical output keys: {self.out_keys}") - self.register_forward_hook(_forward_hook_safe_action) + self.safe = safe + if self.safe: + self.register_forward_hook(_forward_hook_safe_action) @property def spec(self): @@ -830,12 +849,14 @@ def step(self, frames: int = 1) -> None: """ for _ in range(frames): if self.annealing_num_steps > 0: - self.eps.data[0] = max( - self.eps_end.item(), - ( - self.eps - - (self.eps_init - self.eps_end) / self.annealing_num_steps - ).item(), + self.eps.data.copy_( + torch.maximum( + self.eps_end, + ( + self.eps + - (self.eps_init - self.eps_end) / self.annealing_num_steps + ), + ) ) else: raise ValueError( @@ -857,9 +878,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker " f"transform to your environment with `env = TransformedEnv(env, InitTracker())`." ) - tensordict = self.ou.add_sample( - tensordict, self.eps.item(), is_init=is_init - ) + tensordict = self.ou.add_sample(tensordict, self.eps, is_init=is_init) return tensordict @@ -923,11 +942,12 @@ def _make_noise_pair( tensordict.set(self.noise_key, noise) tensordict.set(self.steps_key, steps) else: - noise = tensordict.get(self.noise_key) - steps = tensordict.get(self.steps_key) + # We must clone for cudagraph, otherwise the same tensor may re-enter the compiled region + noise = tensordict.get(self.noise_key).clone() + steps = tensordict.get(self.steps_key).clone() if is_init is not None: - noise[is_init] = 0 - steps[is_init] = 0 + noise = torch.masked_fill(noise, is_init, 0) + steps = torch.masked_fill(steps, is_init, 0) return noise, steps def add_sample( @@ -977,9 +997,9 @@ def add_sample( * np.sqrt(self.dt) * torch.randn_like(prev_noise) ) - tensordict.set_(self.noise_key, noise - self.x0) - tensordict.set_(self.key, tensordict.get(self.key) + eps * noise) - tensordict.set_(self.steps_key, n_steps + 1) + tensordict.set(self.noise_key, noise - self.x0) + tensordict.set(self.key, tensordict.get(self.key) + eps * noise) + tensordict.set(self.steps_key, n_steps + 1) return tensordict def current_sigma(self, n_steps: torch.Tensor) -> torch.Tensor: diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index be05e2fa66b..57310a5fc3d 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -12,6 +12,7 @@ from dataclasses import dataclass from typing import Iterator, List, Optional, Tuple +import torch from tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams @@ -515,7 +516,22 @@ def _default_value_estimator(self): from :obj:`torchrl.objectives.utils.DEFAULT_VALUE_FUN_PARAMS`. """ - self.make_value_estimator(self.default_value_estimator) + self.make_value_estimator( + self.default_value_estimator, device=self._default_device + ) + + @property + def _default_device(self) -> torch.device | None: + """A util to find the default device. + + Returns ``None`` if parameters are spread across multiple devices. + """ + devices = set() + for p in self.parameters(): + devices.add(p.device) + if len(devices) == 1: + return list(devices)[0] + return None def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): """Value-function constructor. diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 14c5d54a61d..191096e7492 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -375,6 +375,7 @@ def __init__( ) self._make_vmap() self.reduction = reduction + _ = self.target_entropy def _make_vmap(self): self._vmap_qvalue_networkN0 = _vmap_func( diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index c555b7a609c..ca6559ac5b8 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -340,6 +340,8 @@ def __init__( self._action_spec = action_spec self._make_vmap() self.reduction = reduction + # init target entropy + _ = self.target_entropy def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( @@ -513,15 +515,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: **metadata_actor, **value_metadata, } - td_out = TensorDict(out, []) - # td_out = td_out.named_apply( - # lambda name, value: ( - # _reduce(value, reduction=self.reduction) - # if name.startswith("loss_") - # else value - # ), - # batch_size=[], - # ) + td_out = TensorDict(out) return td_out @property @@ -543,6 +537,7 @@ def actor_loss( Returns: a differentiable tensor with the alpha loss along with a metadata dictionary containing the detached `"log_prob"` of the sampled action. """ + tensordict = tensordict.copy() with set_exploration_type( ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): @@ -584,6 +579,7 @@ def qvalue_loss( Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing the detached `"td_error"` to be used for prioritized sampling. """ + tensordict = tensordict.copy() # # compute next action with torch.no_grad(): with set_exploration_type( diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 632d3e615b6..1b1f0aa4e0b 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -292,6 +292,7 @@ def __init__( *, loss_function: str = "l2", reduction: str = None, + device: torch.device | None = None, ) -> None: self._in_keys = None self._out_keys = None @@ -343,7 +344,7 @@ def out_keys(self, values): def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Compute the loss for the Online Decision Transformer.""" # extract action targets - tensordict = tensordict.clone(False) + tensordict = tensordict.copy() target_actions = tensordict.get(self.tensor_keys.action_target).detach() with self.actor_network_params.to_module(self.actor_network): @@ -356,8 +357,5 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: loss_function=self.loss_function, ) loss = _reduce(loss, reduction=self.reduction) - out = { - "loss": loss, - } - td_out = TensorDict(out, []) + td_out = TensorDict(loss=loss) return td_out diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 739fb9a018e..8ac64bf3d21 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -197,6 +197,8 @@ def forward( to be passed to the functional value network module. target_params (TensorDictBase, optional): A nested TensorDict containing the target params to be passed to the functional value network module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. Returns: An updated TensorDict with an advantage and a value_error keys as defined in the constructor. @@ -213,8 +215,14 @@ def __init__( advantage_key: NestedKey = None, value_target_key: NestedKey = None, value_key: NestedKey = None, + device: torch.device | None = None, ): super().__init__() + if device is None: + device = torch.get_default_device() + # this is saved for tracking only and should not be used to cast anything else than buffers during + # init. + self._device = device self._tensor_keys = None self.differentiable = differentiable self.skip_existing = skip_existing @@ -518,7 +526,8 @@ class TD0Estimator(ValueEstimatorBase): of the advantage entry. Defaults to ``"value_target"``. value_key (str or tuple of str, optional): [Deprecated] the value key to read from the input tensordict. Defaults to ``"state_value"``. - device (torch.device, optional): device of the module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. """ @@ -544,8 +553,9 @@ def __init__( value_target_key=value_target_key, value_key=value_key, skip_existing=skip_existing, + device=device, ) - self.register_buffer("gamma", torch.tensor(gamma, device=device)) + self.register_buffer("gamma", torch.tensor(gamma, device=self._device)) self.average_rewards = average_rewards @_self_set_skip_existing @@ -668,7 +678,6 @@ def value_estimate( if self.gamma.device != device: self.gamma = self.gamma.to(device) gamma = self.gamma - steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -727,7 +736,8 @@ class TD1Estimator(ValueEstimatorBase): estimation, for instance) and (2) when the parameters used at time ``t`` and ``t+1`` are identical (which is not the case when target parameters are to be used). Defaults to ``False``. - device (torch.device, optional): device of the module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. time_dim (int, optional): the dimension corresponding to the time in the input tensordict. If not provided, defaults to the dimension markes with the ``"time"`` name if any, and to the last dimension @@ -761,8 +771,9 @@ def __init__( value_key=value_key, shifted=shifted, skip_existing=skip_existing, + device=device, ) - self.register_buffer("gamma", torch.tensor(gamma, device=device)) + self.register_buffer("gamma", torch.tensor(gamma, device=self._device)) self.average_rewards = average_rewards self.time_dim = time_dim @@ -887,7 +898,6 @@ def value_estimate( if self.gamma.device != device: self.gamma = self.gamma.to(device) gamma = self.gamma - steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -951,7 +961,8 @@ class TDLambdaEstimator(ValueEstimatorBase): estimation, for instance) and (2) when the parameters used at time ``t`` and ``t+1`` are identical (which is not the case when target parameters are to be used). Defaults to ``False``. - device (torch.device, optional): device of the module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. time_dim (int, optional): the dimension corresponding to the time in the input tensordict. If not provided, defaults to the dimension markes with the ``"time"`` name if any, and to the last dimension @@ -987,9 +998,10 @@ def __init__( value_key=value_key, skip_existing=skip_existing, shifted=shifted, + device=device, ) - self.register_buffer("gamma", torch.tensor(gamma, device=device)) - self.register_buffer("lmbda", torch.tensor(lmbda, device=device)) + self.register_buffer("gamma", torch.tensor(gamma, device=self._device)) + self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device)) self.average_rewards = average_rewards self.vectorized = vectorized self.time_dim = time_dim @@ -1115,7 +1127,6 @@ def value_estimate( if self.gamma.device != device: self.gamma = self.gamma.to(device) gamma = self.gamma - steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) @@ -1197,7 +1208,8 @@ class GAE(ValueEstimatorBase): estimation, for instance) and (2) when the parameters used at time ``t`` and ``t+1`` are identical (which is not the case when target parameters are to be used). Defaults to ``False``. - device (torch.device, optional): device of the module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. time_dim (int, optional): the dimension corresponding to the time in the input tensordict. If not provided, defaults to the dimension marked with the ``"time"`` name if any, and to the last dimension @@ -1245,9 +1257,10 @@ def __init__( value_target_key=value_target_key, value_key=value_key, skip_existing=skip_existing, + device=device, ) - self.register_buffer("gamma", torch.tensor(gamma, device=device)) - self.register_buffer("lmbda", torch.tensor(lmbda, device=device)) + self.register_buffer("gamma", torch.tensor(gamma, device=self._device)) + self.register_buffer("lmbda", torch.tensor(lmbda, device=self._device)) self.average_gae = average_gae self.vectorized = vectorized self.time_dim = time_dim @@ -1530,7 +1543,8 @@ class VTrace(ValueEstimatorBase): estimation, for instance) and (2) when the parameters used at time ``t`` and ``t+1`` are identical (which is not the case when target parameters are to be used). Defaults to ``False``. - device (torch.device, optional): device of the module. + device (torch.device, optional): the device where the buffers will be instantiated. + Defaults to ``torch.get_default_device()``. time_dim (int, optional): the dimension corresponding to the time in the input tensordict. If not provided, defaults to the dimension markes with the ``"time"`` name if any, and to the last dimension @@ -1575,13 +1589,14 @@ def __init__( value_target_key=value_target_key, value_key=value_key, skip_existing=skip_existing, + device=device, ) if not isinstance(gamma, torch.Tensor): - gamma = torch.tensor(gamma, device=device) + gamma = torch.tensor(gamma, device=self._device) if not isinstance(rho_thresh, torch.Tensor): - rho_thresh = torch.tensor(rho_thresh, device=device) + rho_thresh = torch.tensor(rho_thresh, device=self._device) if not isinstance(c_thresh, torch.Tensor): - c_thresh = torch.tensor(c_thresh, device=device) + c_thresh = torch.tensor(c_thresh, device=self._device) self.register_buffer("gamma", gamma) self.register_buffer("rho_thresh", rho_thresh) @@ -1716,7 +1731,6 @@ def forward( if self.gamma.device != device: self.gamma = self.gamma.to(device) gamma = self.gamma - steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) From 408cf7d04705f18e6a1d58f4b2b7255d67a443d9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Nov 2024 16:40:58 +0000 Subject: [PATCH 3/3] [BugFix] requested_frames_per_batch in distributed collectors ghstack-source-id: 49289de6956460d9aed13d982eb8003eafc35118 Pull Request resolved: https://github.com/pytorch/rl/pull/2579 --- torchrl/collectors/distributed/generic.py | 1 + torchrl/collectors/distributed/rpc.py | 1 + torchrl/collectors/distributed/sync.py | 1 + 3 files changed, 3 insertions(+) diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 729b8a48171..5e49ad95f49 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -448,6 +448,7 @@ def __init__( self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch + self.requested_frames_per_batch = frames_per_batch self.device = device self.storing_device = storing_device diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 73247df4b0c..98220727a45 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -304,6 +304,7 @@ def __init__( self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch + self.requested_frames_per_batch = frames_per_batch self.device = device self.storing_device = storing_device diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 481fb70cc31..b90111763d7 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -315,6 +315,7 @@ def __init__( self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch + self.requested_frames_per_batch = frames_per_batch self.device = device self.storing_device = storing_device