Skip to content

Commit

Permalink
[Refactor] Allow safe-tanh for torch >= 2.6.0
Browse files Browse the repository at this point in the history
ghstack-source-id: 92df1954451453ee051bbde499f6db5ebaafed49
Pull Request resolved: #2580
  • Loading branch information
vmoens committed Nov 18, 2024
1 parent 600760f commit 1474f85
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions torchrl/modules/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,9 @@

import numpy as np
import torch
from packaging import version
from torch import distributions as D, nn

try:
from torch.compiler import assume_constant_result
except ImportError:
from torch._dynamo import assume_constant_result

from torch.distributions import constraints
from torch.distributions.transforms import _InverseTransform

Expand All @@ -36,11 +32,20 @@
# speeds up distribution construction
D.Distribution.set_default_validate_args(False)

try:
from torch.compiler import assume_constant_result
except ImportError:
from torch._dynamo import assume_constant_result

try:
from torch.compiler import is_dynamo_compiling
except ImportError:
from torch._dynamo import is_compiling as is_dynamo_compiling

# True when the installed torch is older than 2.6.0 (ignoring any local/dev
# suffix such as "+cu121" or "a0" via ``base_version``).
# NOTE: ``Version.base_version`` is a *str*; it must be re-parsed before
# comparing against a ``Version`` object, otherwise the ``str < Version``
# comparison raises ``TypeError`` at import time.
TORCH_VERSION_PRE_2_6 = version.parse(
    version.parse(torch.__version__).base_version
) < version.parse("2.6.0")


class IndependentNormal(D.Independent):
"""Implements a Normal distribution with location scaling.
Expand Down Expand Up @@ -437,7 +442,7 @@ def __init__(
self.high = high

if safe_tanh:
if is_dynamo_compiling():
if is_dynamo_compiling() and TORCH_VERSION_PRE_2_6:
_err_compile_safetanh()
t = SafeTanhTransform()
else:
Expand Down Expand Up @@ -772,8 +777,8 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor:

def _err_compile_safetanh():
raise RuntimeError(
"safe_tanh=True in TanhNormal is not compatible with torch.compile. To deactivate it, pass"
"safe_tanh=False. "
"safe_tanh=True in TanhNormal is not compatible with torch.compile with torch pre 2.6.0. "
"To deactivate it, pass safe_tanh=False. "
"If you are using a ProbabilisticTensorDictModule, this can be done via "
"`distribution_kwargs={'safe_tanh': False}`. "
"See https://github.com/pytorch/pytorch/issues/133529 for more details."
Expand Down

0 comments on commit 1474f85

Please sign in to comment.