[CI] Add benchmarks to test runs #2410

Open · wants to merge 17 commits into base branch gh/vmoens/22/base · Changes from all commits
3 changes: 3 additions & 0 deletions .github/unittest/linux/scripts/environment.yml
@@ -15,6 +15,7 @@ dependencies:
   - pytest-cov
   - pytest-mock
   - pytest-instafail
+  - pytest-benchmark
   - pytest-rerunfailures
   - pytest-timeout
   - expecttest
@@ -33,3 +34,5 @@ dependencies:
   - transformers
   - ninja
   - timm
+  - gymnasium[atari,accept-rom-license]
+  - mo-gymnasium[mujoco]
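
These conda entries replace the ad-hoc `pip3 install` calls removed from `run_all.sh` below. As an import-level sanity check, something like the following sketch could confirm the extras resolved; the environment IDs are illustrative assumptions, not taken from this PR:

```python
# Sketch of an import-level sanity check for the new dependencies; the
# environment IDs are illustrative assumptions, not taken from this PR.
import gymnasium as gym
import mo_gymnasium as mo_gym

env = gym.make("ALE/Pong-v5")          # exercises gymnasium[atari] + the ROM license
mo_env = mo_gym.make("mo-hopper-v4")   # exercises mo-gymnasium[mujoco]
obs, info = env.reset(seed=0)
print(env.action_space, mo_env.unwrapped.reward_space)
```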
11 changes: 5 additions & 6 deletions .github/unittest/linux/scripts/run_all.sh
@@ -87,12 +87,6 @@ conda env update --file "${this_dir}/environment.yml" --prune
 conda deactivate
 conda activate "${env_dir}"
 
-echo "installing gymnasium"
-pip3 install "gymnasium"
-pip3 install ale_py
-pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
-pip3 install "mujoco" -U
-
 # sanity check: remove?
 python3 -c """
 import dm_control
@@ -189,9 +183,14 @@ export MKL_THREADING_LAYER=GNU
 export CKPT_BACKEND=torch
 export MAX_IDLE_COUNT=100
 export BATCHED_PIPE_TIMEOUT=60
+export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
 
 pytest test/smoke_test.py -v --durations 200
 pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb'
 
+# Check that benchmarks run
+python -m pytest benchmarks
+
 if [ "${CU_VERSION:-}" != cpu ] ; then
   python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
     --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
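
The new `python -m pytest benchmarks` step only checks that the benchmark suite executes; it does not gate on timings. For readers unfamiliar with pytest-benchmark: any test that takes the `benchmark` fixture is collected as a benchmark. A minimal sketch of such a test follows; the environment choice and rollout lengths are illustrative, not taken from torchrl's benchmarks/ directory:

```python
# Minimal pytest-benchmark sketch; the environment and rollout lengths are
# illustrative assumptions, not taken from torchrl's benchmarks/ directory.
from torchrl.envs import GymEnv


def test_rollout_speed(benchmark):
    env = GymEnv("CartPole-v1")
    env.rollout(3)  # warm-up outside the timed region
    # benchmark(fn, *args) calls fn repeatedly and records timing statistics
    benchmark(env.rollout, 100)
```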
2 changes: 1 addition & 1 deletion .github/unittest/linux_libs/scripts_gym/batch_scripts.sh
@@ -126,7 +126,7 @@ do
   conda activate ./cloned_env
 
   echo "Testing gym version: ${GYM_VERSION}"
-  pip3 install 'gymnasium[atari,accept-rom-license,ale-py]'==$GYM_VERSION
+  pip3 install 'gymnasium[atari,accept-rom-license]'==$GYM_VERSION
 
   $DIR/run_test.sh
 
8 changes: 6 additions & 2 deletions torchrl/__init__.py
@@ -21,6 +21,10 @@
 
 from ._extension import _init_extension
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except Exception:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
 
 try:
     from .version import __version__
@@ -69,7 +73,7 @@ def _inv(self):
         inv = self._inv()
     if inv is None:
         inv = _InverseTransform(self)
-        if not torch.compiler.is_dynamo_compiling():
+        if not is_dynamo_compiling():
             self._inv = weakref.ref(inv)
     return inv
 
@@ -84,7 +88,7 @@ def _inv(self):
         inv = self._inv()
     if inv is None:
         inv = ComposeTransform([p.inv for p in reversed(self.parts)])
-        if not torch.compiler.is_dynamo_compiling():
+        if not is_dynamo_compiling():
             self._inv = weakref.ref(inv)
             inv._inv = weakref.ref(self)
     else:
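
Two things happen in this file: the `try`/`except` import keeps compatibility across PyTorch versions (`torch.compiler.is_dynamo_compiling` is the public API in newer releases, while older ones only expose the private `torch._dynamo.is_compiling`), and the `is_dynamo_compiling()` guard skips the weakref cache write while dynamo is tracing, since mutating attributes mid-trace can cause graph breaks. A condensed sketch of the pattern, with illustrative names rather than torchrl's:

```python
# Condensed sketch of the guarded weakref-caching pattern above;
# the class and attribute names are illustrative, not torchrl's.
import weakref

try:
    from torch.compiler import is_dynamo_compiling  # public API (newer torch)
except Exception:
    from torch._dynamo import is_compiling as is_dynamo_compiling  # fallback


class Node:
    _inv_ref = None

    @property
    def inv(self):
        inv = self._inv_ref() if self._inv_ref is not None else None
        if inv is None:
            inv = Node()  # stand-in for the expensive inverse construction
            if not is_dynamo_compiling():
                # Only cache in eager mode: attribute mutation under dynamo
                # can trigger graph breaks or stale guards.
                self._inv_ref = weakref.ref(inv)
        return inv
```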
21 changes: 11 additions & 10 deletions torchrl/modules/distributions/continuous.py
@@ -33,6 +33,11 @@
 )
 from torchrl.modules.utils import mappings
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except Exception:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 # speeds up distribution construction
 D.Distribution.set_default_validate_args(False)
 
@@ -112,7 +117,7 @@ def inv(self):
             inv = self._inv()
         if inv is None:
             inv = _InverseTransform(self)
-            if not torch.compiler.is_dynamo_compiling():
+            if not is_dynamo_compiling():
                 self._inv = weakref.ref(inv)
         return inv
 
@@ -334,7 +339,7 @@ def inv(self):
             inv = self._inv()
         if inv is None:
             inv = _PatchedComposeTransform([p.inv for p in reversed(self.parts)])
-            if not torch.compiler.is_dynamo_compiling():
+            if not is_dynamo_compiling():
                 self._inv = weakref.ref(inv)
                 inv._inv = weakref.ref(self)
         return inv
@@ -348,7 +353,7 @@ def inv(self):
             inv = self._inv()
         if inv is None:
             inv = _InverseTransform(self)
-            if not torch.compiler.is_dynamo_compiling():
+            if not is_dynamo_compiling():
                 self._inv = weakref.ref(inv)
         return inv
 
@@ -460,15 +465,13 @@ def __init__(
         self.high = high
 
         if safe_tanh:
-            if torch.compiler.is_dynamo_compiling():
+            if is_dynamo_compiling():
                 _err_compile_safetanh()
             t = SafeTanhTransform()
         else:
             t = D.TanhTransform()
         # t = D.TanhTransform()
-        if torch.compiler.is_dynamo_compiling() or (
-            self.non_trivial_max or self.non_trivial_min
-        ):
+        if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min):
             t = _PatchedComposeTransform(
                 [
                     t,
@@ -495,9 +498,7 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
         if self.tanh_loc:
             loc = (loc / self.upscale).tanh() * self.upscale
             # loc must be rescaled if tanh_loc
-        if torch.compiler.is_dynamo_compiling() or (
-            self.non_trivial_max or self.non_trivial_min
-        ):
+        if is_dynamo_compiling() or (self.non_trivial_max or self.non_trivial_min):
             loc = loc + (self.high - self.low) / 2 + self.low
         self.loc = loc
         self.scale = scale
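
The constructor logic above composes a tanh squash with an affine rescaling whenever the bounds are non-trivial, so samples land in `[low, high]` instead of `[-1, 1]`; the `update()` shift `loc + (high - low) / 2 + low` is the same recentering applied to the location parameter. A sketch of the equivalent construction with stock `torch.distributions` (torchrl's `_Patched*` classes add the dynamo-safe inverse caching on top of this):

```python
# Equivalent tanh-then-affine construction with stock torch.distributions;
# torchrl's patched transforms add dynamo-safe inverse caching on top.
import torch
import torch.distributions as D

low, high = torch.tensor(-2.0), torch.tensor(3.0)
base = D.Normal(torch.zeros(1), torch.ones(1))
transform = D.ComposeTransform([
    D.TanhTransform(),                   # squash to (-1, 1)
    D.AffineTransform(
        loc=(high + low) / 2,            # recenter on the interval midpoint
        scale=(high - low) / 2,          # stretch to the interval half-width
    ),
])
dist = D.TransformedDistribution(base, transform)
sample = dist.rsample()                  # lies in (low, high)
log_prob = dist.log_prob(sample)         # includes the Jacobian corrections
```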
7 changes: 6 additions & 1 deletion torchrl/objectives/utils.py
@@ -26,6 +26,11 @@
     raise err_ft from err
 from torchrl.envs.utils import step_mdp
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except Exception:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 _GAMMA_LMBDA_DEPREC_ERROR = (
     "Passing gamma / lambda parameters through the loss constructor "
     "is a deprecated feature. To customize your value function, "
@@ -460,7 +465,7 @@ def _cache_values(func):
 
     @functools.wraps(func)
     def new_func(self, netname=None):
-        if torch.compiler.is_dynamo_compiling():
+        if is_dynamo_compiling():
             if netname is not None:
                 return func(self, netname)
             else:
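
The `_cache_values` decorator memoizes method results on the loss module; under dynamo it now calls the wrapped function directly, since the cache writes would otherwise be traced as side effects. A minimal sketch of that bypass pattern — the attribute-based cache layout is an assumption for illustration, not torchrl's exact implementation:

```python
# Minimal sketch of a compile-aware caching decorator; the attribute-based
# cache layout is an illustrative assumption, not torchrl's implementation.
import functools

try:
    from torch.compiler import is_dynamo_compiling
except Exception:
    from torch._dynamo import is_compiling as is_dynamo_compiling


def cache_values(func):
    cache_name = "_cached_" + func.__name__

    @functools.wraps(func)
    def new_func(self):
        if is_dynamo_compiling():
            # Skip the cache while tracing: setattr would be a traced
            # side effect and could leak stale values across graphs.
            return func(self)
        out = getattr(self, cache_name, None)
        if out is None:
            out = func(self)
            setattr(self, cache_name, out)
        return out

    return new_func
```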