From 90b3a603fbf6bca45e96ac83b2d8ffa776e1b1f5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 4 Oct 2024 15:53:10 +0100 Subject: [PATCH 01/24] Update [ghstack-poisoned] --- benchmarks/test_objectives_benchmarks.py | 2 +- sota-implementations/a2c/a2c_atari.py | 155 +++++++++++++------- sota-implementations/a2c/a2c_mujoco.py | 120 +++++++++------ sota-implementations/a2c/config_atari.yaml | 4 + sota-implementations/a2c/config_mujoco.yaml | 4 + sota-implementations/a2c/utils_atari.py | 22 +-- sota-implementations/a2c/utils_mujoco.py | 18 ++- torchrl/_utils.py | 11 +- torchrl/envs/utils.py | 2 +- torchrl/modules/distributions/continuous.py | 99 +++++++------ torchrl/objectives/a2c.py | 19 ++- torchrl/objectives/utils.py | 7 + torchrl/objectives/value/functional.py | 30 ++-- torchrl/objectives/value/utils.py | 5 +- 14 files changed, 312 insertions(+), 186 deletions(-) diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index d07b40595bc..7ff3f23b7a5 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -50,7 +50,7 @@ ) # Anything from 2.5, incl. nightlies, allows for fullgraph -@pytest.fixture(scope="module") +@pytest.fixture(scope="module", autouse=True) def set_default_device(): cur_device = torch.get_default_device() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 42ef4301c4d..46aaeacd517 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import hydra +from tensordict.nn import CudaGraphModule from torchrl._utils import logger as torchrl_logger from torchrl.record import VideoRecorder @@ -15,9 +16,9 @@ def main(cfg: "DictConfig"): # noqa: F821 import torch.optim import tqdm - from tensordict import TensorDict + from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector - from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type from torchrl.objectives import A2CLoss @@ -25,7 +26,11 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils_atari import eval_model, make_parallel_env, make_ppo_models - device = "cpu" if not torch.cuda.device_count() else "cuda" + device = cfg.loss.device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) # Correct for frame_skip frame_skip = 4 @@ -35,28 +40,17 @@ def main(cfg: "DictConfig"): # noqa: F821 test_interval = cfg.logger.test_interval // frame_skip # Create models (check utils_atari.py) - actor, critic, critic_head = make_ppo_models(cfg.env.env_name) + actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device) actor, critic, critic_head = ( actor.to(device), critic.to(device), critic_head.to(device), ) - # Create collector - collector = SyncDataCollector( - create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), - policy=actor, - frames_per_batch=frames_per_batch, - total_frames=total_frames, - device=device, - 
storing_device=device, - max_frames_per_traj=-1, - ) - # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(frames_per_batch), + storage=LazyTensorStorage(frames_per_batch, device=device), sampler=sampler, batch_size=mini_batch_size, ) @@ -83,9 +77,10 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create optimizer optim = torch.optim.Adam( loss_module.parameters(), - lr=cfg.optim.lr, + lr=torch.tensor(cfg.optim.lr, device=device), weight_decay=cfg.optim.weight_decay, eps=cfg.optim.eps, + capturable=device.type == "cuda", ) # Create logger @@ -115,6 +110,57 @@ def main(cfg: "DictConfig"): # noqa: F821 ) test_env.eval() + # update function + def update(batch, max_grad_norm=cfg.optim.max_grad_norm): + # Forward pass A2C loss + loss = loss_module(batch) + + loss_sum = loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + + # Backward pass + loss_sum.backward() + gn = torch.nn.utils.clip_grad_norm_( + loss_module.parameters(), max_norm=max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad(set_to_none=True) + + return ( + loss.select("loss_critic", "loss_entropy", "loss_objective") + .detach() + .set("grad_norm", gn) + ) + + if cfg.loss.compile: + compile_mode = cfg.loss.compile_mode + if compile_mode in ("", None): + if cfg.loss.cudagraphs: + compile_mode = None + else: + compile_mode = "reduce-overhead" + update = torch.compile(update, mode=compile_mode) + actor = torch.compile(actor, mode=compile_mode) + adv_module = torch.compile(adv_module, mode=compile_mode) + + if cfg.loss.cudagraphs: + update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5) + actor = CudaGraphModule(actor) + adv_module = CudaGraphModule(adv_module) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + policy_device=device, + ) + # Main loop collected_frames = 0 num_network_updates = 0 @@ -122,9 +168,13 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar = tqdm.tqdm(total=total_frames) num_mini_batches = frames_per_batch // mini_batch_size total_network_updates = (total_frames // frames_per_batch) * num_mini_batches + lr = cfg.optim.lr sampling_start = time.time() - for i, data in enumerate(collector): + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + data = next(c_iter) log_info = {} sampling_time = time.time() - sampling_start @@ -144,59 +194,52 @@ def main(cfg: "DictConfig"): # noqa: F821 } ) - losses = TensorDict({}, batch_size=[num_mini_batches]) + losses = [] training_start = time.time() # Compute GAE - with torch.no_grad(): + with torch.no_grad(), timeit("advantage"): data = adv_module(data) data_reshape = data.reshape(-1) # Update the data buffer - data_buffer.extend(data_reshape) - - for k, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(device) - - # Linearly decrease the learning rate and clip epsilon - alpha = 1.0 - if cfg.optim.anneal_lr: - alpha = 1 - (num_network_updates / total_network_updates) - for group in optim.param_groups: - group["lr"] = cfg.optim.lr * alpha - num_network_updates += 1 - - # Forward pass A2C loss - loss = loss_module(batch) - losses[k] = loss.select( - "loss_critic", "loss_entropy", "loss_objective" - ).detach() - loss_sum = ( - loss["loss_critic"] + 
loss["loss_objective"] + loss["loss_entropy"] - ) - - # Backward pass - loss_sum.backward() - torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm - ) - - # Update the networks - optim.step() - optim.zero_grad() - + with timeit("emptying"): + data_buffer.empty() + with timeit("extending"): + data_buffer.extend(data_reshape) + + with timeit("optim"): + for batch in data_buffer: + + # Linearly decrease the learning rate and clip epsilon + with timeit("optim - lr"): + alpha = 1.0 + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"].copy_(lr * alpha) + + num_network_updates += 1 + + with timeit("optim - update"): + loss = update(batch) + losses.append(loss) + + if i % 200 == 0: + timeit.print() + timeit.erase() # Get training losses training_time = time.time() - training_start - losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + losses = torch.stack(losses).float().mean() + for key, value in losses.items(): log_info.update({f"train/{key}": value.item()}) log_info.update( { - "train/lr": alpha * cfg.optim.lr, + "train/lr": lr * alpha, "train/sampling_time": sampling_time, "train/training_time": training_time, + **timeit.todict(prefix="time"), } ) diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index 2b390d39d2a..af6609a1f7b 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import hydra +from tensordict.nn import CudaGraphModule from torchrl._utils import logger as torchrl_logger from torchrl.record import VideoRecorder @@ -15,9 +16,8 @@ def main(cfg: "DictConfig"): # noqa: F821 import torch.optim import tqdm - from tensordict import TensorDict from torchrl.collectors import SyncDataCollector - from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type from torchrl.objectives import A2CLoss @@ -26,31 +26,25 @@ def main(cfg: "DictConfig"): # noqa: F821 from utils_mujoco import eval_model, make_env, make_ppo_models # Define paper hyperparameters - device = "cpu" if not torch.cuda.device_count() else "cuda" + + device = cfg.loss.device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) + num_mini_batches = cfg.collector.frames_per_batch // cfg.loss.mini_batch_size total_network_updates = ( cfg.collector.total_frames // cfg.collector.frames_per_batch ) * num_mini_batches # Create models (check utils_mujoco.py) - actor, critic = make_ppo_models(cfg.env.env_name) - actor, critic = actor.to(device), critic.to(device) - - # Create collector - collector = SyncDataCollector( - create_env_fn=make_env(cfg.env.env_name, device), - policy=actor, - frames_per_batch=cfg.collector.frames_per_batch, - total_frames=cfg.collector.total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, - ) + actor, critic = make_ppo_models(cfg.env.env_name, device=device) # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(cfg.collector.frames_per_batch), + 
storage=LazyTensorStorage(cfg.collector.frames_per_batch, device=device), sampler=sampler, batch_size=cfg.loss.mini_batch_size, ) @@ -71,8 +65,16 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizers - actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr) - critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr) + actor_optim = torch.optim.Adam( + actor.parameters(), + lr=torch.tensor(cfg.optim.lr, device=device), + capturable=device.type == "cuda", + ) + critic_optim = torch.optim.Adam( + critic.parameters(), + lr=torch.tensor(cfg.optim.lr, device=device), + capturable=device.type == "cuda", + ) # Create logger logger = None @@ -99,7 +101,55 @@ def main(cfg: "DictConfig"): # noqa: F821 logger, tag=f"rendered/{cfg.env.env_name}", in_keys=["pixels"] ), ) + + def update(batch): + # Forward pass A2C loss + loss = loss_module(batch) + critic_loss = loss["loss_critic"] + actor_loss = loss["loss_objective"] + loss.get("loss_entropy", 0.0) + + # Backward pass + (actor_loss + critic_loss).backward() + + # Update the networks + actor_optim.step() + critic_optim.step() + + actor_optim.zero_grad(set_to_none=True) + critic_optim.zero_grad(set_to_none=True) + return loss.select("loss_critic", "loss_objective").detach() # , "loss_entropy" + + if cfg.loss.compile: + compile_mode = cfg.loss.compile_mode + if compile_mode in ("", None): + if cfg.loss.cudagraphs: + compile_mode = None + else: + compile_mode = "reduce-overhead" + + update = torch.compile(update, mode=compile_mode) + actor = torch.compile(actor, mode=compile_mode) + adv_module = torch.compile(adv_module, mode=compile_mode) + + if cfg.loss.cudagraphs: + update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=10) + actor = CudaGraphModule(update, warmup=10) + adv_module = CudaGraphModule(adv_module) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=cfg.collector.frames_per_batch, + total_frames=cfg.collector.total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + trust_policy=True, + ) + test_env.eval() + lr = cfg.optim.lr # Main loop collected_frames = 0 @@ -128,7 +178,7 @@ def main(cfg: "DictConfig"): # noqa: F821 } ) - losses = TensorDict({}, batch_size=[num_mini_batches]) + losses = [] training_start = time.time() # Compute GAE @@ -139,42 +189,24 @@ def main(cfg: "DictConfig"): # noqa: F821 # Update the data buffer data_buffer.extend(data_reshape) - for k, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(device) + for batch in data_buffer: # Linearly decrease the learning rate and clip epsilon alpha = 1.0 if cfg.optim.anneal_lr: alpha = 1 - (num_network_updates / total_network_updates) for group in actor_optim.param_groups: - group["lr"] = cfg.optim.lr * alpha + group["lr"].copy_(lr * alpha) for group in critic_optim.param_groups: - group["lr"] = cfg.optim.lr * alpha + group["lr"].copy_(lr * alpha) num_network_updates += 1 - # Forward pass A2C loss - loss = loss_module(batch) - losses[k] = loss.select( - "loss_critic", "loss_objective" # , "loss_entropy" - ).detach() - critic_loss = loss["loss_critic"] - actor_loss = loss["loss_objective"] # + loss["loss_entropy"] - - # Backward pass - actor_loss.backward() - critic_loss.backward() - - # Update the networks - actor_optim.step() - critic_optim.step() - actor_optim.zero_grad() - critic_optim.zero_grad() + loss = update(batch) + losses.append(loss) # Get training losses training_time = time.time() - 
training_start - losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + losses = torch.stack(losses).float().mean() for key, value in losses.items(): log_info.update({f"train/{key}": value.item()}) log_info.update( diff --git a/sota-implementations/a2c/config_atari.yaml b/sota-implementations/a2c/config_atari.yaml index dd0f43b52cb..5a7586ee95d 100644 --- a/sota-implementations/a2c/config_atari.yaml +++ b/sota-implementations/a2c/config_atari.yaml @@ -34,3 +34,7 @@ loss: critic_coef: 0.25 entropy_coef: 0.01 loss_critic_type: l2 + compile: False + compile_mode: + cudagraphs: False + device: diff --git a/sota-implementations/a2c/config_mujoco.yaml b/sota-implementations/a2c/config_mujoco.yaml index 03a0bde32c5..9219d557bf8 100644 --- a/sota-implementations/a2c/config_mujoco.yaml +++ b/sota-implementations/a2c/config_mujoco.yaml @@ -31,3 +31,7 @@ loss: critic_coef: 0.25 entropy_coef: 0.0 loss_critic_type: l2 + compile: False + compile_mode: + cudagraphs: False + device: diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index 6a09ff715e4..bf7e23cd8f9 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -86,7 +86,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): # -------------------------------------------------------------------- -def make_ppo_modules_pixels(proof_environment): +def make_ppo_modules_pixels(proof_environment, device): # Define input shape input_shape = proof_environment.observation_spec["pixels"].shape @@ -100,8 +100,8 @@ def make_ppo_modules_pixels(proof_environment): num_outputs = proof_environment.action_spec.shape distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.action_spec.space.low.to(device), + "high": proof_environment.action_spec.space.high.to(device), } # Define input keys @@ -113,14 +113,16 @@ def make_ppo_modules_pixels(proof_environment): num_cells=[32, 64, 64], kernel_sizes=[8, 4, 3], strides=[4, 2, 1], + device=device, ) - common_cnn_output = common_cnn(torch.ones(input_shape)) + common_cnn_output = common_cnn(torch.ones(input_shape, device=device)) common_mlp = MLP( in_features=common_cnn_output.shape[-1], activation_class=torch.nn.ReLU, activate_last_layer=True, out_features=512, num_cells=[], + device=device, ) common_mlp_output = common_mlp(common_cnn_output) @@ -137,6 +139,7 @@ def make_ppo_modules_pixels(proof_environment): out_features=num_outputs, activation_class=torch.nn.ReLU, num_cells=[], + device=device, ) policy_module = TensorDictModule( module=policy_net, @@ -148,7 +151,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=Composite(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec.to(device)), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, @@ -161,6 +164,7 @@ def make_ppo_modules_pixels(proof_environment): in_features=common_mlp_output.shape[-1], out_features=1, num_cells=[], + device=device, ) value_module = ValueOperator( value_net, @@ -170,11 +174,11 @@ def make_ppo_modules_pixels(proof_environment): return common_module, policy_module, value_module -def make_ppo_models(env_name): +def make_ppo_models(env_name, device): proof_environment = make_parallel_env(env_name, 1, device="cpu") common_module, policy_module, 
value_module = make_ppo_modules_pixels( - proof_environment + proof_environment, device=device ) # Wrap modules in a single ActorCritic operator @@ -185,8 +189,8 @@ def make_ppo_models(env_name): ) with torch.no_grad(): - td = proof_environment.rollout(max_steps=100, break_when_any_done=False) - td = actor_critic(td) + td = proof_environment.fake_tensordict().expand(1) + td = actor_critic(td.to(device)) del td actor = actor_critic.get_policy_operator() diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index 996706ce4f9..6fc82890b7d 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -48,7 +48,7 @@ def make_env( # -------------------------------------------------------------------- -def make_ppo_models_state(proof_environment): +def make_ppo_models_state(proof_environment, device): # Define input shape input_shape = proof_environment.observation_spec["observation"].shape @@ -57,8 +57,8 @@ def make_ppo_models_state(proof_environment): num_outputs = proof_environment.action_spec.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec.space.low, - "high": proof_environment.action_spec.space.high, + "low": proof_environment.action_spec.space.low.to(device), + "high": proof_environment.action_spec.space.high.to(device), "tanh_loc": False, } @@ -68,6 +68,7 @@ def make_ppo_models_state(proof_environment): activation_class=torch.nn.Tanh, out_features=num_outputs, # predict only loc num_cells=[64, 64], + device=device, ) # Initialize policy weights @@ -79,7 +80,9 @@ def make_ppo_models_state(proof_environment): # Add state-independent normal scale policy_mlp = torch.nn.Sequential( policy_mlp, - AddStateIndependentNormalScale(proof_environment.action_spec.shape[-1]), + AddStateIndependentNormalScale( + proof_environment.action_spec.shape[-1], device=device + ), ) # Add probabilistic sampling of the actions @@ -90,7 +93,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=Composite(action=proof_environment.action_spec), + spec=Composite(action=proof_environment.action_spec.to(device)), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, @@ -103,6 +106,7 @@ def make_ppo_models_state(proof_environment): activation_class=torch.nn.Tanh, out_features=1, num_cells=[64, 64], + device=device, ) # Initialize value weights @@ -120,9 +124,9 @@ def make_ppo_models_state(proof_environment): return policy_module, value_module -def make_ppo_models(env_name): +def make_ppo_models(env_name, device): proof_environment = make_env(env_name, device="cpu") - actor, critic = make_ppo_models_state(proof_environment) + actor, critic = make_ppo_models_state(proof_environment, device=device) return actor, critic diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 0bfdc7b07ce..2833462d915 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -99,10 +99,15 @@ def print(prefix=None): # noqa: T202 logger.info(" -- ".join(strings)) @classmethod - def todict(cls, percall=True): + def todict(cls, percall=True, prefix=None): + def _make_key(key): + if prefix: + return f"{prefix}/{key}" + return key + if percall: - return {key: val[0] for key, val in cls._REG.items()} - return {key: val[1] for key, val in cls._REG.items()} + return {_make_key(key): val[0] for key, val in cls._REG.items()} + return {_make_key(key): val[1] for key, val in cls._REG.items()} @staticmethod def 
erase(): diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 9701e96ef62..6a080abae48 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -1436,7 +1436,7 @@ def _make_compatible_policy( env_maker=None, env_maker_kwargs=None, ): - if trust_policy: + if trust_policy or isinstance(policy, torch._dynamo.eval_frame.OptimizedModule): return policy if policy is None: input_spec = None diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 33dfe6aa1df..57ae97b76ed 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -12,11 +12,12 @@ import numpy as np import torch from torch import distributions as D, nn +from torch.cuda import is_current_stream_capturing try: - from torch.compiler import assume_constant_result + from torch.compiler import is_dynamo_compiling except ImportError: - from torch._dynamo import assume_constant_result + from torch._dynamo import is_compiling as is_dynamo_compiling from torch.distributions import constraints from torch.distributions.transforms import _InverseTransform @@ -37,6 +38,12 @@ D.Distribution.set_default_validate_args(False) +def _maybe_is_current_stream_capturing(): + if not torch.cuda.is_available(): + return False + return is_current_stream_capturing() + + class IndependentNormal(D.Independent): """Implements a Normal distribution with location scaling. @@ -401,11 +408,11 @@ def _warn_minmax(self): def __init__( self, - loc: torch.Tensor, - scale: torch.Tensor, - upscale: Union[torch.Tensor, Number] = 5.0, - low: Union[torch.Tensor, Number] = -1.0, - high: Union[torch.Tensor, Number] = 1.0, + loc: torch.Tensor | Number | np.ndarray, + scale: torch.Tensor | Number | np.ndarray, + upscale: torch.Tensor | Number = 5.0, + low: torch.Tensor | Number | np.ndarray = -1.0, + high: torch.Tensor | Number | np.ndarray = 1.0, event_dims: int | None = None, tanh_loc: bool = False, safe_tanh: bool = True, @@ -419,62 +426,75 @@ def __init__( low = kwargs.pop("min") if not isinstance(loc, torch.Tensor): - loc = torch.as_tensor(loc, dtype=torch.get_default_dtype()) + loc = torch.as_tensor( + loc, + dtype=torch.get_default_dtype(), + device=getattr(scale, "device", None), + ) + self.device = loc.device + if not isinstance(scale, torch.Tensor): - scale = torch.as_tensor(scale, dtype=torch.get_default_dtype()) + scale = torch.as_tensor( + scale, dtype=torch.get_default_dtype(), device=self.device + ) if event_dims is None: event_dims = min(1, loc.ndim) err_msg = "TanhNormal high values must be strictly greater than low values" - if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): + compiling = is_dynamo_compiling() or _maybe_is_current_stream_capturing() + if compiling: + pass + elif isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): if not (high > low).all(): raise RuntimeError(err_msg) elif isinstance(high, Number) and isinstance(low, Number): if not high > low: raise RuntimeError(err_msg) - else: - if not all(high > low): - raise RuntimeError(err_msg) + elif not all(high > low): + raise RuntimeError(err_msg) - high = torch.as_tensor(high, device=loc.device) - low = torch.as_tensor(low, device=loc.device) - self.non_trivial_max = (high != 1.0).any() + if not isinstance(high, (torch.Tensor, Number)): + high = torch.as_tensor(high, device=self.device) + if not isinstance(low, (torch.Tensor, Number)): + low = torch.as_tensor(low, device=self.device) + + if isinstance(high, torch.Tensor) and high.device != 
self.device: + high = high.to(loc.device) + if isinstance(low, torch.Tensor) and low.device != self.device: + low = low.to(loc.device) - self.non_trivial_min = (low != -1.0).any() + if compiling: + self.non_trivial_max = True + self.non_trivial_min = True + else: + self.non_trivial_max = high != 1.0 + if isinstance(self.non_trivial_max, torch.Tensor): + self.non_trivial_max = self.non_trivial_max.any() + self.non_trivial_min = low != -1.0 + if isinstance(self.non_trivial_min, torch.Tensor): + self.non_trivial_min = self.non_trivial_min.any() self.tanh_loc = tanh_loc self._event_dims = event_dims - self.device = loc.device self.upscale = ( upscale if not isinstance(upscale, torch.Tensor) else upscale.to(self.device) ) - if isinstance(high, torch.Tensor): - high = high.to(loc.device) - if isinstance(low, torch.Tensor): - low = low.to(loc.device) self.low = low self.high = high if safe_tanh: - if torch.compiler.is_dynamo_compiling(): - _err_compile_safetanh() t = SafeTanhTransform() else: t = D.TanhTransform() - # t = D.TanhTransform() - if torch.compiler.is_dynamo_compiling() or ( - self.non_trivial_max or self.non_trivial_min - ): + if self.non_trivial_max or self.non_trivial_min: t = _PatchedComposeTransform( [ t, - _PatchedAffineTransform( - loc=(high + low) / 2, scale=(high - low) / 2 - ), + D.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2), ] ) self._t = t @@ -495,9 +515,7 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None: if self.tanh_loc: loc = (loc / self.upscale).tanh() * self.upscale # loc must be rescaled if tanh_loc - if torch.compiler.is_dynamo_compiling() or ( - self.non_trivial_max or self.non_trivial_min - ): + if self.non_trivial_max or self.non_trivial_min: loc = loc + (self.high - self.low) / 2 + self.low self.loc = loc self.scale = scale @@ -816,16 +834,3 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor: uniform_sample_delta = _uniform_sample_delta - - -def _err_compile_safetanh(): - raise RuntimeError( - "safe_tanh=True in TanhNormal is not compatible with torch.compile. To deactivate it, pass" - "safe_tanh=False. " - "If you are using a ProbabilisticTensorDictModule, this can be done via " - "`distribution_kwargs={'safe_tanh': False}`. " - "See https://github.com/pytorch/pytorch/issues/133529 for more details." - ) - - -_warn_compile_safetanh = assume_constant_result(_err_compile_safetanh) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index c823788b4c2..12620812bf8 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -26,6 +26,7 @@ _cache_values, _clip_value_loss, _GAMMA_LMBDA_DEPREC_ERROR, + _get_default_device, _reduce, default_value_kwargs, distance_loss, @@ -316,10 +317,7 @@ def __init__( self.entropy_bonus = entropy_bonus and entropy_coef self.reduction = reduction - try: - device = next(self.parameters()).device - except AttributeError: - device = torch.device("cpu") + device = _get_default_device(self) self.register_buffer( "entropy_coef", torch.as_tensor(entropy_coef, device=device) @@ -347,7 +345,11 @@ def __init__( raise ValueError( f"clip_value must be a float or a scalar tensor, got {clip_value}." 
            )
-        self.register_buffer("clip_value", clip_value)
+            self.register_buffer(
+                "clip_value", torch.as_tensor(clip_value, device=device)
+            )
+        else:
+            self.clip_value = None
 
     @property
     def functional(self):
@@ -456,7 +458,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]:
             old_state_value = old_state_value.clone()
 
         # TODO: if the advantage is gathered by forward, this introduces an
-        # overhead that we could easily reduce.
+        # overhead that we could easily reduce.
         target_return = tensordict.get(
             self.tensor_keys.value_target, None
         )  # TODO: None soon to be removed
@@ -487,7 +489,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> Tuple[torch.Tensor, float]:
             loss_value, clip_fraction = _clip_value_loss(
                 old_state_value,
                 state_value,
-                self.clip_value.to(state_value.device),
+                self.clip_value,
                 target_return,
                 loss_value,
                 self.loss_critic_type,
@@ -541,6 +543,9 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
         hp = dict(default_value_kwargs(value_type))
         hp.update(hyperparams)
 
+        device = _get_default_device(self)
+        hp["device"] = device
+
         if hasattr(self, "gamma"):
             hp["gamma"] = self.gamma
         if value_type == ValueEstimators.TD1:
diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py
index 66eae215e54..3bd4be5dddb 100644
--- a/torchrl/objectives/utils.py
+++ b/torchrl/objectives/utils.py
@@ -549,3 +549,10 @@ def _clip_value_loss(
     # Chose the most pessimistic value prediction between clipped and non-clipped
     loss_value = torch.max(loss_value, loss_value_clipped)
     return loss_value, clip_fraction
+
+
+def _get_default_device(net):
+    for p in net.parameters():
+        return p.device
+    else:
+        return torch.get_default_device()
diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py
index ddd688610c2..bc2e9ea39f7 100644
--- a/torchrl/objectives/value/functional.py
+++ b/torchrl/objectives/value/functional.py
@@ -12,6 +12,10 @@
 
 import torch
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
 
 __all__ = [
     "generalized_advantage_estimate",
@@ -181,19 +185,25 @@ def generalized_advantage_estimate(
 def _geom_series_like(t, r, thr):
     """Creates a geometric series of the form [1, gammalmbda, gammalmbda**2] with the shape of `t`.
 
-    Drops all elements which are smaller than `thr`.
+    Drops all elements which are smaller than `thr` (unless in compile mode).
""" - if isinstance(r, torch.Tensor): - r = r.item() - - if r == 0.0: - return torch.zeros_like(t) - elif r >= 1.0: - lim = t.numel() + if is_dynamo_compiling(): + if isinstance(r, torch.Tensor): + rs = r.expand_as(t) + else: + rs = torch.full_like(t, r) else: - lim = int(math.log(thr) / math.log(r)) + if isinstance(r, torch.Tensor): + r = r.item() + + if r == 0.0: + return torch.zeros_like(t) + elif r >= 1.0: + lim = t.numel() + else: + lim = int(math.log(thr) / math.log(r)) - rs = torch.full_like(t[:lim], r) + rs = torch.full_like(t[:lim], r) rs[0] = 1.0 rs = rs.cumprod(0) rs = rs.unsqueeze(-1) diff --git a/torchrl/objectives/value/utils.py b/torchrl/objectives/value/utils.py index ec1d33069a5..7910611e36d 100644 --- a/torchrl/objectives/value/utils.py +++ b/torchrl/objectives/value/utils.py @@ -301,7 +301,10 @@ def _fill_tensor(tensor): device=tensor.device, ) mask_expand = expand_right(mask, (*mask.shape, *tensor.shape[1:])) - return torch.masked_scatter(empty_tensor, mask_expand, tensor.reshape(-1)) + # return torch.where(mask_expand, tensor, 0.0) + # return torch.masked_scatter(empty_tensor, mask_expand, tensor.reshape(-1)) + empty_tensor[mask_expand] = tensor.reshape(-1) + return empty_tensor if isinstance(tensor, TensorDictBase): tensor = tensor.apply(_fill_tensor, batch_size=[*shape]) From c432b0c458d3c6d1d042ea0e176c3aea25b4700a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 4 Oct 2024 15:56:08 +0100 Subject: [PATCH 02/24] Update [ghstack-poisoned] --- sota-implementations/a2c/a2c_atari.py | 1 + sota-implementations/a2c/a2c_mujoco.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 46aaeacd517..1c2d5b685f8 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -222,6 +222,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): num_network_updates += 1 with timeit("optim - update"): + torch.compiler.cudagraph_mark_step_begin() loss = update(batch) losses.append(loss) diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index af6609a1f7b..f12cc950624 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -133,7 +133,7 @@ def update(batch): if cfg.loss.cudagraphs: update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=10) - actor = CudaGraphModule(update, warmup=10) + actor = CudaGraphModule(actor, warmup=10) adv_module = CudaGraphModule(adv_module) # Create collector @@ -200,7 +200,7 @@ def update(batch): for group in critic_optim.param_groups: group["lr"].copy_(lr * alpha) num_network_updates += 1 - + torch.compiler.cudagraph_mark_step_begin() loss = update(batch) losses.append(loss) From 6415db8566befe6fde36c119d332f1c75813fcf3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 4 Oct 2024 15:57:33 +0100 Subject: [PATCH 03/24] Update [ghstack-poisoned] --- sota-implementations/a2c/a2c_atari.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 1c2d5b685f8..bea1dc28f84 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -267,8 +267,8 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) - collector.update_policy_weights_() sampling_start = time.time() + 
torch.compiler.cudagraph_mark_step_begin() collector.shutdown() if not test_env.is_closed: From 62a729546345387fcbcf65afd42c6ffdf4209d8e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 4 Oct 2024 15:59:47 +0100 Subject: [PATCH 04/24] Update [ghstack-poisoned] --- sota-implementations/a2c/a2c_atari.py | 2 +- sota-implementations/a2c/a2c_mujoco.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index bea1dc28f84..d4e6e0369d3 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -174,6 +174,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): c_iter = iter(collector) for i in range(len(collector)): with timeit("collecting"): + torch.compiler.cudagraph_mark_step_begin() data = next(c_iter) log_info = {} @@ -268,7 +269,6 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): logger.log_scalar(key, value, collected_frames) sampling_start = time.time() - torch.compiler.cudagraph_mark_step_begin() collector.shutdown() if not test_env.is_closed: diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index f12cc950624..f15f6bd9a15 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -241,8 +241,8 @@ def update(batch): for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) - collector.update_policy_weights_() sampling_start = time.time() + torch.compiler.cudagraph_mark_step_begin() collector.shutdown() if not test_env.is_closed: From a46dd6ab1e580dc1ca5ae6e0804991f9e60b769a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 4 Oct 2024 16:07:51 +0100 Subject: [PATCH 05/24] Update [ghstack-poisoned] --- sota-implementations/a2c/a2c_atari.py | 2 +- sota-implementations/a2c/a2c_mujoco.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index d4e6e0369d3..0b0aeeffc31 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -61,6 +61,7 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.loss.gae_lambda, value_network=critic, average_gae=True, + vectorized=not cfg.loss.compile, ) loss_module = A2CLoss( actor_network=actor, @@ -157,7 +158,6 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): total_frames=total_frames, device=device, storing_device=device, - max_frames_per_traj=-1, policy_device=device, ) diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index f15f6bd9a15..df35e8ea20e 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -55,6 +55,7 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.loss.gae_lambda, value_network=critic, average_gae=False, + vectorized=not cfg.loss.compile, ) loss_module = A2CLoss( actor_network=actor, From 22bc1e550bbf23272f7dd4f713f50e9b09e28e5e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 4 Oct 2024 16:36:54 +0100 Subject: [PATCH 06/24] Update [ghstack-poisoned] --- torchrl/envs/gym_like.py | 4 ++-- torchrl/objectives/value/advantages.py | 18 ++++++++++++++++-- torchrl/objectives/value/functional.py | 22 +++++++++++----------- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 995f245a8ac..01c1ce701cf 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ 
-224,7 +224,7 @@ def read_done(
         if truncated is not None and done is None:
             done = truncated | terminated
         elif truncated is None and done is None:
-            done = terminated
+            done = terminated.clone()
         do_break = done.any() if not isinstance(done, bool) else done
         if isinstance(done, bool):
             done = [done]
@@ -323,7 +323,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         # if truncated/terminated is not in the keys, we just don't pass it even if it
         # is defined.
         if terminated is None:
-            terminated = done
+            terminated = done.clone()
         if truncated is not None:
             obs_dict["truncated"] = truncated
         obs_dict["done"] = done
diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py
index b7db2e8242e..a4e2fe13b0d 100644
--- a/torchrl/objectives/value/advantages.py
+++ b/torchrl/objectives/value/advantages.py
@@ -38,6 +38,10 @@
     vtrace_advantage_estimate,
 )
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
 
 try:
     from torch import vmap
@@ -1161,7 +1165,7 @@ class GAE(ValueEstimatorBase):
             pass detached parameters for functional modules.
 
         vectorized (bool, optional): whether to use the vectorized version of the
-            lambda return. Default is `True`.
+            lambda return. Default is `True` if not compiling.
         skip_existing (bool, optional): if ``True``, the value network will skip
             modules which outputs are already present in the tensordict.
             Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()`
@@ -1211,7 +1215,7 @@ def __init__(
         value_network: TensorDictModule,
         average_gae: bool = False,
         differentiable: bool = False,
-        vectorized: bool = True,
+        vectorized: bool | None = None,
         skip_existing: bool | None = None,
         advantage_key: NestedKey = None,
         value_target_key: NestedKey = None,
@@ -1235,6 +1239,16 @@ def __init__(
         self.vectorized = vectorized
         self.time_dim = time_dim
 
+    @property
+    def vectorized(self):
+        if is_dynamo_compiling():
+            return False
+        return self._vectorized
+
+    @vectorized.setter
+    def vectorized(self, value):
+        self._vectorized = value
+
     @_self_set_skip_existing
     @_self_set_grad_enabled
     @dispatch
diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py
index bc2e9ea39f7..bb737d7c20d 100644
--- a/torchrl/objectives/value/functional.py
+++ b/torchrl/objectives/value/functional.py
@@ -151,7 +151,7 @@ def generalized_advantage_estimate(
     """
     if terminated is None:
-        terminated = done
+        terminated = done.clone()
     if not (
         next_state_value.shape
         == state_value.shape
@@ -302,7 +302,7 @@ def vec_generalized_advantage_estimate(
     """
     if terminated is None:
-        terminated = done
+        terminated = done.clone()
     if not (
         next_state_value.shape
         == state_value.shape
@@ -401,7 +401,7 @@ def td0_advantage_estimate(
     """
     if terminated is None:
-        terminated = done
+        terminated = done.clone()
     if not (
         next_state_value.shape
         == state_value.shape
@@ -445,7 +445,7 @@ def td0_return_estimate(
     """
     if done is not None and terminated is None:
-        terminated = done
+        terminated = done.clone()
         warnings.warn(
             "done for td0_return_estimate is deprecated. Pass ``terminated`` instead."
) @@ -509,7 +509,7 @@ def td1_return_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) not_done = (~done).int() @@ -606,7 +606,7 @@ def td1_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -736,7 +736,7 @@ def vec_td1_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -814,7 +814,7 @@ def td_lambda_return_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) @@ -920,7 +920,7 @@ def td_lambda_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape @@ -1056,7 +1056,7 @@ def vec_td_lambda_return_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not (next_state_value.shape == reward.shape == done.shape == terminated.shape): raise RuntimeError(SHAPE_ERR) @@ -1206,7 +1206,7 @@ def vec_td_lambda_advantage_estimate( """ if terminated is None: - terminated = done + terminated = done.clone() if not ( next_state_value.shape == state_value.shape From ecf5933f6ad3e47e6fe18311c08f577b24a380c3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 11 Nov 2024 14:29:40 +0000 Subject: [PATCH 07/24] Update [ghstack-poisoned] --- sota-implementations/a2c/a2c_atari.py | 11 +++--- sota-implementations/a2c/a2c_mujoco.py | 55 ++++++++++++++++---------- 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index fcdffed57a0..09919eb7dd0 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -129,11 +129,12 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): .set("grad_norm", gn) ) + compile_mode = None if cfg.loss.compile: compile_mode = cfg.loss.compile_mode if compile_mode in ("", None): if cfg.loss.cudagraphs: - compile_mode = None + compile_mode = "default" else: compile_mode = "reduce-overhead" update = torch.compile(update, mode=compile_mode) @@ -152,7 +153,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): device=device, storing_device=device, policy_device=device, - compile_policy=cfg.loss.compile_mode if cfg.loss.compile else False, + compile_policy={"mode": compile_mode} if cfg.loss.compile else False, cudagraph_policy=cfg.loss.cudagraphs, ) @@ -222,9 +223,6 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): loss = update(batch) losses.append(loss) - if i % 200 == 0: - timeit.print() - timeit.erase() # Get training losses training_time = time.time() - training_start losses = torch.stack(losses).float().mean() @@ -239,6 +237,9 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): **timeit.todict(prefix="time"), } ) + if i % 200 == 0: + timeit.print() + timeit.erase() # Get test rewards with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index b6bb4b88efd..1160626ce8e 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -16,6 +16,7 @@ def main(cfg: "DictConfig"): # 
noqa: F821 import torch.optim import tqdm + from torchrl._utils import timeit from torchrl.collectors import SyncDataCollector from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement @@ -127,7 +128,7 @@ def update(batch): compile_mode = cfg.loss.compile_mode if compile_mode in ("", None): if cfg.loss.cudagraphs: - compile_mode = None + compile_mode = "default" else: compile_mode = "reduce-overhead" @@ -150,7 +151,7 @@ def update(batch): storing_device=device, max_frames_per_traj=-1, trust_policy=True, - compile_policy=compile_mode if cfg.loss.compile else False, + compile_policy={"mode": compile_mode} if cfg.loss.compile else False, cudagraph_policy=cfg.loss.cudagraphs, ) @@ -164,7 +165,11 @@ def update(batch): pbar = tqdm.tqdm(total=cfg.collector.total_frames) sampling_start = time.time() - for i, data in enumerate(collector): + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + torch.compiler.cudagraph_mark_step_begin() + data = next(c_iter) log_info = {} sampling_time = time.time() - sampling_start @@ -188,27 +193,33 @@ def update(batch): training_start = time.time() # Compute GAE - with torch.no_grad(): + with torch.no_grad(), timeit("advantage"): data = adv_module(data) data_reshape = data.reshape(-1) # Update the data buffer - data_buffer.extend(data_reshape) - - for batch in data_buffer: - - # Linearly decrease the learning rate and clip epsilon - alpha = 1.0 - if cfg.optim.anneal_lr: - alpha = 1 - (num_network_updates / total_network_updates) - for group in actor_optim.param_groups: - group["lr"].copy_(lr * alpha) - for group in critic_optim.param_groups: - group["lr"].copy_(lr * alpha) - num_network_updates += 1 - torch.compiler.cudagraph_mark_step_begin() - loss = update(batch) - losses.append(loss) + with timeit("emptying"): + data_buffer.empty() + with timeit("extending"): + data_buffer.extend(data_reshape) + + with timeit("optim"): + for batch in data_buffer: + + # Linearly decrease the learning rate and clip epsilon + with timeit("optim - lr"): + alpha = 1.0 + if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in actor_optim.param_groups: + group["lr"].copy_(lr * alpha) + for group in critic_optim.param_groups: + group["lr"].copy_(lr * alpha) + num_network_updates += 1 + with timeit("optim - update"): + torch.compiler.cudagraph_mark_step_begin() + loss = update(batch) + losses.append(loss) # Get training losses training_time = time.time() - training_start @@ -220,8 +231,12 @@ def update(batch): "train/lr": alpha * cfg.optim.lr, "train/sampling_time": sampling_time, "train/training_time": training_time, + **timeit.todict(prefix="time"), } ) + if i % 200 == 0: + timeit.print() + timeit.erase() # Get test rewards with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): From 9f329cee122be49b93272cd4622f6e0e3ec9204a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Nov 2024 09:42:53 +0000 Subject: [PATCH 08/24] Update [ghstack-poisoned] --- test/_utils_internal.py | 100 ++++++++++++++++++++++++++++++++++++++-- test/test_cost.py | 3 +- 2 files changed, 99 insertions(+), 4 deletions(-) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 48492459315..683a0d60182 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -11,6 +11,7 @@ import os.path import time import unittest +import warnings from functools import wraps # Get relative file path @@ -20,9 +21,15 
@@ import torch import torch.cuda -from tensordict import tensorclass, TensorDict -from torch import nn -from torchrl._utils import implement_for, logger as torchrl_logger, seed_generator +from tensordict import NestedKey, tensorclass, TensorDict, TensorDictBase +from tensordict.nn import TensorDictModuleBase +from torch import nn, vmap +from torchrl._utils import ( + implement_for, + logger as torchrl_logger, + RL_WARNINGS, + seed_generator, +) from torchrl.data.utils import CloudpickleWrapper from torchrl.envs import MultiThreadedEnv, ObservationNorm @@ -35,6 +42,7 @@ ToTensorImage, TransformedEnv, ) +from torchrl.objectives.value.advantages import _vmap_func # Specified for test_utils.py __version__ = "0.3" @@ -713,3 +721,89 @@ def forward( ): input = self.mlp(input) return self._lstm(input, hidden0_in, hidden1_in) + + +def _call_value_nets( + value_net: TensorDictModuleBase, + data: TensorDictBase, + params: TensorDictBase, + next_params: TensorDictBase, + single_call: bool, + value_key: NestedKey, + detach_next: bool, + vmap_randomness: str = "error", +): + in_keys = value_net.in_keys + if single_call: + for i, name in enumerate(data.names): + if name == "time": + ndim = i + 1 + break + else: + ndim = None + if ndim is not None: + # get data at t and last of t+1 + idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),) + idx = (slice(None),) * (ndim - 1) + (slice(None, -1),) + idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False)[idx0], + ], + ndim - 1, + ) + else: + if RL_WARNINGS: + warnings.warn( + "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. " + "This warning can be turned off by setting the environment variable RL_WARNINGS to False." + ) + ndim = data.ndim + idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),) + idx_ = (slice(None),) * (ndim - 1) + (slice(data.shape[ndim - 1], None),) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + ndim - 1, + ) + + # next_params should be None or be identical to params + if next_params is not None and next_params is not params: + raise ValueError( + "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed." + ) + if params is not None: + with params.to_module(value_net): + value_est = value_net(data_in).get(value_key) + else: + value_est = value_net(data_in).get(value_key) + value, value_ = value_est[idx], value_est[idx_] + else: + data_in = torch.stack( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + 0, + ) + if (params is not None) ^ (next_params is not None): + raise ValueError( + "params and next_params must be either both provided or not." 
+ ) + elif params is not None: + params_stack = torch.stack([params, next_params], 0).contiguous() + data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( + data_in, params_stack + ) + else: + data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) + value_est = data_out.get(value_key) + value, value_ = value_est[0], value_est[1] + data.set(value_key, value) + data.set(("next", value_key), value_) + if detach_next: + value_ = value_.detach() + return value, value_ diff --git a/test/test_cost.py b/test/test_cost.py index 1b54b8bf111..1e157fd7a2f 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -106,7 +106,6 @@ ValueEstimators, ) from torchrl.objectives.value.advantages import ( - _call_value_nets, GAE, TD1Estimator, TDLambdaEstimator, @@ -135,6 +134,7 @@ if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import ( # noqa + _call_value_nets, dtype_fixture, get_available_devices, get_default_devices, @@ -142,6 +142,7 @@ from pytorch.rl.test.mocking_classes import ContinuousActionConvMockEnv else: from _utils_internal import ( # noqa + _call_value_nets, dtype_fixture, get_available_devices, get_default_devices, From 74afeef281f05f194ccd81b460dc888e85dbf66c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Nov 2024 11:17:36 +0000 Subject: [PATCH 09/24] Update [ghstack-poisoned] --- sota-implementations/a2c/a2c_atari.py | 14 +++++++------- sota-implementations/a2c/a2c_mujoco.py | 16 ++++++++-------- sota-implementations/a2c/config_atari.yaml | 2 ++ sota-implementations/a2c/config_mujoco.yaml | 4 +++- 4 files changed, 20 insertions(+), 16 deletions(-) diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 09919eb7dd0..6881efdabe8 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -56,7 +56,7 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.loss.gae_lambda, value_network=critic, average_gae=True, - vectorized=not cfg.loss.compile, + vectorized=not cfg.compile.compile, ) loss_module = A2CLoss( actor_network=actor, @@ -130,17 +130,17 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): ) compile_mode = None - if cfg.loss.compile: - compile_mode = cfg.loss.compile_mode + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode if compile_mode in ("", None): - if cfg.loss.cudagraphs: + if cfg.compile.cudagraphs: compile_mode = "default" else: compile_mode = "reduce-overhead" update = torch.compile(update, mode=compile_mode) adv_module = torch.compile(adv_module, mode=compile_mode) - if cfg.loss.cudagraphs: + if cfg.compile.cudagraphs: update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5) adv_module = CudaGraphModule(adv_module) @@ -153,8 +153,8 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm): device=device, storing_device=device, policy_device=device, - compile_policy={"mode": compile_mode} if cfg.loss.compile else False, - cudagraph_policy=cfg.loss.cudagraphs, + compile_policy={"mode": compile_mode} if cfg.compile.compile else False, + cudagraph_policy=cfg.compile.cudagraphs, ) # Main loop diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index 1160626ce8e..cf513962a12 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -41,7 +41,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create models (check utils_mujoco.py) actor, critic = make_ppo_models( - cfg.env.env_name, device=device, compile=cfg.loss.compile 
+ cfg.env.env_name, device=device, compile=cfg.compile.compile ) # Create data buffer @@ -58,7 +58,7 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.loss.gae_lambda, value_network=critic, average_gae=False, - vectorized=not cfg.loss.compile, + vectorized=not cfg.compile.compile, ) loss_module = A2CLoss( actor_network=actor, @@ -124,10 +124,10 @@ def update(batch): return loss.select("loss_critic", "loss_objective").detach() # , "loss_entropy" compile_mode = None - if cfg.loss.compile: - compile_mode = cfg.loss.compile_mode + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode if compile_mode in ("", None): - if cfg.loss.cudagraphs: + if cfg.compile.cudagraphs: compile_mode = "default" else: compile_mode = "reduce-overhead" @@ -136,7 +136,7 @@ def update(batch): actor = torch.compile(actor, mode=compile_mode) adv_module = torch.compile(adv_module, mode=compile_mode) - if cfg.loss.cudagraphs: + if cfg.compile.cudagraphs: update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=10) actor = CudaGraphModule(actor, warmup=10) adv_module = CudaGraphModule(adv_module) @@ -151,8 +151,8 @@ def update(batch): storing_device=device, max_frames_per_traj=-1, trust_policy=True, - compile_policy={"mode": compile_mode} if cfg.loss.compile else False, - cudagraph_policy=cfg.loss.cudagraphs, + compile_policy={"mode": compile_mode} if cfg.compile.compile else False, + cudagraph_policy=cfg.compile.cudagraphs, ) test_env.eval() diff --git a/sota-implementations/a2c/config_atari.yaml b/sota-implementations/a2c/config_atari.yaml index 5a7586ee95d..20ff7b91b48 100644 --- a/sota-implementations/a2c/config_atari.yaml +++ b/sota-implementations/a2c/config_atari.yaml @@ -34,6 +34,8 @@ loss: critic_coef: 0.25 entropy_coef: 0.01 loss_critic_type: l2 + +compile: compile: False compile_mode: cudagraphs: False diff --git a/sota-implementations/a2c/config_mujoco.yaml b/sota-implementations/a2c/config_mujoco.yaml index a42087b2631..43ab041e0d5 100644 --- a/sota-implementations/a2c/config_mujoco.yaml +++ b/sota-implementations/a2c/config_mujoco.yaml @@ -31,7 +31,9 @@ loss: critic_coef: 0.25 entropy_coef: 0.0 loss_critic_type: l2 + device: + +compile: compile: False compile_mode: default cudagraphs: False - device: From ac6145cf60eb7ebc367043ff125860a6f4e3ffab Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Nov 2024 11:22:36 +0000 Subject: [PATCH 10/24] Update [ghstack-poisoned] --- sota-implementations/a2c/config_atari.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/a2c/config_atari.yaml b/sota-implementations/a2c/config_atari.yaml index 20ff7b91b48..1c2fc9fdff5 100644 --- a/sota-implementations/a2c/config_atari.yaml +++ b/sota-implementations/a2c/config_atari.yaml @@ -34,9 +34,9 @@ loss: critic_coef: 0.25 entropy_coef: 0.01 loss_critic_type: l2 + device: compile: compile: False compile_mode: cudagraphs: False - device: From 36251666a29ebafbb1c365810ca20ab20d27058e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Nov 2024 11:34:27 +0000 Subject: [PATCH 11/24] Update [ghstack-poisoned] --- torchrl/collectors/collectors.py | 11 +++++++ torchrl/objectives/value/advantages.py | 40 ++++++++++++++++++++++---- 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 20128e4f6a2..3b003a6a567 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -67,6 +67,15 @@ set_exploration_type, ) +try: + from torchrl.compiler import 
diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py
index c90f16911dc..739fb9a018e 100644
--- a/torchrl/objectives/value/advantages.py
+++ b/torchrl/objectives/value/advantages.py
@@ -664,7 +664,11 @@ def value_estimate(
     ):
         reward = tensordict.get(("next", self.tensor_keys.reward))
         device = reward.device
-        gamma = self.gamma.to(device)
+
+        if self.gamma.device != device:
+            self.gamma = self.gamma.to(device)
+        gamma = self.gamma
+
         steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
         if steps_to_next_obs is not None:
             gamma = gamma ** steps_to_next_obs.view_as(reward)
@@ -879,7 +883,11 @@ def value_estimate(
     ):
         reward = tensordict.get(("next", self.tensor_keys.reward))
         device = reward.device
-        gamma = self.gamma.to(device)
+
+        if self.gamma.device != device:
+            self.gamma = self.gamma.to(device)
+        gamma = self.gamma
+
         steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
         if steps_to_next_obs is not None:
             gamma = gamma ** steps_to_next_obs.view_as(reward)
@@ -1103,7 +1111,11 @@ def value_estimate(
     ):
         reward = tensordict.get(("next", self.tensor_keys.reward))
         device = reward.device
-        gamma = self.gamma.to(device)
+
+        if self.gamma.device != device:
+            self.gamma = self.gamma.to(device)
+        gamma = self.gamma
+
         steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
         if steps_to_next_obs is not None:
             gamma = gamma ** steps_to_next_obs.view_as(reward)
@@ -1336,7 +1348,13 @@ def forward(
         )
         reward = tensordict.get(("next", self.tensor_keys.reward))
         device = reward.device
-        gamma, lmbda = self.gamma.to(device), self.lmbda.to(device)
+
+        if self.gamma.device != device:
+            self.gamma = self.gamma.to(device)
+        if self.lmbda.device != device:
+            self.lmbda = self.lmbda.to(device)
+        gamma, lmbda = self.gamma, self.lmbda
+
         steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
         if steps_to_next_obs is not None:
             gamma = gamma ** steps_to_next_obs.view_as(reward)
@@ -1417,7 +1435,13 @@ def value_estimate(
         )
         reward = tensordict.get(("next", self.tensor_keys.reward))
         device = reward.device
-        gamma, lmbda = self.gamma.to(device), self.lmbda.to(device)
+
+        if self.gamma.device != device:
+            self.gamma = self.gamma.to(device)
+        if self.lmbda.device != device:
+            self.lmbda = self.lmbda.to(device)
+        gamma, lmbda = self.gamma, self.lmbda
+
         steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
         if steps_to_next_obs is not None:
             gamma = gamma ** steps_to_next_obs.view_as(reward)
@@ -1688,7 +1712,11 @@ def forward(
         )
         reward = tensordict.get(("next", self.tensor_keys.reward))
         device = reward.device
-        gamma = self.gamma.to(device)
+
+        if self.gamma.device != device:
+            self.gamma = self.gamma.to(device)
+        gamma = self.gamma
+
         steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None)
         if steps_to_next_obs is not None:
             gamma = gamma ** steps_to_next_obs.view_as(reward)
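Every hunk in this patch replaces a per-call `self.gamma.to(device)` with a cached move: the old `.to()` allocated a fresh tensor on every forward, which is wasted work and hostile to CUDA-graph capture, whereas rebinding the attribute once keeps subsequent calls allocation-free. A toy sketch of the pattern (not the torchrl estimator class):

```python
import torch
from torch import nn


class DiscountedReward(nn.Module):
    """Toy module illustrating the device-caching pattern used above."""

    def __init__(self, gamma: float):
        super().__init__()
        self.gamma = torch.tensor(gamma)

    def forward(self, reward: torch.Tensor) -> torch.Tensor:
        # Rebind the stored tensor once per device change instead of calling
        # `.to(device)` on every forward, which allocates a new tensor each time.
        if self.gamma.device != reward.device:
            self.gamma = self.gamma.to(reward.device)
        return self.gamma * reward
```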
From 20573c586191aa6de5caf8043e69d574aeb4b78f Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 18 Nov 2024 11:57:23 +0000
Subject: [PATCH 12/24] Update

[ghstack-poisoned]
---
 sota-implementations/a2c/a2c_atari.py  | 1 +
 sota-implementations/a2c/a2c_mujoco.py | 1 +
 torchrl/collectors/collectors.py       | 1 +
 3 files changed, 3 insertions(+)

diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py
index 6881efdabe8..614d4feadf4 100644
--- a/sota-implementations/a2c/a2c_atari.py
+++ b/sota-implementations/a2c/a2c_atari.py
@@ -57,6 +57,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         value_network=critic,
         average_gae=True,
         vectorized=not cfg.compile.compile,
+        device=device,
     )
     loss_module = A2CLoss(
         actor_network=actor,
diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py
index cf513962a12..9f35a8fa111 100644
--- a/sota-implementations/a2c/a2c_mujoco.py
+++ b/sota-implementations/a2c/a2c_mujoco.py
@@ -59,6 +59,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         value_network=critic,
         average_gae=False,
         vectorized=not cfg.compile.compile,
+        device=device,
     )
     loss_module = A2CLoss(
         actor_network=actor,
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 3b003a6a567..e07c02cf1e5 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -841,6 +841,7 @@ def _make_final_rollout(self):
             policy_input_clone = (
                 policy_input.clone()
             )  # to test if values have changed in-place
+            cudagraph_mark_step_begin()
             policy_output = self.policy(policy_input)
 
             # check that we don't have exclusive keys, because they don't appear in keys
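Patch 12 threads the training device into the GAE constructor so the discount constants live with the data from the start. A hedged usage sketch with a toy value network — the keyword set mirrors the constructor calls in the two scripts, while the numeric values are illustrative (the scripts read them from `cfg.loss`):

```python
import torch
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.objectives.value import GAE

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
critic = TensorDictModule(
    nn.Linear(4, 1).to(device), in_keys=["observation"], out_keys=["state_value"]
)
adv_module = GAE(
    gamma=0.99,
    lmbda=0.95,
    value_network=critic,
    average_gae=True,
    vectorized=False,  # the scanned path composes better with torch.compile
    device=device,
)
```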
From 7637609c685cd9120095e62925e6bec66317ed6c Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 18 Nov 2024 12:38:11 +0000
Subject: [PATCH 13/24] Update

[ghstack-poisoned]
---
 sota-implementations/a2c/a2c_atari.py  | 8 +++++---
 sota-implementations/a2c/a2c_mujoco.py | 8 +++++---
 torchrl/collectors/collectors.py       | 7 +++++--
 3 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py
index 614d4feadf4..7ebc3333ef8 100644
--- a/sota-implementations/a2c/a2c_atari.py
+++ b/sota-implementations/a2c/a2c_atari.py
@@ -3,15 +3,17 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import hydra
-from tensordict.nn import CudaGraphModule
-from torchrl._utils import logger as torchrl_logger
-from torchrl.record import VideoRecorder
+import torch
 
+torch.set_float32_matmul_precision("high")
 
 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
 
     import time
 
+    from tensordict.nn import CudaGraphModule
+    from torchrl._utils import logger as torchrl_logger
+    from torchrl.record import VideoRecorder
     import torch.optim
     import tqdm
 
diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py
index 9f35a8fa111..803ac32444e 100644
--- a/sota-implementations/a2c/a2c_mujoco.py
+++ b/sota-implementations/a2c/a2c_mujoco.py
@@ -3,16 +3,18 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import hydra
-from tensordict.nn import CudaGraphModule
-from torchrl._utils import logger as torchrl_logger
-from torchrl.record import VideoRecorder
+import torch
 
+torch.set_float32_matmul_precision("high")
 
 @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
 
     import time
 
+    from tensordict.nn import CudaGraphModule
+    from torchrl._utils import logger as torchrl_logger
+    from torchrl.record import VideoRecorder
     import torch.optim
     import tqdm
 
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index e07c02cf1e5..e6ba91cb953 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -73,7 +73,7 @@
 
     def cudagraph_mark_step_begin():
         """Placeholder for missing cudagraph_mark_step_begin method."""
-        ...
+        raise NotImplementedError("cudagraph_mark_step_begin not implemented.")
 
 
 _TIMEOUT = 1.0
@@ -841,7 +841,8 @@ def _make_final_rollout(self):
             policy_input_clone = (
                 policy_input.clone()
             )  # to test if values have changed in-place
-            cudagraph_mark_step_begin()
+            if self.compiled_policy:
+                cudagraph_mark_step_begin()
             policy_output = self.policy(policy_input)
 
             # check that we don't have exclusive keys, because they don't appear in keys
@@ -1158,6 +1159,8 @@ def rollout(self) -> TensorDictBase:
                 if self.compiled_policy:
                     cudagraph_mark_step_begin()
                 policy_output = self.policy(policy_input)
+                if self.compiled_policy:
+                    policy_output = policy_output.clone()
                 if self._shuttle is not policy_output:
                     # ad-hoc update shuttle
                     self._shuttle.update(
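The `policy_output.clone()` added above guards against a CUDA-graph pitfall: a captured callable replays into static output buffers, so any result stored across iterations must be copied out before the next replay overwrites it. A standalone demonstration (toy module, assuming a CUDA device is available):

```python
import torch
from torch import nn

if torch.cuda.is_available():
    mod = torch.compile(nn.Linear(2, 2).cuda(), mode="reduce-overhead")
    x = torch.randn(4, 2, device="cuda")
    outs = []
    for _ in range(3):
        torch.compiler.cudagraph_mark_step_begin()
        y = mod(x)
        # Without the clone, every entry would alias the same static buffer
        # and be silently overwritten by the next replay.
        outs.append(y.clone())
```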
From b3c088e10621d77f268d17191988bce7c5246fd5 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 18 Nov 2024 12:40:52 +0000
Subject: [PATCH 14/24] Update

[ghstack-poisoned]
---
 sota-implementations/a2c/a2c_atari.py  | 8 ++++----
 sota-implementations/a2c/a2c_mujoco.py | 9 +++++----
 torchrl/collectors/collectors.py       | 2 +-
 3 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py
index 7ebc3333ef8..043f36f658b 100644
--- a/sota-implementations/a2c/a2c_atari.py
+++ b/sota-implementations/a2c/a2c_atari.py
@@ -7,24 +7,24 @@
 
 torch.set_float32_matmul_precision("high")
 
+
 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
 
     import time
 
-    from tensordict.nn import CudaGraphModule
-    from torchrl._utils import logger as torchrl_logger
-    from torchrl.record import VideoRecorder
     import torch.optim
     import tqdm
 
+    from tensordict.nn import CudaGraphModule
 
-    from torchrl._utils import timeit
+    from torchrl._utils import logger as torchrl_logger, timeit
     from torchrl.collectors import SyncDataCollector
     from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
     from torchrl.envs import ExplorationType, set_exploration_type
     from torchrl.objectives import A2CLoss
     from torchrl.objectives.value.advantages import GAE
+    from torchrl.record import VideoRecorder
     from torchrl.record.loggers import generate_exp_name, get_logger
     from utils_atari import eval_model, make_parallel_env, make_ppo_models
 
diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py
index 803ac32444e..1947f48beee 100644
--- a/sota-implementations/a2c/a2c_mujoco.py
+++ b/sota-implementations/a2c/a2c_mujoco.py
@@ -7,24 +7,25 @@
 
 torch.set_float32_matmul_precision("high")
 
+
 @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
 
     import time
 
-    from tensordict.nn import CudaGraphModule
-    from torchrl._utils import logger as torchrl_logger
-    from torchrl.record import VideoRecorder
     import torch.optim
     import tqdm
 
-    from torchrl._utils import timeit
+    from tensordict.nn import CudaGraphModule
+
+    from torchrl._utils import logger as torchrl_logger, timeit
     from torchrl.collectors import SyncDataCollector
     from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
     from torchrl.envs import ExplorationType, set_exploration_type
     from torchrl.objectives import A2CLoss
     from torchrl.objectives.value.advantages import GAE
+    from torchrl.record import VideoRecorder
     from torchrl.record.loggers import generate_exp_name, get_logger
     from utils_mujoco import eval_model, make_env, make_ppo_models
 
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index e6ba91cb953..f40b6ebc20e 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -68,7 +68,7 @@
 )
 
 try:
-    from torchrl.compiler import cudagraph_mark_step_begin
+    from torch.compiler import cudagraph_mark_step_begin
 except ImportError:
 
     def cudagraph_mark_step_begin():

From 0250de25d3b0e5be5a741a1ce258b4044b71310a Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 18 Nov 2024 12:49:55 +0000
Subject: [PATCH 15/24] Update

[ghstack-poisoned]
---
 sota-implementations/a2c/config_atari.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sota-implementations/a2c/config_atari.yaml b/sota-implementations/a2c/config_atari.yaml
index 1c2fc9fdff5..6c68c73980e 100644
--- a/sota-implementations/a2c/config_atari.yaml
+++ b/sota-implementations/a2c/config_atari.yaml
@@ -1,11 +1,11 @@
 # Environment
 env:
   env_name: PongNoFrameskip-v4
-  num_envs: 1
+  num_envs: 16
 
 # collector
 collector:
-  frames_per_batch: 80
+  frames_per_batch: 8000
   total_frames: 40_000_000
 
 # logger

From 56bdd3f679734510ab2e157fedd8053e7bb96b3f Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 18 Nov 2024 12:53:17 +0000
Subject: [PATCH 16/24] Update

[ghstack-poisoned]
---
 sota-implementations/a2c/a2c_atari.py  | 3 +--
 sota-implementations/a2c/a2c_mujoco.py | 3 +--
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py
index 043f36f658b..d0bbfaeb69d 100644
--- a/sota-implementations/a2c/a2c_atari.py
+++ b/sota-implementations/a2c/a2c_atari.py
@@ -173,7 +173,6 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
     c_iter = iter(collector)
     for i in range(len(collector)):
         with timeit("collecting"):
-            torch.compiler.cudagraph_mark_step_begin()
             data = next(c_iter)
 
         log_info = {}
@@ -223,7 +222,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
 
             with timeit("optim - update"):
                 torch.compiler.cudagraph_mark_step_begin()
-                loss = update(batch)
+                loss = update(batch).clone()
             losses.append(loss)
 
         # Get training losses
diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py
index 1947f48beee..154d9feae7f 100644
--- a/sota-implementations/a2c/a2c_mujoco.py
+++ b/sota-implementations/a2c/a2c_mujoco.py
@@ -172,7 +172,6 @@ def update(batch):
     c_iter = iter(collector)
     for i in range(len(collector)):
         with timeit("collecting"):
-            torch.compiler.cudagraph_mark_step_begin()
            data = next(c_iter)
 
         log_info = {}
@@ -222,7 +221,7 @@ def update(batch):
             num_network_updates += 1
 
             with timeit("optim - update"):
                torch.compiler.cudagraph_mark_step_begin()
-                loss = update(batch)
+                loss = update(batch).clone()
             losses.append(loss)
 
         # Get training losses

From f2238dda42d3c577253e6ed74938a5a67067494f Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 18 Nov 2024 13:10:01 +0000
Subject: [PATCH 17/24] Update

[ghstack-poisoned]
---
 sota-implementations/a2c/a2c_mujoco.py  | 12 +++++-------
 sota-implementations/a2c/utils_atari.py |  5 +++--
 2 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py
index 154d9feae7f..161de736d0f 100644
--- a/sota-implementations/a2c/a2c_mujoco.py
+++ b/sota-implementations/a2c/a2c_mujoco.py
@@ -83,6 +83,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
         lr=torch.tensor(cfg.optim.lr, device=device),
         capturable=device.type == "cuda",
     )
+    optim = group_optimizers(actor_optim, critic_optim)
+    del actor_optim, critic_optim
 
     # Create logger
     logger = None
@@ -120,11 +122,9 @@ def update(batch):
         (actor_loss + critic_loss).backward()
 
         # Update the networks
-        actor_optim.step()
-        critic_optim.step()
+        optim.step()
 
-        actor_optim.zero_grad(set_to_none=True)
-        critic_optim.zero_grad(set_to_none=True)
+        optim.zero_grad(set_to_none=True)
         return loss.select("loss_critic", "loss_objective").detach()  # , "loss_entropy"
 
     compile_mode = None
@@ -214,9 +214,7 @@ def update(batch):
             alpha = 1.0
             if cfg.optim.anneal_lr:
                 alpha = 1 - (num_network_updates / total_network_updates)
-                for group in actor_optim.param_groups:
-                    group["lr"].copy_(lr * alpha)
-                for group in critic_optim.param_groups:
+                for group in optim.param_groups:
                     group["lr"].copy_(lr * alpha)
             num_network_updates += 1
             with timeit("optim - update"):
diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py
index bf7e23cd8f9..99c3ce2338c 100644
--- a/sota-implementations/a2c/utils_atari.py
+++ b/sota-implementations/a2c/utils_atari.py
@@ -64,10 +64,12 @@ def make_base_env(
 def make_parallel_env(env_name, num_envs, device, is_test=False):
     env = ParallelEnv(
         num_envs,
-        EnvCreator(lambda: make_base_env(env_name, device=device)),
+        EnvCreator(lambda: make_base_env(env_name)),
         serial_for_single=True,
+        device=device,
     )
     env = TransformedEnv(env)
+    env.append_transform(DoubleToFloat())
     env.append_transform(ToTensorImage())
     env.append_transform(GrayScale())
     env.append_transform(Resize(84, 84))
@@ -76,7 +78,6 @@ def make_parallel_env(env_name, num_envs, device, is_test=False):
     env.append_transform(StepCounter(max_steps=4500))
     if not is_test:
         env.append_transform(SignTransform(in_keys=["reward"]))
-        env.append_transform(DoubleToFloat())
     env.append_transform(VecNorm(in_keys=["pixels"]))
     return env
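`group_optimizers` fuses the actor and critic Adam instances into a single optimizer, so the compiled `update` issues one `step()`/`zero_grad()` pair and the annealing loop walks a single `param_groups` list. A self-contained sketch with toy networks:

```python
import torch
from torch import nn
from torchrl.objectives import group_optimizers

actor_net, critic_net = nn.Linear(8, 2), nn.Linear(8, 1)
actor_optim = torch.optim.Adam(actor_net.parameters(), lr=3e-4)
critic_optim = torch.optim.Adam(critic_net.parameters(), lr=3e-4)

# One handle: a single step()/zero_grad() pair inside the captured update,
# and a single param_groups list for the lr-annealing loop to walk.
optim = group_optimizers(actor_optim, critic_optim)
del actor_optim, critic_optim

loss = actor_net(torch.ones(8)).sum() + critic_net(torch.ones(8)).sum()
loss.backward()
optim.step()
optim.zero_grad(set_to_none=True)
```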
From d999cb6ef79ac2a41c06cd1a57706186e9e61dea Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 18 Nov 2024 13:26:20 +0000
Subject: [PATCH 18/24] Update

[ghstack-poisoned]
---
 sota-implementations/a2c/a2c_atari.py      | 47 +++++++++-------------
 sota-implementations/a2c/a2c_mujoco.py     | 37 +++++++----------
 sota-implementations/a2c/config_atari.yaml |  2 +-
 3 files changed, 35 insertions(+), 51 deletions(-)

diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py
index d0bbfaeb69d..f6401b9946c 100644
--- a/sota-implementations/a2c/a2c_atari.py
+++ b/sota-implementations/a2c/a2c_atari.py
@@ -11,13 +11,14 @@
 
 @hydra.main(config_path="", config_name="config_atari", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
 
-    import time
+    from copy import deepcopy
 
     import torch.optim
     import tqdm
 
+    from tensordict import from_module
     from tensordict.nn import CudaGraphModule
 
-    from torchrl._utils import logger as torchrl_logger, timeit
+    from torchrl._utils import timeit
     from torchrl.collectors import SyncDataCollector
     from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
@@ -43,6 +44,10 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
     # Create models (check utils_atari.py)
     actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device)
+    with from_module(actor).data.to("meta").to_module(actor):
+        actor_eval = deepcopy(actor)
+    actor_eval.eval()
+    from_module(actor).data.to_module(actor_eval)
 
     # Create data buffer
     sampler = SamplerWithoutReplacement()
@@ -163,20 +168,17 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
     # Main loop
     collected_frames = 0
     num_network_updates = 0
-    start_time = time.time()
     pbar = tqdm.tqdm(total=total_frames)
     num_mini_batches = frames_per_batch // mini_batch_size
     total_network_updates = (total_frames // frames_per_batch) * num_mini_batches
     lr = cfg.optim.lr
 
-    sampling_start = time.time()
     c_iter = iter(collector)
     for i in range(len(collector)):
         with timeit("collecting"):
             data = next(c_iter)
 
         log_info = {}
-        sampling_time = time.time() - sampling_start
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch * frame_skip
         pbar.update(data.numel())
@@ -194,17 +196,17 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
         )
 
         losses = []
-        training_start = time.time()
 
         # Compute GAE
         with torch.no_grad(), timeit("advantage"):
+            torch.compiler.cudagraph_mark_step_begin()
             data = adv_module(data)
         data_reshape = data.reshape(-1)
 
         # Update the data buffer
-        with timeit("emptying"):
+        with timeit("rb - emptying"):
             data_buffer.empty()
-        with timeit("extending"):
+        with timeit("rb - extending"):
             data_buffer.extend(data_reshape)
 
         with timeit("optim"):
@@ -220,13 +222,12 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
 
             num_network_updates += 1
 
-            with timeit("optim - update"):
+            with timeit("update"):
                 torch.compiler.cudagraph_mark_step_begin()
                 loss = update(batch).clone()
             losses.append(loss)
 
         # Get training losses
-        training_time = time.time() - training_start
         losses = torch.stack(losses).float().mean()
 
         for key, value in losses.items():
@@ -234,46 +235,36 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
         log_info.update(
             {
                 "train/lr": lr * alpha,
-                "train/sampling_time": sampling_time,
-                "train/training_time": training_time,
-                **timeit.todict(prefix="time"),
             }
         )
-        if i % 200 == 0:
-            timeit.print()
-            timeit.erase()
 
         # Get test rewards
-        with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
+        with torch.no_grad(), set_exploration_type(
+            ExplorationType.DETERMINISTIC
+        ), timeit("eval"):
            if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
                 i * frames_in_batch * frame_skip
             ) // test_interval:
-                actor.eval()
-                eval_start = time.time()
                 test_rewards = eval_model(
-                    actor, test_env, num_episodes=cfg.logger.num_test_episodes
+                    actor_eval, test_env, num_episodes=cfg.logger.num_test_episodes
                 )
-                eval_time = time.time() - eval_start
                 log_info.update(
                     {
                         "test/reward": test_rewards.mean(),
-                        "test/eval_time": eval_time,
                     }
                 )
-                actor.train()
+        if i % 200 == 0:
+            log_info.update(timeit.todict(prefix="time"))
+            timeit.print()
+            timeit.erase()
 
         if logger:
             for key, value in log_info.items():
                 logger.log_scalar(key, value, collected_frames)
 
-        sampling_start = time.time()
-
     collector.shutdown()
     if not test_env.is_closed:
         test_env.close()
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
 
 
 if __name__ == "__main__":
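The four lines that build `actor_eval` are the tensordict idiom for a weight-tied copy: the parameters are swapped out for meta tensors while `deepcopy` runs (so no storage is duplicated), then the live parameters are re-bound to the copy. The eval actor thus tracks the training weights for free while staying outside the compiled/captured policy. In isolation, on a toy module:

```python
from copy import deepcopy

from tensordict import from_module
from torch import nn

actor = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 2))

# Swap the parameters out for meta tensors while deepcopy runs, so no
# storage is duplicated ...
with from_module(actor).data.to("meta").to_module(actor):
    actor_eval = deepcopy(actor)
actor_eval.eval()
# ... then bind the live parameters to the copy: the weights stay shared.
from_module(actor).data.to_module(actor_eval)
```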
diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py
index 161de736d0f..cc962020ea2 100644
--- a/sota-implementations/a2c/a2c_mujoco.py
+++ b/sota-implementations/a2c/a2c_mujoco.py
@@ -11,20 +11,21 @@
 
 @hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
 def main(cfg: "DictConfig"):  # noqa: F821
 
-    import time
+    from copy import deepcopy
 
     import torch.optim
     import tqdm
 
+    from tensordict import from_module
     from tensordict.nn import CudaGraphModule
 
-    from torchrl._utils import logger as torchrl_logger, timeit
+    from torchrl._utils import timeit
     from torchrl.collectors import SyncDataCollector
     from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
     from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
     from torchrl.envs import ExplorationType, set_exploration_type
-    from torchrl.objectives import A2CLoss
-    from torchrl.objectives.value.advantages import GAE
+    from torchrl.objectives import A2CLoss, group_optimizers
+    from torchrl.objectives.value import GAE
     from torchrl.record import VideoRecorder
     from torchrl.record.loggers import generate_exp_name, get_logger
     from utils_mujoco import eval_model, make_env, make_ppo_models
@@ -46,6 +47,10 @@ def main(cfg: "DictConfig"):  # noqa: F821
     actor, critic = make_ppo_models(
         cfg.env.env_name, device=device, compile=cfg.compile.compile
     )
+    with from_module(actor).data.to("meta").to_module(actor):
+        actor_eval = deepcopy(actor)
+    actor_eval.eval()
+    from_module(actor).data.to_module(actor_eval)
 
     # Create data buffer
     sampler = SamplerWithoutReplacement()
@@ -165,17 +170,14 @@ def update(batch):
     # Main loop
     collected_frames = 0
     num_network_updates = 0
-    start_time = time.time()
     pbar = tqdm.tqdm(total=cfg.collector.total_frames)
 
-    sampling_start = time.time()
     c_iter = iter(collector)
     for i in range(len(collector)):
         with timeit("collecting"):
             data = next(c_iter)
 
         log_info = {}
-        sampling_time = time.time() - sampling_start
         frames_in_batch = data.numel()
         collected_frames += frames_in_batch
         pbar.update(data.numel())
@@ -193,10 +195,10 @@ def update(batch):
         )
 
         losses = []
-        training_start = time.time()
 
         # Compute GAE
         with torch.no_grad(), timeit("advantage"):
+            torch.compiler.cudagraph_mark_step_begin()
             data = adv_module(data)
         data_reshape = data.reshape(-1)
 
@@ -223,21 +225,14 @@ def update(batch):
             losses.append(loss)
 
         # Get training losses
-        training_time = time.time() - training_start
         losses = torch.stack(losses).float().mean()
         for key, value in losses.items():
             log_info.update({f"train/{key}": value.item()})
         log_info.update(
             {
                 "train/lr": alpha * cfg.optim.lr,
-                "train/sampling_time": sampling_time,
-                "train/training_time": training_time,
-                **timeit.todict(prefix="time"),
             }
         )
-        if i % 200 == 0:
-            timeit.print()
-            timeit.erase()
 
         # Get test rewards
         with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
@@ -246,32 +241,30 @@ def update(batch):
             final = collected_frames >= collector.total_frames
             if prev_test_frame < cur_test_frame or final:
                 actor.eval()
-                eval_start = time.time()
                 test_rewards = eval_model(
                     actor, test_env, num_episodes=cfg.logger.num_test_episodes
                 )
-                eval_time = time.time() - eval_start
                 log_info.update(
                     {
                         "test/reward": test_rewards.mean(),
-                        "test/eval_time": eval_time,
                     }
                 )
                 actor.train()
 
+        if i % 200 == 0:
+            log_info.update(timeit.todict(prefix="time"))
+            timeit.print()
+            timeit.erase()
+
         if logger:
             for key, value in log_info.items():
                 logger.log_scalar(key, value, collected_frames)
 
-        sampling_start = time.time()
+    torch.compiler.cudagraph_mark_step_begin()
     collector.shutdown()
     if not test_env.is_closed:
         test_env.close()
-    end_time = time.time()
-    execution_time = end_time - start_time
-    torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")
 
 
 if __name__ == "__main__":
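Both scripts now rely on `timeit` sections instead of manual `time.time()` bookkeeping, folding the aggregated timings into the logs every 200 iterations. A runnable sketch of the same usage — the phase functions are hypothetical stand-ins for the real loop phases:

```python
import time

from torchrl._utils import timeit


def collect_batch():  # hypothetical stand-in for data collection
    time.sleep(0.001)


def train_step():  # hypothetical stand-in for the optimizer update
    time.sleep(0.002)


for i in range(1, 401):
    with timeit("collecting"):
        collect_batch()
    with timeit("update"):
        train_step()
    if i % 200 == 0:
        stats = timeit.todict(prefix="time")  # e.g. {"time/collecting": ...}
        timeit.print()
        timeit.erase()  # reset the accumulators for the next window
```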
diff --git a/sota-implementations/a2c/config_atari.yaml b/sota-implementations/a2c/config_atari.yaml
index 6c68c73980e..59a0a621756 100644
--- a/sota-implementations/a2c/config_atari.yaml
+++ b/sota-implementations/a2c/config_atari.yaml
@@ -5,7 +5,7 @@ env:
 
 # collector
 collector:
-  frames_per_batch: 8000
+  frames_per_batch: 800
   total_frames: 40_000_000
 
 # logger

From 0f09e9fbeedc2f66cb6acaeb068c61c5cb6a2881 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 18 Nov 2024 15:28:38 +0000
Subject: [PATCH 19/24] Update

[ghstack-poisoned]
---
 sota-implementations/a2c/README.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sota-implementations/a2c/README.md b/sota-implementations/a2c/README.md
index b3c651b7d4e..63833001f04 100644
--- a/sota-implementations/a2c/README.md
+++ b/sota-implementations/a2c/README.md
@@ -33,7 +33,7 @@ python a2c_mujoco.py compile.compile=1 compile.cudagraphs=1
 
 Runtimes when executed on H100:
 
-| Environment | Eager | Compile | Compile+cudagraphs |
-|-------------|-------|---------|--------------------|
-| MUJOCO      |       |         |                    |
-| ATARI       |       | 60 mins | 43 mins            |
+| Environment | Eager   | Compile | Compile+cudagraphs |
+|-------------|---------|---------|--------------------|
+| MUJOCO      |         |         |                    |
+| ATARI       | 80 mins | 60 mins | 43 mins            |

From 4d21003e7f0e91b5b9b6f2315dbbbd77872156da Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 18 Nov 2024 15:33:48 +0000
Subject: [PATCH 20/24] Update

[ghstack-poisoned]
---
 sota-implementations/a2c/a2c_mujoco.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py
index cc962020ea2..b75a5224bc5 100644
--- a/sota-implementations/a2c/a2c_mujoco.py
+++ b/sota-implementations/a2c/a2c_mujoco.py
@@ -142,13 +142,11 @@ def update(batch):
                 compile_mode = "reduce-overhead"
 
         update = torch.compile(update, mode=compile_mode)
-        actor = torch.compile(actor, mode=compile_mode)
         adv_module = torch.compile(adv_module, mode=compile_mode)
 
     if cfg.compile.cudagraphs:
-        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=10)
-        actor = CudaGraphModule(actor, warmup=10)
-        adv_module = CudaGraphModule(adv_module)
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=20)
+        adv_module = CudaGraphModule(adv_module, warmup=20)
 
     # Create collector
     collector = SyncDataCollector(
@@ -160,7 +158,7 @@ def update(batch):
         storing_device=device,
         max_frames_per_traj=-1,
         trust_policy=True,
-        compile_policy={"mode": compile_mode} if cfg.compile.compile else False,
+        compile_policy={"mode": compile_mode} if compile_mode is not None else False,
         cudagraph_policy=cfg.compile.cudagraphs,
     )
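Patch 20 stops compiling the actor inside the training script — the collector now owns that via `compile_policy`/`cudagraph_policy` — and doubles the `CudaGraphModule` warmup to 20, leaving room for the recompilations `torch.compile` can trigger during the first calls before the graph is captured. A toy illustration of the wrapper on its own (assuming a CUDA device; `warmup` counts the eager calls made before capture):

```python
import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule, TensorDictModule
from torch import nn

if torch.cuda.is_available():
    net = TensorDictModule(
        nn.Linear(4, 4).cuda(), in_keys=["obs"], out_keys=["hidden"]
    )
    # warmup=20: run eagerly twenty times before capturing, long enough for
    # torch.compile recompilations in the first iterations to settle.
    gnet = CudaGraphModule(net, warmup=20)
    td = TensorDict({"obs": torch.randn(8, 4, device="cuda")}, batch_size=[8])
    for _ in range(25):
        out = gnet(td)
```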
From b255b57dd9ce8f7550a66548189e360a700b9656 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 18 Nov 2024 15:35:50 +0000
Subject: [PATCH 21/24] Update

[ghstack-poisoned]
---
 sota-implementations/a2c/config_mujoco.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sota-implementations/a2c/config_mujoco.yaml b/sota-implementations/a2c/config_mujoco.yaml
index 43ab041e0d5..9e8f36a5995 100644
--- a/sota-implementations/a2c/config_mujoco.yaml
+++ b/sota-implementations/a2c/config_mujoco.yaml
@@ -4,7 +4,7 @@ env:
 
 # collector
 collector:
-  frames_per_batch: 64
+  frames_per_batch: 640
   total_frames: 1_000_000
 
 # logger

From 1351bd466eaf968fe7fcd9bd8cae4281c1ad43f8 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 18 Nov 2024 15:48:35 +0000
Subject: [PATCH 22/24] Update

[ghstack-poisoned]
---
 sota-implementations/a2c/README.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sota-implementations/a2c/README.md b/sota-implementations/a2c/README.md
index 63833001f04..91c9099c8c9 100644
--- a/sota-implementations/a2c/README.md
+++ b/sota-implementations/a2c/README.md
@@ -33,7 +33,7 @@ python a2c_mujoco.py compile.compile=1 compile.cudagraphs=1
 
 Runtimes when executed on H100:
 
-| Environment | Eager   | Compile | Compile+cudagraphs |
-|-------------|---------|---------|--------------------|
-| MUJOCO      |         |         |                    |
-| ATARI       | 80 mins | 60 mins | 43 mins            |
+| Environment | Eager     | Compile   | Compile+cudagraphs |
+|-------------|-----------|-----------|--------------------|
+| MUJOCO      | < 25 mins | < 23 mins | < 20 mins          |
+| ATARI       | < 85 mins | < 60 mins | < 45 mins          |

From 263f5670a506904fb14dedc4114deb7738f3779c Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 18 Nov 2024 16:23:33 +0000
Subject: [PATCH 23/24] Update

[ghstack-poisoned]
---
 sota-implementations/a2c/utils_mujoco.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py
index e16bcefc890..87587d092f0 100644
--- a/sota-implementations/a2c/utils_mujoco.py
+++ b/sota-implementations/a2c/utils_mujoco.py
@@ -60,7 +60,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
         "low": proof_environment.action_spec.space.low.to(device),
         "high": proof_environment.action_spec.space.high.to(device),
         "tanh_loc": False,
-        "safe_tanh": not compile,
+        "safe_tanh": True,
     }
 
     # Define policy architecture

From 540eaa8c50366e5d77fe2aa81e67fc61abe58716 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 18 Nov 2024 16:26:30 +0000
Subject: [PATCH 24/24] Update

[ghstack-poisoned]
---
 torchrl/modules/distributions/continuous.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py
index bae3b28c4b2..62b5df5d14b 100644
--- a/torchrl/modules/distributions/continuous.py
+++ b/torchrl/modules/distributions/continuous.py
@@ -42,9 +42,8 @@
 except ImportError:
     from torch._dynamo import is_compiling as is_dynamo_compiling
 
-TORCH_VERSION_PRE_2_6 = version.parse(torch.__version__).base_version < version.parse(
-    "2.6.0"
-)
+TORCH_VERSION = version.parse(torch.__version__).base_version
+TORCH_VERSION_PRE_2_6 = version.parse(TORCH_VERSION) < version.parse("2.6.0")
 
 
 class IndependentNormal(D.Independent):
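The last hunk fixes a type bug as much as it refactors: `Version.base_version` is a plain `str`, so the removed expression compared a `str` against a `Version`. Re-parsing the base version makes the comparison well-typed and keeps nightlies grouped with their release:

```python
import torch
from packaging import version

# `base_version` strips local/dev suffixes ("2.6.0a0+git1234" -> "2.6.0") but
# returns a plain str; re-parsing it makes the comparison sound, whereas
# `str < Version` raises a TypeError.
TORCH_VERSION = version.parse(torch.__version__).base_version
TORCH_VERSION_PRE_2_6 = version.parse(TORCH_VERSION) < version.parse("2.6.0")
```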