From 1474f8517959268cb3faae2e974cca29c1994328 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Nov 2024 16:01:46 +0000 Subject: [PATCH] [Refactor] Allow safe-tanh for torch >= 2.6.0 ghstack-source-id: 92df1954451453ee051bbde499f6db5ebaafed49 Pull Request resolved: https://github.com/pytorch/rl/pull/2580 --- torchrl/modules/distributions/continuous.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 8b0d5654b8d..cde2c95d30f 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -11,13 +11,9 @@ import numpy as np import torch +from packaging import version from torch import distributions as D, nn -try: - from torch.compiler import assume_constant_result -except ImportError: - from torch._dynamo import assume_constant_result - from torch.distributions import constraints from torch.distributions.transforms import _InverseTransform @@ -36,11 +32,20 @@ # speeds up distribution construction D.Distribution.set_default_validate_args(False) +try: + from torch.compiler import assume_constant_result +except ImportError: + from torch._dynamo import assume_constant_result + try: from torch.compiler import is_dynamo_compiling except ImportError: from torch._dynamo import is_compiling as is_dynamo_compiling +TORCH_VERSION_PRE_2_6 = version.parse(torch.__version__).base_version < version.parse( + "2.6.0" +) + class IndependentNormal(D.Independent): """Implements a Normal distribution with location scaling. @@ -437,7 +442,7 @@ def __init__( self.high = high if safe_tanh: - if is_dynamo_compiling(): + if is_dynamo_compiling() and TORCH_VERSION_PRE_2_6: _err_compile_safetanh() t = SafeTanhTransform() else: @@ -772,8 +777,8 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor: def _err_compile_safetanh(): raise RuntimeError( - "safe_tanh=True in TanhNormal is not compatible with torch.compile. To deactivate it, pass" - "safe_tanh=False. " + "safe_tanh=True in TanhNormal is not compatible with torch.compile with torch pre 2.6.0. " + "To deactivate it, pass safe_tanh=False. " "If you are using a ProbabilisticTensorDictModule, this can be done via " "`distribution_kwargs={'safe_tanh': False}`. " "See https://github.com/pytorch/pytorch/issues/133529 for more details."