From 4e8aad8ba315642e5ec2c34b70d89fd45ecef8e8 Mon Sep 17 00:00:00 2001 From: Aleksey Veresov Date: Wed, 17 Jul 2024 14:41:51 +0200 Subject: [PATCH] Ruff format --- python/hopsworks/core/environment_api.py | 23 +++-- python/hopsworks/core/secret_api.py | 18 ++-- python/hopsworks/environment.py | 20 +++-- python/hopsworks/job_schedule.py | 2 +- python/hopsworks/util.py | 4 +- python/hsfs/core/explicit_provenance.py | 4 +- python/hsfs/core/feature_logging.py | 32 ++++--- python/hsfs/core/feature_view_api.py | 16 ++-- python/hsfs/core/feature_view_engine.py | 98 +++++++++++++-------- python/hsfs/core/opensearch.py | 18 ++-- python/hsfs/core/storage_connector_api.py | 5 +- python/hsfs/core/vector_db_client.py | 61 ++++++++----- python/hsfs/core/vector_server.py | 22 +++-- python/hsfs/feature.py | 25 ++---- python/hsfs/storage_connector.py | 4 +- python/hsfs/usage.py | 22 ++--- python/hsml/connection.py | 3 + python/hsml/model_serving.py | 7 +- python/hsml/util.py | 2 + python/tests/core/test_feature_group_api.py | 4 +- python/tests/core/test_opensearch.py | 1 - python/tests/core/test_vector_db_client.py | 4 +- python/tests/test_util.py | 4 +- 23 files changed, 241 insertions(+), 158 deletions(-) diff --git a/python/hopsworks/core/environment_api.py b/python/hopsworks/core/environment_api.py index 18c0c55d1..6a9ccf2ea 100644 --- a/python/hopsworks/core/environment_api.py +++ b/python/hopsworks/core/environment_api.py @@ -32,7 +32,13 @@ def __init__( self._environment_engine = environment_engine.EnvironmentEngine(project_id) - def create_environment(self, name: str, description: Optional[str] = None, base_environment_name: Optional[str] = "python-feature-pipeline", await_creation: Optional[bool] = True) -> environment.Environment: + def create_environment( + self, + name: str, + description: Optional[str] = None, + base_environment_name: Optional[str] = "python-feature-pipeline", + await_creation: Optional[bool] = True, + ) -> environment.Environment: """Create Python environment for the project ```python @@ -66,13 +72,14 @@ def create_environment(self, name: str, description: Optional[str] = None, base_ name, ] headers = {"content-type": "application/json"} - data = {"name": name, - "baseImage": { - "name": base_environment_name, - "description": description - }} + data = { + "name": name, + "baseImage": {"name": base_environment_name, "description": description}, + } env = environment.Environment.from_response_json( - _client._send_request("POST", path_params, headers=headers, data=json.dumps(data)), + _client._send_request( + "POST", path_params, headers=headers, data=json.dumps(data) + ), self._project_id, self._project_name, ) @@ -148,4 +155,4 @@ def _delete(self, name): name, ] headers = {"content-type": "application/json"} - _client._send_request("DELETE", path_params, headers=headers), + (_client._send_request("DELETE", path_params, headers=headers),) diff --git a/python/hopsworks/core/secret_api.py b/python/hopsworks/core/secret_api.py index 169ac6ff1..bf47b6ad8 100644 --- a/python/hopsworks/core/secret_api.py +++ b/python/hopsworks/core/secret_api.py @@ -72,7 +72,9 @@ def get_secret(self, name: str, owner: str = None) -> secret.Secret: "shared", ] - return secret.Secret.from_response_json(_client._send_request("GET", path_params, query_params=query_params))[0] + return secret.Secret.from_response_json( + _client._send_request("GET", path_params, query_params=query_params) + )[0] def get(self, name: str, owner: str = None) -> str: """Get the secret's value. @@ -90,16 +92,20 @@ def get(self, name: str, owner: str = None) -> str: return self.get_secret(name=name, owner=owner).value except RestAPIError as e: if ( - e.response.json().get("errorCode", "") == 160048 - and e.response.status_code == 404 - and util.is_interactive() + e.response.json().get("errorCode", "") == 160048 + and e.response.status_code == 404 + and util.is_interactive() ): - secret_input = getpass.getpass(prompt="\nCould not find secret, enter value here to create it: ") + secret_input = getpass.getpass( + prompt="\nCould not find secret, enter value here to create it: " + ) return self.create_secret(name, secret_input).value else: raise e - def create_secret(self, name: str, value: str, project: str = None) -> secret.Secret: + def create_secret( + self, name: str, value: str, project: str = None + ) -> secret.Secret: """Create a new secret. ```python diff --git a/python/hopsworks/environment.py b/python/hopsworks/environment.py index 3d087cad0..f286bdf8c 100644 --- a/python/hopsworks/environment.py +++ b/python/hopsworks/environment.py @@ -133,16 +133,18 @@ def install_wheel(self, path: str, await_installation: Optional[bool] = True): "packageSource": "WHEEL", } - library_rest = self._library_api._install( - library_name, self.name, library_spec - ) + library_rest = self._library_api._install(library_name, self.name, library_spec) if await_installation: - return self._environment_engine.await_library_command(self.name, library_name) + return self._environment_engine.await_library_command( + self.name, library_name + ) return library_rest - def install_requirements(self, path: str, await_installation: Optional[bool] = True): + def install_requirements( + self, path: str, await_installation: Optional[bool] = True + ): """Install libraries specified in a requirements.txt file ```python @@ -184,12 +186,12 @@ def install_requirements(self, path: str, await_installation: Optional[bool] = T "packageSource": "REQUIREMENTS_TXT", } - library_rest = self._library_api._install( - library_name, self.name, library_spec - ) + library_rest = self._library_api._install(library_name, self.name, library_spec) if await_installation: - return self._environment_engine.await_library_command(self.name, library_name) + return self._environment_engine.await_library_command( + self.name, library_name + ) return library_rest diff --git a/python/hopsworks/job_schedule.py b/python/hopsworks/job_schedule.py index 6a4a7f103..301b04122 100644 --- a/python/hopsworks/job_schedule.py +++ b/python/hopsworks/job_schedule.py @@ -30,7 +30,7 @@ def __init__( next_execution_date_time=None, id=None, end_date_time=None, - **kwargs + **kwargs, ): self._id = id self._start_date_time = ( diff --git a/python/hopsworks/util.py b/python/hopsworks/util.py index 35785783f..b5f46f29b 100644 --- a/python/hopsworks/util.py +++ b/python/hopsworks/util.py @@ -81,6 +81,8 @@ def get_hostname_replaced_url(sub_path: str): url_parsed = client.get_instance().replace_public_host(urlparse(href)) return url_parsed.geturl() + def is_interactive(): import __main__ as main - return not hasattr(main, '__file__') + + return not hasattr(main, "__file__") diff --git a/python/hsfs/core/explicit_provenance.py b/python/hsfs/core/explicit_provenance.py index 2ce4f8c80..450a00310 100644 --- a/python/hsfs/core/explicit_provenance.py +++ b/python/hsfs/core/explicit_provenance.py @@ -415,9 +415,7 @@ def default(self, obj): } elif isinstance( obj, - ( - storage_connector.StorageConnector - ), + (storage_connector.StorageConnector), ): return { "name": obj.name, diff --git a/python/hsfs/core/feature_logging.py b/python/hsfs/core/feature_logging.py index b29a7317d..bdf68d2ca 100644 --- a/python/hsfs/core/feature_logging.py +++ b/python/hsfs/core/feature_logging.py @@ -6,25 +6,32 @@ class FeatureLogging: - - def __init__(self, id: int, - transformed_features: "feature_group.FeatureGroup", - untransformed_features: "feature_group.FeatureGroup"): + def __init__( + self, + id: int, + transformed_features: "feature_group.FeatureGroup", + untransformed_features: "feature_group.FeatureGroup", + ): self._id = id self._transformed_features = transformed_features self._untransformed_features = untransformed_features @classmethod - def from_response_json(cls, json_dict: Dict[str, Any]) -> 'FeatureLogging': + def from_response_json(cls, json_dict: Dict[str, Any]) -> "FeatureLogging": from hsfs.feature_group import FeatureGroup # avoid circular import + json_decamelized = humps.decamelize(json_dict) - transformed_features = json_decamelized.get('transformed_log') - untransformed_features = json_decamelized.get('untransformed_log') + transformed_features = json_decamelized.get("transformed_log") + untransformed_features = json_decamelized.get("untransformed_log") if transformed_features: transformed_features = FeatureGroup.from_response_json(transformed_features) if untransformed_features: - untransformed_features = FeatureGroup.from_response_json(untransformed_features) - return cls(json_decamelized.get('id'), transformed_features, untransformed_features) + untransformed_features = FeatureGroup.from_response_json( + untransformed_features + ) + return cls( + json_decamelized.get("id"), transformed_features, untransformed_features + ) @property def transformed_features(self) -> "feature_group.FeatureGroup": @@ -40,9 +47,9 @@ def id(self) -> str: def to_dict(self): return { - 'id': self._id, - 'transformed_log': self._transformed_features, - 'untransformed_log': self._untransformed_features, + "id": self._id, + "transformed_log": self._transformed_features, + "untransformed_log": self._untransformed_features, } def json(self) -> Dict[str, Any]: @@ -50,4 +57,3 @@ def json(self) -> Dict[str, Any]: def __repr__(self): return self.json() - diff --git a/python/hsfs/core/feature_view_api.py b/python/hsfs/core/feature_view_api.py index cf67b0216..ac6a8ef84 100644 --- a/python/hsfs/core/feature_view_api.py +++ b/python/hsfs/core/feature_view_api.py @@ -46,7 +46,6 @@ class FeatureViewApi: _TRANSFORMED_lOG = "transformed" _UNTRANSFORMED_LOG = "untransformed" - def __init__(self, feature_store_id: int) -> None: self._feature_store_id = feature_store_id self._client = client.get_instance() @@ -407,7 +406,8 @@ def get_models_provenance( def enable_feature_logging( self, feature_view_name: str, - feature_view_version: int,): + feature_view_version: int, + ): _client = client.get_instance() path_params = self._base_path + [ feature_view_name, @@ -420,7 +420,8 @@ def enable_feature_logging( def pause_feature_logging( self, feature_view_name: str, - feature_view_version: int,): + feature_view_version: int, + ): _client = client.get_instance() path_params = self._base_path + [ feature_view_name, @@ -434,7 +435,8 @@ def pause_feature_logging( def resume_feature_logging( self, feature_view_name: str, - feature_view_version: int,): + feature_view_version: int, + ): _client = client.get_instance() path_params = self._base_path + [ feature_view_name, @@ -448,7 +450,8 @@ def resume_feature_logging( def materialize_feature_logging( self, feature_view_name: str, - feature_view_version: int,): + feature_view_version: int, + ): _client = client.get_instance() path_params = self._base_path + [ feature_view_name, @@ -469,7 +472,8 @@ def materialize_feature_logging( def get_feature_logging( self, feature_view_name: str, - feature_view_version: int,): + feature_view_version: int, + ): _client = client.get_instance() path_params = self._base_path + [ feature_view_name, diff --git a/python/hsfs/core/feature_view_engine.py b/python/hsfs/core/feature_view_engine.py index a29acf89f..4fdc7fdbf 100644 --- a/python/hsfs/core/feature_view_engine.py +++ b/python/hsfs/core/feature_view_engine.py @@ -822,8 +822,8 @@ def get_batch_data( def transform_batch_data(self, features, transformation_functions): return engine.get_instance()._apply_transformation_function( - transformation_functions, dataset=features, inplace=False - ) + transformation_functions, dataset=features, inplace=False + ) def add_tag( self, feature_view_obj, name: str, value, training_dataset_version=None @@ -996,7 +996,16 @@ def _get_logging_fg(self, fv, transformed): else: return feature_logging.untransformed_features - def log_features(self, fv, features, prediction=None, transformed=False, write_options=None, training_dataset_version=None, hsml_model=None): + def log_features( + self, + fv, + features, + prediction=None, + transformed=False, + write_options=None, + training_dataset_version=None, + hsml_model=None, + ): default_write_options = { "start_offline_materialization": False, } @@ -1017,29 +1026,41 @@ def log_features(self, fv, features, prediction=None, transformed=False, write_o ) return fg.insert(df, write_options=default_write_options) - def read_feature_logs(self, fv, - start_time: Optional[ - Union[str, int, datetime, datetime.date]] = None, - end_time: Optional[ - Union[str, int, datetime, datetime.date]] = None, - filter: Optional[Union[Filter, Logic]]=None, - transformed: Optional[bool]=False, - training_dataset_version=None, - hsml_model=None, - ): + def read_feature_logs( + self, + fv, + start_time: Optional[Union[str, int, datetime, datetime.date]] = None, + end_time: Optional[Union[str, int, datetime, datetime.date]] = None, + filter: Optional[Union[Filter, Logic]] = None, + transformed: Optional[bool] = False, + training_dataset_version=None, + hsml_model=None, + ): fg = self._get_logging_fg(fv, transformed) fv_feat_name_map = self._get_fv_feature_name_map(fv) query = fg.select_all() if start_time: - query = query.filter(fg.get_feature(FeatureViewEngine._LOG_TIME) >= start_time) + query = query.filter( + fg.get_feature(FeatureViewEngine._LOG_TIME) >= start_time + ) if end_time: - query = query.filter(fg.get_feature(FeatureViewEngine._LOG_TIME) <= end_time) + query = query.filter( + fg.get_feature(FeatureViewEngine._LOG_TIME) <= end_time + ) if training_dataset_version: - query = query.filter(fg.get_feature(FeatureViewEngine._LOG_TD_VERSION) == training_dataset_version) + query = query.filter( + fg.get_feature(FeatureViewEngine._LOG_TD_VERSION) + == training_dataset_version + ) if hsml_model: - query = query.filter(fg.get_feature(FeatureViewEngine._HSML_MODEL) == self.get_hsml_model_value(hsml_model)) + query = query.filter( + fg.get_feature(FeatureViewEngine._HSML_MODEL) + == self.get_hsml_model_value(hsml_model) + ) if filter: - query = query.filter(self._convert_to_log_fg_filter(fg, fv, filter, fv_feat_name_map)) + query = query.filter( + self._convert_to_log_fg_filter(fg, fv, filter, fv_feat_name_map) + ) df = query.read() df = df.drop(["log_id", FeatureViewEngine._LOG_TIME], axis=1) return df @@ -1062,9 +1083,12 @@ def _convert_to_log_fg_filter(self, fg, fv, filter, fv_feat_name_map): ) elif isinstance(filter, Filter): fv_feature_name = fv_feat_name_map.get( - f"{filter.feature.feature_group_id}_{filter.feature.name}") + f"{filter.feature.feature_group_id}_{filter.feature.name}" + ) if fv_feature_name is None: - raise FeatureStoreException("Filter feature {filter.feature.name} does not exist in feature view feature.") + raise FeatureStoreException( + "Filter feature {filter.feature.name} does not exist in feature view feature." + ) return Filter( fg.get_feature(filter.feature.name), filter.condition, @@ -1076,32 +1100,30 @@ def _convert_to_log_fg_filter(self, fg, fv, filter, fv_feat_name_map): def _get_fv_feature_name_map(self, fv) -> Dict[str, str]: result_dict = {} for td_feature in fv.features: - fg_feature_key = f"{td_feature.feature_group.id}_{td_feature.feature_group_feature_name}" + fg_feature_key = ( + f"{td_feature.feature_group.id}_{td_feature.feature_group_feature_name}" + ) result_dict[fg_feature_key] = td_feature.name return result_dict - def get_log_timeline(self, fv, - wallclock_time: Optional[ - Union[str, int, datetime, datetime.date]] = None, - limit: Optional[int] = None, - transformed: Optional[bool]=False, - ) -> Dict[str, Dict[str, str]]: + def get_log_timeline( + self, + fv, + wallclock_time: Optional[Union[str, int, datetime, datetime.date]] = None, + limit: Optional[int] = None, + transformed: Optional[bool] = False, + ) -> Dict[str, Dict[str, str]]: fg = self._get_logging_fg(fv, transformed) return fg.commit_details(wallclock_time=wallclock_time, limit=limit) def pause_logging(self, fv): - self._feature_view_api.pause_feature_logging( - fv.name, fv.version - ) + self._feature_view_api.pause_feature_logging(fv.name, fv.version) + def resume_logging(self, fv): - self._feature_view_api.resume_feature_logging( - fv.name, fv.version - ) + self._feature_view_api.resume_feature_logging(fv.name, fv.version) def materialize_feature_logs(self, fv, wait): - jobs = self._feature_view_api.materialize_feature_logging( - fv.name, fv.version - ) + jobs = self._feature_view_api.materialize_feature_logging(fv.name, fv.version) if wait: for job in jobs: try: @@ -1111,6 +1133,4 @@ def materialize_feature_logs(self, fv, wait): return jobs def delete_feature_logs(self, fv, transformed): - self._feature_view_api.delete_feature_logs( - fv.name, fv.version, transformed - ) + self._feature_view_api.delete_feature_logs(fv.name, fv.version, transformed) diff --git a/python/hsfs/core/opensearch.py b/python/hsfs/core/opensearch.py index 3865c7ab0..6e1ca5091 100644 --- a/python/hsfs/core/opensearch.py +++ b/python/hsfs/core/opensearch.py @@ -54,7 +54,8 @@ def error_handler_wrapper(*args, **kw): caused_by = e.info.get("error") and e.info["error"].get("caused_by") if caused_by and caused_by["type"] == "illegal_argument_exception": raise OpenSearchClientSingleton()._create_vector_database_exception( - caused_by["reason"]) from e + caused_by["reason"] + ) from e raise VectorDatabaseException( VectorDatabaseException.OTHERS, f"Error in Opensearch request: {e}", @@ -100,16 +101,19 @@ def get_options(cls, options: dict): attribute values of the OpensearchRequestOption class, and values are obtained either from the provided options or default values if not available. """ - default_option = (cls.DEFAULT_OPTION_MAP - if cls.get_version() < (2, 3) - else cls.DEFAULT_OPTION_MAP_V2_3) + default_option = ( + cls.DEFAULT_OPTION_MAP + if cls.get_version() < (2, 3) + else cls.DEFAULT_OPTION_MAP_V2_3 + ) if options: # make lower case to avoid issues with cases options = {k.lower(): v for k, v in options.items()} new_options = {} for option, value in default_option.items(): if option in options: - if (option == "timeout" + if ( + option == "timeout" and cls.get_version() < (2, 3) and isinstance(options[option], int) ): @@ -161,7 +165,9 @@ def _refresh_opensearch_connection(self): ) @_handle_opensearch_exception def search(self, index=None, body=None, options=None): - return self._opensearch_client.search(body=body, index=index, params=OpensearchRequestOption.get_options(options)) + return self._opensearch_client.search( + body=body, index=index, params=OpensearchRequestOption.get_options(options) + ) @retry( wait_exponential_multiplier=1000, diff --git a/python/hsfs/core/storage_connector_api.py b/python/hsfs/core/storage_connector_api.py index d30201a11..01d1898de 100644 --- a/python/hsfs/core/storage_connector_api.py +++ b/python/hsfs/core/storage_connector_api.py @@ -101,9 +101,7 @@ def get_kafka_connector( _client._send_request("GET", path_params, query_params=query_params) ) - def get_feature_groups_provenance( - self, storage_connector_instance - ): + def get_feature_groups_provenance(self, storage_connector_instance): """Get the generated feature groups using this storage connector, based on explicit provenance. These feature groups can be accessible or inaccessible. Explicit provenance does not track deleted generated feature group links, so deleted @@ -135,6 +133,7 @@ def get_feature_groups_provenance( } links_json = _client._send_request("GET", path_params, query_params) from hsfs.core import explicit_provenance + return explicit_provenance.Links.from_response_json( links_json, explicit_provenance.Links.Direction.DOWNSTREAM, diff --git a/python/hsfs/core/vector_db_client.py b/python/hsfs/core/vector_db_client.py index b9fdc86ab..71060c983 100644 --- a/python/hsfs/core/vector_db_client.py +++ b/python/hsfs/core/vector_db_client.py @@ -96,7 +96,9 @@ def init(self): ) self._embedding_fg_by_join_index[i] = join_fg for embedding_feature in join_fg.embedding_index.get_embeddings(): - self._td_embedding_feature_names.add((join.prefix or "") + embedding_feature.name) + self._td_embedding_feature_names.add( + (join.prefix or "") + embedding_feature.name + ) vdb_col_td_col_map = {} for feat in join_fg.features: vdb_col_td_col_map[ @@ -191,10 +193,13 @@ def find_neighbors( return [ ( 1 / item["_score"] - 1, - self._convert_to_pandas_type(embedding_feature.feature_group.features, self._rewrite_result_key( - item["_source"], - self._fg_vdb_col_td_col_map[embedding_feature.feature_group.id], - )), + self._convert_to_pandas_type( + embedding_feature.feature_group.features, + self._rewrite_result_key( + item["_source"], + self._fg_vdb_col_td_col_map[embedding_feature.feature_group.id], + ), + ), ) for item in results["hits"]["hits"] ] @@ -207,11 +212,15 @@ def _convert_to_pandas_type(self, schema, result): if not feature_value: # Feature value can be null continue elif feature_type == "date": - result[feature_name] = datetime.utcfromtimestamp(feature_value // 10**3).date() + result[feature_name] = datetime.utcfromtimestamp( + feature_value // 10**3 + ).date() elif feature_type == "timestamp": # convert timestamp in ms to datetime in s result[feature_name] = datetime.utcfromtimestamp(feature_value // 10**3) - elif feature_type == "binary" or (feature.is_complex() and feature not in self._embedding_features): + elif feature_type == "binary" or ( + feature.is_complex() and feature not in self._embedding_features + ): result[feature_name] = base64.b64decode(feature_value) return result @@ -337,18 +346,20 @@ def read(self, fg_id, schema, keys=None, pk=None, index_name=None, n=10): if VectorDbClient._index_result_limit_n.get(index_name) is None: try: query["size"] = 2**31 - 1 - self._opensearch_client.search(body=query, - index=index_name) + self._opensearch_client.search(body=query, index=index_name) except VectorDatabaseException as e: if ( - e.reason == VectorDatabaseException.REQUESTED_NUM_RESULT_TOO_LARGE + e.reason + == VectorDatabaseException.REQUESTED_NUM_RESULT_TOO_LARGE and e.info.get( - VectorDatabaseException.REQUESTED_NUM_RESULT_TOO_LARGE_INFO_N - ) - ): - VectorDbClient._index_result_limit_n[index_name] = e.info.get( VectorDatabaseException.REQUESTED_NUM_RESULT_TOO_LARGE_INFO_N ) + ): + VectorDbClient._index_result_limit_n[index_name] = ( + e.info.get( + VectorDatabaseException.REQUESTED_NUM_RESULT_TOO_LARGE_INFO_N + ) + ) else: raise e query["size"] = VectorDbClient._index_result_limit_n.get(index_name) @@ -356,24 +367,32 @@ def read(self, fg_id, schema, keys=None, pk=None, index_name=None, n=10): results = self._opensearch_client.search(body=query, index=index_name) # https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces return [ - self._convert_to_pandas_type(schema, self._rewrite_result_key( - item["_source"], self._fg_vdb_col_td_col_map[fg_id] - )) + self._convert_to_pandas_type( + schema, + self._rewrite_result_key( + item["_source"], self._fg_vdb_col_td_col_map[fg_id] + ), + ) for item in results["hits"]["hits"] ] @staticmethod - def read_feature_group(feature_group: "hsfs.feature_group.FeatureGroup", n: int =None) -> list: + def read_feature_group( + feature_group: "hsfs.feature_group.FeatureGroup", n: int = None + ) -> list: if feature_group.embedding_index: vector_db_client = VectorDbClient(feature_group.select_all()) results = vector_db_client.read( feature_group.id, feature_group.features, - pk=feature_group.embedding_index.col_prefix + feature_group.primary_key[0], + pk=feature_group.embedding_index.col_prefix + + feature_group.primary_key[0], index_name=feature_group.embedding_index.index_name, - n=n + n=n, ) - return [[result[f.name] for f in feature_group.features] for result in results] + return [ + [result[f.name] for f in feature_group.features] for result in results + ] else: raise FeatureStoreException("Feature group does not have embedding.") diff --git a/python/hsfs/core/vector_server.py b/python/hsfs/core/vector_server.py index 44a522564..97d9b83b0 100755 --- a/python/hsfs/core/vector_server.py +++ b/python/hsfs/core/vector_server.py @@ -99,10 +99,7 @@ def __init__( self._untransformed_feature_vector_col_name = [ feat.name for feat in features - if not ( - feat.label - or feat.training_helper_column - ) + if not (feat.label or feat.training_helper_column) ] self._inference_helper_col_name = [ feat.name for feat in features if feat.inference_helper_column @@ -451,17 +448,26 @@ def assemble_feature_vector( for fname in self.transformed_feature_vector_col_name ] else: - return [result_dict.get(fname, None) for fname in self._untransformed_feature_vector_col_name] + return [ + result_dict.get(fname, None) + for fname in self._untransformed_feature_vector_col_name + ] def transform_feature_vectors(self, batch_features): - return [self.apply_transformation(self.get_untransformed_features_map(features)) + return [ + self.apply_transformation(self.get_untransformed_features_map(features)) for features in batch_features ] def get_untransformed_features_map(self, features) -> Dict[str, Any]: return dict( - [(fname, fvalue) for fname, fvalue - in zip(self._untransformed_feature_vector_col_name, features)]) + [ + (fname, fvalue) + for fname, fvalue in zip( + self._untransformed_feature_vector_col_name, features + ) + ] + ) def handle_feature_vector_return_type( self, diff --git a/python/hsfs/feature.py b/python/hsfs/feature.py index 896980567..f66fa9807 100644 --- a/python/hsfs/feature.py +++ b/python/hsfs/feature.py @@ -209,36 +209,29 @@ def feature_group_id(self) -> Optional[int]: def _get_filter_value(self, value: Any) -> Any: if self.type == "timestamp": - return (datetime.fromtimestamp( - util.convert_event_time_to_timestamp(value)/1000) - .strftime("%Y-%m-%d %H:%M:%S") - ) + return datetime.fromtimestamp( + util.convert_event_time_to_timestamp(value) / 1000 + ).strftime("%Y-%m-%d %H:%M:%S") else: return value def __lt__(self, other: Any) -> "filter.Filter": - return filter.Filter(self, filter.Filter.LT, - self._get_filter_value(other)) + return filter.Filter(self, filter.Filter.LT, self._get_filter_value(other)) def __le__(self, other: Any) -> "filter.Filter": - return filter.Filter(self, filter.Filter.LE, - self._get_filter_value(other)) + return filter.Filter(self, filter.Filter.LE, self._get_filter_value(other)) def __eq__(self, other: Any) -> "filter.Filter": - return filter.Filter(self, filter.Filter.EQ, - self._get_filter_value(other)) + return filter.Filter(self, filter.Filter.EQ, self._get_filter_value(other)) def __ne__(self, other: Any) -> "filter.Filter": - return filter.Filter(self, filter.Filter.NE, - self._get_filter_value(other)) + return filter.Filter(self, filter.Filter.NE, self._get_filter_value(other)) def __ge__(self, other: Any) -> "filter.Filter": - return filter.Filter(self, filter.Filter.GE, - self._get_filter_value(other)) + return filter.Filter(self, filter.Filter.GE, self._get_filter_value(other)) def __gt__(self, other: Any) -> "filter.Filter": - return filter.Filter(self, filter.Filter.GT, - self._get_filter_value(other)) + return filter.Filter(self, filter.Filter.GT, self._get_filter_value(other)) def contains(self, other: Union[str, List[Any]]) -> "filter.Filter": """ diff --git a/python/hsfs/storage_connector.py b/python/hsfs/storage_connector.py index 96596a5b0..8e0c90b0b 100644 --- a/python/hsfs/storage_connector.py +++ b/python/hsfs/storage_connector.py @@ -211,7 +211,9 @@ def get_feature_groups(self): feature_groups_provenance = self.get_feature_groups_provenance() if feature_groups_provenance.inaccessible or feature_groups_provenance.deleted: - _logger.info("There are deleted or inaccessible feature groups. For more details access `get_feature_groups_provenance`") + _logger.info( + "There are deleted or inaccessible feature groups. For more details access `get_feature_groups_provenance`" + ) if feature_groups_provenance.accessible: return feature_groups_provenance.accessible diff --git a/python/hsfs/usage.py b/python/hsfs/usage.py index 3428de21f..bd724c293 100644 --- a/python/hsfs/usage.py +++ b/python/hsfs/usage.py @@ -85,16 +85,18 @@ def get_timezone(self): return self._timezone def json(self): - return json.dumps({ - "platform": self.get_platform(), - "hsml_version": self.get_hsml_version(), - "hsfs_version": self.get_hsfs_version(), - "hopsworks_version": self.get_hopsworks_version(), - "user_id": self.get_user_id(), - "backend_version": self.get_backend_version(), - "timezone": str(self.get_timezone()), - "python_version": self.get_python_version(), - }) + return json.dumps( + { + "platform": self.get_platform(), + "hsml_version": self.get_hsml_version(), + "hsfs_version": self.get_hsfs_version(), + "hopsworks_version": self.get_hopsworks_version(), + "user_id": self.get_user_id(), + "backend_version": self.get_backend_version(), + "timezone": str(self.get_timezone()), + "python_version": self.get_python_version(), + } + ) class MethodCounter: diff --git a/python/hsml/connection.py b/python/hsml/connection.py index 899589a4e..f4ca72512 100644 --- a/python/hsml/connection.py +++ b/python/hsml/connection.py @@ -97,6 +97,7 @@ def __init__( api_key_value: str = None, ): from hsml.core import model_api, model_registry_api, model_serving_api + self._host = host self._port = port self._project = project @@ -163,6 +164,7 @@ def connect(self): """ from hsml import client from hsml.core import model_api + self._connected = True try: # init client @@ -196,6 +198,7 @@ def close(self): Usage is recommended but optional. """ from hsml import client + client.stop() self._model_api = None self._connected = False diff --git a/python/hsml/model_serving.py b/python/hsml/model_serving.py index 21d04b833..d298e669f 100644 --- a/python/hsml/model_serving.py +++ b/python/hsml/model_serving.py @@ -285,7 +285,12 @@ def postprocess(self, outputs): return Transformer(script_file=script_file, resources=resources) - def create_deployment(self, predictor: Predictor, name: Optional[str] = None, environment: Optional[str] = None): + def create_deployment( + self, + predictor: Predictor, + name: Optional[str] = None, + environment: Optional[str] = None, + ): """Create a Deployment metadata object. !!! example diff --git a/python/hsml/util.py b/python/hsml/util.py index 6ef6d9053..6fffc4033 100644 --- a/python/hsml/util.py +++ b/python/hsml/util.py @@ -100,6 +100,7 @@ def set_model_class(model): from hsml.sklearn.model import Model as SkLearnModel from hsml.tensorflow.model import Model as TFModel from hsml.torch.model import Model as TorchModel + if "href" in model: _ = model.pop("href") if "type" in model: # backwards compatibility @@ -241,6 +242,7 @@ def get_predictor_for_model(model, **kwargs): from hsml.tensorflow.predictor import Predictor as TFPredictor from hsml.torch.model import Model as TorchModel from hsml.torch.predictor import Predictor as TorchPredictor + if not isinstance(model, BaseModel): raise ValueError( "model is of type {}, but an instance of {} class is expected".format( diff --git a/python/tests/core/test_feature_group_api.py b/python/tests/core/test_feature_group_api.py index 37459d897..9366f4401 100644 --- a/python/tests/core/test_feature_group_api.py +++ b/python/tests/core/test_feature_group_api.py @@ -54,9 +54,7 @@ def test_get_smart_with_infer_type(self, mocker, backend_fixtures): def test_check_features(self, mocker, backend_fixtures): # Arrange fg_api = feature_group_api.FeatureGroupApi() - json = backend_fixtures["feature_group"]["get_basic_info"][ - "response" - ] + json = backend_fixtures["feature_group"]["get_basic_info"]["response"] fg = fg_mod.FeatureGroup.from_response_json(json) # Act diff --git a/python/tests/core/test_opensearch.py b/python/tests/core/test_opensearch.py index 3ae804cdc..5a4bcb681 100644 --- a/python/tests/core/test_opensearch.py +++ b/python/tests/core/test_opensearch.py @@ -69,7 +69,6 @@ def test_create_vector_database_exception( class TestOpensearchRequestOption: - def test_version_1_no_options(self): OpensearchRequestOption.get_version = lambda: (1, 1) options = OpensearchRequestOption.get_options({}) diff --git a/python/tests/core/test_vector_db_client.py b/python/tests/core/test_vector_db_client.py index 4f17a1dbe..a4261a5dd 100644 --- a/python/tests/core/test_vector_db_client.py +++ b/python/tests/core/test_vector_db_client.py @@ -220,7 +220,9 @@ def test_check_filter_when_filter_is_not_logic_or_filter(self): self.target._check_filter("f1 > 20", self.fg2) def test_read_with_keys(self): - actual = self.target.read(self.fg.id, self.fg.features, keys={"f1": 10, "f2": 20}) + actual = self.target.read( + self.fg.id, self.fg.features, keys={"f1": 10, "f2": 20} + ) expected_query = { "query": {"bool": {"must": [{"match": {"f1": 10}}, {"match": {"f2": 20}}]}}, diff --git a/python/tests/test_util.py b/python/tests/test_util.py index b39501162..330c76b5c 100644 --- a/python/tests/test_util.py +++ b/python/tests/test_util.py @@ -736,7 +736,9 @@ def test_get_dataset_type_HIVEDB_with_dfs(self): assert db_type == "HIVEDB" def test_get_dataset_type_DATASET(self): - db_type = hsfs.util.get_dataset_type("/Projects/temp/Resources/kafka__tstore.jks") + db_type = hsfs.util.get_dataset_type( + "/Projects/temp/Resources/kafka__tstore.jks" + ) assert db_type == "DATASET" def test_get_dataset_type_DATASET_with_dfs(self):