Skip to content

Commit

Permalink
Fix cupy detection (#1635)
Browse files Browse the repository at this point in the history
* fix cupy installed detection

* fix
  • Loading branch information
fxmarty authored Jan 9, 2024
1 parent fc214d4 commit 8c6b6b6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
4 changes: 2 additions & 2 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
"""Utility functions, classes and constants for ONNX Runtime."""

import importlib.util
import os
import re
from enum import Enum
Expand All @@ -28,6 +27,7 @@
import onnxruntime as ort

from ..exporters.onnx import OnnxConfig, OnnxConfigWithLoss
from ..utils.import_utils import _is_package_available


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -83,7 +83,7 @@ def is_cupy_available():
"""
Checks if onnxruntime-training is available.
"""
return importlib.util.find_spec("cupy") is not None
return _is_package_available("cupy")


class ORTConfigManager:
Expand Down
34 changes: 25 additions & 9 deletions optimum/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,29 @@
import sys
from collections import OrderedDict
from contextlib import contextmanager
from typing import Union
from typing import Tuple, Union

import numpy as np
import packaging
from transformers.utils import is_torch_available


def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
try:
package_version = importlib.metadata.version(pkg_name)
package_exists = True
except importlib.metadata.PackageNotFoundError:
package_exists = False
if return_version:
return package_exists, package_version
else:
return package_exists


# The package importlib_metadata is in a different place, depending on the python version.
if sys.version_info < (3, 8):
import importlib_metadata
Expand All @@ -42,14 +58,14 @@
ORT_QUANTIZE_MINIMUM_VERSION = packaging.version.parse("1.4.0")


_onnx_available = importlib.util.find_spec("onnx") is not None
_onnxruntime_available = importlib.util.find_spec("onnxruntime") is not None
_pydantic_available = importlib.util.find_spec("pydantic") is not None
_accelerate_available = importlib.util.find_spec("accelerate") is not None
_diffusers_available = importlib.util.find_spec("diffusers") is not None
_auto_gptq_available = importlib.util.find_spec("auto_gptq") is not None
_timm_available = importlib.util.find_spec("timm") is not None
_sentence_transformers_available = importlib.util.find_spec("sentence_transformers") is not None
_onnx_available = _is_package_available("onnx")
_onnxruntime_available = _is_package_available("onnxruntime")
_pydantic_available = _is_package_available("pydantic")
_accelerate_available = _is_package_available("accelerate")
_diffusers_available = _is_package_available("diffusers")
_auto_gptq_available = _is_package_available("auto_gptq")
_timm_available = _is_package_available("timm")
_sentence_transformers_available = _is_package_available("sentence_transformers")

torch_version = None
if is_torch_available():
Expand Down

0 comments on commit 8c6b6b6

Please sign in to comment.