From f4fdad0774b01139bafef7f8546dd74f03a8528a Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Mon, 19 Feb 2024 13:26:16 -0500 Subject: [PATCH] remove onnxruntime-gpu requirement --- .../rastervision/pytorch_learner/utils/utils.py | 6 ++++-- rastervision_pytorch_learner/requirements.txt | 1 - 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/utils/utils.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/utils/utils.py index 83f44cb73..afeee2d42 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/utils/utils.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/utils/utils.py @@ -12,7 +12,6 @@ from albumentations.core.transforms_interface import ImageOnlyTransform import cv2 import pandas as pd -import onnxruntime as ort from rastervision.pipeline.file_system.utils import (file_exists, file_to_json, get_tmp_dir) @@ -20,6 +19,7 @@ upgrade_config) if TYPE_CHECKING: + import onnxruntime as ort from rastervision.pytorch_learner import LearnerConfig log = logging.getLogger(__name__) @@ -456,7 +456,7 @@ class ONNXRuntimeAdapter: also outputs PyTorch Tensors. """ - def __init__(self, ort_session: ort.InferenceSession) -> None: + def __init__(self, ort_session: 'ort.InferenceSession') -> None: """Constructor. Args: @@ -482,6 +482,8 @@ def from_file(cls, path: str, providers: Optional[List[str]] = None Returns: ONNXRuntimeAdapter: An ONNXRuntimeAdapter instance. """ + import onnxruntime as ort + if providers is None: providers = ort.get_available_providers() log.info(f'Using ONNX execution providers: {providers}') diff --git a/rastervision_pytorch_learner/requirements.txt b/rastervision_pytorch_learner/requirements.txt index 0af6d0c7d..c8e997a1c 100644 --- a/rastervision_pytorch_learner/requirements.txt +++ b/rastervision_pytorch_learner/requirements.txt @@ -13,4 +13,3 @@ opencv-python-headless==4.9.0.80 matplotlib==3.8.2 tqdm==4.66.1 onnx==1.15.0 -onnxruntime-gpu==1.17