From f6c5db6f7c444125a241cde5062c7ca6acd06dd2 Mon Sep 17 00:00:00 2001
From: Ralph Rassweiler
Date: Wed, 21 Aug 2024 13:01:32 -0300
Subject: [PATCH] Fix dup code (#373)

* fix: dedup code

---
 .../historical_feature_store_writer.py      |  9 ++++
 tests/unit/butterfree/transform/conftest.py | 53 ++++++-------------
 2 files changed, 26 insertions(+), 36 deletions(-)

diff --git a/butterfree/load/writers/historical_feature_store_writer.py b/butterfree/load/writers/historical_feature_store_writer.py
index 0be7d6af..99bfe66a 100644
--- a/butterfree/load/writers/historical_feature_store_writer.py
+++ b/butterfree/load/writers/historical_feature_store_writer.py
@@ -93,6 +93,15 @@ class HistoricalFeatureStoreWriter(Writer):
     improve queries performance. The data is stored in partition folders in AWS S3
     based on time (per year, month and day).
 
+    >>> spark_client = SparkClient()
+    >>> writer = HistoricalFeatureStoreWriter()
+    >>> writer.write(feature_set=feature_set,
+    ...              dataframe=dataframe,
+    ...              spark_client=spark_client,
+    ...              merge_on=["id", "timestamp"])
+
+    This procedure will skip the dataframe write and activate Delta Merge instead.
+    Use it when the table already exists.
     """
 
     PARTITION_BY = [
diff --git a/tests/unit/butterfree/transform/conftest.py b/tests/unit/butterfree/transform/conftest.py
index 104300c9..d66d1c39 100644
--- a/tests/unit/butterfree/transform/conftest.py
+++ b/tests/unit/butterfree/transform/conftest.py
@@ -16,6 +16,15 @@
 from butterfree.transform.utils import Function
 
 
+def create_dataframe(data, timestamp_col="ts"):
+    pdf = ps.DataFrame.from_dict(data)
+    df = pdf.to_spark()
+    df = df.withColumn(
+        TIMESTAMP_COLUMN, df[timestamp_col].cast(DataType.TIMESTAMP.spark)
+    )
+    return df
+
+
 def make_dataframe(spark_context, spark_session):
     data = [
         {
@@ -54,11 +63,7 @@ def make_dataframe(spark_context, spark_session):
             "nonfeature": 0,
         },
     ]
-    pdf = ps.DataFrame.from_dict(data)
-    df = pdf.to_spark()
-    df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark))
-
-    return df
+    return create_dataframe(data)
 
 
 def make_filtering_dataframe(spark_context, spark_session):
@@ -71,11 +76,7 @@ def make_filtering_dataframe(spark_context, spark_session):
         {"id": 1, "ts": 6, "feature1": None, "feature2": None, "feature3": None},
         {"id": 1, "ts": 7, "feature1": None, "feature2": None, "feature3": None},
     ]
-    pdf = ps.DataFrame.from_dict(data)
-    df = pdf.to_spark()
-    df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark))
-
-    return df
+    return create_dataframe(data)
 
 
 def make_output_filtering_dataframe(spark_context, spark_session):
@@ -86,11 +87,7 @@ def make_output_filtering_dataframe(spark_context, spark_session):
         {"id": 1, "ts": 4, "feature1": 0, "feature2": 1, "feature3": 1},
         {"id": 1, "ts": 6, "feature1": None, "feature2": None, "feature3": None},
     ]
-    pdf = ps.DataFrame.from_dict(data)
-    df = pdf.to_spark()
-    df = df.withColumn(TIMESTAMP_COLUMN, df.ts.cast(DataType.TIMESTAMP.spark))
-
-    return df
+    return create_dataframe(data)
 
 
 def make_rolling_windows_agg_dataframe(spark_context, spark_session):
@@ -126,11 +123,7 @@ def make_rolling_windows_agg_dataframe(spark_context, spark_session):
             "feature2__avg_over_1_week_rolling_windows": None,
         },
     ]
-    pdf = ps.DataFrame.from_dict(data)
-    df = pdf.to_spark()
-    df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark))
-
-    return df
+    return create_dataframe(data, timestamp_col="timestamp")
 
 
 def make_rolling_windows_hour_slide_agg_dataframe(spark_context, spark_session):
@@ -154,11 +147,7 @@ def make_rolling_windows_hour_slide_agg_dataframe(spark_context, spark_session):
             "feature2__avg_over_1_day_rolling_windows": 500.0,
         },
     ]
-    pdf = ps.DataFrame.from_dict(data)
-    df = pdf.to_spark()
-    df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark))
-
-    return df
+    return create_dataframe(data, timestamp_col="timestamp")
 
 
 def make_multiple_rolling_windows_hour_slide_agg_dataframe(
@@ -202,11 +191,7 @@ def make_multiple_rolling_windows_hour_slide_agg_dataframe(
             "feature2__avg_over_3_days_rolling_windows": 500.0,
         },
     ]
-    pdf = ps.DataFrame.from_dict(data)
-    df = pdf.to_spark()
-    df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark))
-
-    return df
+    return create_dataframe(data, timestamp_col="timestamp")
 
 
 def make_fs(spark_context, spark_session):
@@ -253,9 +238,7 @@ def make_fs_dataframe_with_distinct(spark_context, spark_session):
             "h3": "86a8100efffffff",
         },
     ]
-    pdf = ps.DataFrame.from_dict(data)
-    df = pdf.to_spark()
-    df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark))
+    df = create_dataframe(data, "timestamp")
 
     return df
 
@@ -283,9 +266,7 @@ def make_target_df_distinct(spark_context, spark_session):
             "feature__sum_over_3_days_rolling_windows": None,
         },
     ]
-    pdf = ps.DataFrame.from_dict(data)
-    df = pdf.to_spark()
-    df = df.withColumn("timestamp", df.timestamp.cast(DataType.TIMESTAMP.spark))
+    df = create_dataframe(data, "timestamp")
 
     return df
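
Illustrative note (not part of the patch): a minimal sketch of how a conftest fixture can reuse the deduplicated helper introduced above, assuming the create_dataframe function from the patched conftest.py is in scope and a SparkSession with pandas-on-Spark support is active. The fixture name and sample data below are hypothetical.

    # Hypothetical fixture reusing the shared helper instead of repeating the
    # pandas-on-Spark -> Spark DataFrame -> timestamp-cast boilerplate.
    def make_example_dataframe(spark_context, spark_session):
        data = [
            {"id": 1, "ts": 0, "feature1": 100},
            {"id": 1, "ts": 1, "feature1": 200},
        ]
        # create_dataframe builds a pandas-on-Spark frame, converts it to a
        # Spark DataFrame, and casts the "ts" column into TIMESTAMP_COLUMN.
        return create_dataframe(data)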