Skip to content

Commit

Permalink
Merge client, connection, version and variable_api
Browse files Browse the repository at this point in the history
  • Loading branch information
aversey committed Jun 28, 2024
1 parent d6b578a commit 227a4e1
Show file tree
Hide file tree
Showing 23 changed files with 895 additions and 1,815 deletions.
30 changes: 19 additions & 11 deletions python/hopsworks/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# limitations under the License.
#

from typing import Literal, Optional, Union

from hopsworks.client import external, hopsworks


Expand All @@ -22,16 +24,19 @@


def init(
client_type,
host=None,
port=None,
project=None,
hostname_verification=None,
trust_store_path=None,
cert_folder=None,
api_key_file=None,
api_key_value=None,
):
client_type: Union[Literal["hopsworks"], Literal["external"]],
host: Optional[str] = None,
port: Optional[int] = None,
project: Optional[str] = None,
engine: Optional[str] = None,
region_name: Optional[str] = None,
secrets_store=None,
hostname_verification: Optional[bool] = None,
trust_store_path: Optional[str] = None,
cert_folder: Optional[str] = None,
api_key_file: Optional[str] = None,
api_key_value: Optional[str] = None,
) -> None:
global _client
if not _client:
if client_type == "hopsworks":
Expand All @@ -41,6 +46,9 @@ def init(
host,
port,
project,
engine,
region_name,
secrets_store,
hostname_verification,
trust_store_path,
cert_folder,
Expand All @@ -49,7 +57,7 @@ def init(
)


def get_instance():
def get_instance() -> Union[hopsworks.Client, external.Client]:
global _client
if _client:
return _client
Expand Down
21 changes: 17 additions & 4 deletions python/hopsworks/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,39 @@
# limitations under the License.
#

from __future__ import annotations

import requests


class BearerAuth(requests.auth.AuthBase):
"""Class to encapsulate a Bearer token."""

def __init__(self, token):
def __init__(self, token: str) -> requests.Request:
self._token = token

def __call__(self, r):
def __call__(self, r: requests.Request) -> requests.Request:
r.headers["Authorization"] = "Bearer " + self._token.strip()
return r


class ApiKeyAuth(requests.auth.AuthBase):
"""Class to encapsulate an API key."""

def __init__(self, token):
def __init__(self, token: str) -> None:
self._token = token

def __call__(self, r):
def __call__(self, r: requests.Request) -> requests.Request:
r.headers["Authorization"] = "ApiKey " + self._token.strip()
return r


class OnlineStoreKeyAuth(requests.auth.AuthBase):
"""Class to encapsulate an API key."""

def __init__(self, token: str) -> None:
self._token = token.strip()

def __call__(self, r: requests.Request) -> requests.Request:
r.headers["X-API-KEY"] = self._token
return r
132 changes: 120 additions & 12 deletions python/hopsworks/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@
# limitations under the License.
#

from __future__ import annotations

import base64
import os
from abc import ABC, abstractmethod
import textwrap
import time
from pathlib import Path

import furl
import requests
Expand All @@ -24,21 +29,26 @@
from hopsworks.decorators import connected


try:
import jks
except ImportError:
pass


urllib3.disable_warnings(urllib3.exceptions.SecurityWarning)
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


class Client(ABC):
class Client:
TOKEN_FILE = "token.jwt"
TOKEN_EXPIRED_RETRY_INTERVAL = 0.6
TOKEN_EXPIRED_MAX_RETRIES = 10

APIKEY_FILE = "api.key"
REST_ENDPOINT = "REST_ENDPOINT"
DEFAULT_DATABRICKS_ROOT_VIRTUALENV_ENV = "DEFAULT_DATABRICKS_ROOT_VIRTUALENV_ENV"
HOPSWORKS_PUBLIC_HOST = "HOPSWORKS_PUBLIC_HOST"

@abstractmethod
def __init__(self):
"""To be implemented by clients."""
pass

def _get_verify(self, verify, trust_store_path):
"""Get verification method for sending HTTP requests to Hopsworks.
Expand Down Expand Up @@ -163,11 +173,9 @@ def _send_request(

if response.status_code == 401 and self.REST_ENDPOINT in os.environ:
# refresh token and retry request - only on hopsworks
self._auth = auth.BearerAuth(self._read_jwt())
# Update request with the new token
request.auth = self._auth
prepped = self._session.prepare_request(request)
response = self._session.send(prepped, verify=self._verify, stream=stream)
response = self._retry_token_expired(
request, stream, self.TOKEN_EXPIRED_RETRY_INTERVAL, 1
)

if response.status_code // 100 != 2:
raise exceptions.RestAPIError(url, response)
Expand All @@ -180,6 +188,106 @@ def _send_request(
return None
return response.json()

def _retry_token_expired(self, request, stream, wait, retries):
"""Refresh the JWT token and retry the request. Only on Hopsworks.
As the token might take a while to get refreshed. Keep trying
"""
# Sleep the waited time before re-issuing the request
time.sleep(wait)

self._auth = auth.BearerAuth(self._read_jwt())
# Update request with the new token
request.auth = self._auth
prepped = self._session.prepare_request(request)
response = self._session.send(prepped, verify=self._verify, stream=stream)

if response.status_code == 401 and retries < self.TOKEN_EXPIRED_MAX_RETRIES:
# Try again.
return self._retry_token_expired(request, stream, wait * 2, retries + 1)
else:
# If the number of retries have expired, the _send_request method
# will throw an exception to the user as part of the status_code validation.
return response

def _close(self):
"""Closes a client. Can be implemented for clean up purposes, not mandatory."""
self._connected = False

def _write_pem(
self, keystore_path, keystore_pw, truststore_path, truststore_pw, prefix
):
ks = jks.KeyStore.load(Path(keystore_path), keystore_pw, try_decrypt_keys=True)
ts = jks.KeyStore.load(
Path(truststore_path), truststore_pw, try_decrypt_keys=True
)

ca_chain_path = os.path.join("/tmp", f"{prefix}_ca_chain.pem")
self._write_ca_chain(ks, ts, ca_chain_path)

client_cert_path = os.path.join("/tmp", f"{prefix}_client_cert.pem")
self._write_client_cert(ks, client_cert_path)

client_key_path = os.path.join("/tmp", f"{prefix}_client_key.pem")
self._write_client_key(ks, client_key_path)

return ca_chain_path, client_cert_path, client_key_path

def _write_ca_chain(self, ks, ts, ca_chain_path):
"""
Converts JKS keystore and truststore file into ca chain PEM to be compatible with Python libraries
"""
ca_chain = ""
for store in [ks, ts]:
for _, c in store.certs.items():
ca_chain = ca_chain + self._bytes_to_pem_str(c.cert, "CERTIFICATE")

with Path(ca_chain_path).open("w") as f:
f.write(ca_chain)

def _write_client_cert(self, ks, client_cert_path):
"""
Converts JKS keystore file into client cert PEM to be compatible with Python libraries
"""
client_cert = ""
for _, pk in ks.private_keys.items():
for c in pk.cert_chain:
client_cert = client_cert + self._bytes_to_pem_str(c[1], "CERTIFICATE")

with Path(client_cert_path).open("w") as f:
f.write(client_cert)

def _write_client_key(self, ks, client_key_path):
"""
Converts JKS keystore file into client key PEM to be compatible with Python libraries
"""
client_key = ""
for _, pk in ks.private_keys.items():
client_key = client_key + self._bytes_to_pem_str(
pk.pkey_pkcs8, "PRIVATE KEY"
)

with Path(client_key_path).open("w") as f:
f.write(client_key)

def _bytes_to_pem_str(self, der_bytes, pem_type):
"""
Utility function for creating PEM files
Args:
der_bytes: DER encoded bytes
pem_type: type of PEM, e.g Certificate, Private key, or RSA private key
Returns:
PEM String for a DER-encoded certificate or private key
"""
pem_str = ""
pem_str = pem_str + "-----BEGIN {}-----".format(pem_type) + "\n"
pem_str = (
pem_str
+ "\r\n".join(
textwrap.wrap(base64.b64encode(der_bytes).decode("ascii"), 64)
)
+ "\n"
)
pem_str = pem_str + "-----END {}-----".format(pem_type) + "\n"
return pem_str
61 changes: 59 additions & 2 deletions python/hopsworks/client/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,29 @@
# limitations under the License.
#

from __future__ import annotations

from enum import Enum
from typing import Any, Union

import requests


class RestAPIError(Exception):
"""REST Exception encapsulating the response object and url."""

def __init__(self, url, response):
class FeatureStoreErrorCode(int, Enum):
FEATURE_GROUP_COMMIT_NOT_FOUND = 270227
STATISTICS_NOT_FOUND = 270228

def __eq__(self, other: Union[int, Any]) -> bool:
if isinstance(other, int):
return self.value == other
if isinstance(other, self.__class__):
return self is other
return False

def __init__(self, url: str, response: requests.Response) -> None:
try:
error_object = response.json()
except Exception:
Expand Down Expand Up @@ -77,8 +95,47 @@ class JobExecutionException(Exception):
"""Generic job executions exception"""


class FeatureStoreException(Exception):
"""Generic feature store exception"""


class ExternalClientError(TypeError):
"""Raised when external client cannot be initialized due to missing arguments."""

def __init__(self, message):
def __init__(self, missing_argument: str) -> None:
message = (
"{0} cannot be of type NoneType, {0} is a non-optional "
"argument to connect to hopsworks from an external environment."
).format(missing_argument)
super().__init__(message)


class VectorDatabaseException(Exception):
# reason
REQUESTED_K_TOO_LARGE = "REQUESTED_K_TOO_LARGE"
REQUESTED_NUM_RESULT_TOO_LARGE = "REQUESTED_NUM_RESULT_TOO_LARGE"
OTHERS = "OTHERS"

# info
REQUESTED_K_TOO_LARGE_INFO_K = "k"
REQUESTED_NUM_RESULT_TOO_LARGE_INFO_N = "n"

def __init__(self, reason: str, message: str, info: str) -> None:
super().__init__(message)
self._info = info
self._reason = reason

@property
def reason(self) -> str:
return self._reason

@property
def info(self) -> str:
return self._info


class DataValidationException(FeatureStoreException):
"""Raised when data validation fails only when using "STRICT" validation ingestion policy."""

def __init__(self, message: str) -> None:
super().__init__(message)
Loading

0 comments on commit 227a4e1

Please sign in to comment.