Skip to content

Commit

Permalink
[CI] Add benchmarks to test runs
Browse files Browse the repository at this point in the history
ghstack-source-id: 8d83ae870d3629117d857fdbe98a992ca56b7838
Pull Request resolved: #2410
  • Loading branch information
vmoens committed Sep 17, 2024
1 parent 0a410ff commit 98d2933
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 16 deletions.
1 change: 1 addition & 0 deletions .github/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies:
- pytest-cov
- pytest-mock
- pytest-instafail
- pytest-benchmark
- pytest-rerunfailures
- pytest-timeout
- expecttest
Expand Down
8 changes: 6 additions & 2 deletions .github/unittest/linux/scripts/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ conda deactivate
conda activate "${env_dir}"

echo "installing gymnasium"
pip3 install "gymnasium"
pip3 install ale_py
pip3 install "gymnasium[atari,accept-rom-license]"
pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
pip3 install "mujoco" -U

Expand Down Expand Up @@ -189,9 +188,14 @@ export MKL_THREADING_LAYER=GNU
export CKPT_BACKEND=torch
export MAX_IDLE_COUNT=100
export BATCHED_PIPE_TIMEOUT=60
export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1

pytest test/smoke_test.py -v --durations 200
pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb'

# Check that benchmarks run
python -m pytest benchmarks

if [ "${CU_VERSION:-}" != cpu ] ; then
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
Expand Down
2 changes: 1 addition & 1 deletion .github/unittest/linux_libs/scripts_gym/batch_scripts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ do
conda activate ./cloned_env

echo "Testing gym version: ${GYM_VERSION}"
pip3 install 'gymnasium[atari,accept-rom-license,ale-py]'==$GYM_VERSION
pip3 install 'gymnasium[atari,accept-rom-license]'==$GYM_VERSION

$DIR/run_test.sh

Expand Down
8 changes: 6 additions & 2 deletions torchrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@

from ._extension import _init_extension

try:
from torch.compiler import is_dynamo_compiling
except Exception:
from torch._dynamo import is_compiling as is_dynamo_compiling

try:
from .version import __version__
Expand Down Expand Up @@ -69,7 +73,7 @@ def _inv(self):
inv = self._inv()
if inv is None:
inv = _InverseTransform(self)
if not torch.compiler.is_dynamo_compiling():
if not is_dynamo_compiling():
self._inv = weakref.ref(inv)
return inv

Expand All @@ -84,7 +88,7 @@ def _inv(self):
inv = self._inv()
if inv is None:
inv = ComposeTransform([p.inv for p in reversed(self.parts)])
if not torch.compiler.is_dynamo_compiling():
if not is_dynamo_compiling():
self._inv = weakref.ref(inv)
inv._inv = weakref.ref(self)
else:
Expand Down
21 changes: 11 additions & 10 deletions torchrl/modules/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@
)
from torchrl.modules.utils import mappings

try:
from torch.compiler import is_dynamo_compiling
except Exception:
from torch._dynamo import is_compiling as is_dynamo_compiling

# speeds up distribution construction
D.Distribution.set_default_validate_args(False)

Expand Down Expand Up @@ -112,7 +117,7 @@ def inv(self):
inv = self._inv()
if inv is None:
inv = _InverseTransform(self)
if not torch.compiler.is_dynamo_compiling():
if not is_dynamo_compiling():
self._inv = weakref.ref(inv)
return inv

Expand Down Expand Up @@ -334,7 +339,7 @@ def inv(self):
inv = self._inv()
if inv is None:
inv = _PatchedComposeTransform([p.inv for p in reversed(self.parts)])
if not torch.compiler.is_dynamo_compiling():
if not is_dynamo_compiling():
self._inv = weakref.ref(inv)
inv._inv = weakref.ref(self)
return inv
Expand All @@ -348,7 +353,7 @@ def inv(self):
inv = self._inv()
if inv is None:
inv = _InverseTransform(self)
if not torch.compiler.is_dynamo_compiling():
if not is_dynamo_compiling():
self._inv = weakref.ref(inv)
return inv

Expand Down Expand Up @@ -460,15 +465,13 @@ def __init__(
self.high = high

if safe_tanh:
if torch.compiler.is_dynamo_compiling():
if is_dynamo_compiling():
_err_compile_safetanh()
t = SafeTanhTransform()
else:
t = D.TanhTransform()
# t = D.TanhTransform()
if torch.compiler.is_dynamo_compiling() or (
self.non_trivial_max or self.non_trivial_min
):
if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min):
t = _PatchedComposeTransform(
[
t,
Expand All @@ -495,9 +498,7 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
if self.tanh_loc:
loc = (loc / self.upscale).tanh() * self.upscale
# loc must be rescaled if tanh_loc
if torch.compiler.is_dynamo_compiling() or (
self.non_trivial_max or self.non_trivial_min
):
if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min):
loc = loc + (self.high - self.low) / 2 + self.low
self.loc = loc
self.scale = scale
Expand Down
7 changes: 6 additions & 1 deletion torchrl/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
raise err_ft from err
from torchrl.envs.utils import step_mdp

try:
from torch.compiler import is_dynamo_compiling
except Exception:
from torch._dynamo import is_compiling as is_dynamo_compiling

_GAMMA_LMBDA_DEPREC_ERROR = (
"Passing gamma / lambda parameters through the loss constructor "
"is a deprecated feature. To customize your value function, "
Expand Down Expand Up @@ -460,7 +465,7 @@ def _cache_values(func):

@functools.wraps(func)
def new_func(self, netname=None):
if torch.compiler.is_dynamo_compiling():
if is_dynamo_compiling():
if netname is not None:
return func(self, netname)
else:
Expand Down

0 comments on commit 98d2933

Please sign in to comment.