diff --git a/python/hopsworks/client/__init__.py b/python/hopsworks/client/__init__.py index 004e49c8b..19e0feb8d 100644 --- a/python/hopsworks/client/__init__.py +++ b/python/hopsworks/client/__init__.py @@ -1,5 +1,5 @@ # -# Copyright 2022 Logical Clocks AB +# Copyright 2024 Hopsworks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,60 +14,27 @@ # limitations under the License. # -from hopsworks.client import external, hopsworks - - -_client = None -_python_version = None - - -def init( - client_type, - host=None, - port=None, - project=None, - hostname_verification=None, - trust_store_path=None, - cert_folder=None, - api_key_file=None, - api_key_value=None, -): - global _client - if not _client: - if client_type == "hopsworks": - _client = hopsworks.Client() - elif client_type == "external": - _client = external.Client( - host, - port, - project, - hostname_verification, - trust_store_path, - cert_folder, - api_key_file, - api_key_value, - ) - - -def get_instance(): - global _client - if _client: - return _client - raise Exception("Couldn't find client. Try reconnecting to Hopsworks.") - - -def get_python_version(): - global _python_version - return _python_version - - -def set_python_version(python_version): - global _python_version - _python_version = python_version - - -def stop(): - global _client - if _client: - _client._close() - _client = None +from hopsworks_common.client import ( + auth, + base, + exceptions, + external, + get_instance, + hopsworks, + init, + online_store_rest_client, + stop, +) + + +__all__ = [ + auth, + base, + exceptions, + external, + get_instance, + hopsworks, + init, + online_store_rest_client, + stop, +] diff --git a/python/hopsworks/client/auth.py b/python/hopsworks/client/auth.py index 8bbd4ae53..e912b1daf 100644 --- a/python/hopsworks/client/auth.py +++ b/python/hopsworks/client/auth.py @@ -1,5 +1,5 @@ # -# Copyright 2022 Logical Clocks AB +# Copyright 2024 Hopsworks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,26 +14,15 @@ # limitations under the License. # -import requests +from hopsworks_common.client.auth import ( + ApiKeyAuth, + BearerAuth, + OnlineStoreKeyAuth, +) -class BearerAuth(requests.auth.AuthBase): - """Class to encapsulate a Bearer token.""" - - def __init__(self, token): - self._token = token - - def __call__(self, r): - r.headers["Authorization"] = "Bearer " + self._token.strip() - return r - - -class ApiKeyAuth(requests.auth.AuthBase): - """Class to encapsulate an API key.""" - - def __init__(self, token): - self._token = token - - def __call__(self, r): - r.headers["Authorization"] = "ApiKey " + self._token.strip() - return r +__all__ = [ + ApiKeyAuth, + BearerAuth, + OnlineStoreKeyAuth, +] diff --git a/python/hopsworks/client/base.py b/python/hopsworks/client/base.py index 852259639..3ff35d800 100644 --- a/python/hopsworks/client/base.py +++ b/python/hopsworks/client/base.py @@ -1,5 +1,5 @@ # -# Copyright 2022 Logical Clocks AB +# Copyright 2024 Hopsworks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,172 +14,11 @@ # limitations under the License. # -import os -from abc import ABC, abstractmethod +from hopsworks_common.client.base import ( + Client, +) -import furl -import requests -import urllib3 -from hopsworks.client import auth, exceptions -from hopsworks.decorators import connected - -urllib3.disable_warnings(urllib3.exceptions.SecurityWarning) -urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - - -class Client(ABC): - TOKEN_FILE = "token.jwt" - APIKEY_FILE = "api.key" - REST_ENDPOINT = "REST_ENDPOINT" - HOPSWORKS_PUBLIC_HOST = "HOPSWORKS_PUBLIC_HOST" - - @abstractmethod - def __init__(self): - """To be implemented by clients.""" - pass - - def _get_verify(self, verify, trust_store_path): - """Get verification method for sending HTTP requests to Hopsworks. - - Credit to https://gist.github.com/gdamjan/55a8b9eec6cf7b771f92021d93b87b2c - - :param verify: perform hostname verification, 'true' or 'false' - :type verify: str - :param trust_store_path: path of the truststore locally if it was uploaded manually to - the external environment such as AWS Sagemaker - :type trust_store_path: str - :return: if verify is true and the truststore is provided, then return the trust store location - if verify is true but the truststore wasn't provided, then return true - if verify is false, then return false - :rtype: str or boolean - """ - if verify == "true": - if trust_store_path is not None: - return trust_store_path - else: - return True - - return False - - def _get_host_port_pair(self): - """ - Removes "http or https" from the rest endpoint and returns a list - [endpoint, port], where endpoint is on the format /path.. without http:// - - :return: a list [endpoint, port] - :rtype: list - """ - endpoint = self._base_url - if "http" in endpoint: - last_index = endpoint.rfind("/") - endpoint = endpoint[last_index + 1 :] - host, port = endpoint.split(":") - return host, port - - def _read_jwt(self): - """Retrieve jwt from local container.""" - return self._read_file(self.TOKEN_FILE) - - def _read_apikey(self): - """Retrieve apikey from local container.""" - return self._read_file(self.APIKEY_FILE) - - def _read_file(self, secret_file): - """Retrieve secret from local container.""" - with open(os.path.join(self._secrets_dir, secret_file), "r") as secret: - return secret.read() - - def _get_credentials(self, project_id): - """Makes a REST call to hopsworks for getting the project user certificates needed to connect to services such as Hive - - :param project_id: id of the project - :type project_id: int - :return: JSON response with credentials - :rtype: dict - """ - return self._send_request("GET", ["project", project_id, "credentials"]) - - def _write_pem_file(self, content: str, path: str) -> None: - with open(path, "w") as f: - f.write(content) - - @connected - def _send_request( - self, - method, - path_params, - query_params=None, - headers=None, - data=None, - stream=False, - files=None, - with_base_path_params=True, - ): - """Send REST request to Hopsworks. - - Uses the client it is executed from. Path parameters are url encoded automatically. - - :param method: 'GET', 'PUT' or 'POST' - :type method: str - :param path_params: a list of path params to build the query url from starting after - the api resource, for example `["project", 119, "featurestores", 67]`. - :type path_params: list - :param query_params: A dictionary of key/value pairs to be added as query parameters, - defaults to None - :type query_params: dict, optional - :param headers: Additional header information, defaults to None - :type headers: dict, optional - :param data: The payload as a python dictionary to be sent as json, defaults to None - :type data: dict, optional - :param stream: Set if response should be a stream, defaults to False - :type stream: boolean, optional - :param files: dictionary for multipart encoding upload - :type files: dict, optional - :raises RestAPIError: Raised when request wasn't correctly received, understood or accepted - :return: Response json - :rtype: dict - """ - f_url = furl.furl(self._base_url) - if with_base_path_params: - base_path_params = ["hopsworks-api", "api"] - f_url.path.segments = base_path_params + path_params - else: - f_url.path.segments = path_params - url = str(f_url) - - request = requests.Request( - method, - url=url, - headers=headers, - data=data, - params=query_params, - auth=self._auth, - files=files, - ) - - prepped = self._session.prepare_request(request) - response = self._session.send(prepped, verify=self._verify, stream=stream) - - if response.status_code == 401 and self.REST_ENDPOINT in os.environ: - # refresh token and retry request - only on hopsworks - self._auth = auth.BearerAuth(self._read_jwt()) - # Update request with the new token - request.auth = self._auth - prepped = self._session.prepare_request(request) - response = self._session.send(prepped, verify=self._verify, stream=stream) - - if response.status_code // 100 != 2: - raise exceptions.RestAPIError(url, response) - - if stream: - return response - else: - # handle different success response codes - if len(response.content) == 0: - return None - return response.json() - - def _close(self): - """Closes a client. Can be implemented for clean up purposes, not mandatory.""" - self._connected = False +__all__ = [ + Client, +] diff --git a/python/hopsworks/client/exceptions.py b/python/hopsworks/client/exceptions.py index 637146492..b34ef198f 100644 --- a/python/hopsworks/client/exceptions.py +++ b/python/hopsworks/client/exceptions.py @@ -1,5 +1,5 @@ # -# Copyright 2022 Logical Clocks AB +# Copyright 2024 Hopsworks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,71 +14,37 @@ # limitations under the License. # - -class RestAPIError(Exception): - """REST Exception encapsulating the response object and url.""" - - def __init__(self, url, response): - try: - error_object = response.json() - except Exception: - error_object = {} - message = ( - "Metadata operation error: (url: {}). Server response: \n" - "HTTP code: {}, HTTP reason: {}, body: {}, error code: {}, error msg: {}, user " - "msg: {}".format( - url, - response.status_code, - response.reason, - response.content, - error_object.get("errorCode", ""), - error_object.get("errorMsg", ""), - error_object.get("usrMsg", ""), - ) - ) - super().__init__(message) - self.url = url - self.response = response - - -class UnknownSecretStorageError(Exception): - """This exception will be raised if an unused secrets storage is passed as a parameter.""" - - -class GitException(Exception): - """Generic git exception""" - - -class JobException(Exception): - """Generic job exception""" - - -class EnvironmentException(Exception): - """Generic python environment exception""" - - -class KafkaException(Exception): - """Generic kafka exception""" - - -class DatasetException(Exception): - """Generic dataset exception""" - - -class ProjectException(Exception): - """Generic project exception""" - - -class OpenSearchException(Exception): - """Generic opensearch exception""" - - -class JobExecutionException(Exception): - """Generic job executions exception""" - - -class ExternalClientError(TypeError): - """Raised when external client cannot be initialized due to missing arguments.""" - - def __init__(self, message): - super().__init__(message) +from hopsworks_common.client.exceptions import ( + DatasetException, + DataValidationException, + EnvironmentException, + ExternalClientError, + FeatureStoreException, + GitException, + JobException, + JobExecutionException, + KafkaException, + OpenSearchException, + ProjectException, + RestAPIError, + UnknownSecretStorageError, + VectorDatabaseException, +) + + +__all__ = [ + DatasetException, + DataValidationException, + EnvironmentException, + ExternalClientError, + FeatureStoreException, + GitException, + JobException, + JobExecutionException, + KafkaException, + OpenSearchException, + ProjectException, + RestAPIError, + UnknownSecretStorageError, + VectorDatabaseException, +] diff --git a/python/hopsworks/client/external.py b/python/hopsworks/client/external.py index d0a277e71..1384b1c20 100644 --- a/python/hopsworks/client/external.py +++ b/python/hopsworks/client/external.py @@ -1,5 +1,5 @@ # -# Copyright 2022 Logical Clocks AB +# Copyright 2024 Hopsworks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,154 +14,11 @@ # limitations under the License. # -import base64 -import os +from hopsworks_common.client.external import ( + Client, +) -import requests -from hopsworks.client import auth, base, exceptions - -class Client(base.Client): - def __init__( - self, - host, - port, - project, - hostname_verification, - trust_store_path, - cert_folder, - api_key_file, - api_key_value, - ): - """Initializes a client in an external environment such as AWS Sagemaker.""" - if not host: - raise exceptions.ExternalClientError("host") - - self._host = host - self._port = port - self._base_url = "https://" + self._host + ":" + str(self._port) - self._project_name = project - if project is not None: - project_info = self._get_project_info(project) - self._project_id = str(project_info["projectId"]) - else: - self._project_id = None - - if api_key_value is not None: - api_key = api_key_value - elif api_key_file is not None: - file = None - if os.path.exists(api_key_file): - try: - file = open(api_key_file, mode="r") - api_key = file.read() - finally: - file.close() - else: - raise IOError( - "Could not find api key file on path: {}".format(api_key_file) - ) - else: - raise exceptions.ExternalClientError( - "Either api_key_file or api_key_value must be set when connecting to" - " hopsworks from an external environment." - ) - - self._auth = auth.ApiKeyAuth(api_key) - - self._session = requests.session() - self._connected = True - self._verify = self._get_verify(self._host, trust_store_path) - - self._cert_folder_base = os.path.join(cert_folder, host) - - def download_certs(self, project_name): - project_info = self._get_project_info(project_name) - project_id = str(project_info["projectId"]) - - project_cert_folder = os.path.join(self._cert_folder_base, project_name) - - trust_store_path = os.path.join(project_cert_folder, "trustStore.jks") - key_store_path = os.path.join(project_cert_folder, "keyStore.jks") - - os.makedirs(project_cert_folder, exist_ok=True) - credentials = self._get_credentials(project_id) - self._write_b64_cert_to_bytes( - str(credentials["kStore"]), - path=key_store_path, - ) - self._write_b64_cert_to_bytes( - str(credentials["tStore"]), - path=trust_store_path, - ) - - self._write_pem_file( - credentials["caChain"], self._get_ca_chain_path(project_name) - ) - self._write_pem_file( - credentials["clientCert"], self._get_client_cert_path(project_name) - ) - self._write_pem_file( - credentials["clientKey"], self._get_client_key_path(project_name) - ) - - with open(os.path.join(project_cert_folder, "material_passwd"), "w") as f: - f.write(str(credentials["password"])) - - def _close(self): - """Closes a client and deletes certificates.""" - # TODO: Implement certificate cleanup. Currently do not remove certificates as it may break users using hsfs python ingestion. - self._connected = False - - def _get_jks_trust_store_path(self): - return self._trust_store_path - - def _get_jks_key_store_path(self): - return self._key_store_path - - def _get_ca_chain_path(self, project_name) -> str: - return os.path.join(self._cert_folder_base, project_name, "ca_chain.pem") - - def _get_client_cert_path(self, project_name) -> str: - return os.path.join(self._cert_folder_base, project_name, "client_cert.pem") - - def _get_client_key_path(self, project_name) -> str: - return os.path.join(self._cert_folder_base, project_name, "client_key.pem") - - def _get_project_info(self, project_name): - """Makes a REST call to hopsworks to get all metadata of a project for the provided project. - - :param project_name: the name of the project - :type project_name: str - :return: JSON response with project info - :rtype: dict - """ - return self._send_request("GET", ["project", "getProjectInfo", project_name]) - - def _write_b64_cert_to_bytes(self, b64_string, path): - """Converts b64 encoded certificate to bytes file . - - :param b64_string: b64 encoded string of certificate - :type b64_string: str - :param path: path where file is saved, including file name. e.g. /path/key-store.jks - :type path: str - """ - - with open(path, "wb") as f: - cert_b64 = base64.b64decode(b64_string) - f.write(cert_b64) - - def _cleanup_file(self, file_path): - """Removes local files with `file_path`.""" - try: - os.remove(file_path) - except OSError: - pass - - def replace_public_host(self, url): - """no need to replace as we are already in external client""" - return url - - @property - def host(self): - return self._host +__all__ = [ + Client, +] diff --git a/python/hopsworks/client/hopsworks.py b/python/hopsworks/client/hopsworks.py index 514e3fe48..c360b8cb9 100644 --- a/python/hopsworks/client/hopsworks.py +++ b/python/hopsworks/client/hopsworks.py @@ -1,5 +1,5 @@ # -# Copyright 2022 Logical Clocks AB +# Copyright 2024 Hopsworks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,224 +14,11 @@ # limitations under the License. # -import base64 -import os -import textwrap -from pathlib import Path +from hopsworks_common.client.hopsworks import ( + Client, +) -import requests -from hopsworks.client import auth, base - -try: - import jks -except ImportError: - pass - - -class Client(base.Client): - REQUESTS_VERIFY = "REQUESTS_VERIFY" - DOMAIN_CA_TRUSTSTORE_PEM = "DOMAIN_CA_TRUSTSTORE_PEM" - PROJECT_ID = "HOPSWORKS_PROJECT_ID" - PROJECT_NAME = "HOPSWORKS_PROJECT_NAME" - HADOOP_USER_NAME = "HADOOP_USER_NAME" - MATERIAL_DIRECTORY = "MATERIAL_DIRECTORY" - HDFS_USER = "HDFS_USER" - T_CERTIFICATE = "t_certificate" - K_CERTIFICATE = "k_certificate" - TRUSTSTORE_SUFFIX = "__tstore.jks" - KEYSTORE_SUFFIX = "__kstore.jks" - PEM_CA_CHAIN = "ca_chain.pem" - CERT_KEY_SUFFIX = "__cert.key" - MATERIAL_PWD = "material_passwd" - SECRETS_DIR = "SECRETS_DIR" - - def __init__(self): - """Initializes a client being run from a job/notebook directly on Hopsworks.""" - self._base_url = self._get_hopsworks_rest_endpoint() - self._host, self._port = self._get_host_port_pair() - self._secrets_dir = ( - os.environ[self.SECRETS_DIR] if self.SECRETS_DIR in os.environ else "" - ) - self._cert_key = self._get_cert_pw() - trust_store_path = self._get_trust_store_path() - hostname_verification = ( - os.environ[self.REQUESTS_VERIFY] - if self.REQUESTS_VERIFY in os.environ - else "true" - ) - self._project_id = os.environ[self.PROJECT_ID] - self._project_name = self._project_name() - try: - self._auth = auth.BearerAuth(self._read_jwt()) - except FileNotFoundError: - self._auth = auth.ApiKeyAuth(self._read_apikey()) - self._verify = self._get_verify(hostname_verification, trust_store_path) - self._session = requests.session() - - self._connected = True - - credentials = self._get_credentials(self._project_id) - - self._write_pem_file( - credentials["caChain"], self._get_ca_chain_path(self._project_name) - ) - self._write_pem_file( - credentials["clientCert"], self._get_client_cert_path(self._project_name) - ) - self._write_pem_file( - credentials["clientKey"], self._get_client_key_path(self._project_name) - ) - - def _get_hopsworks_rest_endpoint(self): - """Get the hopsworks REST endpoint for making requests to the REST API.""" - return os.environ[self.REST_ENDPOINT] - - def _get_trust_store_path(self): - """Convert truststore from jks to pem and return the location""" - ca_chain_path = Path(self.PEM_CA_CHAIN) - if not ca_chain_path.exists(): - self._write_ca_chain(ca_chain_path) - return str(ca_chain_path) - - def _get_ca_chain_path(self, project_name) -> str: - return os.path.join("/tmp", "ca_chain.pem") - - def _get_client_cert_path(self, project_name) -> str: - return os.path.join("/tmp", "client_cert.pem") - - def _get_client_key_path(self, project_name) -> str: - return os.path.join("/tmp", "client_key.pem") - - def _write_ca_chain(self, ca_chain_path): - """ - Converts JKS trustore file into PEM to be compatible with Python libraries - """ - keystore_pw = self._cert_key - keystore_ca_cert = self._convert_jks_to_pem( - self._get_jks_key_store_path(), keystore_pw - ) - truststore_ca_cert = self._convert_jks_to_pem( - self._get_jks_trust_store_path(), keystore_pw - ) - - with ca_chain_path.open("w") as f: - f.write(keystore_ca_cert + truststore_ca_cert) - - def _convert_jks_to_pem(self, jks_path, keystore_pw): - """ - Converts a keystore JKS that contains client private key, - client certificate and CA certificate that was used to - sign the certificate to PEM format and returns the CA certificate. - Args: - :jks_path: path to the JKS file - :pw: password for decrypting the JKS file - Returns: - strings: (ca_cert) - """ - # load the keystore and decrypt it with password - ks = jks.KeyStore.load(jks_path, keystore_pw, try_decrypt_keys=True) - ca_certs = "" - - # Convert CA Certificates into PEM format and append to string - for _alias, c in ks.certs.items(): - ca_certs = ca_certs + self._bytes_to_pem_str(c.cert, "CERTIFICATE") - return ca_certs - - def _bytes_to_pem_str(self, der_bytes, pem_type): - """ - Utility function for creating PEM files - - Args: - der_bytes: DER encoded bytes - pem_type: type of PEM, e.g Certificate, Private key, or RSA private key - - Returns: - PEM String for a DER-encoded certificate or private key - """ - pem_str = "" - pem_str = pem_str + "-----BEGIN {}-----".format(pem_type) + "\n" - pem_str = ( - pem_str - + "\r\n".join( - textwrap.wrap(base64.b64encode(der_bytes).decode("ascii"), 64) - ) - + "\n" - ) - pem_str = pem_str + "-----END {}-----".format(pem_type) + "\n" - return pem_str - - def _get_jks_trust_store_path(self): - """ - Get truststore location - - Returns: - truststore location - """ - t_certificate = Path(self.T_CERTIFICATE) - if t_certificate.exists(): - return str(t_certificate) - else: - username = os.environ[self.HADOOP_USER_NAME] - material_directory = Path(os.environ[self.MATERIAL_DIRECTORY]) - return str(material_directory.joinpath(username + self.TRUSTSTORE_SUFFIX)) - - def _get_jks_key_store_path(self): - """ - Get keystore location - - Returns: - keystore location - """ - k_certificate = Path(self.K_CERTIFICATE) - if k_certificate.exists(): - return str(k_certificate) - else: - username = os.environ[self.HADOOP_USER_NAME] - material_directory = Path(os.environ[self.MATERIAL_DIRECTORY]) - return str(material_directory.joinpath(username + self.KEYSTORE_SUFFIX)) - - def _project_name(self): - try: - return os.environ[self.PROJECT_NAME] - except KeyError: - pass - - hops_user = self._project_user() - hops_user_split = hops_user.split( - "__" - ) # project users have username project__user - project = hops_user_split[0] - return project - - def _project_user(self): - try: - hops_user = os.environ[self.HADOOP_USER_NAME] - except KeyError: - hops_user = os.environ[self.HDFS_USER] - return hops_user - - def _get_cert_pw(self): - """ - Get keystore password from local container - - Returns: - Certificate password - """ - pwd_path = Path(self.MATERIAL_PWD) - if not pwd_path.exists(): - username = os.environ[self.HADOOP_USER_NAME] - material_directory = Path(os.environ[self.MATERIAL_DIRECTORY]) - pwd_path = material_directory.joinpath(username + self.CERT_KEY_SUFFIX) - - with pwd_path.open() as f: - return f.read() - - def replace_public_host(self, url): - """replace hostname to public hostname set in HOPSWORKS_PUBLIC_HOST""" - ui_url = url._replace(netloc=os.environ[self.HOPSWORKS_PUBLIC_HOST]) - return ui_url - - @property - def host(self): - return self._host +__all__ = [ + Client, +] diff --git a/python/hopsworks/client/online_store_rest_client.py b/python/hopsworks/client/online_store_rest_client.py new file mode 100644 index 000000000..c75be81b7 --- /dev/null +++ b/python/hopsworks/client/online_store_rest_client.py @@ -0,0 +1,28 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from hopsworks_common.client.online_store_rest_client import ( + OnlineStoreRestClientSingleton, + get_instance, + init_or_reset_online_store_rest_client, +) + + +__all__ = [ + OnlineStoreRestClientSingleton, + get_instance, + init_or_reset_online_store_rest_client, +] diff --git a/python/hopsworks/connection.py b/python/hopsworks/connection.py index 61f2e3d6a..c43cfeeb9 100644 --- a/python/hopsworks/connection.py +++ b/python/hopsworks/connection.py @@ -215,12 +215,6 @@ def _check_compatibility(self): ) sys.stderr.flush() - def _set_client_variables(self): - python_version = self._variable_api.get_variable( - "docker_base_image_python_version" - ) - client.set_python_version(python_version) - @not_connected def connect(self): """Instantiate the connection. @@ -271,7 +265,6 @@ def connect(self): ) self._check_compatibility() - self._set_client_variables() def close(self): """Close a connection gracefully. diff --git a/python/hopsworks/core/variable_api.py b/python/hopsworks/core/variable_api.py new file mode 100644 index 000000000..9d6e9765f --- /dev/null +++ b/python/hopsworks/core/variable_api.py @@ -0,0 +1,24 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from hopsworks_common.core.variable_api import ( + VariableApi, +) + + +__all__ = [ + VariableApi, +] diff --git a/python/hopsworks/decorators.py b/python/hopsworks/decorators.py index 51b7d635a..1165a2daa 100644 --- a/python/hopsworks/decorators.py +++ b/python/hopsworks/decorators.py @@ -1,5 +1,5 @@ # -# Copyright 2022 Logical Clocks AB +# Copyright 2024 Hopsworks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,42 +14,21 @@ # limitations under the License. # -import functools - - -def not_connected(fn): - @functools.wraps(fn) - def if_not_connected(inst, *args, **kwargs): - if inst._connected: - raise HopsworksConnectionError - return fn(inst, *args, **kwargs) - - return if_not_connected - - -def connected(fn): - @functools.wraps(fn) - def if_connected(inst, *args, **kwargs): - if not inst._connected: - raise NoHopsworksConnectionError - return fn(inst, *args, **kwargs) - - return if_connected - - -class HopsworksConnectionError(Exception): - """Thrown when attempted to change connection attributes while connected.""" - - def __init__(self): - super().__init__( - "Connection is currently in use. Needs to be closed for modification." - ) - - -class NoHopsworksConnectionError(Exception): - """Thrown when attempted to perform operation on connection while not connected.""" - - def __init__(self): - super().__init__( - "Connection is not active. Needs to be connected for hopsworks operations." - ) +from hopsworks_common.decorators import ( + HopsworksConnectionError, + NoHopsworksConnectionError, + connected, + not_connected, + typechecked, + uses_great_expectations, +) + + +__all__ = [ + HopsworksConnectionError, + NoHopsworksConnectionError, + connected, + not_connected, + typechecked, + uses_great_expectations, +] diff --git a/python/hopsworks_common/client/__init__.py b/python/hopsworks_common/client/__init__.py index 736b2006f..0ef17040e 100644 --- a/python/hopsworks_common/client/__init__.py +++ b/python/hopsworks_common/client/__init__.py @@ -13,14 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # + from __future__ import annotations from typing import Literal, Optional, Union -from hsfs.client import external, hopsworks +from hopsworks_common.client import external, hopsworks -_client = None +_client: Union[hopsworks.Client, external.Client, None] = None def init( @@ -66,5 +67,6 @@ def get_instance() -> Union[hopsworks.Client, external.Client]: def stop() -> None: global _client - _client._close() + if _client: + _client._close() _client = None diff --git a/python/hopsworks_common/client/auth.py b/python/hopsworks_common/client/auth.py index 1556a5b4c..a960a7a8b 100644 --- a/python/hopsworks_common/client/auth.py +++ b/python/hopsworks_common/client/auth.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # + from __future__ import annotations import requests diff --git a/python/hopsworks_common/client/base.py b/python/hopsworks_common/client/base.py index eeb6eb369..ec3d377d8 100644 --- a/python/hopsworks_common/client/base.py +++ b/python/hopsworks_common/client/base.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # + from __future__ import annotations import base64 @@ -24,8 +25,8 @@ import furl import requests import urllib3 -from hsfs.client import auth, exceptions -from hsfs.decorators import connected +from hopsworks_common.client import auth, exceptions +from hopsworks_common.decorators import connected try: @@ -123,6 +124,7 @@ def _send_request( data=None, stream=False, files=None, + with_base_path_params=True, ): """Send REST request to Hopsworks. @@ -144,13 +146,16 @@ def _send_request( :type stream: boolean, optional :param files: dictionary for multipart encoding upload :type files: dict, optional - :raises hsfs.client.exceptions.RestAPIError: Raised when request wasn't correctly received, understood or accepted + :raises RestAPIError: Raised when request wasn't correctly received, understood or accepted :return: Response json :rtype: dict """ - base_path_params = ["hopsworks-api", "api"] f_url = furl.furl(self._base_url) - f_url.path.segments = base_path_params + path_params + if with_base_path_params: + base_path_params = ["hopsworks-api", "api"] + f_url.path.segments = base_path_params + path_params + else: + f_url.path.segments = path_params url = str(f_url) request = requests.Request( diff --git a/python/hopsworks_common/client/exceptions.py b/python/hopsworks_common/client/exceptions.py index 7a7f67d5c..6f5b26e40 100644 --- a/python/hopsworks_common/client/exceptions.py +++ b/python/hopsworks_common/client/exceptions.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # + from __future__ import annotations from enum import Enum @@ -108,3 +109,35 @@ def __init__(self, missing_argument: str) -> None: "argument to connect to hopsworks from an external environment." ).format(missing_argument) super().__init__(message) + + +class GitException(Exception): + """Generic git exception""" + + +class JobException(Exception): + """Generic job exception""" + + +class EnvironmentException(Exception): + """Generic python environment exception""" + + +class KafkaException(Exception): + """Generic kafka exception""" + + +class DatasetException(Exception): + """Generic dataset exception""" + + +class ProjectException(Exception): + """Generic project exception""" + + +class OpenSearchException(Exception): + """Generic opensearch exception""" + + +class JobExecutionException(Exception): + """Generic job executions exception""" diff --git a/python/hopsworks_common/client/external.py b/python/hopsworks_common/client/external.py index e99fc20b4..dfb30b722 100644 --- a/python/hopsworks_common/client/external.py +++ b/python/hopsworks_common/client/external.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # + from __future__ import annotations import base64 @@ -22,6 +23,8 @@ import boto3 import requests +from hopsworks_common.client import auth, base, exceptions +from hopsworks_common.client.exceptions import FeatureStoreException try: @@ -29,9 +32,6 @@ except ImportError: pass -from hsfs.client import auth, base, exceptions -from hsfs.client.exceptions import FeatureStoreException - _logger = logging.getLogger(__name__) @@ -60,14 +60,18 @@ def __init__( _logger.info("Initializing external client") if not host: raise exceptions.ExternalClientError("host") - if not project: - raise exceptions.ExternalClientError("project") self._host = host self._port = port self._base_url = "https://" + self._host + ":" + str(self._port) _logger.info("Base URL: %s", self._base_url) self._project_name = project + if project is not None: + project_info = self._get_project_info(project) + self._project_id = str(project_info["projectId"]) + _logger.debug("Setting Project ID: %s", self._project_id) + else: + self._project_id = None _logger.debug("Project name: %s", self._project_name) self._region_name = region_name or self.DEFAULT_REGION _logger.debug("Region name: %s", self._region_name) @@ -77,6 +81,8 @@ def __init__( api_key = api_key_value else: _logger.debug("Querying secrets store for API key") + if secrets_store is None: + secrets_store = self.LOCAL_STORE api_key = self._get_secret(secrets_store, "api-key", api_key_file) _logger.debug("Using api key to setup header authentification") @@ -89,22 +95,15 @@ def __init__( self._verify = self._get_verify(self._host, trust_store_path) _logger.debug("Verify: %s", self._verify) - project_info = self._get_project_info(self._project_name) - - self._project_id = str(project_info["projectId"]) - _logger.debug("Setting Project ID: %s", self._project_id) - self._cert_key = None - self._cert_folder_base = None + self._cert_folder_base = cert_folder + self._cert_folder = None - if engine == "python": - credentials = self._materialize_certs(cert_folder, host, project) + if project is None: + return - self._write_pem_file(credentials["caChain"], self._get_ca_chain_path()) - self._write_pem_file( - credentials["clientCert"], self._get_client_cert_path() - ) - self._write_pem_file(credentials["clientKey"], self._get_client_key_path()) + if engine == "python": + self.download_certs(project) elif engine == "spark": # When using the Spark engine with metastore connection, the certificates @@ -126,12 +125,13 @@ def __init__( self._key_store_path = _spark_session.conf.get( "spark.hadoop.hops.ssl.keystore.name" ) + elif engine == "spark-no-metastore": _logger.debug( "Running in Spark environment with no metastore, initializing Spark session" ) _spark_session = SparkSession.builder.getOrCreate() - self._materialize_certs(cert_folder, host, project) + self.download_certs(project) # Set credentials location in the Spark configuration # Set other options in the Spark configuration @@ -150,20 +150,33 @@ def __init__( for conf_key, conf_value in configuration_dict.items(): _spark_session._jsc.hadoopConfiguration().set(conf_key, conf_value) - def _materialize_certs(self, cert_folder, host, project): - self._cert_folder_base = cert_folder - self._cert_folder = os.path.join(cert_folder, host, project) + def download_certs(self, project): + res = self._materialize_certs(self, project) + self._write_pem_file(res["caChain"], self._get_ca_chain_path()) + self._write_pem_file(res["clientCert"], self._get_client_cert_path()) + self._write_pem_file(res["clientKey"], self._get_client_key_path()) + return res + + def _materialize_certs(self, project): + if project != self._project_name: + self._project_name = project + _logger.debug("Project name: %s", self._project_name) + project_info = self._get_project_info(project) + self._project_id = str(project_info["projectId"]) + _logger.debug("Setting Project ID: %s", self._project_id) + + self._cert_folder = os.path.join(self._cert_folder_base, self._host, project) self._trust_store_path = os.path.join(self._cert_folder, "trustStore.jks") self._key_store_path = os.path.join(self._cert_folder, "keyStore.jks") if os.path.exists(self._cert_folder): _logger.debug( - f"Running in Python environment, reading certificates from certificates folder {cert_folder}" + f"Running in Python environment, reading certificates from certificates folder {self._cert_folder_base}" ) - _logger.debug("Found certificates: %s", os.listdir(cert_folder)) + _logger.debug("Found certificates: %s", os.listdir(self._cert_folder_base)) else: _logger.debug( - f"Running in Python environment, creating certificates folder {cert_folder}" + f"Running in Python environment, creating certificates folder {self._cert_folder_base}" ) os.makedirs(self._cert_folder, exist_ok=True) @@ -176,6 +189,7 @@ def _materialize_certs(self, cert_folder, host, project): str(credentials["tStore"]), path=self._get_jks_trust_store_path(), ) + self._cert_key = str(credentials["password"]) self._cert_key_path = os.path.join(self._cert_folder, "material_passwd") with open(self._cert_key_path, "w") as f: @@ -213,7 +227,8 @@ def _validate_spark_configuration(self, _spark_session): def _close(self): """Closes a client and deletes certificates.""" _logger.info("Closing external client and cleaning up certificates.") - if self._cert_folder_base is None: + self._connected = False + if self._cert_folder is None: _logger.debug("No certificates to clean up.") # On external Spark clients (Databricks, Spark Cluster), # certificates need to be provided before the Spark application starts. @@ -237,7 +252,8 @@ def _close(self): os.rmdir(self._cert_folder_base) except OSError: pass - self._connected = False + + self._cert_folder = None def _get_jks_trust_store_path(self): _logger.debug("Getting trust store path: %s", self._trust_store_path) @@ -247,18 +263,30 @@ def _get_jks_key_store_path(self): _logger.debug("Getting key store path: %s", self._key_store_path) return self._key_store_path - def _get_ca_chain_path(self) -> str: - path = os.path.join(self._cert_folder, "ca_chain.pem") + def _get_ca_chain_path(self, project_name=None) -> str: + if project_name is None: + project_name = self._project_name + path = os.path.join( + self._cert_folder_base, self._host, project_name, "ca_chain.pem" + ) _logger.debug(f"Getting ca chain path {path}") return path - def _get_client_cert_path(self) -> str: - path = os.path.join(self._cert_folder, "client_cert.pem") + def _get_client_cert_path(self, project_name=None) -> str: + if project_name is None: + project_name = self._project_name + path = os.path.join( + self._cert_folder_base, self._host, project_name, "client_cert.pem" + ) _logger.debug(f"Getting client cert path {path}") return path - def _get_client_key_path(self) -> str: - path = os.path.join(self._cert_folder, "client_key.pem") + def _get_client_key_path(self, project_name=None) -> str: + if project_name is None: + project_name = self._project_name + path = os.path.join( + self._cert_folder_base, self._host, project_name, "client_key.pem" + ) _logger.debug(f"Getting client key path {path}") return path @@ -374,9 +402,6 @@ def replace_public_host(self, url): """no need to replace as we are already in external client""" return url - def _is_external(self) -> bool: - return True - @property - def host(self) -> str: + def host(self): return self._host diff --git a/python/hopsworks_common/client/hopsworks.py b/python/hopsworks_common/client/hopsworks.py index 2134756b1..df4154a85 100644 --- a/python/hopsworks_common/client/hopsworks.py +++ b/python/hopsworks_common/client/hopsworks.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from __future__ import annotations +import base64 import os +import textwrap from pathlib import Path import requests -from hsfs.client import auth, base +from hopsworks_common.client import auth, base try: @@ -84,28 +85,76 @@ def _get_trust_store_path(self): """Convert truststore from jks to pem and return the location""" ca_chain_path = Path(self.PEM_CA_CHAIN) if not ca_chain_path.exists(): - ks = jks.KeyStore.load( - self._get_jks_key_store_path(), self._cert_key, try_decrypt_keys=True - ) - ts = jks.KeyStore.load( - self._get_jks_trust_store_path(), self._cert_key, try_decrypt_keys=True - ) - self._write_ca_chain( - ks, - ts, - ca_chain_path, - ) + self._write_ca_chain(ca_chain_path) return str(ca_chain_path) - def _get_ca_chain_path(self) -> str: + def _get_ca_chain_path(self, project_name=None) -> str: return os.path.join("/tmp", "ca_chain.pem") - def _get_client_cert_path(self) -> str: + def _get_client_cert_path(self, project_name=None) -> str: return os.path.join("/tmp", "client_cert.pem") - def _get_client_key_path(self) -> str: + def _get_client_key_path(self, project_name=None) -> str: return os.path.join("/tmp", "client_key.pem") + def _write_ca_chain(self, ca_chain_path): + """ + Converts JKS trustore file into PEM to be compatible with Python libraries + """ + keystore_pw = self._cert_key + keystore_ca_cert = self._convert_jks_to_pem( + self._get_jks_key_store_path(), keystore_pw + ) + truststore_ca_cert = self._convert_jks_to_pem( + self._get_jks_trust_store_path(), keystore_pw + ) + + with ca_chain_path.open("w") as f: + f.write(keystore_ca_cert + truststore_ca_cert) + + def _convert_jks_to_pem(self, jks_path, keystore_pw): + """ + Converts a keystore JKS that contains client private key, + client certificate and CA certificate that was used to + sign the certificate to PEM format and returns the CA certificate. + Args: + :jks_path: path to the JKS file + :pw: password for decrypting the JKS file + Returns: + strings: (ca_cert) + """ + # load the keystore and decrypt it with password + ks = jks.KeyStore.load(jks_path, keystore_pw, try_decrypt_keys=True) + ca_certs = "" + + # Convert CA Certificates into PEM format and append to string + for _alias, c in ks.certs.items(): + ca_certs = ca_certs + self._bytes_to_pem_str(c.cert, "CERTIFICATE") + return ca_certs + + def _bytes_to_pem_str(self, der_bytes, pem_type): + """ + Utility function for creating PEM files + + Args: + der_bytes: DER encoded bytes + pem_type: type of PEM, e.g Certificate, Private key, or RSA private key + + Returns: + PEM String for a DER-encoded certificate or private key + """ + pem_str = "" + pem_str = pem_str + "-----BEGIN {}-----".format(pem_type) + "\n" + pem_str = ( + pem_str + + "\r\n".join( + textwrap.wrap(base64.b64encode(der_bytes).decode("ascii"), 64) + ) + + "\n" + ) + pem_str = pem_str + "-----END {}-----".format(pem_type) + "\n" + return pem_str + def _get_jks_trust_store_path(self): """ Get truststore location @@ -143,9 +192,8 @@ def _project_name(self): pass hops_user = self._project_user() - hops_user_split = hops_user.split( - "__" - ) # project users have username project__user + # project users have username project__user: + hops_user_split = hops_user.split("__") project = hops_user_split[0] return project @@ -177,9 +225,6 @@ def replace_public_host(self, url): ui_url = url._replace(netloc=os.environ[self.HOPSWORKS_PUBLIC_HOST]) return ui_url - def _is_external(self): - return False - @property def host(self): return self._host diff --git a/python/hopsworks_common/client/online_store_rest_client.py b/python/hopsworks_common/client/online_store_rest_client.py index b733269a1..03d77471c 100644 --- a/python/hopsworks_common/client/online_store_rest_client.py +++ b/python/hopsworks_common/client/online_store_rest_client.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # + from __future__ import annotations import logging @@ -22,9 +23,9 @@ import requests import requests.adapters from furl import furl -from hsfs import client -from hsfs.client.exceptions import FeatureStoreException -from hsfs.core import variable_api +from hopsworks_common import client +from hopsworks_common.client.exceptions import FeatureStoreException +from hopsworks_common.core import variable_api _logger = logging.getLogger(__name__) diff --git a/python/hopsworks_common/core/__init__.py b/python/hopsworks_common/core/__init__.py new file mode 100644 index 000000000..ff8055b9b --- /dev/null +++ b/python/hopsworks_common/core/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/hopsworks_common/core/constants.py b/python/hopsworks_common/core/constants.py index 74f14aca7..56f98d01e 100644 --- a/python/hopsworks_common/core/constants.py +++ b/python/hopsworks_common/core/constants.py @@ -1,3 +1,19 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import importlib.util diff --git a/python/hopsworks_common/core/variable_api.py b/python/hopsworks_common/core/variable_api.py index d4e8d188c..7b3c74575 100644 --- a/python/hopsworks_common/core/variable_api.py +++ b/python/hopsworks_common/core/variable_api.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # + from __future__ import annotations import re from typing import Optional, Tuple -from hopsworks import client -from hopsworks.client.exceptions import RestAPIError +from hopsworks_common import client +from hopsworks_common.client.exceptions import RestAPIError class VariableApi: diff --git a/python/hopsworks_common/decorators.py b/python/hopsworks_common/decorators.py index 3ce15277f..d24321ffb 100644 --- a/python/hopsworks_common/decorators.py +++ b/python/hopsworks_common/decorators.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # + from __future__ import annotations import functools import os -from hsfs.core.constants import ( +from hopsworks_common.core.constants import ( HAS_GREAT_EXPECTATIONS, great_expectations_not_installed_message, ) diff --git a/python/hsfs/client/__init__.py b/python/hsfs/client/__init__.py new file mode 100644 index 000000000..19e0feb8d --- /dev/null +++ b/python/hsfs/client/__init__.py @@ -0,0 +1,40 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from hopsworks_common.client import ( + auth, + base, + exceptions, + external, + get_instance, + hopsworks, + init, + online_store_rest_client, + stop, +) + + +__all__ = [ + auth, + base, + exceptions, + external, + get_instance, + hopsworks, + init, + online_store_rest_client, + stop, +] diff --git a/python/hsfs/client/auth.py b/python/hsfs/client/auth.py new file mode 100644 index 000000000..e912b1daf --- /dev/null +++ b/python/hsfs/client/auth.py @@ -0,0 +1,28 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from hopsworks_common.client.auth import ( + ApiKeyAuth, + BearerAuth, + OnlineStoreKeyAuth, +) + + +__all__ = [ + ApiKeyAuth, + BearerAuth, + OnlineStoreKeyAuth, +] diff --git a/python/hsfs/client/base.py b/python/hsfs/client/base.py new file mode 100644 index 000000000..3ff35d800 --- /dev/null +++ b/python/hsfs/client/base.py @@ -0,0 +1,24 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from hopsworks_common.client.base import ( + Client, +) + + +__all__ = [ + Client, +] diff --git a/python/hsfs/client/exceptions.py b/python/hsfs/client/exceptions.py new file mode 100644 index 000000000..b34ef198f --- /dev/null +++ b/python/hsfs/client/exceptions.py @@ -0,0 +1,50 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from hopsworks_common.client.exceptions import ( + DatasetException, + DataValidationException, + EnvironmentException, + ExternalClientError, + FeatureStoreException, + GitException, + JobException, + JobExecutionException, + KafkaException, + OpenSearchException, + ProjectException, + RestAPIError, + UnknownSecretStorageError, + VectorDatabaseException, +) + + +__all__ = [ + DatasetException, + DataValidationException, + EnvironmentException, + ExternalClientError, + FeatureStoreException, + GitException, + JobException, + JobExecutionException, + KafkaException, + OpenSearchException, + ProjectException, + RestAPIError, + UnknownSecretStorageError, + VectorDatabaseException, +] diff --git a/python/hsfs/client/external.py b/python/hsfs/client/external.py new file mode 100644 index 000000000..1384b1c20 --- /dev/null +++ b/python/hsfs/client/external.py @@ -0,0 +1,24 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from hopsworks_common.client.external import ( + Client, +) + + +__all__ = [ + Client, +] diff --git a/python/hsfs/client/hopsworks.py b/python/hsfs/client/hopsworks.py new file mode 100644 index 000000000..c360b8cb9 --- /dev/null +++ b/python/hsfs/client/hopsworks.py @@ -0,0 +1,24 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from hopsworks_common.client.hopsworks import ( + Client, +) + + +__all__ = [ + Client, +] diff --git a/python/hsfs/client/online_store_rest_client.py b/python/hsfs/client/online_store_rest_client.py new file mode 100644 index 000000000..c75be81b7 --- /dev/null +++ b/python/hsfs/client/online_store_rest_client.py @@ -0,0 +1,28 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from hopsworks_common.client.online_store_rest_client import ( + OnlineStoreRestClientSingleton, + get_instance, + init_or_reset_online_store_rest_client, +) + + +__all__ = [ + OnlineStoreRestClientSingleton, + get_instance, + init_or_reset_online_store_rest_client, +] diff --git a/python/hsfs/core/constants.py b/python/hsfs/core/constants.py new file mode 100644 index 000000000..8de65ecc9 --- /dev/null +++ b/python/hsfs/core/constants.py @@ -0,0 +1,46 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from hopsworks_common.core.constants import ( + HAS_AIOMYSQL, + HAS_ARROW, + HAS_AVRO, + HAS_CONFLUENT_KAFKA, + HAS_FAST_AVRO, + HAS_GREAT_EXPECTATIONS, + HAS_NUMPY, + HAS_PANDAS, + HAS_POLARS, + HAS_SQLALCHEMY, + great_expectations_not_installed_message, + initialise_expectation_suite_for_single_expectation_api_message, +) + + +__all__ = [ + HAS_AIOMYSQL, + HAS_ARROW, + HAS_AVRO, + HAS_CONFLUENT_KAFKA, + HAS_FAST_AVRO, + HAS_GREAT_EXPECTATIONS, + HAS_NUMPY, + HAS_PANDAS, + HAS_POLARS, + HAS_SQLALCHEMY, + great_expectations_not_installed_message, + initialise_expectation_suite_for_single_expectation_api_message, +] diff --git a/python/hsfs/core/variable_api.py b/python/hsfs/core/variable_api.py index b499bd9b4..9d6e9765f 100644 --- a/python/hsfs/core/variable_api.py +++ b/python/hsfs/core/variable_api.py @@ -1,5 +1,5 @@ # -# Copyright 2022 Hopsworks AB +# Copyright 2024 Hopsworks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,66 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from __future__ import annotations -import re +from hopsworks_common.core.variable_api import ( + VariableApi, +) -from hsfs import client -from hsfs.client.exceptions import RestAPIError - -class VariableApi: - def get_version(self, software: str): - _client = client.get_instance() - path_params = [ - "variables", - "versions", - ] - - resp = _client._send_request("GET", path_params) - for entry in resp: - if entry["software"] == software: - return entry["version"] - return None - - def parse_major_and_minor(self, backend_version): - version_pattern = r"(\d+)\.(\d+)" - matches = re.match(version_pattern, backend_version) - - return matches.group(1), matches.group(2) - - def get_flyingduck_enabled(self): - _client = client.get_instance() - path_params = [ - "variables", - "enable_flyingduck", - ] - - resp = _client._send_request("GET", path_params) - return resp["successMessage"] == "true" - - def get_loadbalancer_external_domain(self): - _client = client.get_instance() - path_params = [ - "variables", - "loadbalancer_external_domain", - ] - - try: - resp = _client._send_request("GET", path_params) - return resp["successMessage"] - except RestAPIError: - return "" - - def get_service_discovery_domain(self): - _client = client.get_instance() - path_params = [ - "variables", - "service_discovery_domain", - ] - - try: - resp = _client._send_request("GET", path_params) - return resp["successMessage"] - except RestAPIError: - return "" +__all__ = [ + VariableApi, +] diff --git a/python/hsfs/decorators.py b/python/hsfs/decorators.py new file mode 100644 index 000000000..1165a2daa --- /dev/null +++ b/python/hsfs/decorators.py @@ -0,0 +1,34 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from hopsworks_common.decorators import ( + HopsworksConnectionError, + NoHopsworksConnectionError, + connected, + not_connected, + typechecked, + uses_great_expectations, +) + + +__all__ = [ + HopsworksConnectionError, + NoHopsworksConnectionError, + connected, + not_connected, + typechecked, + uses_great_expectations, +] diff --git a/python/hsml/decorators.py b/python/hsml/decorators.py index 826fd5aa2..1165a2daa 100644 --- a/python/hsml/decorators.py +++ b/python/hsml/decorators.py @@ -1,5 +1,5 @@ # -# Copyright 2021 Logical Clocks AB +# Copyright 2024 Hopsworks AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,42 +14,21 @@ # limitations under the License. # -import functools - - -def not_connected(fn): - @functools.wraps(fn) - def if_not_connected(inst, *args, **kwargs): - if inst._connected: - raise HopsworksConnectionError - return fn(inst, *args, **kwargs) - - return if_not_connected - - -def connected(fn): - @functools.wraps(fn) - def if_connected(inst, *args, **kwargs): - if not inst._connected: - raise NoHopsworksConnectionError - return fn(inst, *args, **kwargs) - - return if_connected - - -class HopsworksConnectionError(Exception): - """Thrown when attempted to change connection attributes while connected.""" - - def __init__(self): - super().__init__( - "Connection is currently in use. Needs to be closed for modification." - ) - - -class NoHopsworksConnectionError(Exception): - """Thrown when attempted to perform operation on connection while not connected.""" - - def __init__(self): - super().__init__( - "Connection is not active. Needs to be connected for model registry operations." - ) +from hopsworks_common.decorators import ( + HopsworksConnectionError, + NoHopsworksConnectionError, + connected, + not_connected, + typechecked, + uses_great_expectations, +) + + +__all__ = [ + HopsworksConnectionError, + NoHopsworksConnectionError, + connected, + not_connected, + typechecked, + uses_great_expectations, +] diff --git a/python/tests/core/test_online_store_rest_client.py b/python/tests/core/test_online_store_rest_client.py index 90d368dfd..39ed1f640 100644 --- a/python/tests/core/test_online_store_rest_client.py +++ b/python/tests/core/test_online_store_rest_client.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import hsfs +import hopsworks_common import pytest from furl import furl -from hsfs.client import auth, exceptions, online_store_rest_client +from hopsworks_common.client import auth, exceptions, online_store_rest_client class MockExternalClient: @@ -50,13 +50,15 @@ def test_setup_rest_client_external(self, mocker, monkeypatch): def client_get_instance(): return MockExternalClient() - monkeypatch.setattr(hsfs.client, "get_instance", client_get_instance) + monkeypatch.setattr( + hopsworks_common.client, "get_instance", client_get_instance + ) variable_api_mock = mocker.patch( - "hsfs.core.variable_api.VariableApi.get_loadbalancer_external_domain", + "hopsworks_common.core.variable_api.VariableApi.get_loadbalancer_external_domain", return_value="app.hopsworks.ai", ) ping_rdrs_mock = mocker.patch( - "hsfs.client.online_store_rest_client.OnlineStoreRestClientSingleton.is_connected", + "hopsworks_common.client.online_store_rest_client.OnlineStoreRestClientSingleton.is_connected", ) # Act @@ -86,14 +88,16 @@ def test_setup_online_store_rest_client_internal(self, mocker, monkeypatch): def client_get_instance(): return MockInternalClient() - monkeypatch.setattr(hsfs.client, "get_instance", client_get_instance) + monkeypatch.setattr( + hopsworks_common.client, "get_instance", client_get_instance + ) variable_api_mock = mocker.patch( - "hsfs.core.variable_api.VariableApi.get_service_discovery_domain", + "hopsworks_common.core.variable_api.VariableApi.get_service_discovery_domain", return_value="consul", ) optional_config = {"api_key": "provided_api_key"} ping_rdrs_mock = mocker.patch( - "hsfs.client.online_store_rest_client.OnlineStoreRestClientSingleton.is_connected", + "hopsworks_common.client.online_store_rest_client.OnlineStoreRestClientSingleton.is_connected", ) # Act