From 8f04818a06f14c9616ac3a6925b21039e19035be Mon Sep 17 00:00:00 2001 From: Danylo Baibak Date: Sun, 25 Feb 2024 01:53:27 +0100 Subject: [PATCH] [BugFix] Fixed import for importlib (#1914) Co-authored-by: vmoens --- .github/workflows/build-wheels-m1.yml | 4 ++-- test/smoke_test.py | 2 ++ torchrl/_extension.py | 5 ++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build-wheels-m1.yml b/.github/workflows/build-wheels-m1.yml index 84fe79d09d2..82971c8233d 100644 --- a/.github/workflows/build-wheels-m1.yml +++ b/.github/workflows/build-wheels-m1.yml @@ -19,12 +19,12 @@ permissions: jobs: generate-matrix: - uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main + uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@Remove-Builds-Limits-for-Testing with: package-type: wheel os: macos-arm64 test-infra-repository: pytorch/test-infra - test-infra-ref: main + test-infra-ref: Remove-Builds-Limits-for-Testing build: needs: generate-matrix strategy: diff --git a/test/smoke_test.py b/test/smoke_test.py index 313c786088c..c6500deb5e8 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -14,3 +14,5 @@ def test_imports(): from torchrl.envs.gym_like import GymLikeEnv # noqa: F401 from torchrl.modules import SafeModule # noqa: F401 from torchrl.objectives.common import LossModule # noqa: F401 + + PrioritizedReplayBuffer(alpha=1.1, beta=1.1) diff --git a/torchrl/_extension.py b/torchrl/_extension.py index 5eb820cb86f..a9e52dbf9a4 100644 --- a/torchrl/_extension.py +++ b/torchrl/_extension.py @@ -3,18 +3,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import importlib +import importlib.util import warnings def is_module_available(*modules: str) -> bool: - r"""Returns if a top-level module with :attr:`name` exists *without** importing it. + """Returns if a top-level module with :attr:`name` exists *without** importing it. This is generally safer than try-catch block around a `import X`. It avoids third party libraries breaking assumptions of some of our tests, e.g., setting multiprocessing start method when imported (see librosa/#747, torchvision/#544). - """ return all(importlib.util.find_spec(m) is not None for m in modules)