From 8c6b6b602e8a458a67372a175dd027a4546c2fa0 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 9 Jan 2024 13:37:39 +0100 Subject: [PATCH] Fix cupy detection (#1635) * fix cupy installed detection * fix --- optimum/onnxruntime/utils.py | 4 ++-- optimum/utils/import_utils.py | 34 +++++++++++++++++++++++++--------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index aea997eb39a..e269c8de718 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -13,7 +13,6 @@ # limitations under the License. """Utility functions, classes and constants for ONNX Runtime.""" -import importlib.util import os import re from enum import Enum @@ -28,6 +27,7 @@ import onnxruntime as ort from ..exporters.onnx import OnnxConfig, OnnxConfigWithLoss +from ..utils.import_utils import _is_package_available logger = logging.get_logger(__name__) @@ -83,7 +83,7 @@ def is_cupy_available(): """ Checks if onnxruntime-training is available. """ - return importlib.util.find_spec("cupy") is not None + return _is_package_available("cupy") class ORTConfigManager: diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py index 2bbfdfaf0ac..0894a3b4902 100644 --- a/optimum/utils/import_utils.py +++ b/optimum/utils/import_utils.py @@ -18,13 +18,29 @@ import sys from collections import OrderedDict from contextlib import contextmanager -from typing import Union +from typing import Tuple, Union import numpy as np import packaging from transformers.utils import is_torch_available +def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: + # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + if package_exists: + try: + package_version = importlib.metadata.version(pkg_name) + package_exists = True + except importlib.metadata.PackageNotFoundError: + package_exists = False + if return_version: + return package_exists, package_version + else: + return package_exists + + # The package importlib_metadata is in a different place, depending on the python version. if sys.version_info < (3, 8): import importlib_metadata @@ -42,14 +58,14 @@ ORT_QUANTIZE_MINIMUM_VERSION = packaging.version.parse("1.4.0") -_onnx_available = importlib.util.find_spec("onnx") is not None -_onnxruntime_available = importlib.util.find_spec("onnxruntime") is not None -_pydantic_available = importlib.util.find_spec("pydantic") is not None -_accelerate_available = importlib.util.find_spec("accelerate") is not None -_diffusers_available = importlib.util.find_spec("diffusers") is not None -_auto_gptq_available = importlib.util.find_spec("auto_gptq") is not None -_timm_available = importlib.util.find_spec("timm") is not None -_sentence_transformers_available = importlib.util.find_spec("sentence_transformers") is not None +_onnx_available = _is_package_available("onnx") +_onnxruntime_available = _is_package_available("onnxruntime") +_pydantic_available = _is_package_available("pydantic") +_accelerate_available = _is_package_available("accelerate") +_diffusers_available = _is_package_available("diffusers") +_auto_gptq_available = _is_package_available("auto_gptq") +_timm_available = _is_package_available("timm") +_sentence_transformers_available = _is_package_available("sentence_transformers") torch_version = None if is_torch_available():