diff --git a/CHANGELOG.md b/CHANGELOG.md index c5ab69fc..9758f0b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -378,7 +378,7 @@ This release breaks backwards compatibility with `deployments` created by earlie * Add performance hint to the ovms config by @ljcornel in https://github.com/openvinotoolkit/geti-sdk/pull/152 * Fix bug in deployment resource clean up method by @ljcornel in https://github.com/openvinotoolkit/geti-sdk/pull/153 * Update python-dotenv requirement from ==0.21.* to ==1.0.* in /requirements by @dependabot in https://github.com/openvinotoolkit/geti-sdk/pull/156 -* Add a short sleep in `Geti.upload_project` after media upload by @ljcornel in https://github.com/openvinotoolkit/geti-sdk/pull/157 +* Add a short sleep in `Geti.upload_project_data` after media upload by @ljcornel in https://github.com/openvinotoolkit/geti-sdk/pull/157 * Add OVMS deployment resources to manifest by @ljcornel in https://github.com/openvinotoolkit/geti-sdk/pull/158 @@ -477,7 +477,7 @@ This release breaks backwards compatibility with `deployments` created by earlie * Update numpy requirement to 1.21.* by @ljcornel in https://github.com/openvinotoolkit/geti-sdk/pull/89 * Reduce permissions upon directory creation by @ljcornel in https://github.com/openvinotoolkit/geti-sdk/pull/90 * Update README to correctly reference Intel Geti brand everywhere by @ljcornel in https://github.com/openvinotoolkit/geti-sdk/pull/92 -* Improve check for video processing in `Geti.upload_project()` to avoid potential infinite loop by @ljcornel in https://github.com/openvinotoolkit/geti-sdk/pull/93 +* Improve check for video processing in `Geti.upload_project_data()` to avoid potential infinite loop by @ljcornel in https://github.com/openvinotoolkit/geti-sdk/pull/93 * Add unit tests to pre-merge test suite by @ljcornel in https://github.com/openvinotoolkit/geti-sdk/pull/91 * Update ProjectStatus and TaskStatus to include new field `n_new_annotations` by @ljcornel in https://github.com/openvinotoolkit/geti-sdk/pull/94 * Add progress bars for up/download of projects, media, annotations and predictions by @ljcornel in https://github.com/openvinotoolkit/geti-sdk/pull/95 diff --git a/README.md b/README.md index 01002a05..51638891 100644 --- a/README.md +++ b/README.md @@ -186,7 +186,7 @@ Instantiating the `Geti` class will establish the connection and perform authent host="https://your_server_hostname_or_ip_address", token="your_personal_access_token" ) - geti.download_project(project_name="dummy_project") + geti.download_project_data(project_name="dummy_project") ``` Here, it is assumed that the project with name 'dummy_project' exists on the cluster. @@ -223,7 +223,7 @@ Instantiating the `Geti` class will establish the connection and perform authent host="https://your_server_hostname_or_ip_address", token="your_personal_access_token" ) - geti.upload_project(target_folder="dummy_project") + geti.upload_project_data(target_folder="dummy_project") ``` The parameter `target_folder` must be a valid path to the directory holding the @@ -301,10 +301,10 @@ the screenshot below). ## High level API reference The `Geti` class provides the following methods: -- `download_project` -- Downloads a project by project name. +- `download_project_data` -- Downloads a project by project name (Geti-SDK representation), returns an interactive object. -- `upload_project` -- Uploads project from a folder. +- `upload_project_data` -- Uploads project (Geti-SDK representation) from a folder. 
 - `download_all_projects` -- Downloads all projects found on the server.

@@ -313,6 +313,18 @@ The `Geti` class provides the following methods:

 - `upload_all_projects` -- Uploads all projects found in a specified folder to
   the server.

+- `export_project` -- Exports a project to an archive on disk. This method is useful for
+  creating a backup of a project, or for migrating a project to a different cluster.
+
+- `import_project` -- Imports a project from an archive on disk. This method is useful for
+  restoring a project from a backup, or for migrating a project to a different cluster.
+
+- `export_dataset` -- Exports a dataset to an archive on disk. This method is useful for
+  creating a backup of a dataset, or for migrating a dataset to a different cluster.
+
+- `import_dataset` -- Imports a dataset from an archive on disk. A new project will
+  be created for the dataset. This method is useful for restoring a project from a dataset
+  backup, or for migrating a dataset to a different cluster.
 
 - `upload_and_predict_image` -- Uploads a single image to an existing project on the
   server, and requests a prediction for that image. Optionally, the prediction can
@@ -427,5 +439,5 @@ docker run --rm -ti -v $(pwd):/app geti-sdk:latest /bin/bash
 
 - Model upload
 - Prediction upload
-- Exporting datasets to COCO/YOLO/VOC format: For this, you can use the export
+- Importing datasets to an existing project: For this, you can use the import
   functionality from the Intel® Geti™ user interface instead.
diff --git a/docs/source/api_reference.rst b/docs/source/api_reference.rst
index bc42455d..f8a04af5 100644
--- a/docs/source/api_reference.rst
+++ b/docs/source/api_reference.rst
@@ -7,6 +7,7 @@ API Reference
    geti
    Annotation readers
    Data models
+   Import Export module
    Deployment
    HTTP session
    REST converters
diff --git a/geti_sdk/__init__.py b/geti_sdk/__init__.py
index d5cdd0db..b8bfdae7 100644
--- a/geti_sdk/__init__.py
+++ b/geti_sdk/__init__.py
@@ -17,9 +17,9 @@
 ------------
 
 These pages contain the documentation for the main SDK class,
-:py:class:`~geti_sdk.sc_rest_client.Geti`.
+:py:class:`~geti_sdk.geti.Geti`.
 
-The :py:class:`~geti_sdk.sc_rest_client.Geti` class implements convenience
+The :py:class:`~geti_sdk.geti.Geti` class implements convenience
 methods for common operations that can be performed on the Intel® Geti™ cluster, such
 as creating a project from a pre-existing dataset, downloading or uploading a
 project, uploading an image and getting a prediction for it and creating a
 deployment for a
@@ -35,7 +35,31 @@
     host="https://0.0.0.0", username="dummy_user", password="dummy_password"
   )
 
-  geti.download_project(project_name="dummy_project")
+  geti.download_project_data(project_name="dummy_project")
+
+The :py:class:`~geti_sdk.geti.Geti` class provides a high-level interface for
+import-export operations on the Intel® Geti™ platform. Here is a list of these operations:
+* Project download
+  :py:meth:`~geti_sdk.geti.Geti.download_project_data` method fetches the project data
+  and creates a local Python object that supports a range of operations with the project.
+* Project upload
+  :py:meth:`~geti_sdk.geti.Geti.upload_project_data` method uploads the project data
+  from a local Python object to the Intel® Geti™ platform.
+* Batched project download and upload
+  :py:meth:`~geti_sdk.geti.Geti.download_all_projects` and
+  :py:meth:`~geti_sdk.geti.Geti.upload_all_projects` methods download and upload
+  multiple projects at once.
+* Project export + :py:meth:`~geti_sdk.geti.Geti.export_project` method exports the project snapshot + to a zip archive. The archive can be used to import the project to another or the same Intel® Geti™ + instance. +* Project import + :py:meth:`~geti_sdk.geti.Geti.import_project` method imports the project from a zip archive. +* Dataset export + :py:meth:`~geti_sdk.geti.Geti.export_dataset` method exports the dataset to a zip archive. +* Dataset import + :py:meth:`~geti_sdk.geti.Geti.import_dataset` method imports the dataset from a zip archive + as a new project. For custom operations or more fine-grained control over the behavior, the :py:mod:`~geti_sdk.rest_clients` subpackage should be used. @@ -48,20 +72,30 @@ .. rubric:: Project download and upload - .. automethod:: download_project + .. automethod:: download_project_data - .. automethod:: upload_project + .. automethod:: upload_project_data .. automethod:: download_all_projects .. automethod:: upload_all_projects + .. automethod:: import_project + + .. automethod:: export_project + + .. rubric:: Dataset export + + .. automethod:: export_dataset + .. rubric:: Project creation from dataset .. automethod:: create_single_task_project_from_dataset .. automethod:: create_task_chain_project_from_dataset + .. automethod:: import_dataset + .. rubric:: Project deployment .. automethod:: deploy_project diff --git a/geti_sdk/data_models/__init__.py b/geti_sdk/data_models/__init__.py index c61aceb5..26a580cf 100644 --- a/geti_sdk/data_models/__init__.py +++ b/geti_sdk/data_models/__init__.py @@ -23,7 +23,7 @@ :py:class:`~geti_sdk.data_models.model.Model` and many more. When interacting with the GETi cluster through the -:py:class:`geti_sdk.sc_rest_client.Geti` or the +:py:class:`geti_sdk.geti.Geti` or the :py:mod:`~geti_sdk.rest_clients`, all entities retrieved from the cluster will be deserialized into the data models defined in this package. diff --git a/geti_sdk/data_models/enums/dataset_format.py b/geti_sdk/data_models/enums/dataset_format.py new file mode 100644 index 00000000..9f4d1385 --- /dev/null +++ b/geti_sdk/data_models/enums/dataset_format.py @@ -0,0 +1,33 @@ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. + +from enum import Enum + + +class DatasetFormat(Enum): + """ + Enum representing the different annotation formats for datasets within an + Intel® Geti™ platform. + """ + + COCO = "coco" + YOLO = "yolo" + VOC = "voc" + DATUMARO = "datumaro" + + def __str__(self): + """ + Return the string representation of the DatasetFormat instance. 
+        """
+        return self.value
diff --git a/geti_sdk/data_models/enums/job_type.py b/geti_sdk/data_models/enums/job_type.py
index dc61965c..b7c5ad58 100644
--- a/geti_sdk/data_models/enums/job_type.py
+++ b/geti_sdk/data_models/enums/job_type.py
@@ -31,6 +31,9 @@ class JobType(Enum):
     TEST = "test"
     PREPARE_IMPORT_TO_NEW_PROJECT = "prepare_import_to_new_project"
     PERFORM_IMPORT_TO_NEW_PROJECT = "perform_import_to_new_project"
+    EXPORT_PROJECT = "export_project"
+    IMPORT_PROJECT = "import_project"
+    EXPORT_DATASET = "export_dataset"
 
     def __str__(self) -> str:
         """
diff --git a/geti_sdk/data_models/job.py b/geti_sdk/data_models/job.py
index b9df9f64..11ea065f 100644
--- a/geti_sdk/data_models/job.py
+++ b/geti_sdk/data_models/job.py
@@ -102,6 +102,31 @@ class ProjectMetadata:
 
     name: Optional[str] = None
     id: Optional[str] = None
+    type: Optional[str] = None
+
+
+@attr.define
+class DatasetMetadata:
+    """
+    Metadata related to a dataset on the GETi cluster.
+
+    :var name: Name of the dataset
+    :var id: ID of the dataset
+    """
+
+    name: Optional[str] = None
+    id: Optional[str] = None
+
+
+@attr.define
+class ParametersMetadata:
+    """
+    Metadata related to a project import to the GETi cluster.
+
+    :var file_id: ID of the uploaded file
+    """
+
+    file_id: Optional[str] = None
 
 
 @attr.define
@@ -153,11 +178,16 @@ class JobMetadata:
 
     task: Optional[TaskMetadata] = None
     project: Optional[ProjectMetadata] = None
+    dataset: Optional[DatasetMetadata] = None
+    parameters: Optional[ParametersMetadata] = None
     test: Optional[TestMetadata] = None
     base_model_id: Optional[str] = None
     model_storage_id: Optional[str] = None
     optimization_type: Optional[str] = None
     optimized_model_id: Optional[str] = None
+    download_url: Optional[str] = None
+    export_format: Optional[str] = None
+    file_id: Optional[str] = None
     scores: Optional[List[ScoreMetadata]] = None
     trained_model: Optional[ModelMetadata] = None  # Added in Geti v1.7
     warnings: Optional[List[dict]] = None  # Added in Geti v1.13 for dataset import jobs
@@ -191,6 +221,11 @@ class Job:
     :var id: Unique database ID of the job
     :var project_id: Unique database ID of the project from which the job originates
     :var type: Type of the job
+    :var creation_time: Time at which the job was created
+    :var start_time: Time at which the job started running
+    :var end_time: Time at which the job finished running
+    :var author: Author of the job
+    :var cancellation_info: Information relating to the cancellation of the job
     :var metadata: JobMetadata object holding metadata for the job
     """
 
@@ -282,6 +317,12 @@ def update(self, session: GetiSession) -> "Job":
         self.steps = response.get("steps", None)
 
         self.state = JobState(response["state"])
+        self.metadata.project_id = response["metadata"].get("project_id", None)
+        self.metadata.download_url = response["metadata"].get("download_url", None)
+        self.metadata.warnings = response["metadata"].get("warnings", None)
+        self.metadata.supported_project_types = response["metadata"].get(
+            "supported_project_types", None
+        )
 
         if self._geti_version is None:
             self.geti_version = session.version
diff --git a/geti_sdk/deployment/__init__.py b/geti_sdk/deployment/__init__.py
index 0b178644..5ee24577 100644
--- a/geti_sdk/deployment/__init__.py
+++ b/geti_sdk/deployment/__init__.py
@@ -24,8 +24,8 @@
 is the same in both cases.
 
 Creating a deployment for a project is done through the
-:py:class:`~geti_sdk.sc_rest_client.Geti` class, which provides a
-convenience method :py:meth:`~geti_sdk.sc_rest_client.Geti.deploy_project`.
+:py:class:`~geti_sdk.geti.Geti` class, which provides a +convenience method :py:meth:`~geti_sdk.geti.Geti.deploy_project`. The following code snippet shows: diff --git a/geti_sdk/deployment/deployed_model.py b/geti_sdk/deployment/deployed_model.py index daeee19b..4e35baa9 100644 --- a/geti_sdk/deployment/deployed_model.py +++ b/geti_sdk/deployment/deployed_model.py @@ -187,8 +187,8 @@ def get_data(self, source: Union[str, os.PathLike, GetiSession]): raise ValueError( "\n" "This deployment model is not compatible with the current SDK. Proposed solutions:\n" - "1. Please deploy a model using GETi Platform version 2.0.0 or higher.\n" - "2. Downgrade to a compatible GETi-SDK version to continue using this model.\n\n" + "1. Please deploy a model using Intel Geti Platform version 2.0.0 or higher.\n" + "2. Downgrade to a compatible Geti-SDK version to continue using this model.\n\n" ) elif isinstance(source, GetiSession): diff --git a/geti_sdk/geti.py b/geti_sdk/geti.py index 1bcb76a1..8bdc9e37 100644 --- a/geti_sdk/geti.py +++ b/geti_sdk/geti.py @@ -14,20 +14,15 @@ import logging import os import sys -import time import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union import numpy as np -from pathvalidate import sanitize_filepath -from tqdm.auto import tqdm -from tqdm.contrib.logging import logging_redirect_tqdm - -from .annotation_readers import ( - AnnotationReader, - DatumAnnotationReader, - GetiAnnotationReader, -) + +from geti_sdk.data_models.enums.dataset_format import DatasetFormat +from geti_sdk.import_export.import_export_module import GetiIE + +from .annotation_readers import AnnotationReader, DatumAnnotationReader from .data_models import ( Dataset, Image, @@ -40,19 +35,13 @@ from .data_models.containers import MediaList from .data_models.model import BaseModel from .deployment import Deployment -from .http_session import ( - GetiRequestException, - GetiSession, - ServerCredentialConfig, - ServerTokenConfig, -) +from .http_session import GetiSession, ServerCredentialConfig, ServerTokenConfig from .rest_clients import ( AnnotationClient, ConfigurationClient, DatasetClient, DeploymentClient, ImageClient, - ModelClient, PredictionClient, ProjectClient, VideoClient, @@ -60,7 +49,6 @@ from .utils import ( generate_classification_labels, get_default_workspace_id, - get_project_folder_name, get_task_types_by_project_type, show_image_with_annotation_scene, show_video_frames_with_annotation_scenes, @@ -192,6 +180,11 @@ def __init__( self.project_client = ProjectClient( workspace_id=workspace_id, session=self.session ) + self.import_export_module = GetiIE( + session=self.session, + workspace_id=self.workspace_id, + project_client=self.project_client, + ) # Cache of deployment clients for projects in the workspace self._deployment_clients: Dict[str, DeploymentClient] = {} @@ -207,7 +200,9 @@ def projects(self) -> List[Project]: """ return self.project_client.get_all_projects() - def get_project(self, project_name: str) -> Project: + def get_project( + self, project_name: str, project_id: Optional[str] = None + ) -> Project: """ Return the Intel® Geti™ project named `project_name`, if any. 
If no project by that name is found on the Intel® Geti™ server, this method will raise a @@ -217,7 +212,9 @@ def get_project(self, project_name: str) -> Project: :raises: KeyError if project named `project_name` is not found on the server :return: Project identified by `project_name` """ - project = self.project_client.get_project_by_name(project_name=project_name) + project = self.project_client.get_project_by_name( + project_name=project_name, project_id=project_id + ) if project is None: raise KeyError( f"Project '{project_name}' was not found in the current workspace on " @@ -225,9 +222,10 @@ def get_project(self, project_name: str) -> Project: ) return project - def download_project( + def download_project_data( self, project_name: str, + project_id: Optional[str] = None, target_folder: Optional[str] = None, include_predictions: bool = False, include_active_models: bool = False, @@ -304,85 +302,13 @@ def download_project( :return: Project object, holding information obtained from the cluster regarding the downloaded project """ - # Obtain project details from cluster - project = self.get_project(project_name) - - # Validate or create target_folder - if target_folder is None: - target_folder = os.path.join(".", get_project_folder_name(project)) - else: - sanitize_filepath(target_folder, platform="auto") - os.makedirs(target_folder, exist_ok=True, mode=0o770) - - # Download project creation parameters: - self.project_client.download_project_info( - project_name=project_name, path_to_folder=target_folder - ) - - # Download images - image_client = ImageClient( - workspace_id=self.workspace_id, session=self.session, project=project - ) - images = image_client.get_all_images() - if len(images) > 0: - image_client.download_all( - path_to_folder=target_folder, - append_image_uid=images.has_duplicate_filenames, - max_threads=max_threads, - ) - - # Download videos - video_client = VideoClient( - workspace_id=self.workspace_id, session=self.session, project=project - ) - videos = video_client.get_all_videos() - if len(videos) > 0: - video_client.download_all( - path_to_folder=target_folder, - append_video_uid=videos.has_duplicate_filenames, - max_threads=max_threads, - ) - - # Download annotations - annotation_client = AnnotationClient( - session=self.session, project=project, workspace_id=self.workspace_id - ) - annotation_client.download_all_annotations( - path_to_folder=target_folder, max_threads=max_threads - ) - - # Download predictions - prediction_client = PredictionClient( - workspace_id=self.workspace_id, session=self.session, project=project - ) - if prediction_client.ready_to_predict and include_predictions: - if len(images) > 0: - prediction_client.download_predictions_for_images( - images=images, - path_to_folder=target_folder, - include_result_media=True, - ) - if len(videos) > 0: - prediction_client.download_predictions_for_videos( - videos=videos, - path_to_folder=target_folder, - include_result_media=True, - inferred_frames_only=False, - ) - - # Download configuration - configuration_client = ConfigurationClient( - workspace_id=self.workspace_id, session=self.session, project=project + project = self.import_export_module.download_project_data( + project=self.get_project(project_name=project_name, project_id=project_id), + target_folder=target_folder, + include_predictions=include_predictions, + include_active_models=include_active_models, + max_threads=max_threads, ) - configuration_client.download_configuration(path_to_folder=target_folder) - - # Download active models - if 
include_active_models: - model_client = ModelClient( - workspace_id=self.workspace_id, session=self.session, project=project - ) - model_client.download_all_active_models(path_to_folder=target_folder) - # Download deployment if include_deployment: logging.info("Creating deployment for project...") @@ -391,7 +317,7 @@ def download_project( logging.info(f"Project '{project.name}' was downloaded successfully.") return project - def upload_project( + def upload_project_data( self, target_folder: str, project_name: Optional[str] = None, @@ -433,126 +359,157 @@ def upload_project( :return: Project object, holding information obtained from the cluster regarding the uploaded project """ - project = self.project_client.create_project_from_folder( - path_to_folder=target_folder, project_name=project_name + return self.import_export_module.upload_project_data( + target_folder=target_folder, + project_name=project_name, + enable_auto_train=enable_auto_train, + max_threads=max_threads, ) - # Disable auto-train to prevent the project from training right away - configuration_client = ConfigurationClient( - workspace_id=self.workspace_id, session=self.session, project=project - ) - configuration_client.set_project_auto_train(auto_train=False) + def download_all_projects( + self, target_folder: str, include_predictions: bool = True + ) -> List[Project]: + """ + Download all projects in the workspace from the Intel® Geti™ server. - # Upload media - image_client = ImageClient( - workspace_id=self.workspace_id, session=self.session, project=project + :param target_folder: Directory on local disk to download the project data to. + If not specified, this method will create a directory named 'projects' in + the current working directory. + :param include_predictions: True to also download the predictions for all + images and videos in the project, False to not download any predictions. + If this is set to True but the project has no trained models, downloading + predictions will be skipped. + :return: List of Project objects, each entry corresponding to one of the + projects found on the Intel® Geti™ server + """ + return self.import_export_module.download_all_projects( + target_folder=target_folder, include_predictions=include_predictions ) - video_client = VideoClient( - workspace_id=self.workspace_id, session=self.session, project=project + + def upload_all_projects(self, target_folder: str) -> List[Project]: + """ + Upload all projects found in the directory `target_folder` on local disk to + the Intel® Geti™ server. + + This method expects the directory `target_folder` to contain subfolders. Each + subfolder should correspond to the (previously downloaded) data for one + project. The method looks for project folders non-recursively, meaning that + only folders directly below the `target_folder` in the hierarchy are + considered to be uploaded as project. + + :param target_folder: Directory on local disk to retrieve the project data from + :return: List of Project objects, each entry corresponding to one of the + projects uploaded to the Intel® Geti™ server. + """ + return self.import_export_module.upload_all_projects( + target_folder=target_folder ) - # Check the media folders inside the project folder. If they are organized - # according to the projects datasets, upload the media into their corresponding - # dataset. Otherwise, upload all media into training dataset. 
-        dataset_client = DatasetClient(
-            workspace_id=self.workspace_id, session=self.session, project=project
+    def export_project(
+        self,
+        filepath: os.PathLike,
+        project_name: str,
+        project_id: Optional[str] = None,
+    ) -> None:
+        """
+        Export a project with name `project_name` to the file specified by `filepath`.
+        The project will be saved in a .zip file format, containing all project data
+        and metadata required for project import to another instance of the Intel® Geti™ platform.
+
+        :param filepath: Path to the file to save the project to
+        :param project_name: Name of the project to export
+        :param project_id: Optional ID of the project to export. If not specified, the
+            project with name `project_name` will be exported.
+        """
+        if project_id is None:
+            project_id = self.get_project(project_name=project_name).id
+            assert project_id is not None
+        self.import_export_module.export_project(
+            project_id=project_id, filepath=filepath
         )
-        if len(project.datasets) == 1 or not dataset_client.has_dataset_subfolders(
-            target_folder
-        ):
-            # Upload all media directly to the training dataset
-            images = image_client.upload_folder(
-                path_to_folder=os.path.join(target_folder, "images"),
-                max_threads=max_threads,
-            )
-            videos = video_client.upload_folder(
-                path_to_folder=os.path.join(target_folder, "videos"),
-                max_threads=max_threads,
-            )
-        else:
-            # Make sure that media is uploaded to the correct dataset
-            images: MediaList[Image] = MediaList([])
-            videos: MediaList[Video] = MediaList([])
-            for dataset in project.datasets:
-                images.extend(
-                    image_client.upload_folder(
-                        path_to_folder=os.path.join(
-                            target_folder, "images", dataset.name
-                        ),
-                        dataset=dataset,
-                        max_threads=max_threads,
-                    )
-                )
-                videos.extend(
-                    video_client.upload_folder(
-                        path_to_folder=os.path.join(
-                            target_folder, "videos", dataset.name
-                        ),
-                        dataset=dataset,
-                        max_threads=max_threads,
-                    )
-                )
-        # Short sleep to make sure all uploaded media is processed server side
-        time.sleep(5)
+    def import_project(
+        self, filepath: os.PathLike, project_name: Optional[str] = None
+    ) -> Project:
+        """
+        Import a project from the zip file specified by `filepath` to the Intel® Geti™ server.
+        The project will be created on the server with the name `project_name`, if
+        specified, else with the archive base name.
+        > Note: The project zip archive should be exported from the Geti™ server of the same version.
 
-        # Upload annotations
-        annotation_reader = GetiAnnotationReader(
-            base_data_folder=os.path.join(target_folder, "annotations"),
-            task_type=None,
+        :param filepath: Path to the file to import the project from
+        :param project_name: Optional name of the project to create on the cluster. If
+            left unspecified, the name of the archive file will be used.
+        :return: Project object, holding information obtained from the cluster
+            regarding the uploaded project.
+        """
+        return self.import_export_module.import_project(
+            filepath=filepath, project_name=project_name
         )
-        annotation_client = AnnotationClient[GetiAnnotationReader](
-            session=self.session,
+
+    def export_dataset(
+        self,
+        project: Project,
+        dataset: Dataset,
+        filepath: os.PathLike,
+        export_format: Union[str, DatasetFormat] = "DATUMARO",
+        include_unannotated_media: bool = False,
+    ):
+        """
+        Export a dataset from a project to a file specified by `filepath`. The dataset
+        will be saved in the format specified by `export_format`.
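+
+        A minimal usage sketch (assuming `geti` is a connected Geti instance; the
+        project, dataset and file names below are placeholders):
+
+        .. code-block:: python
+
+           project = geti.get_project(project_name="dummy_project")
+           geti.export_dataset(
+               project=project,
+               dataset=project.datasets[0],
+               filepath="dummy_dataset.zip",
+               export_format="COCO",
+           )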
+
+        :param project: Project object to export the dataset from
+        :param dataset: Dataset object to export
+        :param filepath: Path to the file to save the dataset to
+        :param export_format: Format to save the dataset in. Provide one of the following
+            strings: 'COCO', 'YOLO', 'VOC', 'DATUMARO' or a corresponding DatasetFormat object.
+        :param include_unannotated_media: True to include media that have no annotations
+            in the dataset, False to only include media with annotations. Defaults to
+            False.
+        """
+        if isinstance(export_format, str):
+            export_format = DatasetFormat[export_format]
+        self.import_export_module.export_dataset(
             project=project,
-            workspace_id=self.workspace_id,
-            annotation_reader=annotation_reader,
+            dataset=dataset,
+            filepath=filepath,
+            export_format=export_format,
+            include_unannotated_media=include_unannotated_media,
         )
-        if len(images) > 0:
-            annotation_client.upload_annotations_for_images(
-                images=images,
-            )
-        if len(videos) > 0:
-            are_videos_processed = False
-            start_time = time.time()
-            logging.info(
-                "Waiting for the Geti server to process all uploaded videos..."
-            )
-            while (not are_videos_processed) and (time.time() - start_time < 100):
-                # Ensure all uploaded videos are processed by the server
-                project_videos = video_client.get_all_videos()
-                uploaded_ids = {video.id for video in videos}
-                project_video_ids = {video.id for video in project_videos}
-                are_videos_processed = uploaded_ids.issubset(project_video_ids)
-                time.sleep(1)
-            annotation_client.upload_annotations_for_videos(
-                videos=videos,
-            )
-
-        configuration_file = os.path.join(target_folder, "configuration.json")
-        if os.path.isfile(configuration_file):
-            result = None
-            try:
-                result = configuration_client.apply_from_file(
-                    path_to_folder=target_folder
-                )
-            except GetiRequestException:
-                logging.warning(
-                    f"Attempted to set configuration according to the "
-                    f"'configuration.json' file in the project directory, but setting "
-                    f"the configuration failed. Probably the configuration specified "
-                    f"in '{configuration_file}' does "
-                    f"not apply to the default model for one of the tasks in the "
-                    f"project. Please make sure to reconfigure the models manually."
-                )
-            if result is None:
-                logging.warning(
-                    f"Not all configurable parameters could be set according to the "
-                    f"configuration in {configuration_file}. Please make sure to "
-                    f"verify model configuration manually."
-                )
-        configuration_client.set_project_auto_train(auto_train=enable_auto_train)
-        logging.info(f"Project '{project.name}' was uploaded successfully.")
-        return project
+    def import_dataset(
+        self, filepath: os.PathLike, project_name: str, project_type: str
+    ) -> Project:
+        """
+        Import a dataset from the zip archive specified by `filepath` to the Intel® Geti™ server.
+        A new project will be created from the dataset on the server with the name `project_name`.
+        Set `project_type` to determine the type of the project. Possible values are:
+
+        * classification
+        * classification_hierarchical
+        * detection
+        * segmentation
+        * instance_segmentation
+        * anomaly_classification
+        * anomaly_detection
+        * anomaly_segmentation
+        * detection_oriented
+        * detection_to_classification
+        * detection_to_segmentation
+
+        > Note: The dataset zip archive should be exported from the Geti™ server of the same version.
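+
+        A minimal usage sketch (assuming `geti` is a connected Geti instance; the
+        file and project names below are placeholders):
+
+        .. code-block:: python
+
+           project = geti.import_dataset(
+               filepath="dummy_dataset.zip",
+               project_name="dummy_project",
+               project_type="detection",
+           )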
+
+        :param filepath: Path to the file to import the dataset from
+        :param project_name: Name of the project to create on the cluster
+        :param project_type: Type of the project; this determines which task the
+            project will perform.
+        :return: Project object, holding information obtained from the cluster
+            regarding the uploaded project.
+        """
+        return self.import_export_module.import_dataset_as_new_project(
+            filepath=filepath, project_name=project_name, project_type=project_type
+        )
 
     def create_single_task_project_from_dataset(
         self,
@@ -844,94 +801,6 @@ def create_task_chain_project_from_dataset(
         configuration_client.set_project_auto_train(auto_train=enable_auto_train)
         return project
 
-    def download_all_projects(
-        self, target_folder: str, include_predictions: bool = True
-    ) -> List[Project]:
-        """
-        Download all projects in the workspace from the Intel® Geti™ server.
-
-        :param target_folder: Directory on local disk to download the project data to.
-            If not specified, this method will create a directory named 'projects' in
-            the current working directory.
-        :param include_predictions: True to also download the predictions for all
-            images and videos in the project, False to not download any predictions.
-            If this is set to True but the project has no trained models, downloading
-            predictions will be skipped.
-        :return: List of Project objects, each entry corresponding to one of the
-            projects found on the Intel® Geti™ server
-        """
-        # Obtain project details from cluster
-        projects = self.projects
-
-        # Validate or create target_folder
-        if target_folder is None:
-            target_folder = os.path.join(".", "projects")
-        os.makedirs(target_folder, exist_ok=True, mode=0o770)
-        logging.info(
-            f"Found {len(projects)} projects in the designated workspace on the "
-            f"Intel® Geti™ server. Commencing project download..."
-        )
-
-        # Download all found projects
-        with logging_redirect_tqdm(tqdm_class=tqdm):
-            for index, project in enumerate(
-                tqdm(projects, desc="Downloading projects")
-            ):
-                logging.info(
-                    f"Downloading project '{project.name}'... {index+1}/{len(projects)}."
-                )
-                self.download_project(
-                    project_name=project.name,
-                    target_folder=os.path.join(
-                        target_folder, get_project_folder_name(project)
-                    ),
-                    include_predictions=include_predictions,
-                )
-        return projects
-
-    def upload_all_projects(self, target_folder: str) -> List[Project]:
-        """
-        Upload all projects found in the directory `target_folder` on local disk to
-        the Intel® Geti™ server.
-
-        This method expects the directory `target_folder` to contain subfolders. Each
-        subfolder should correspond to the (previously downloaded) data for one
-        project. The method looks for project folders non-recursively, meaning that
-        only folders directly below the `target_folder` in the hierarchy are
-        considered to be uploaded as project.
-
-        :param target_folder: Directory on local disk to retrieve the project data from
-        :return: List of Project objects, each entry corresponding to one of the
-            projects uploaded to the Intel® Geti™ server.
-        """
-        candidate_project_folders = [
-            os.path.join(target_folder, subfolder)
-            for subfolder in os.listdir(target_folder)
-        ]
-        project_folders = [
-            folder
-            for folder in candidate_project_folders
-            if ProjectClient.is_project_dir(folder)
-        ]
-        logging.info(
-            f"Found {len(project_folders)} project data folders in the target "
-            f"directory '{target_folder}'. Commencing project upload..."
- ) - projects: List[Project] = [] - with logging_redirect_tqdm(tqdm_class=tqdm): - for index, project_folder in enumerate( - tqdm(project_folders, desc="Uploading projects") - ): - logging.info( - f"Uploading project from folder '{os.path.basename(project_folder)}'..." - f" {index + 1}/{len(project_folders)}." - ) - project = self.upload_project( - target_folder=project_folder, enable_auto_train=False - ) - projects.append(project) - return projects - def upload_and_predict_media_folder( self, project_name: str, diff --git a/geti_sdk/http_session/geti_session.py b/geti_sdk/http_session/geti_session.py index 575597e6..0c34ff16 100644 --- a/geti_sdk/http_session/geti_session.py +++ b/geti_sdk/http_session/geti_session.py @@ -233,10 +233,11 @@ def get_rest_response( url: str, method: str, contenttype: str = "json", - data=None, + data: Optional[Any] = None, allow_reauthentication: bool = True, include_organization_id: bool = True, allow_text_response: bool = False, + request_headers: Dict[str, str] = {}, ) -> Union[Response, dict, list]: """ Return the REST response from a request to `url` with `method`. @@ -257,9 +258,10 @@ def get_rest_response( when authentication has expired. However, some endpoints are designed to return text responses, for those endpoints this parameter should be set to True + :param request_headers: Additional headers to include in the request """ - if url.startswith(self.config.api_pattern): - url = url[len(self.config.api_pattern) :] + if self.config.api_pattern in url: + url = url.split(self.config.api_pattern)[-1] self._update_headers_for_content_type(content_type=contenttype) @@ -280,8 +282,10 @@ def get_rest_response( else: raise ValueError( f"Making a POST request with content of type {contenttype} is " - f"currently not supported through the Geti SDK." + f"currently not supported through the Intel Geti SDK." 
                )
+        elif method == "PATCH":
+            kw_data_arg = {"data": data}
         else:
             kw_data_arg = {}
 
@@ -306,7 +310,9 @@
         last_conn_error: Optional[ConnectionError] = None
         while retries:
             try:
-                response = self.request(**request_params, proxies=self._proxies)
+                response = self.request(
+                    **request_params, proxies=self._proxies, headers=request_headers
+                )
                 break
             except requests.exceptions.SSLError as error:
                 raise requests.exceptions.SSLError(
@@ -320,13 +326,15 @@
 
         if last_conn_error is not None:
             raise last_conn_error
 
         response_content_type = response.headers.get("Content-Type", [])
         if (
             response.status_code not in SUCCESS_STATUS_CODES
             or "text/html" in response_content_type
         ):
-            if not ("text/html" in response_content_type and allow_text_response):
+            if response.status_code == 204 and method in ["OPTIONS", "PATCH"]:
+                pass
+            elif not ("text/html" in response_content_type and allow_text_response):
                 response = self._handle_error_response(
                     response=response,
                     request_params=request_params,
@@ -334,7 +341,6 @@
                     allow_reauthentication=allow_reauthentication,
                     content_type=contenttype,
                 )
-
         if response.headers.get("Content-Type", "").startswith("application/json"):
             result = response.json()
         else:
             result = response.content
@@ -528,6 +534,10 @@ def _update_headers_for_content_type(self, content_type: str) -> None:
             self.headers.pop("Content-Type", None)
         elif content_type == "zip":
             self.headers.update({"Content-Type": "application/zip"})
+        elif content_type == "offset+octet-stream":
+            self.headers.update({"Content-Type": "application/offset+octet-stream"})
+        else:
+            self.headers.update({"Content-Type": content_type})
 
     @property
     def base_url(self) -> str:
diff --git a/geti_sdk/import_export/__init__.py b/geti_sdk/import_export/__init__.py
new file mode 100644
index 00000000..d5c6047d
--- /dev/null
+++ b/geti_sdk/import_export/__init__.py
@@ -0,0 +1,32 @@
+# Copyright (C) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions
+# and limitations under the License.
+
+"""
+Introduction
+------------
+
+The `import-export` package contains the `GetiIE` class with a number of methods for importing and
+exporting projects and datasets to and from the Intel® Geti™ platform.
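+
+`GetiIE` is normally not instantiated directly; it is created by, and reached
+through, the :py:class:`~geti_sdk.geti.Geti` class, which wraps its methods. A
+minimal usage sketch (the host, token and names below are placeholders):
+
+.. code-block:: python
+
+   from geti_sdk import Geti
+
+   geti = Geti(host="https://0.0.0.0", token="dummy_token")
+   geti.export_project(project_name="dummy_project", filepath="dummy_project.zip")
+   geti.import_project(filepath="dummy_project.zip", project_name="dummy_copy")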
+
+Module contents
+---------------
+"""
+
+from .import_export_module import GetiIE
+from .tus_uploader import TUSUploader
+
+__all__ = [
+    "TUSUploader",
+    "GetiIE",
+]
diff --git a/geti_sdk/import_export/import_export_module.py b/geti_sdk/import_export/import_export_module.py
new file mode 100644
index 00000000..274ff76c
--- /dev/null
+++ b/geti_sdk/import_export/import_export_module.py
@@ -0,0 +1,592 @@
+import logging
+import os
+import time
+from typing import List, Optional
+
+from pathvalidate import sanitize_filepath
+from tqdm.auto import tqdm
+from tqdm.contrib.logging import logging_redirect_tqdm
+
+from geti_sdk.annotation_readers.geti_annotation_reader import GetiAnnotationReader
+from geti_sdk.data_models.containers.media_list import MediaList
+from geti_sdk.data_models.enums.dataset_format import DatasetFormat
+from geti_sdk.data_models.media import Image, Video
+from geti_sdk.data_models.project import Dataset, Project
+from geti_sdk.http_session.exception import GetiRequestException
+from geti_sdk.http_session.geti_session import GetiSession
+from geti_sdk.import_export.tus_uploader import TUSUploader
+from geti_sdk.rest_clients.annotation_clients.annotation_client import AnnotationClient
+from geti_sdk.rest_clients.configuration_client import ConfigurationClient
+from geti_sdk.rest_clients.dataset_client import DatasetClient
+from geti_sdk.rest_clients.media_client.image_client import ImageClient
+from geti_sdk.rest_clients.media_client.video_client import VideoClient
+from geti_sdk.rest_clients.model_client import ModelClient
+from geti_sdk.rest_clients.prediction_client import PredictionClient
+from geti_sdk.rest_clients.project_client.project_client import ProjectClient
+from geti_sdk.utils.job_helpers import get_job_with_timeout, monitor_job
+from geti_sdk.utils.project_helpers import get_project_folder_name
+
+
+class GetiIE:
+    """
+    Class to handle importing and exporting projects and datasets to and from the Intel® Geti™ platform.
+    """
+
+    def __init__(
+        self, workspace_id: str, session: GetiSession, project_client: ProjectClient
+    ) -> None:
+        """
+        Initialize the GetiIE class.
+
+        :param workspace_id: The workspace id.
+        :param session: The Geti session.
+        :param project_client: The project client.
+        """
+        self.workspace_id = workspace_id
+        self.session = session
+        self.base_url = f"workspaces/{workspace_id}/"
+        self.project_client = project_client
+
+    def download_project_data(
+        self,
+        project: Project,
+        target_folder: Optional[str] = None,
+        include_predictions: bool = False,
+        include_active_models: bool = False,
+        max_threads: int = 10,
+    ) -> Project:
+        """
+        Download a project from the Geti Platform.
+
+        :param project: The `Project` object to download.
+        :param target_folder: The path to save the downloaded project.
+        :param include_predictions: Whether to download predictions for the project.
+        :param include_active_models: Whether to download active models for the project.
+        :param max_threads: The maximum number of threads to use for downloading media.
+        :return: The downloaded project.
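+
+        A minimal sketch of the intended call path (assuming `geti` is a connected
+        Geti instance; names below are placeholders):
+
+        .. code-block:: python
+
+           project = geti.get_project(project_name="dummy_project")
+           geti.import_export_module.download_project_data(
+               project=project, target_folder="./dummy_project"
+           )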
+ """ + # Validate or create target_folder + if target_folder is None: + target_folder = os.path.join(".", get_project_folder_name(project)) + else: + sanitize_filepath(target_folder, platform="auto") + os.makedirs(target_folder, exist_ok=True, mode=0o770) + + # Download project creation parameters: + self.project_client.download_project_info( + project_name=project.name, path_to_folder=target_folder + ) + + # Download images + image_client = ImageClient( + workspace_id=self.workspace_id, session=self.session, project=project + ) + images = image_client.get_all_images() + if len(images) > 0: + image_client.download_all( + path_to_folder=target_folder, + append_image_uid=images.has_duplicate_filenames, + max_threads=max_threads, + ) + + # Download videos + video_client = VideoClient( + workspace_id=self.workspace_id, session=self.session, project=project + ) + videos = video_client.get_all_videos() + if len(videos) > 0: + video_client.download_all( + path_to_folder=target_folder, + append_video_uid=videos.has_duplicate_filenames, + max_threads=max_threads, + ) + + # Download annotations + annotation_client = AnnotationClient( + session=self.session, project=project, workspace_id=self.workspace_id + ) + annotation_client.download_all_annotations( + path_to_folder=target_folder, max_threads=max_threads + ) + + # Download predictions + prediction_client = PredictionClient( + workspace_id=self.workspace_id, session=self.session, project=project + ) + if prediction_client.ready_to_predict and include_predictions: + if len(images) > 0: + prediction_client.download_predictions_for_images( + images=images, + path_to_folder=target_folder, + include_result_media=True, + ) + if len(videos) > 0: + prediction_client.download_predictions_for_videos( + videos=videos, + path_to_folder=target_folder, + include_result_media=True, + inferred_frames_only=False, + ) + + # Download configuration + configuration_client = ConfigurationClient( + workspace_id=self.workspace_id, session=self.session, project=project + ) + configuration_client.download_configuration(path_to_folder=target_folder) + + # Download active models + if include_active_models: + model_client = ModelClient( + workspace_id=self.workspace_id, session=self.session, project=project + ) + model_client.download_all_active_models(path_to_folder=target_folder) + + return project + + def upload_project_data( + self, + target_folder: str, + project_name: Optional[str] = None, + enable_auto_train: bool = True, + max_threads: int = 5, + ) -> Project: + """ + Upload a project to the Geti Platform. + + :param target_folder: The path to the project data folder. + :param project_name: The name of the project. + :param enable_auto_train: Whether to enable auto-train for the project. + :param max_threads: The maximum number of threads to use for uploading media. + :return: The uploaded project. + """ + project = self.project_client.create_project_from_folder( + path_to_folder=target_folder, project_name=project_name + ) + + # Disable auto-train to prevent the project from training right away + configuration_client = ConfigurationClient( + workspace_id=self.workspace_id, session=self.session, project=project + ) + configuration_client.set_project_auto_train(auto_train=False) + + # Upload media + image_client = ImageClient( + workspace_id=self.workspace_id, session=self.session, project=project + ) + video_client = VideoClient( + workspace_id=self.workspace_id, session=self.session, project=project + ) + + # Check the media folders inside the project folder. 
If they are organized + # according to the projects datasets, upload the media into their corresponding + # dataset. Otherwise, upload all media into training dataset. + dataset_client = DatasetClient( + workspace_id=self.workspace_id, session=self.session, project=project + ) + if len(project.datasets) == 1 or not dataset_client.has_dataset_subfolders( + target_folder + ): + # Upload all media directly to the training dataset + images = image_client.upload_folder( + path_to_folder=os.path.join(target_folder, "images"), + max_threads=max_threads, + ) + videos = video_client.upload_folder( + path_to_folder=os.path.join(target_folder, "videos"), + max_threads=max_threads, + ) + else: + # Make sure that media is uploaded to the correct dataset + images: MediaList[Image] = MediaList([]) + videos: MediaList[Video] = MediaList([]) + for dataset in project.datasets: + images.extend( + image_client.upload_folder( + path_to_folder=os.path.join( + target_folder, "images", dataset.name + ), + dataset=dataset, + max_threads=max_threads, + ) + ) + videos.extend( + video_client.upload_folder( + path_to_folder=os.path.join( + target_folder, "videos", dataset.name + ), + dataset=dataset, + max_threads=max_threads, + ) + ) + + # Short sleep to make sure all uploaded media is processed server side + time.sleep(5) + + # Upload annotations + annotation_reader = GetiAnnotationReader( + base_data_folder=os.path.join(target_folder, "annotations"), + task_type=None, + ) + annotation_client = AnnotationClient[GetiAnnotationReader]( + session=self.session, + project=project, + workspace_id=self.workspace_id, + annotation_reader=annotation_reader, + ) + if len(images) > 0: + annotation_client.upload_annotations_for_images( + images=images, + ) + if len(videos) > 0: + are_videos_processed = False + start_time = time.time() + logging.info( + "Waiting for the Geti server to process all uploaded videos..." + ) + while (not are_videos_processed) and (time.time() - start_time < 100): + # Ensure all uploaded videos are processed by the server + project_videos = video_client.get_all_videos() + uploaded_ids = {video.id for video in videos} + project_video_ids = {video.id for video in project_videos} + are_videos_processed = uploaded_ids.issubset(project_video_ids) + time.sleep(1) + annotation_client.upload_annotations_for_videos( + videos=videos, + ) + + configuration_file = os.path.join(target_folder, "configuration.json") + if os.path.isfile(configuration_file): + result = None + try: + result = configuration_client.apply_from_file( + path_to_folder=target_folder + ) + except GetiRequestException: + logging.warning( + f"Attempted to set configuration according to the " + f"'configuration.json' file in the project directory, but setting " + f"the configuration failed. Probably the configuration specified " + f"in '{configuration_file}' does " + f"not apply to the default model for one of the tasks in the " + f"project. Please make sure to reconfigure the models manually." + ) + if result is None: + logging.warning( + f"Not all configurable parameters could be set according to the " + f"configuration in {configuration_file}. Please make sure to " + f"verify model configuration manually." + ) + configuration_client.set_project_auto_train(auto_train=enable_auto_train) + logging.info(f"Project '{project.name}' was uploaded successfully.") + return project + + def download_all_projects( + self, target_folder: str, include_predictions: bool = True + ) -> List[Project]: + """ + Download all projects from the Geti Platform. 
+ + :param target_folder: The path to the directory to save the downloaded projects. + :param include_predictions: Whether to download predictions for the projects. + :return: The downloaded projects. + """ + # Obtain project details from cluster + projects = self.project_client.get_all_projects() + + # Validate or create target_folder + if target_folder is None: + target_folder = os.path.join(".", "projects") + os.makedirs(target_folder, exist_ok=True, mode=0o770) + logging.info( + f"Found {len(projects)} projects in the designated workspace on the " + f"Intel® Geti™ server. Commencing project download..." + ) + + # Download all found projects + with logging_redirect_tqdm(tqdm_class=tqdm): + for index, project in enumerate( + tqdm(projects, desc="Downloading projects") + ): + logging.info( + f"Downloading project '{project.name}'... {index+1}/{len(projects)}." + ) + self.download_project_data( + project=project, + target_folder=os.path.join( + target_folder, get_project_folder_name(project) + ), + include_predictions=include_predictions, + ) + return projects + + def upload_all_projects(self, target_folder: str) -> List[Project]: + """ + Upload all projects in the target directory to the Geti Platform. + + :param target_folder: The path to the directory containing the project data folders. + :return: The uploaded projects. + """ + candidate_project_folders = [ + os.path.join(target_folder, subfolder) + for subfolder in os.listdir(target_folder) + ] + project_folders = [ + folder + for folder in candidate_project_folders + if ProjectClient.is_project_dir(folder) + ] + logging.info( + f"Found {len(project_folders)} project data folders in the target " + f"directory '{target_folder}'. Commencing project upload..." + ) + projects: List[Project] = [] + with logging_redirect_tqdm(tqdm_class=tqdm): + for index, project_folder in enumerate( + tqdm(project_folders, desc="Uploading projects") + ): + logging.info( + f"Uploading project from folder '{os.path.basename(project_folder)}'..." + f" {index + 1}/{len(project_folders)}." + ) + project = self.upload_project_data( + target_folder=project_folder, enable_auto_train=False + ) + projects.append(project) + return projects + + def import_dataset_as_new_project( + self, filepath: os.PathLike, project_name: str, project_type: str + ) -> Project: + """ + Import a dataset as a new project to the Geti Platform. + + :param filepath: The path to the dataset archive. + :param project_name: The name of the new project. + :param project_type: The type of the new project. Provide one of + [classification, classification_hierarchical, detection, segmentation, + instance_segmentation, anomaly_classification, anomaly_detection, anomaly_segmentation, + detection_oriented, detection_to_classification, detection_to_segmentation] + :return: The imported project. + :raises: RuntimeError if the project type is not supported for the imported dataset. 
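+
+        Note: the import runs as two server-side jobs: a preparation job that
+        analyzes the uploaded archive and reports the supported project types,
+        followed by an import job that creates the new project from the dataset.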
+ """ + # Upload the dataset archive to the server + upload_endpoint = self.base_url + "datasets/uploads/resumable" + file_id = self._tus_upload_file( + upload_endpoint=upload_endpoint, filepath=filepath + ) + # Prepare for import + response = self.session.get_rest_response( + url=f"{self.base_url}datasets:prepare-for-import?file_id={file_id}", + method="POST", + ) + job = get_job_with_timeout( + job_id=response["job_id"], + session=self.session, + workspace_id=self.workspace_id, + job_type="import_dataset", + ) + job = monitor_job(session=self.session, job=job, interval=5) + # Make sure that the project type is supported for the imported dataset + if "_to_" in project_type: + # Translate the SDK `detection_to_segmentation` project type format + # to the Geti Platform `detection_segmentation` format + project_type = project_type.replace("_to_", "_") + project_dict = next( + ( + entry + for entry in job.metadata.supported_project_types + if entry["project_type"] == project_type + ), + None, + ) + if project_dict is None: + supported_project_types = [ + entry["project_type"] for entry in job.metadata.supported_project_types + ] + raise RuntimeError( + f"Project type '{project_type}' is not supported for the imported dataset.\n" + f" Please select one of the supported project types: `{supported_project_types}`" + ) + # Create a new project from the imported dataset + label_names = [ + label_dict["name"] + for task_dict in project_dict["pipeline"]["tasks"] + for label_dict in task_dict["labels"] + ] + data = { + "project_name": project_name, + "task_type": project_type, + "file_id": job.metadata.file_id, + "labels": [{"name": label_name} for label_name in label_names], + } + response = self.session.get_rest_response( + url=f"{self.base_url}projects:import-from-dataset", + method="POST", + data=data, + ) + # Get the job id and monitor the job + # until it returns the project id + job = get_job_with_timeout( + job_id=response["job_id"], + session=self.session, + workspace_id=self.workspace_id, + job_type="import_project_from_dataset", + ) + job = monitor_job(session=self.session, job=job, interval=5) + logging.info( + f"Project '{project_name}' was successfully imported from the dataset." + ) + imported_project = self.project_client.get_project_by_name( + project_name=project_name, + project_id=job.metadata.project_id, + ) + if imported_project is None: + raise RuntimeError( + f"Failed to retrieve the imported project '{project_name}'." + ) + return imported_project + + def import_project( + self, filepath: os.PathLike, project_name: Optional[str] = None + ) -> Project: + """ + Import a project to the Geti Platform. + + :param filepath: The path to the project archive. + :param project_name: The name of the project. + :return: The imported project. 
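+
+        A minimal usage sketch (assuming `geti_ie` is an initialized GetiIE
+        instance; the file name below is a placeholder):
+
+        .. code-block:: python
+
+           project = geti_ie.import_project(filepath="dummy_project.zip")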
+        """
+        if project_name is None:
+            project_name = os.path.basename(filepath).split(".")[0]
+
+        upload_endpoint = self.base_url + "projects/uploads/resumable"
+        file_id = self._tus_upload_file(
+            upload_endpoint=upload_endpoint, filepath=filepath
+        )
+
+        # Start project import process using the uploaded archive
+        response = self.session.get_rest_response(
+            url=f"{self.base_url}projects:import",
+            method="POST",
+            data={
+                "file_id": file_id,
+                "project_name": project_name,
+            },
+        )
+
+        job = get_job_with_timeout(
+            job_id=response["job_id"],
+            session=self.session,
+            workspace_id=self.workspace_id,
+            job_type="import_project",
+        )
+
+        job = monitor_job(session=self.session, job=job, interval=5)
+        imported_project = self.project_client.get_project_by_name(
+            project_name=project_name,
+            project_id=job.metadata.project_id,
+        )
+        if imported_project is None:
+            raise RuntimeError(
+                f"Failed to retrieve the imported project '{project_name}'."
+            )
+        return imported_project
+
+    def _tus_upload_file(self, upload_endpoint: str, filepath: os.PathLike) -> str:
+        """
+        Upload a file using the TUS protocol.
+
+        :param upload_endpoint: The TUS upload endpoint.
+        :param filepath: The path to the file to upload.
+        :return: The file id created on the Geti Platform.
+        :raises: RuntimeError if the file id is not retrieved.
+        """
+        tus_uploader = TUSUploader(
+            session=self.session, base_url=upload_endpoint, file_path=filepath
+        )
+        tus_uploader.upload()
+        file_id = tus_uploader.get_file_id()
+        if file_id is None or len(file_id) < 2:
+            raise RuntimeError(f"Failed to get file id for the file '{filepath}'.")
+        return file_id
+
+    def export_project(self, project_id: str, filepath: os.PathLike):
+        """
+        Export a project from the Geti Platform.
+
+        :param project_id: The ID of the project to export.
+        :param filepath: The path to save the exported project.
+        :raises: RuntimeError if the download url is not retrieved.
+        """
+        url = f"{self.base_url}projects/{project_id}:export"
+        self._export_snapshot(url=url, filepath=filepath)
+
+    def export_dataset(
+        self,
+        project: Project,
+        dataset: Dataset,
+        filepath: os.PathLike,
+        export_format: DatasetFormat = DatasetFormat.DATUMARO,
+        include_unannotated_media: bool = False,
+    ):
+        """
+        Export a dataset from the Geti Platform.
+
+        :param project: The project containing the dataset.
+        :param dataset: The dataset to export.
+        :param filepath: The path to save the exported dataset.
+        :param export_format: The format to export the dataset in.
+        :param include_unannotated_media: Whether to include media that has not been annotated.
+        :raises: RuntimeError if the download url is not retrieved.
+        """
+        query_params = (
+            f"export_format={str(export_format)}&"
+            f"include_unannotated_media={str(include_unannotated_media).lower()}"
+        )
+        url = (
+            f"{self.base_url}projects/{project.id}/datasets/{dataset.id}"
+            f":prepare-for-export?{query_params}"
+        )
+
+        self._export_snapshot(url=url, filepath=filepath)
+
+    def _export_snapshot(self, url: str, filepath: os.PathLike):
+        """
+        Export an entity from the Geti Platform.
+
+        :param url: The export endpoint.
+        :param filepath: The path to save the exported entity.
+        :raises: RuntimeError if the download url is not retrieved.
+ """ + parent_dir = os.path.dirname(filepath) + os.makedirs(parent_dir, exist_ok=True) + + response = self.session.get_rest_response( + url=url, + method="POST", + ) + if response.get("job_id") is None: + raise RuntimeError("Failed to get job id for the export entity.") + + job = get_job_with_timeout( + job_id=response.get("job_id"), + session=self.session, + workspace_id=self.workspace_id, + job_type="export_project", + ) + + job = monitor_job(session=self.session, job=job, interval=5) + if job.metadata.download_url is None: + raise RuntimeError("Failed to get download url for the exported entity.") + url = job.metadata.download_url + + if not url.startswith("/"): + url = "/" + url + + logging.info("Downloading the archive...") + zip_response = self.session.get_rest_response( + url=url, method="GET", contenttype="multipart" + ) + with open(filepath, "wb") as f: + f.write(zip_response.content) diff --git a/geti_sdk/import_export/tus_uploader.py b/geti_sdk/import_export/tus_uploader.py new file mode 100644 index 00000000..b68c45ef --- /dev/null +++ b/geti_sdk/import_export/tus_uploader.py @@ -0,0 +1,225 @@ +import os +import time +from io import BufferedReader +from typing import Optional + +from tqdm.auto import tqdm +from tqdm.contrib.logging import logging_redirect_tqdm + +from geti_sdk.http_session.geti_session import GetiSession + + +class TUSUploader: + """ + Class to handle tus uploads. + """ + + DEFAULT_CHUNK_SIZE = 5 * 2**20 # 5MB + + def __init__( + self, + session: GetiSession, + base_url: str, + file_path: Optional[os.PathLike] = None, + file_stream: Optional[BufferedReader] = None, + chunk_size: int = DEFAULT_CHUNK_SIZE, + retries: int = 0, + retry_delay: int = 30, + ): + """ + Initialize TUSUploader instance. + + :param session: GetiSession instance. + :param base_url: Base url for the tus upload. + :param file_path: Path to the file to be uploaded. + :param file_stream: File stream of the file to be uploaded. + :param chunk_size: Size of the chunk to be uploaded at each cycle. + :param retries: Number of retries to be made in case of upload failure. + :param retry_delay: Delay between retries. + """ + if file_path is None and file_stream is None: + raise ValueError("Either 'file_path' or 'file_stream' cannot be None.") + + self.file_path = file_path + self.file_stream = file_stream + if self.file_stream is None: + self.file_stream = self.get_file_stream() + self.stop_at = self.get_file_size() + self.session = session + self.base_url = base_url + self.tus_resumable_version = self._get_tus_resumable_version() + self.offset = 0 + self.upload_url = None + self.chunk_size = chunk_size + self.retries = retries + self.request = None + self._retried = 0 + self.retry_delay = retry_delay + + def _get_tus_resumable_version(self): + """ + Return tus resumable version. + """ + response = self.session.get_rest_response( + url=self.base_url, + method="OPTIONS", + ) + return response.headers["tus-resumable"] + + def get_offset(self): + """ + Return offset from tus server. + + Make an http request to the tus server to retrieve the current offset. + + :return: Offset value. + :raises: Exception if offset retrieval fails. 
+ """ + response = self.session.get_rest_response( + url=self.upload_url, + method="HEAD", + request_headers={ + "tus-resumable": self.tus_resumable_version, + }, + ) + offset = response.headers.get("upload-offset") + if offset is None: + raise Exception("Attempt to retrieve offset failed") + return int(offset) + + def get_request_length(self): + """ + Return length of next chunk upload. + """ + remainder = self.stop_at - self.offset + return self.chunk_size if remainder > self.chunk_size else remainder + + def get_file_stream(self) -> BufferedReader: + """ + Return a file stream instance of the upload. + + :return: File stream instance. + :raises: ValueError if file_path is invalid. + """ + if self.file_stream: + self.file_stream.seek(0) + return self.file_stream + elif self.file_path is not None and os.path.isfile(self.file_path): + return open(self.file_path, "rb") + else: + raise ValueError("invalid file {}".format(self.file_path)) + + def get_file_size(self): + """ + Return size of the file. + """ + stream = self.get_file_stream() + stream.seek(0, os.SEEK_END) + return stream.tell() + + def upload(self, stop_at: Optional[int] = None): + """ + Perform file upload. + + Performs continous upload of chunks of the file. The size uploaded at each cycle is + the value of the attribute 'chunk_size'. + + :param stop_at: Offset value at which the upload should stop. If not specified this + defaults to the file size. + """ + self.stop_at = stop_at or self.get_file_size() + + if not self.upload_url: + self.upload_url = self.create_upload_url() + self.offset = 0 + + self.file_stream = self.get_file_stream() + with logging_redirect_tqdm(tqdm_class=tqdm): + with tqdm( + total=self.stop_at >> 20, + desc="Uploading file", + unit="MB", + ) as tbar: + while self.offset < self.stop_at: + self.upload_chunk() + tbar.update((self.offset >> 20) - tbar.n) + + def create_upload_url(self): + """ + Return upload url. + + Makes request to tus server to create a new upload url for the required file upload. + """ + response = self.session.get_rest_response( + url=self.base_url, + method="POST", + request_headers={ + "tus-resumable": self.tus_resumable_version, + "upload-length": str(self.get_file_size()), + }, + ) + upload_url = response.headers.get("location") + if upload_url is None: + raise ValueError("Upload url not returned by server") + return upload_url + + def get_file_id(self) -> Optional[str]: + """ + Return file id from upload url. + + :return: File id. + """ + if self.upload_url is None: + return + return self.upload_url.split("/")[-1] + + def upload_chunk(self): + """ + Upload chunk of file. + """ + self._retried = 0 + try: + self.offset = self._patch() + except Exception as err: + self._retry(err) + + def _patch(self) -> int: + """ + Perform actual request. + + :return: Offset value after the request. + """ + chunk = self.file_stream.read(self.get_request_length()) + # self.add_checksum(chunk) + response = self.session.get_rest_response( + url=self.upload_url, + method="PATCH", + data=chunk, + contenttype="offset+octet-stream", + request_headers={ + "upload-offset": str(self.offset), + "tus-resumable": self.tus_resumable_version, + }, + ) + upload_offset = int(response.headers.get("upload-offset")) + return int(upload_offset) + + def _retry(self, error): + """ + Retry upload in case of failure. + + :param error: Error that caused the upload to fail. + :raises: error if retries are exhausted. 
+ """ + if self.retries > self._retried: + time.sleep(self.retry_delay) + + self._retried += 1 + try: + self.offset = self.get_offset() + except Exception as err: + self._retry(err) + else: + self.upload() + else: + raise error diff --git a/geti_sdk/rest_clients/annotation_clients/annotation_client.py b/geti_sdk/rest_clients/annotation_clients/annotation_client.py index 5606c78f..90af6a4e 100644 --- a/geti_sdk/rest_clients/annotation_clients/annotation_client.py +++ b/geti_sdk/rest_clients/annotation_clients/annotation_client.py @@ -48,10 +48,7 @@ def get_latest_annotations_for_video(self, video: Video) -> List[AnnotationScene return [] else: raise error - if self.session.version.is_sc_1_1 or self.session.version.is_sc_mvp: - annotations = response - else: - annotations = response["video_annotations"] + annotations = response["video_annotations"] annotation_scenes = [ self.annotation_scene_from_rest_response( annotation_scene, media_information=video.media_information diff --git a/geti_sdk/rest_clients/annotation_clients/base_annotation_client.py b/geti_sdk/rest_clients/annotation_clients/base_annotation_client.py index 64b10117..0774f9d0 100644 --- a/geti_sdk/rest_clients/annotation_clients/base_annotation_client.py +++ b/geti_sdk/rest_clients/annotation_clients/base_annotation_client.py @@ -37,9 +37,6 @@ from geti_sdk.http_session import GetiRequestException, GetiSession from geti_sdk.rest_clients.dataset_client import DatasetClient from geti_sdk.rest_converters import AnnotationRESTConverter -from geti_sdk.rest_converters.annotation_rest_converter import ( - NormalizedAnnotationRESTConverter, -) AnnotationReaderType = TypeVar("AnnotationReaderType", bound=AnnotationReader) MediaType = TypeVar("MediaType", Image, Video) @@ -220,17 +217,9 @@ def _upload_annotation_for_2d_media_item( ) if scene_to_upload.has_data: scene_to_upload.prepare_for_post() - if self.session.version.is_sc_mvp or self.session.version.is_sc_1_1: - rest_data = NormalizedAnnotationRESTConverter.to_normalized_dict( - scene_to_upload, - deidentify=False, - image_width=media_item.media_information.width, - image_height=media_item.media_information.height, - ) - else: - rest_data = AnnotationRESTConverter.to_dict( - scene_to_upload, deidentify=False - ) + rest_data = AnnotationRESTConverter.to_dict( + scene_to_upload, deidentify=False + ) rest_data.pop("kind") self.session.get_rest_response( url=f"{media_item.base_url}/annotations", @@ -267,17 +256,9 @@ def _append_annotation_for_2d_media_item( annotation_scene.extend(new_annotation_scene.annotations) if annotation_scene.has_data: - if self.session.version.is_sc_mvp or self.session.version.is_sc_1_1: - rest_data = NormalizedAnnotationRESTConverter.to_normalized_dict( - annotation_scene, - deidentify=False, - image_width=media_item.media_information.width, - image_height=media_item.media_information.height, - ) - else: - rest_data = AnnotationRESTConverter.to_dict( - annotation_scene, deidentify=False - ) + rest_data = AnnotationRESTConverter.to_dict( + annotation_scene, deidentify=False + ) rest_data.pop("kind", None) rest_data.pop("annotation_state_per_task", None) rest_data.pop("id", None) @@ -332,16 +313,7 @@ def annotation_scene_from_rest_response( annotation applies :return: AnnotationScene object corresponding to the data in the response_dict """ - if self.session.version.is_sc_mvp or self.session.version.is_sc_1_1: - annotation_scene = ( - NormalizedAnnotationRESTConverter.normalized_annotation_scene_from_dict( - response_dict, - image_width=media_information.width, 
- image_height=media_information.height, - ) - ) - else: - annotation_scene = AnnotationRESTConverter.from_dict(response_dict) + annotation_scene = AnnotationRESTConverter.from_dict(response_dict) return annotation_scene def _get_latest_annotation_for_2d_media_item( diff --git a/geti_sdk/rest_clients/model_client.py b/geti_sdk/rest_clients/model_client.py index bf8335ec..0c2b8074 100644 --- a/geti_sdk/rest_clients/model_client.py +++ b/geti_sdk/rest_clients/model_client.py @@ -59,15 +59,7 @@ def get_all_model_groups(self) -> List[ModelGroup]: :return: List of model groups in the project """ response = self.session.get_rest_response(url=self.base_url, method="GET") - if self.session.version.is_sc_1_1 or self.session.version.is_sc_mvp: - # The API is not fully consistent here, depending on exact release. - # Response may either be a dict or a list - try: - response_array = response["items"] - except TypeError: - response_array = response - else: - response_array = response["model_groups"] + response_array = response["model_groups"] model_groups = [ ModelRESTConverter.model_group_from_dict(group) for group in response_array ] @@ -512,10 +504,7 @@ def get_model_for_job(self, job: Job, check_status: bool = True) -> Model: """ if check_status: job.update(self.session) - if self.session.version.is_sc_mvp or self.session.version.is_sc_1_1: - job_pid = job.project_id - else: - job_pid = job.metadata.project.id + job_pid = job.metadata.project.id if job_pid != self.project.id: raise ValueError( f"Cannot get model for job `{job.description}`. This job does not " diff --git a/geti_sdk/rest_clients/prediction_client.py b/geti_sdk/rest_clients/prediction_client.py index d26a2898..d3eba6d3 100644 --- a/geti_sdk/rest_clients/prediction_client.py +++ b/geti_sdk/rest_clients/prediction_client.py @@ -38,10 +38,7 @@ from geti_sdk.data_models.enums import MediaType, PredictionMode from geti_sdk.data_models.predictions import ResultMedium from geti_sdk.http_session import GetiRequestException, GetiSession -from geti_sdk.rest_converters.prediction_rest_converter import ( - NormalizedPredictionRESTConverter, - PredictionRESTConverter, -) +from geti_sdk.rest_converters.prediction_rest_converter import PredictionRESTConverter class PredictionClient: @@ -72,17 +69,7 @@ def __are_models_trained(self) -> bool: ) model_info_array: List[Dict[str, Any]] - if self.session.version.is_sc_1_1 or self.session.version.is_sc_mvp: - if isinstance(response, dict): - model_info_array = response.get("items", []) - elif isinstance(response, list): - model_info_array = response - else: - raise ValueError( - f"Unexpected response from Intel® Geti™ server: {response}" - ) - else: - model_info_array = response.get("model_groups", []) + model_info_array = response.get("model_groups", []) task_ids = [task.id for task in self.project.get_trainable_tasks()] tasks_with_models: List[str] = [] @@ -253,14 +240,7 @@ def _get_prediction_for_media_item( data=data, ) if isinstance(media_item, (Image, VideoFrame)): - if self.session.version.is_sc_mvp or self.session.version.is_sc_1_1: - result = NormalizedPredictionRESTConverter.normalized_prediction_from_dict( - prediction=response, - image_height=media_item.media_information.height, - image_width=media_item.media_information.width, - ) - else: - result = PredictionRESTConverter.from_dict(response) + result = PredictionRESTConverter.from_dict(response) if include_explanation: maps: List[ResultMedium] = [] for map_dict in explain_response.get("maps", []): @@ -274,36 +254,24 @@ def 
_get_prediction_for_media_item(
             result.resolve_label_names_and_colors(labels=self._labels)
 
         elif isinstance(media_item, Video):
-            if self.session.version.is_sc_mvp or self.session.version.is_sc_1_1:
-                result = [
-                    NormalizedPredictionRESTConverter.normalized_prediction_from_dict(
-                        prediction=prediction,
-                        image_width=media_item.media_information.width,
-                        image_height=media_item.media_information.height,
-                    ).resolve_labels_for_result_media(
-                        labels=self._labels
-                    )
-                    for prediction in response
-                ]
-            else:
-                result = []
-                for ind, prediction in enumerate(response["video_predictions"]):
-                    pred_object = PredictionRESTConverter.from_dict(prediction)
-                    pred_object.resolve_label_names_and_colors(labels=self._labels)
-                    if include_explanation:
-                        maps: List[ResultMedium] = []
-                        for map_dict in explain_response["explanations"][ind].get(
-                            "maps", []
-                        ):
-                            map = ResultMedium(
-                                name="saliency_map",
-                                label_id=map_dict.get("label_id", None),
-                            )
-                            map.data = map_dict["data"].encode("utf-8")
-                            maps.append(map)
-                        pred_object.maps = maps
-                    pred_object.resolve_labels_for_result_media(labels=self._labels)
-                    result.append(pred_object)
+            result = []
+            for ind, prediction in enumerate(response["video_predictions"]):
+                pred_object = PredictionRESTConverter.from_dict(prediction)
+                pred_object.resolve_label_names_and_colors(labels=self._labels)
+                if include_explanation:
+                    maps: List[ResultMedium] = []
+                    for map_dict in explain_response["explanations"][ind].get(
+                        "maps", []
+                    ):
+                        map = ResultMedium(
+                            name="saliency_map",
+                            label_id=map_dict.get("label_id", None),
+                        )
+                        map.data = map_dict["data"].encode("utf-8")
+                        maps.append(map)
+                    pred_object.maps = maps
+                pred_object.resolve_labels_for_result_media(labels=self._labels)
+                result.append(pred_object)
         else:
             raise TypeError(
                 f"Getting predictions is not supported for media item of type "
diff --git a/geti_sdk/rest_clients/project_client/project_client.py b/geti_sdk/rest_clients/project_client/project_client.py
index 9e6936a4..96bc8177 100644
--- a/geti_sdk/rest_clients/project_client/project_client.py
+++ b/geti_sdk/rest_clients/project_client/project_client.py
@@ -78,9 +78,6 @@ def get_all_projects(
         :return: List of Project objects, containing the project information for each
             project on the Intel® Geti™ server
         """
-        project_key = "projects"
-        num_total_projects_key = "project_counts"
-
         # The 'projects' endpoint uses pagination: multiple HTTP requests may be
         # necessary to fetch the full list of projects
         project_rest_list: List[Dict] = []
         while response := self.session.get_rest_response(
             url=f"{self.base_url}projects?limit={request_page_size}&skip={len(project_rest_list)}",
             method="GET",
         ):
-            project_rest_list.extend(response[project_key])
-            if len(project_rest_list) >= response[num_total_projects_key]:
+            project_rest_list.extend(response["projects"])
+            if len(project_rest_list) >= response["project_counts"]:
                 break
 
         project_list = [
diff --git a/geti_sdk/rest_clients/training_client.py b/geti_sdk/rest_clients/training_client.py
index 1ae12354..ef8f0f6c 100644
--- a/geti_sdk/rest_clients/training_client.py
+++ b/geti_sdk/rest_clients/training_client.py
@@ -88,11 +88,7 @@ def get_jobs(
             query += f"&project_id={self.project.id}"
         if running_only:
             query += "&state=running"
-
-        if self.session.version.is_sc_mvp or self.session.version.is_sc_1_1:
-            response_list_key = "items"
-        else:
-            response_list_key = "jobs"
+        response_list_key = "jobs"
 
         job_rest_list: List[dict] = []
         while response := self.session.get_rest_response(
diff --git a/geti_sdk/utils/job_helpers.py b/geti_sdk/utils/job_helpers.py
index 61239538..b93a7ca8 100644
--- a/geti_sdk/utils/job_helpers.py
+++ b/geti_sdk/utils/job_helpers.py
@@ -274,7 +274,7 @@ def monitor_jobs(
 
 def monitor_job(
     session: GetiSession, job: Job, timeout: int = 10000, interval: int = 15
-) -> List[Job]:
+) -> Job:
     """
     Monitor and print the progress of a single `job`. Execution is halted until the
     job has either finished, failed or was cancelled.
@@ -326,7 +326,11 @@ def monitor_job(
     previous_progress = 0
     previous_message = job.current_step_message
     current_step = job.current_step
-    outer_description = f"Project `{job.metadata.project.name}` - {job.name}"
+    outer_description = (
+        f"Project `{job.metadata.project.name}` - "
+        if job.metadata.project
+        else ""
+    ) + f"{job.name}"
     total_steps = job.total_steps
     outer_bar = tqdm(
         total=total_steps,
diff --git a/geti_sdk/utils/workspace_helpers.py b/geti_sdk/utils/workspace_helpers.py
index c3ecc36c..4ae05514 100644
--- a/geti_sdk/utils/workspace_helpers.py
+++ b/geti_sdk/utils/workspace_helpers.py
@@ -26,10 +26,7 @@ def get_default_workspace_id(rest_session: GetiSession) -> str:
     if isinstance(workspaces, list):
         workspace_list = workspaces
     elif isinstance(workspaces, dict):
-        if rest_session.version.is_sc_mvp or rest_session.version.is_sc_1_1:
-            workspace_list = workspaces["items"]
-        else:
-            workspace_list = workspaces["workspaces"]
+        workspace_list = workspaces["workspaces"]
     else:
         raise ValueError(
             f"Unexpected response from cluster: {workspaces}. Expected to receive a "
diff --git a/notebooks/009_download_and_upload_project.ipynb b/notebooks/009_download_and_upload_project.ipynb
index eb3628cd..73df1ca9 100644
--- a/notebooks/009_download_and_upload_project.ipynb
+++ b/notebooks/009_download_and_upload_project.ipynb
@@ -74,7 +74,7 @@
    "metadata": {},
    "source": [
     "## Project download\n",
-    "Now, let's do the project download itself. The `Geti` provides a method `download_project()` to do so. It takes the following arguments:\n",
+    "Now, let's do the project download itself. The `Geti` class provides a method `download_project_data()` to do so. It takes the following arguments:\n",
     "\n",
     "- `project_name`: Name of the project to download\n",
     "- `target_folder`: Path of the folder to download to. If left empty, a folder named `project_name` will be created in the current directory\n",
@@ -94,7 +94,7 @@
    "source": [
     "import os\n",
     "\n",
-    "project = geti.download_project(\n",
+    "project = geti.download_project_data(\n",
     "    project_name=PROJECT_NAME,\n",
     "    target_folder=os.path.join(\"projects\", PROJECT_NAME),\n",
     "    include_predictions=False,\n",
@@ -110,7 +110,7 @@
    "source": [
     "That's all there is to it! Now, you should have a folder `projects` showing up in the current directory. Inside it should be another folder named `{PROJECT_NAME}`, which should contain a file `project.json` holding the project details, as well as all media and annotations in the project and a file `configuration.json` that contains the full project configuration. \n",
     "\n",
-    "In addition, the `download_project` method can also create a deployment for the project (see notebook [008 deploy_project](008_deploy_project.ipynb) for more details on deployments), if the parameter `include_deployment` is set to True. In that case you should see a folder called `deployment` in the project directory as well.\n",
+    "In addition, the `download_project_data` method can also create a deployment for the project (see notebook [008 deploy_project](008_deploy_project.ipynb) for more details on deployments), if the parameter `include_deployment` is set to True. In that case you should see a folder called `deployment` in the project directory as well.\n",
     "\n",
     "Note that in this case predictions, models and the deployment are not included because we have set `include_predictions=False`, `include_active_models=False` and `include_deployment=False`."
    ]
@@ -121,7 +121,7 @@
    "metadata": {},
    "source": [
     "## Project upload\n",
-    "Now that the project is downloaded, we can use it to create a new project on the platform. Once the project is created, we can upload the media, annotations and configuration that were downloaded to the project folder to it. The `Geti` class provides a `upload_project()` method to do all that, which takes three parameters:\n",
+    "Now that the project is downloaded, we can use it to create a new project on the platform. Once the project is created, we can upload the media, annotations and configuration that were downloaded to the project folder to it. The `Geti` class provides an `upload_project_data()` method to do all that, which takes three parameters:\n",
     "- `target_folder`: Path to the folder containing the project data to upload\n",
     "- `project_name`: Optional name to assign to the new project on the platform. If left unspecified, the project name will correspond to the downloaded project name\n",
     "- `enable_auto_train`: True to turn on auto training for all tasks in the project, once the media and annotation upload is complete. False to leave auto training turned off."
    ]
   },
@@ -134,7 +134,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "uploaded_project = geti.upload_project(\n",
+    "uploaded_project = geti.upload_project_data(\n",
     "    target_folder=os.path.join(\"projects\", PROJECT_NAME),\n",
     "    project_name=PROJECT_NAME,\n",
     "    enable_auto_train=False,\n",
@@ -146,7 +146,7 @@
    "id": "632426d1-3e42-4571-a9ea-26eedf2a81c3",
    "metadata": {},
    "source": [
-    "Done! The uploaded project should now show up in your workspace. Note that it is of course also possible to upload the project to a different server, simply by setting up a new Geti instance to that server and calling the `upload_project()` method from that instance."
+    "Done! The uploaded project should now show up in your workspace. Note that it is of course also possible to upload the project to a different server, simply by setting up a new Geti instance to that server and calling the `upload_project_data()` method from that instance."
] }, { diff --git a/tests/fixtures/cassettes/DEVELOP/TestImportExport.test_export_import_dataset.cassette b/tests/fixtures/cassettes/DEVELOP/TestImportExport.test_export_import_dataset.cassette new file mode 100644 index 00000000..c4da205c --- /dev/null +++ b/tests/fixtures/cassettes/DEVELOP/TestImportExport.test_export_import_dataset.cassette @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:992a9ebc828326c6b1fdf07c62da0d2ac83577eb7fa730e90ae8668e8b62be48 +size 1144834 diff --git a/tests/fixtures/cassettes/DEVELOP/TestImportExport.test_export_import_project.cassette b/tests/fixtures/cassettes/DEVELOP/TestImportExport.test_export_import_project.cassette new file mode 100644 index 00000000..ab25c307 --- /dev/null +++ b/tests/fixtures/cassettes/DEVELOP/TestImportExport.test_export_import_project.cassette @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79736e0d1e21589d66f8bf135a8d704ab44e67e097fdc029365cdefb5c2a242f +size 188743390 diff --git a/tests/fixtures/cassettes/DEVELOP/geti.cassette b/tests/fixtures/cassettes/DEVELOP/geti.cassette index 4f217cba..10863b7d 100644 --- a/tests/fixtures/cassettes/DEVELOP/geti.cassette +++ b/tests/fixtures/cassettes/DEVELOP/geti.cassette @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a8e15f74e44a536481d53e2d3f3728f6eb25e9e3410d1c817d622725aa1fd963 -size 82870 +oid sha256:e84bf8317b63cf54ef4074f51f47132e95e11560a2ca20821d26d25a5afc879e +size 44600 diff --git a/tests/pre-merge/integration/test_geti.py b/tests/pre-merge/integration/test_geti.py index 6c90c9f3..d704bcd0 100644 --- a/tests/pre-merge/integration/test_geti.py +++ b/tests/pre-merge/integration/test_geti.py @@ -286,7 +286,7 @@ def test_download_and_upload_project( project = lazy_fxt_project_service.project target_folder = os.path.join(fxt_temp_directory, project.name) - fxt_geti.download_project( + fxt_geti.download_project_data( project.name, target_folder=target_folder, max_threads=1, @@ -298,7 +298,7 @@ def test_download_and_upload_project( n_images = len(os.listdir(os.path.join(target_folder, "images"))) n_annotations = len(os.listdir(os.path.join(target_folder, "annotations"))) - uploaded_project = fxt_geti.upload_project( + uploaded_project = fxt_geti.upload_project_data( target_folder=target_folder, project_name=f"{project.name}_upload", enable_auto_train=False, @@ -660,7 +660,7 @@ def test_download_project_including_models_and_predictions( target_folder = os.path.join( fxt_temp_directory, project.name + "_all_inclusive" ) - fxt_geti.download_project( + fxt_geti.download_project_data( project_name=project.name, target_folder=target_folder, include_predictions=True, diff --git a/tests/pre-merge/integration/test_import_export.py b/tests/pre-merge/integration/test_import_export.py new file mode 100644 index 00000000..24e5d6f2 --- /dev/null +++ b/tests/pre-merge/integration/test_import_export.py @@ -0,0 +1,124 @@ +# Copyright (C) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. 
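+#
+# Note: these tests are marked with `pytest.mark.vcr`, so the export/import
+# round trips can be replayed from the recorded cassettes added above, without
+# requiring a live Intel® Geti™ server.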
+ +import os + +import pytest + +from geti_sdk.geti import Geti + + +class TestImportExport: + """ + Integration tests for the Import Export methods in the Geti class. + """ + + @pytest.mark.vcr() + def test_export_import_project( + self, + fxt_geti: Geti, + fxt_temp_directory: str, + ) -> None: + project = fxt_geti.project_client.get_all_projects(get_project_details=False)[0] + target_folder = os.path.join(fxt_temp_directory, project.name + "_snapshot") + archive_path = target_folder + "/project_archive.zip" + imported_project_name = "IMPORTED_PROJECT" + + # Project is exported + assert not os.path.exists(archive_path) + fxt_geti.export_project( + project_name=project.name, project_id=project.id, filepath=archive_path + ) + assert os.path.exists(archive_path) + # Project is imported + existing_projects_pre_import = fxt_geti.project_client.get_all_projects( + get_project_details=False + ) + imported_project = fxt_geti.import_project( + filepath=archive_path, project_name=imported_project_name + ) + assert imported_project.name == imported_project_name + existing_projects = fxt_geti.project_client.get_all_projects( + get_project_details=False + ) + assert ( + next((p for p in existing_projects if p.id == imported_project.id), None) + is not None + ) + assert ( + next( + ( + p + for p in existing_projects_pre_import + if p.id == imported_project.id + ), + None, + ) + is None + ) + # Project is deleted + fxt_geti.project_client.delete_project( + imported_project, requires_confirmation=False + ) + + @pytest.mark.vcr() + def test_export_import_dataset( + self, + fxt_geti: Geti, + fxt_temp_directory: str, + ) -> None: + project = fxt_geti.project_client.get_all_projects(get_project_details=False)[0] + project = fxt_geti.project_client._get_project_detail(project) + assert project.datasets + dataset = project.datasets[0] + target_folder = os.path.join(fxt_temp_directory, project.name + "_snapshot") + archive_path = target_folder + "/dataset_archive.zip" + imported_project_name = "IMPORTED_PROJECT_FROM_DATASET" + + # Dataset is exported + assert not os.path.exists(archive_path) + fxt_geti.export_dataset(project=project, dataset=dataset, filepath=archive_path) + assert os.path.exists(archive_path) + # Dataset is imported as a project + existing_projects_pre_import = fxt_geti.project_client.get_all_projects( + get_project_details=False + ) + imported_project = fxt_geti.import_dataset( + filepath=archive_path, + project_name=imported_project_name, + project_type=project.project_type, + ) + assert imported_project.name == imported_project_name + existing_projects = fxt_geti.project_client.get_all_projects( + get_project_details=False + ) + assert ( + next((p for p in existing_projects if p.id == imported_project.id), None) + is not None + ) + assert ( + next( + ( + p + for p in existing_projects_pre_import + if p.id == imported_project.id + ), + None, + ) + is None + ) + # Project is deleted + fxt_geti.project_client.delete_project( + imported_project, requires_confirmation=False + ) diff --git a/tests/pre-merge/unit/benchmarking/test_benchmarker.py b/tests/pre-merge/unit/benchmarking/test_benchmarker.py index 04f95069..51bbc5d6 100644 --- a/tests/pre-merge/unit/benchmarking/test_benchmarker.py +++ b/tests/pre-merge/unit/benchmarking/test_benchmarker.py @@ -59,7 +59,9 @@ def test_initialize( ) # Assert - mock_get_project_by_name.assert_called_once_with(project_name=project_name) + mock_get_project_by_name.assert_called_once_with( + project_name=project_name, project_id=None + ) 
mocked_model_client.assert_called_once() mocked_training_client.assert_called_once() assert benchmarker._is_single_task @@ -118,7 +120,9 @@ def test_initialize_task_chain( ) # Assert - mock_get_project_by_name.assert_called_once_with(project_name=project_name) + mock_get_project_by_name.assert_called_once_with( + project_name=project_name, project_id=None + ) mock_image_client_get_all.assert_called_once() mocked_model_client.assert_called_once() model_client_object_mock.get_all_active_models.assert_called_once() diff --git a/tests/pre-merge/unit/test_geti_unit.py b/tests/pre-merge/unit/test_geti_unit.py index 388053f8..ee5a770d 100644 --- a/tests/pre-merge/unit/test_geti_unit.py +++ b/tests/pre-merge/unit/test_geti_unit.py @@ -94,7 +94,9 @@ def test_download_all_projects( "geti_sdk.geti.ProjectClient.get_all_projects", return_value=fxt_nightly_projects, ) - mock_download_project = mocker.patch.object(fxt_mocked_geti, "download_project") + mock_download_project_data = mocker.patch( + "geti_sdk.import_export.import_export_module.GetiIE.download_project_data" + ) # Act projects = fxt_mocked_geti.download_all_projects( @@ -103,7 +105,7 @@ def test_download_all_projects( # Assert mock_get_all_projects.assert_called_once() - assert mock_download_project.call_count == len(projects) + assert mock_download_project_data.call_count == len(projects) def test_upload_all_projects( self, @@ -119,11 +121,13 @@ def test_upload_all_projects( mock_is_project_dir = mocker.patch( "geti_sdk.geti.ProjectClient.is_project_dir", return_value=True ) - mock_upload_project = mocker.patch.object(fxt_mocked_geti, "upload_project") + mock_upload_project_data = mocker.patch( + "geti_sdk.import_export.import_export_module.GetiIE.upload_project_data" + ) # Act fxt_mocked_geti.upload_all_projects(target_folder=target_dir) # Assert assert mock_is_project_dir.call_count == len(fxt_nightly_projects) - assert mock_upload_project.call_count == len(fxt_nightly_projects) + assert mock_upload_project_data.call_count == len(fxt_nightly_projects)