[FSTORE-1436] Add foreign_keys and make helper columns consistent in materialized training dataset #406

Merged · 9 commits · Dec 9, 2024
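The sketch below (not part of the diff) illustrates the headline change, assuming the public hsfs API exposes a foreign_key argument on create_feature_group after this PR; the connection details, feature group name, and column names are illustrative only.

import hsfs

# Connect to the feature store; configuration is assumed to come from the environment.
fs = hsfs.connection().get_feature_store()

# Hypothetical feature group whose "account_id" column references another group.
fg = fs.create_feature_group(
    name="transactions",
    version=1,
    primary_key=["tx_id"],
    foreign_key=["account_id"],  # assumed new argument introduced by this PR
    event_time="tx_time",
)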
50 changes: 39 additions & 11 deletions python/hopsworks_common/core/dataset_api.py
@@ -208,8 +208,10 @@ def upload(

if self.exists(destination_path):
if overwrite:
if 'datasetType' in self._get(destination_path):
raise DatasetException("overwrite=True not supported on a top-level dataset")
if "datasetType" in self._get(destination_path):
raise DatasetException(
"overwrite=True not supported on a top-level dataset"
)
else:
self.remove(destination_path)
else:
@@ -240,7 +242,14 @@ def upload(
# uploading files in the same folder is done concurrently
futures = [
executor.submit(
self._upload_file, f_name, root + os.sep + f_name, remote_base_path, chunk_size, simultaneous_chunks, max_chunk_retries, chunk_retry_interval
self._upload_file,
f_name,
root + os.sep + f_name,
remote_base_path,
chunk_size,
simultaneous_chunks,
max_chunk_retries,
chunk_retry_interval,
)
for f_name in files
]
@@ -252,13 +261,28 @@
except Exception as e:
raise e
else:
self._upload_file(file_name, local_path, upload_path, chunk_size, simultaneous_chunks, max_chunk_retries, chunk_retry_interval)
self._upload_file(
file_name,
local_path,
upload_path,
chunk_size,
simultaneous_chunks,
max_chunk_retries,
chunk_retry_interval,
)

return upload_path + "/" + os.path.basename(local_path)


def _upload_file(self, file_name, local_path, upload_path, chunk_size, simultaneous_chunks, max_chunk_retries, chunk_retry_interval):

def _upload_file(
self,
file_name,
local_path,
upload_path,
chunk_size,
simultaneous_chunks,
max_chunk_retries,
chunk_retry_interval,
):
file_size = os.path.getsize(local_path)

num_chunks = math.ceil(file_size / chunk_size)
@@ -508,8 +532,10 @@ def copy(self, source_path: str, destination_path: str, overwrite: bool = False)
"""
if self.exists(destination_path):
if overwrite:
if 'datasetType' in self._get(destination_path):
raise DatasetException("overwrite=True not supported on a top-level dataset")
if "datasetType" in self._get(destination_path):
raise DatasetException(
"overwrite=True not supported on a top-level dataset"
)
else:
self.remove(destination_path)
else:
@@ -551,8 +577,10 @@ def move(self, source_path: str, destination_path: str, overwrite: bool = False)
"""
if self.exists(destination_path):
if overwrite:
if 'datasetType' in self._get(destination_path):
raise DatasetException("overwrite=True not supported on a top-level dataset")
if "datasetType" in self._get(destination_path):
raise DatasetException(
"overwrite=True not supported on a top-level dataset"
)
else:
self.remove(destination_path)
else:
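A minimal sketch (not part of the diff) of the guard added above: overwrite=True is rejected when the destination resolves to a top-level dataset. The project handle, the paths, and the exception import path are assumptions.

import hopsworks
# Import path assumed; adjust to where DatasetException lives in your client version.
from hopsworks.client.exceptions import DatasetException

project = hopsworks.login()
dataset_api = project.get_dataset_api()

# Copying over an existing nested file with overwrite=True replaces it.
dataset_api.copy("Resources/report.csv", "Resources/backup/report.csv", overwrite=True)

# Copying onto a top-level dataset such as "Resources" with overwrite=True is refused.
try:
    dataset_api.copy("Resources/report.csv", "Resources", overwrite=True)
except DatasetException as err:
    print(err)  # overwrite=True not supported on a top-level dataset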
28 changes: 23 additions & 5 deletions python/hsfs/constructor/query.py
@@ -720,7 +720,12 @@ def _get_feature_by_name(
query_features[feat.name] = query_features.get(feat.name, []) + [
feature_entry
]
for join_obj in self.joins:

# Collect joins, recursing so that nested joins are included as well.
joins = set(self.joins)
[self._fg_rec_add_joins(q_join, joins) for q_join in self.joins]

for join_obj in joins:
for feat in join_obj.query._left_features:
feature_entry = (
feat,
@@ -815,17 +820,30 @@ def get_feature(self, feature_name: str) -> Feature:
"""
return self._get_feature_by_name(feature_name)[0]

def _fg_rec_add(self, join_object, featuregroups):
def _fg_rec_add_joins(self, join_object, joins):
"""
Recursively traverse nested joins and add each join object to the joins set.

# Arguments
join_object: `Join object`.
"""
if len(join_object.query.joins) > 0:
for nested_join in join_object.query.joins:
self._fg_rec_add_joins(nested_join, joins)
for q_join in join_object.query.joins:
joins.add(q_join)

def _fg_rec_add(self, join_object, feature_groups):
"""
Recursively get a feature groups from nested join and add to featuregroups list.
Recursively collect feature groups from nested joins and add them to the feature_groups set.

# Arguments
join_object: `Join object`.
"""
if len(join_object.query.joins) > 0:
for nested_join in join_object.query.joins:
self._fg_rec_add(nested_join, featuregroups)
featuregroups.add(join_object.query._left_feature_group)
self._fg_rec_add(nested_join, feature_groups)
feature_groups.add(join_object.query._left_feature_group)

def __getattr__(self, name: str) -> Any:
try:
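A minimal sketch (not part of the diff) of what the recursive join collection in _get_feature_by_name enables: resolving a feature that only exists in a nested join. The feature group and feature names are illustrative.

import hsfs

fs = hsfs.connection().get_feature_store()

fg_orders = fs.get_feature_group("orders", version=1)
fg_customers = fs.get_feature_group("customers", version=1)
fg_regions = fs.get_feature_group("regions", version=1)

# "regions" is only reachable through the join nested inside the "customers" query.
query = fg_orders.select_all().join(
    fg_customers.select_all().join(fg_regions.select_all())
)

# The lookup previously scanned only top-level joins; with the recursive collection
# above, a feature living on the nested "regions" join can now be resolved.
region_feature = query.get_feature("region_name")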
32 changes: 10 additions & 22 deletions python/hsfs/core/delta_engine.py
@@ -57,9 +57,7 @@ def register_temporary_table(self, delta_fg_alias, read_options):
delta_options = self._setup_delta_read_opts(delta_fg_alias, read_options)
self._spark_session.read.format(self.DELTA_SPARK_FORMAT).options(
**delta_options
).load(location).createOrReplaceTempView(
delta_fg_alias.alias
)
).load(location).createOrReplaceTempView(delta_fg_alias.alias)

def _setup_delta_read_opts(self, delta_fg_alias, read_options):
delta_options = {}
@@ -89,16 +87,12 @@ def _setup_delta_read_opts(self, delta_fg_alias, read_options):
def delete_record(self, delete_df):
location = self._feature_group.prepare_spark_location()

if not DeltaTable.isDeltaTable(
self._spark_session, location
):
if not DeltaTable.isDeltaTable(self._spark_session, location):
raise FeatureStoreException(
f"This is no data available in Feature group {self._feature_group.name}, or it not DELTA enabled "
)
else:
fg_source_table = DeltaTable.forPath(
self._spark_session, location
)
fg_source_table = DeltaTable.forPath(self._spark_session, location)

source_alias = (
f"{self._feature_group.name}_{self._feature_group.version}_source"
@@ -112,9 +106,7 @@ def delete_record(self, delete_df):
delete_df.alias(updates_alias), merge_query_str
).whenMatchedDelete().execute()

fg_commit = self._get_last_commit_metadata(
self._spark_session, location
)
fg_commit = self._get_last_commit_metadata(self._spark_session, location)
return self._feature_group_api.commit(self._feature_group, fg_commit)

def _write_delta_dataset(self, dataset, write_options):
@@ -123,9 +115,7 @@ def _write_delta_dataset(self, dataset, write_options):
if write_options is None:
write_options = {}

if not DeltaTable.isDeltaTable(
self._spark_session, location
):
if not DeltaTable.isDeltaTable(self._spark_session, location):
(
dataset.write.format(DeltaEngine.DELTA_SPARK_FORMAT)
.options(**write_options)
@@ -138,9 +128,7 @@ def _write_delta_dataset(self, dataset, write_options):
.save(location)
)
else:
fg_source_table = DeltaTable.forPath(
self._spark_session, location
)
fg_source_table = DeltaTable.forPath(self._spark_session, location)

source_alias = (
f"{self._feature_group.name}_{self._feature_group.version}_source"
@@ -154,13 +142,13 @@
dataset.alias(updates_alias), merge_query_str
).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()

return self._get_last_commit_metadata(
self._spark_session, location
)
return self._get_last_commit_metadata(self._spark_session, location)

def vacuum(self, retention_hours: int):
location = self._feature_group.prepare_spark_location()
retention = f"RETAIN {retention_hours} HOURS" if retention_hours is not None else ""
retention = (
f"RETAIN {retention_hours} HOURS" if retention_hours is not None else ""
)
self._spark_session.sql(f"VACUUM '{location}' {retention}")

def _generate_merge_query(self, source_alias, updates_alias):
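For reference, a standalone sketch of the Delta Lake merge pattern that _write_delta_dataset relies on, assuming a SparkSession with the delta-spark package available; the path, schema, and merge condition are illustrative, not the engine's own code.

from delta.tables import DeltaTable
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
location = "/tmp/demo_fg"  # illustrative storage location
updates = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])

if not DeltaTable.isDeltaTable(spark, location):
    # First write bootstraps the Delta table.
    updates.write.format("delta").mode("append").save(location)
else:
    target = DeltaTable.forPath(spark, location)
    target.alias("source").merge(
        updates.alias("updates"),
        "source.id = updates.id",  # condition built from the primary key columns
    ).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()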
4 changes: 3 additions & 1 deletion python/hsfs/core/external_feature_group_engine.py
@@ -42,12 +42,14 @@ def save(self, feature_group):
external_dataset
)

# set primary and partition key columns
# set primary, foreign and partition key columns
# we should move this to the backend
util.verify_attribute_key_names(feature_group, True)
for feat in feature_group.features:
if feat.name in feature_group.primary_key:
feat.primary = True
if feat.name in feature_group.foreign_key:
feat.foreign = True
util.validate_embedding_feature_type(
feature_group.embedding_index, feature_group._features
)
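A simplified, standalone illustration (not hopsworks code) of the key-flagging loop above: features whose names appear in the primary_key or foreign_key lists get the corresponding flag set before the metadata is sent to the backend.

from dataclasses import dataclass

@dataclass
class Feature:
    name: str
    primary: bool = False
    foreign: bool = False

features = [Feature("tx_id"), Feature("account_id"), Feature("amount")]
primary_key, foreign_key = ["tx_id"], ["account_id"]

for feat in features:
    if feat.name in primary_key:
        feat.primary = True
    if feat.name in foreign_key:
        feat.foreign = True

assert [f.name for f in features if f.primary] == ["tx_id"]
assert [f.name for f in features if f.foreign] == ["account_id"]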
4 changes: 1 addition & 3 deletions python/hsfs/core/feature_group_api.py
@@ -446,9 +446,7 @@ def update_table_schema(

headers = {"content-type": "application/json"}
return job.Job.from_response_json(
_client._send_request(
"POST", path_params, headers=headers
),
_client._send_request("POST", path_params, headers=headers),
)

def get_parent_feature_groups(
4 changes: 3 additions & 1 deletion python/hsfs/core/feature_group_engine.py
@@ -443,13 +443,15 @@ def save_feature_group_metadata(
feature_group.features, dataframe_features
)

# set primary and partition key columns
# set primary, foreign and partition key columns
# we should move this to the backend
util.verify_attribute_key_names(feature_group)

for feat in feature_group.features:
if feat.name in feature_group.primary_key:
feat.primary = True
if feat.name in feature_group.foreign_key:
feat.foreign = True
if feat.name in feature_group.partition_key:
feat.partition = True
if (
6 changes: 3 additions & 3 deletions python/hsfs/core/feature_view_engine.py
@@ -389,9 +389,9 @@ def create_training_dataset(
training_dataset_obj,
user_write_options,
spine=None,
primary_keys=False,
event_time=False,
training_helper_columns=False,
primary_keys=True,
event_time=True,
training_helper_columns=True,
):
self._set_event_time(feature_view_obj, training_dataset_obj)
updated_instance = self._create_training_data_metadata(
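A minimal sketch (not part of the diff) of the call path the new defaults affect, assuming an existing feature view; with primary_keys, event_time, and training_helper_columns now defaulting to True, those columns are kept in the materialized files so reads can drop them consistently. The feature view name and options are illustrative.

import hsfs

fs = hsfs.connection().get_feature_store()
feature_view = fs.get_feature_view("transactions_view", version=1)

# Materialize a training dataset; helper columns are written alongside the
# training features instead of being dropped at write time.
td_version, job = feature_view.create_training_data(
    description="training data with helper columns materialized",
    write_options={"wait_for_job": True},
)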
14 changes: 8 additions & 6 deletions python/hsfs/core/hudi_engine.py
@@ -20,9 +20,13 @@


class HudiEngine:

HUDI_SPEC_FEATURE_NAMES = ["_hoodie_record_key", "_hoodie_partition_path",
"_hoodie_commit_time", "_hoodie_file_name", "_hoodie_commit_seqno"]
HUDI_SPEC_FEATURE_NAMES = [
"_hoodie_record_key",
"_hoodie_partition_path",
"_hoodie_commit_time",
"_hoodie_file_name",
"_hoodie_commit_seqno",
]

HUDI_SPARK_FORMAT = "org.apache.hudi"
HUDI_TABLE_NAME = "hoodie.table.name"
@@ -109,9 +113,7 @@ def register_temporary_table(self, hudi_fg_alias, read_options):
hudi_options = self._setup_hudi_read_opts(hudi_fg_alias, read_options)
self._spark_session.read.format(self.HUDI_SPARK_FORMAT).options(
**hudi_options
).load(location).createOrReplaceTempView(
hudi_fg_alias.alias
)
).load(location).createOrReplaceTempView(hudi_fg_alias.alias)

def _write_hudi_dataset(self, dataset, save_mode, operation, write_options):
location = self._feature_group.prepare_spark_location()
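For context, one common use of a metadata-column list like HUDI_SPEC_FEATURE_NAMES is to strip Hudi's internal bookkeeping columns from a Spark dataframe read back out of a Hudi table; the helper below is a generic sketch, not the engine's own code.

HUDI_SPEC_FEATURE_NAMES = [
    "_hoodie_record_key",
    "_hoodie_partition_path",
    "_hoodie_commit_time",
    "_hoodie_file_name",
    "_hoodie_commit_seqno",
]

def drop_hudi_metadata_columns(df):
    """Return the Spark dataframe without Hudi's bookkeeping columns."""
    return df.drop(*[c for c in HUDI_SPEC_FEATURE_NAMES if c in df.columns])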
4 changes: 3 additions & 1 deletion python/hsfs/core/spine_group_engine.py
@@ -34,12 +34,14 @@ def save(self, feature_group):
feature_group.dataframe
)

# set primary and partition key columns
# set primary, foreign and partition key columns
# we should move this to the backend
util.verify_attribute_key_names(feature_group, True)
for feat in feature_group.features:
if feat.name in feature_group.primary_key:
feat.primary = True
if feat.name in feature_group.foreign_key:
feat.foreign = True

# need to save dataframe during save since otherwise it will be lost
dataframe = feature_group.dataframe
20 changes: 14 additions & 6 deletions python/hsfs/engine/python.py
@@ -1203,11 +1203,11 @@ def save_stream_dataframe(
"Stream ingestion is not available on Python environments, because it requires Spark as engine."
)

def update_table_schema(self, feature_group: Union[FeatureGroup, ExternalFeatureGroup]) -> None:
def update_table_schema(
self, feature_group: Union[FeatureGroup, ExternalFeatureGroup]
) -> None:
_job = self._feature_group_api.update_table_schema(feature_group)
_job._wait_for_job(
await_termination=True
)
_job._wait_for_job(await_termination=True)

def _get_app_options(
self, user_write_options: Optional[Dict[str, Any]] = None
@@ -1516,7 +1516,11 @@ def _write_dataframe_kafka(
now = datetime.now(timezone.utc)
feature_group.materialization_job.run(
args=feature_group.materialization_job.config.get("defaultArgs", "")
+ (f" -initialCheckPointString {initial_check_point}" if initial_check_point else ""),
+ (
f" -initialCheckPointString {initial_check_point}"
if initial_check_point
else ""
),
await_termination=offline_write_options.get("wait_for_job", False),
)
offline_backfill_every_hr = offline_write_options.pop(
@@ -1546,7 +1550,11 @@ def _write_dataframe_kafka(
# provide the initial_check_point as it reduces the read amplification of the materialization job
feature_group.materialization_job.run(
args=feature_group.materialization_job.config.get("defaultArgs", "")
+ (f" -initialCheckPointString {initial_check_point}" if initial_check_point else ""),
+ (
f" -initialCheckPointString {initial_check_point}"
if initial_check_point
else ""
),
await_termination=offline_write_options.get("wait_for_job", False),
)
return feature_group.materialization_job
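A simplified, standalone illustration (not hopsworks code) of the argument construction reformatted above: the -initialCheckPointString flag is appended to the job's default arguments only when a checkpoint value is available.

from typing import Optional

def build_materialization_args(default_args: str, initial_check_point: Optional[str]) -> str:
    # Append the checkpoint flag only when a checkpoint string was produced.
    return default_args + (
        f" -initialCheckPointString {initial_check_point}" if initial_check_point else ""
    )

print(build_materialization_args("-op offline_fg_materialization", "kafka_offsets_123"))
# -op offline_fg_materialization -initialCheckPointString kafka_offsets_123
print(build_materialization_args("-op offline_fg_materialization", None))
# -op offline_fg_materialization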