feat: Supports nested struct columns as features, timestamp fields (#153)

* feat: Supports nested struct columns as features, timestamp fields

* fix: Added field mapping support for spark streaming

* feat: Add support for field mapping in SparkOfflineStore

* feat: Add timing for batch write operations in SparkKafkaProcessor

* feat: Enhance SparkKafkaProcessor logging and add unit tests for SparkOfflineStore

* fix: Remove unnecessary f-string usage in SparkOfflineStore tests

* fix: Renamed integration test file name to avoid conflicts

* refactor: Remove unused ingest_df method and clean up imports in ExpediaProvider

---------

Co-authored-by: Bhargav Dodla <[email protected]>
EXPEbdodla and Bhargav Dodla authored Nov 7, 2024
1 parent d56d202 commit cf0f2f2
Showing 7 changed files with 216 additions and 84 deletions.
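
The commit message above summarizes the feature: columns nested inside struct types (for example, a timestamp carried inside an event-header struct) can now be used as features or timestamp fields by flattening them through a source's field_mapping. A minimal sketch of the intended usage, adapted from the unit test added in this commit and trimmed to the relevant arguments:

from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
    SparkSource,
)

# The nested column event_header.event_published_datetime_utc is exposed under
# the flat alias nested_timestamp via field_mapping, which lets it be used as
# the source's timestamp_field.
source = SparkSource(
    name="test_nested_batch_source",
    table="offline_store_database_name.offline_store_table_name",
    timestamp_field="nested_timestamp",
    field_mapping={
        "event_header.event_published_datetime_utc": "nested_timestamp",
    },
)
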
39 changes: 0 additions & 39 deletions sdk/python/feast/expediagroup/provider/expedia.py
@@ -1,12 +1,7 @@
import logging
from typing import List, Set

import pandas as pd

from feast.feature_view import FeatureView
from feast.infra.passthrough_provider import PassthroughProvider
from feast.repo_config import RepoConfig
from feast.stream_feature_view import StreamFeatureView

logger = logging.getLogger(__name__)

@@ -24,37 +19,3 @@ def __init__(self, config: RepoConfig):
)

super().__init__(config)

def ingest_df(
self,
feature_view: FeatureView,
df: pd.DataFrame,
):
drop_list: List[str] = []
fv_schema: Set[str] = set(map(lambda field: field.name, feature_view.schema))
# Add timestamp field to the schema so we don't delete from dataframe
if isinstance(feature_view, StreamFeatureView):
fv_schema.add(feature_view.timestamp_field)
if feature_view.source.created_timestamp_column:
fv_schema.add(feature_view.source.created_timestamp_column)

if isinstance(feature_view, FeatureView):
if feature_view.stream_source is not None:
fv_schema.add(feature_view.stream_source.timestamp_field)
if feature_view.stream_source.created_timestamp_column:
fv_schema.add(feature_view.stream_source.created_timestamp_column)
else:
fv_schema.add(feature_view.batch_source.timestamp_field)
if feature_view.batch_source.created_timestamp_column:
fv_schema.add(feature_view.batch_source.created_timestamp_column)

for column in df.columns:
if column not in fv_schema:
drop_list.append(column)

if len(drop_list) > 0:
print(
f"INFO!!! Dropping extra columns in the dataframe: {drop_list}. Avoid unnecessary columns in the dataframe."
)

super().ingest_df(feature_view, df.drop(drop_list, axis=1))
72 changes: 37 additions & 35 deletions sdk/python/feast/infra/contrib/spark_kafka_processor.py
@@ -1,3 +1,4 @@
import time
from types import MethodType
from typing import List, Optional, Set, Union, no_type_check

@@ -199,7 +200,37 @@ def _ingest_stream_data(self) -> StreamTable:

def _construct_transformation_plan(self, df: StreamTable) -> StreamTable:
if isinstance(self.sfv, FeatureView):
return df
# Apply field mapping if it exists.
if self.sfv.stream_source is not None:
if self.sfv.stream_source.field_mapping is not None:
for (
field_mapping_key,
field_mapping_value,
) in self.sfv.stream_source.field_mapping.items():
df = df.withColumn(field_mapping_value, df[field_mapping_key])

# Drop unused columns
## Note: This may need reconsideration when we support writing to offline store for Feature Views
drop_list: List[str] = []
fv_schema: Set[str] = set(
map(lambda field: field.name, self.sfv.schema)
)

fv_schema.add(self.sfv.stream_source.timestamp_field)
if self.sfv.stream_source.created_timestamp_column:
fv_schema.add(self.sfv.stream_source.created_timestamp_column)

for column in df.columns:
if column not in fv_schema:
drop_list.append(column)

if len(drop_list) > 0:
print(
f"INFO!!! Dropping extra columns in the DataFrame: {drop_list}. Avoid unnecessary columns in the dataframe."
)
return df.drop(*drop_list)
else:
raise Exception(f"Stream source is not defined for {self.sfv.name}")
elif isinstance(self.sfv, StreamFeatureView):
return self.sfv.udf.__call__(df) if self.sfv.udf else df
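
For a plain FeatureView backed by a stream source, the new _construct_transformation_plan above first materializes each field_mapping entry as an aliased column and then drops every column that is not in the view's schema or its timestamp fields. A standalone sketch of that map-and-prune step, assuming a pyspark DataFrame and caller-supplied mappings (the helper name is hypothetical):

from typing import Dict, List, Set

from pyspark.sql import DataFrame


def _map_and_prune(
    df: DataFrame, field_mapping: Dict[str, str], keep_columns: Set[str]
) -> DataFrame:
    # Alias mapped (possibly nested, dot-separated) source columns to their
    # target names, mirroring df.withColumn(value, df[key]) in the diff above.
    for src, dst in field_mapping.items():
        df = df.withColumn(dst, df[src])
    # Drop anything the feature view does not declare.
    drop_list: List[str] = [c for c in df.columns if c not in keep_columns]
    return df.drop(*drop_list) if drop_list else df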

@@ -271,45 +302,16 @@ def batch_write(
join_keys,
feature_view,
):
drop_list: List[str] = []
fv_schema: Set[str] = set(
map(lambda field: field.name, feature_view.schema)
)
# Add timestamp field to the schema so we don't delete from dataframe
if isinstance(feature_view, StreamFeatureView):
fv_schema.add(feature_view.timestamp_field)
if feature_view.source.created_timestamp_column:
fv_schema.add(feature_view.source.created_timestamp_column)

if isinstance(feature_view, FeatureView):
if feature_view.stream_source is not None:
fv_schema.add(feature_view.stream_source.timestamp_field)
if feature_view.stream_source.created_timestamp_column:
fv_schema.add(
feature_view.stream_source.created_timestamp_column
)
else:
fv_schema.add(feature_view.batch_source.timestamp_field)
if feature_view.batch_source.created_timestamp_column:
fv_schema.add(
feature_view.batch_source.created_timestamp_column
)

for column in df.columns:
if column not in fv_schema:
drop_list.append(column)

if len(drop_list) > 0:
print(
f"INFO!!! Dropping extra columns in the dataframe: {drop_list}. Avoid unnecessary columns in the dataframe."
)

sdf.drop(*drop_list).mapInPandas(
start_time = time.time()
sdf.mapInPandas(
lambda x: batch_write_pandas_df(
x, spark_serialized_artifacts, join_keys
),
"status int",
).count() # dummy action to force evaluation
print(
f"Time taken to write batch {batch_id} is: {(time.time() - start_time) * 1000:.2f} ms"
)

query = (
df.writeStream.outputMode("update")
@@ -252,6 +252,8 @@ def _map_by_partition(
) = spark_serialized_artifacts.unserialize()

if feature_view.batch_source.field_mapping is not None:
# Spark offline store does the field mapping during pull_latest_from_table_or_query
# This is for the case where the offline store is not spark
table = _run_pyarrow_field_mapping(
table, feature_view.batch_source.field_mapping
)
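
The comment added above records a division of labor: when the offline store is Spark, field mapping has already been applied in SQL during pull_latest_from_table_or_query, so the pyarrow-level rename only matters for other offline stores. As a rough illustration of what a pyarrow field mapping amounts to (a sketch, not Feast's _run_pyarrow_field_mapping implementation):

from typing import Dict

import pyarrow as pa


def rename_mapped_columns(table: pa.Table, field_mapping: Dict[str, str]) -> pa.Table:
    # Rename columns listed in the mapping; leave every other column untouched.
    return table.rename_columns(
        [field_mapping.get(name, name) for name in table.column_names]
    )
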
@@ -34,6 +34,7 @@
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.type_map import spark_schema_to_np_dtypes
from feast.utils import _get_fields_with_aliases

# Make sure Spark warnings are ignored
warnings.simplefilter("ignore", RuntimeWarning)
@@ -91,16 +92,23 @@ def pull_latest_from_table_or_query(
if created_timestamp_column:
timestamps.append(created_timestamp_column)
timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
field_string = ", ".join(join_key_columns + feature_name_columns + timestamps)

(fields_with_aliases, aliases) = _get_fields_with_aliases(
fields=join_key_columns + feature_name_columns + timestamps,
field_mappings=data_source.field_mapping,
)

fields_as_string = ", ".join(fields_with_aliases)
aliases_as_string = ", ".join(aliases)

start_date_str = _format_datetime(start_date)
end_date_str = _format_datetime(end_date)
query = f"""
SELECT
{field_string}
{aliases_as_string}
{f", {repr(DUMMY_ENTITY_VAL)} AS {DUMMY_ENTITY_ID}" if not join_key_columns else ""}
FROM (
SELECT {field_string},
SELECT {fields_as_string},
ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS feast_row_
FROM {from_expression} t1
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}')
@@ -280,14 +288,19 @@ def pull_all_from_table_or_query(
spark_session = get_spark_session_or_start_new_with_repoconfig(
store_config=config.offline_store
)

fields = ", ".join(join_key_columns + feature_name_columns + [timestamp_field])
from_expression = data_source.get_table_query_string()
start_date = start_date.astimezone(tz=timezone.utc)
end_date = end_date.astimezone(tz=timezone.utc)

(fields_with_aliases, aliases) = _get_fields_with_aliases(
fields=join_key_columns + feature_name_columns + [timestamp_field],
field_mappings=data_source.field_mapping,
)

fields_with_alias_string = ", ".join(fields_with_aliases)

query = f"""
SELECT {fields}
SELECT {fields_with_alias_string}
FROM {from_expression}
WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}'
"""
33 changes: 29 additions & 4 deletions sdk/python/feast/utils.py
@@ -103,7 +103,8 @@ def _get_requested_feature_views_to_features_dict(
on_demand_feature_views: List["OnDemandFeatureView"],
) -> Tuple[Dict["FeatureView", List[str]], Dict["OnDemandFeatureView", List[str]]]:
"""Create a dict of FeatureView -> List[Feature] for all requested features.
Set full_feature_names to True to have feature names prefixed by their feature view name."""
Set full_feature_names to True to have feature names prefixed by their feature view name.
"""

feature_views_to_feature_map: Dict["FeatureView", List[str]] = defaultdict(list)
on_demand_feature_views_to_feature_map: Dict["OnDemandFeatureView", List[str]] = (
@@ -209,6 +210,28 @@ def _run_pyarrow_field_mapping(
return table


def _get_fields_with_aliases(
fields: List[str],
field_mappings: Dict[str, str],
) -> Tuple[List[str], List[str]]:
"""
Get a list of fields with aliases based on the field mappings.
"""
for field in fields:
if "." in field and field not in field_mappings:
raise ValueError(
f"Feature {field} contains a '.' character, which is not allowed in field names. Use field mappings to rename fields."
)
fields_with_aliases = [
f"{field} AS {field_mappings[field]}" if field in field_mappings else field
for field in fields
]
aliases = [
field_mappings[field] if field in field_mappings else field for field in fields
]
return (fields_with_aliases, aliases)
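
The helper above is what enables nested fields in the Spark offline store queries: a field containing '.' must be renamed through field_mapping, and the helper returns both the aliased SELECT expressions and the bare alias names. A hypothetical call using the names from this commit's unit tests:

fields_with_aliases, aliases = _get_fields_with_aliases(
    fields=["key1", "event_header.event_published_datetime_utc"],
    field_mappings={"event_header.event_published_datetime_utc": "nested_timestamp"},
)
# fields_with_aliases ->
#     ["key1", "event_header.event_published_datetime_utc AS nested_timestamp"]
# aliases -> ["key1", "nested_timestamp"]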


def _coerce_datetime(ts):
"""
Depending on underlying time resolution, arrow to_pydict() sometimes returns pd
@@ -678,9 +701,11 @@ def _populate_response_from_feature_data(
"""
# Add the feature names to the response.
requested_feature_refs = [
f"{table.projection.name_to_use()}__{feature_name}"
if full_feature_names
else feature_name
(
f"{table.projection.name_to_use()}__{feature_name}"
if full_feature_names
else feature_name
)
for feature_name in requested_features
]
online_features_response.metadata.feature_names.val.extend(requested_feature_refs)
@@ -0,0 +1,129 @@
from datetime import datetime
from unittest.mock import MagicMock, patch

from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
SparkOfflineStore,
SparkOfflineStoreConfig,
)
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
SparkSource,
)
from feast.infra.offline_stores.offline_store import RetrievalJob
from feast.repo_config import RepoConfig


@patch(
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
)
def test_pull_latest_from_table_with_nested_timestamp_or_query(mock_get_spark_session):
mock_spark_session = MagicMock()
mock_get_spark_session.return_value = mock_spark_session

test_repo_config = RepoConfig(
project="test_project",
registry="test_registry",
provider="local",
offline_store=SparkOfflineStoreConfig(type="spark"),
)

test_data_source = SparkSource(
name="test_nested_batch_source",
description="test_nested_batch_source",
table="offline_store_database_name.offline_store_table_name",
timestamp_field="nested_timestamp",
field_mapping={
"event_header.event_published_datetime_utc": "nested_timestamp",
},
)

# Define the parameters for the method
join_key_columns = ["key1", "key2"]
feature_name_columns = ["feature1", "feature2"]
timestamp_field = "event_header.event_published_datetime_utc"
created_timestamp_column = "created_timestamp"
start_date = datetime(2021, 1, 1)
end_date = datetime(2021, 1, 2)

# Call the method
retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query(
config=test_repo_config,
data_source=test_data_source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
start_date=start_date,
end_date=end_date,
)

expected_query = """SELECT
key1, key2, feature1, feature2, nested_timestamp, created_timestamp
FROM (
SELECT key1, key2, feature1, feature2, event_header.event_published_datetime_utc AS nested_timestamp, created_timestamp,
ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_header.event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_
FROM `offline_store_database_name`.`offline_store_table_name` t1
WHERE event_header.event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000')
) t2
WHERE feast_row_ = 1""" # noqa: W293

assert isinstance(retrieval_job, RetrievalJob)
assert retrieval_job.query.strip() == expected_query.strip()


@patch(
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
)
def test_pull_latest_from_table_without_nested_timestamp_or_query(
mock_get_spark_session,
):
mock_spark_session = MagicMock()
mock_get_spark_session.return_value = mock_spark_session

test_repo_config = RepoConfig(
project="test_project",
registry="test_registry",
provider="local",
offline_store=SparkOfflineStoreConfig(type="spark"),
)

test_data_source = SparkSource(
name="test_batch_source",
description="test_nested_batch_source",
table="offline_store_database_name.offline_store_table_name",
timestamp_field="event_published_datetime_utc",
)

# Define the parameters for the method
join_key_columns = ["key1", "key2"]
feature_name_columns = ["feature1", "feature2"]
timestamp_field = "event_published_datetime_utc"
created_timestamp_column = "created_timestamp"
start_date = datetime(2021, 1, 1)
end_date = datetime(2021, 1, 2)

# Call the method
retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query(
config=test_repo_config,
data_source=test_data_source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
start_date=start_date,
end_date=end_date,
)

expected_query = """SELECT
key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp
FROM (
SELECT key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp,
ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_
FROM `offline_store_database_name`.`offline_store_table_name` t1
WHERE event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000')
) t2
WHERE feast_row_ = 1""" # noqa: W293

assert isinstance(retrieval_job, RetrievalJob)
assert retrieval_job.query.strip() == expected_query.strip()
