diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml index 2234683a497..8b512526a06 100644 --- a/.github/unittest/linux/scripts/environment.yml +++ b/.github/unittest/linux/scripts/environment.yml @@ -15,6 +15,7 @@ dependencies: - pytest-cov - pytest-mock - pytest-instafail + - pytest-benchmark - pytest-rerunfailures - pytest-timeout - expecttest @@ -33,3 +34,5 @@ dependencies: - transformers - ninja - timm + - gymnasium[atari,accept-rom-license] + - mo-gymnasium[mujoco] diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index 07de5e33099..dec06bd2d8d 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -87,12 +87,6 @@ conda env update --file "${this_dir}/environment.yml" --prune conda deactivate conda activate "${env_dir}" -echo "installing gymnasium" -pip3 install "gymnasium" -pip3 install ale_py -pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py -pip3 install "mujoco" -U - # sanity check: remove? python3 -c """ import dm_control @@ -189,9 +183,14 @@ export MKL_THREADING_LAYER=GNU export CKPT_BACKEND=torch export MAX_IDLE_COUNT=100 export BATCHED_PIPE_TIMEOUT=60 +export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 pytest test/smoke_test.py -v --durations 200 pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb' + +# Check that benchmarks run +python -m pytest benchmarks + if [ "${CU_VERSION:-}" != cpu ] ; then python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \ --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \ diff --git a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh index 9622984a421..a99e4b5a104 100755 --- a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh +++ b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh @@ -126,7 +126,7 @@ do conda activate ./cloned_env echo "Testing gym version: ${GYM_VERSION}" - pip3 install 'gymnasium[atari,accept-rom-license,ale-py]'==$GYM_VERSION + pip3 install 'gymnasium[atari,accept-rom-license]'==$GYM_VERSION $DIR/run_test.sh diff --git a/torchrl/__init__.py b/torchrl/__init__.py index cbd7b66a65e..30224da113e 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -21,6 +21,10 @@ from ._extension import _init_extension +try: + from torch.compiler import is_dynamo_compiling +except Exception: + from torch._dynamo import is_compiling as is_dynamo_compiling try: from .version import __version__ @@ -69,7 +73,7 @@ def _inv(self): inv = self._inv() if inv is None: inv = _InverseTransform(self) - if not torch.compiler.is_dynamo_compiling(): + if not is_dynamo_compiling(): self._inv = weakref.ref(inv) return inv @@ -84,7 +88,7 @@ def _inv(self): inv = self._inv() if inv is None: inv = ComposeTransform([p.inv for p in reversed(self.parts)]) - if not torch.compiler.is_dynamo_compiling(): + if not is_dynamo_compiling(): self._inv = weakref.ref(inv) inv._inv = weakref.ref(self) else: diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 71fee70d5b8..6da0b2b2895 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -33,6 +33,11 @@ ) from torchrl.modules.utils import mappings +try: + from torch.compiler import is_dynamo_compiling +except Exception: + from torch._dynamo import is_compiling as is_dynamo_compiling + # speeds up distribution construction D.Distribution.set_default_validate_args(False) @@ -112,7 +117,7 @@ def inv(self): inv = self._inv() if inv is None: inv = _InverseTransform(self) - if not torch.compiler.is_dynamo_compiling(): + if not is_dynamo_compiling(): self._inv = weakref.ref(inv) return inv @@ -334,7 +339,7 @@ def inv(self): inv = self._inv() if inv is None: inv = _PatchedComposeTransform([p.inv for p in reversed(self.parts)]) - if not torch.compiler.is_dynamo_compiling(): + if not is_dynamo_compiling(): self._inv = weakref.ref(inv) inv._inv = weakref.ref(self) return inv @@ -348,7 +353,7 @@ def inv(self): inv = self._inv() if inv is None: inv = _InverseTransform(self) - if not torch.compiler.is_dynamo_compiling(): + if not is_dynamo_compiling(): self._inv = weakref.ref(inv) return inv @@ -460,15 +465,13 @@ def __init__( self.high = high if safe_tanh: - if torch.compiler.is_dynamo_compiling(): + if is_dynamo_compiling(): _err_compile_safetanh() t = SafeTanhTransform() else: t = D.TanhTransform() # t = D.TanhTransform() - if torch.compiler.is_dynamo_compiling() or ( - self.non_trivial_max or self.non_trivial_min - ): + if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min): t = _PatchedComposeTransform( [ t, @@ -495,9 +498,7 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None: if self.tanh_loc: loc = (loc / self.upscale).tanh() * self.upscale # loc must be rescaled if tanh_loc - if torch.compiler.is_dynamo_compiling() or ( - self.non_trivial_max or self.non_trivial_min - ): + if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min): loc = loc + (self.high - self.low) / 2 + self.low self.loc = loc self.scale = scale diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 66eae215e54..701a7426882 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -26,6 +26,11 @@ raise err_ft from err from torchrl.envs.utils import step_mdp +try: + from torch.compiler import is_dynamo_compiling +except Exception: + from torch._dynamo import is_compiling as is_dynamo_compiling + _GAMMA_LMBDA_DEPREC_ERROR = ( "Passing gamma / lambda parameters through the loss constructor " "is a deprecated feature. To customize your value function, " @@ -460,7 +465,7 @@ def _cache_values(func): @functools.wraps(func) def new_func(self, netname=None): - if torch.compiler.is_dynamo_compiling(): + if is_dynamo_compiling(): if netname is not None: return func(self, netname) else: