From ff7746ba6245e797a4e5f645ee6609f21d6fc4e7 Mon Sep 17 00:00:00 2001 From: Matthew Powers Date: Wed, 31 Jan 2024 11:20:26 -0500 Subject: [PATCH] blackify code (#46) --- dat/generated_tables.py | 581 ++++++++++++---------- dat/main.py | 17 +- dat/models.py | 10 +- dat/spark_builder.py | 20 +- tests/conftest.py | 4 +- tests/pyspark_delta/test_pyspark_delta.py | 45 +- 6 files changed, 355 insertions(+), 322 deletions(-) diff --git a/dat/generated_tables.py b/dat/generated_tables.py index 4671cc4..a7a77f8 100644 --- a/dat/generated_tables.py +++ b/dat/generated_tables.py @@ -24,16 +24,16 @@ def get_version_metadata(case: TestCaseInfo) -> TableVersionMetadata: return TableVersionMetadata( version=table.history(1).collect()[0].version, - properties=detail['properties'], - min_reader_version=detail['minReaderVersion'], - min_writer_version=detail['minWriterVersion'], + properties=detail["properties"], + min_reader_version=detail["minReaderVersion"], + min_writer_version=detail["minWriterVersion"], ) def save_expected(case: TestCaseInfo, as_latest=False) -> None: """Save the specified version of a Delta Table as a Parquet file.""" spark = get_spark_session() - df = spark.read.format('delta').load(case.delta_root) + df = spark.read.format("delta").load(case.delta_root) version_metadata = get_version_metadata(case) version = None if as_latest else version_metadata.version @@ -43,24 +43,21 @@ def save_expected(case: TestCaseInfo, as_latest=False) -> None: df.write.parquet(case.expected_path(version)) - out_path = case.expected_root(version) / 'table_version_metadata.json' - with open(out_path, 'w') as f: + out_path = case.expected_root(version) / "table_version_metadata.json" + with open(out_path, "w") as f: f.write(version_metadata.json(indent=2)) def reference_table(name: str, description: str): - case = TestCaseInfo( - name=name, - description=description - ) + case = TestCaseInfo(name=name, description=description) def wrapper(create_table): def inner(): spark = get_spark_session() create_table(case, spark) - with open(case.root / 'test_case_info.json', 'w') as f: - f.write(case.json(indent=2, separators=(',', ': '))) + with open(case.root / "test_case_info.json", "w") as f: + f.write(case.json(indent=2, separators=(",", ": "))) # Write out latest save_expected(case, as_latest=True) @@ -73,469 +70,501 @@ def inner(): @reference_table( - name='basic_append', - description='A basic table with two append writes.' + name="basic_append", description="A basic table with two append writes." 
) def create_basic_append(case: TestCaseInfo, spark: SparkSession): - columns = ['letter', 'number', 'a_float'] - data = [('a', 1, 1.1), ('b', 2, 2.2), ('c', 3, 3.3)] + columns = ["letter", "number", "a_float"] + data = [("a", 1, 1.1), ("b", 2, 2.2), ("c", 3, 3.3)] df = spark.createDataFrame(data, schema=columns) - df.repartition(1).write.format('delta').save(case.delta_root) + df.repartition(1).write.format("delta").save(case.delta_root) save_expected(case) - data = [('d', 4, 4.4), ('e', 5, 5.5)] + data = [("d", 4, 4.4), ("e", 5, 5.5)] df = spark.createDataFrame(data, schema=columns) - df.repartition(1).write.format('delta').mode( - 'append').save(case.delta_root) + df.repartition(1).write.format("delta").mode("append").save(case.delta_root) save_expected(case) @reference_table( - name='basic_partitioned', - description='A basic partitioned table', + name="basic_partitioned", + description="A basic partitioned table", ) def create_basic_partitioned(case: TestCaseInfo, spark: SparkSession): - columns = ['letter', 'number', 'a_float'] - data = [('a', 1, 1.1), ('b', 2, 2.2), ('c', 3, 3.3)] + columns = ["letter", "number", "a_float"] + data = [("a", 1, 1.1), ("b", 2, 2.2), ("c", 3, 3.3)] df = spark.createDataFrame(data, schema=columns) - df.repartition(1).write.partitionBy('letter').format( - 'delta').save(case.delta_root) + df.repartition(1).write.partitionBy("letter").format("delta").save(case.delta_root) save_expected(case) - data2 = [('a', 4, 4.4), ('e', 5, 5.5), ('f', 6, 6.6)] + data2 = [("a", 4, 4.4), ("e", 5, 5.5), ("f", 6, 6.6)] df = spark.createDataFrame(data2, schema=columns) - df.repartition(1).write.partitionBy('letter').format( - 'delta').mode('append').save(case.delta_root) + df.repartition(1).write.partitionBy("letter").format("delta").mode("append").save( + case.delta_root + ) save_expected(case) @reference_table( - name='partitioned_with_null', - description='A partitioned table with a null partition', + name="partitioned_with_null", + description="A partitioned table with a null partition", ) def create_partitioned_with_null(case: TestCaseInfo, spark: SparkSession): - columns = ['letter', 'number', 'a_float'] - data = [('a', 1, 1.1), ('b', 2, 2.2), ('c', 3, 3.3)] + columns = ["letter", "number", "a_float"] + data = [("a", 1, 1.1), ("b", 2, 2.2), ("c", 3, 3.3)] df = spark.createDataFrame(data, schema=columns) - df.repartition(1).write.partitionBy('letter').format( - 'delta').save(case.delta_root) + df.repartition(1).write.partitionBy("letter").format("delta").save(case.delta_root) save_expected(case) - data2 = [('a', 4, 4.4), ('e', 5, 5.5), (None, 6, 6.6)] + data2 = [("a", 4, 4.4), ("e", 5, 5.5), (None, 6, 6.6)] df = spark.createDataFrame(data2, schema=columns) - df.repartition(1).write.partitionBy('letter').format( - 'delta').mode('append').save(case.delta_root) + df.repartition(1).write.partitionBy("letter").format("delta").mode("append").save( + case.delta_root + ) save_expected(case) @reference_table( - name='multi_partitioned', - description=('A table with multiple partitioning columns. Partition ' - 'values include nulls and escape characters.'), + name="multi_partitioned", + description=( + "A table with multiple partitioning columns. Partition " + "values include nulls and escape characters." 
+ ), ) def create_multi_partitioned(case: TestCaseInfo, spark: SparkSession): - columns = ['letter', 'date', 'data', 'number'] - partition_columns = ['letter', 'date', 'data'] + columns = ["letter", "date", "data", "number"] + partition_columns = ["letter", "date", "data"] data = [ - ('a', date(1970, 1, 1), b'hello', 1), - ('b', date(1970, 1, 1), b'world', 2), - ('b', date(1970, 1, 2), b'world', 3) + ("a", date(1970, 1, 1), b"hello", 1), + ("b", date(1970, 1, 1), b"world", 2), + ("b", date(1970, 1, 2), b"world", 3), ] df = spark.createDataFrame(data, schema=columns) schema = df.schema - df.repartition(1).write.format('delta').partitionBy( - *partition_columns).save(case.delta_root) + df.repartition(1).write.format("delta").partitionBy(*partition_columns).save( + case.delta_root + ) save_expected(case) # Introduce null values in partition columns - data2 = [ - ('a', None, b'x', 4), - (None, None, None, 5) - ] + data2 = [("a", None, b"x", 4), (None, None, None, 5)] df = spark.createDataFrame(data2, schema=schema) - df.repartition(1).write.format('delta').mode( - 'append').save(case.delta_root) + df.repartition(1).write.format("delta").mode("append").save(case.delta_root) save_expected(case) # Introduce escape characters data3 = [ - ('/%20%f', date(1970, 1, 1), b'hello', 6), - ('b', date(1970, 1, 1), '😈'.encode(), 7) + ("/%20%f", date(1970, 1, 1), b"hello", 6), + ("b", date(1970, 1, 1), "😈".encode(), 7), ] df = spark.createDataFrame(data3, schema=schema) - df.repartition(1).write.format('delta').mode( - 'overwrite').save(case.delta_root) + df.repartition(1).write.format("delta").mode("overwrite").save(case.delta_root) save_expected(case) @reference_table( - name='multi_partitioned_2', - description=('Multiple levels of partitioning, with boolean, timestamp, and ' - 'decimal partition columns') + name="multi_partitioned_2", + description=( + "Multiple levels of partitioning, with boolean, timestamp, and " + "decimal partition columns" + ), ) def create_multi_partitioned_2(case: TestCaseInfo, spark: SparkSession): - columns = ['bool', 'time', 'amount', 'int'] - partition_columns = ['bool', 'time', 'amount'] + columns = ["bool", "time", "amount", "int"] + partition_columns = ["bool", "time", "amount"] data = [ - (True, datetime(1970, 1, 1), Decimal('200.00'), 1), - (True, datetime(1970, 1, 1, 12, 30), Decimal('200.00'), 2), - (False, datetime(1970, 1, 2, 8, 45), Decimal('12.00'), 3) + (True, datetime(1970, 1, 1), Decimal("200.00"), 1), + (True, datetime(1970, 1, 1, 12, 30), Decimal("200.00"), 2), + (False, datetime(1970, 1, 2, 8, 45), Decimal("12.00"), 3), ] df = spark.createDataFrame(data, schema=columns) - df.repartition(1).write.format('delta').partitionBy( - *partition_columns).save(case.delta_root) + df.repartition(1).write.format("delta").partitionBy(*partition_columns).save( + case.delta_root + ) @reference_table( - name='with_schema_change', - description='Table which has schema change using overwriteSchema=True.', + name="with_schema_change", + description="Table which has schema change using overwriteSchema=True.", ) def with_schema_change(case: TestCaseInfo, spark: SparkSession): - columns = ['letter', 'number'] - data = [('a', 1), ('b', 2), ('c', 3)] + columns = ["letter", "number"] + data = [("a", 1), ("b", 2), ("c", 3)] df = spark.createDataFrame(data, schema=columns) - df.repartition(1).write.format('delta').save(case.delta_root) + df.repartition(1).write.format("delta").save(case.delta_root) - columns = ['num1', 'num2'] + columns = ["num1", "num2"] data2 = [(22, 33), (44, 55), 
(66, 77)] df = spark.createDataFrame(data2, schema=columns) - df.repartition(1).write.mode('overwrite').option( - 'overwriteSchema', True).format('delta').save( - case.delta_root) + df.repartition(1).write.mode("overwrite").option("overwriteSchema", True).format( + "delta" + ).save(case.delta_root) save_expected(case) @reference_table( - name='all_primitive_types', - description='Table containing all non-nested types', + name="all_primitive_types", + description="Table containing all non-nested types", ) def create_all_primitive_types(case: TestCaseInfo, spark: SparkSession): - schema = types.StructType([ - types.StructField('utf8', types.StringType()), - types.StructField('int64', types.LongType()), - types.StructField('int32', types.IntegerType()), - types.StructField('int16', types.ShortType()), - types.StructField('int8', types.ByteType()), - types.StructField('float32', types.FloatType()), - types.StructField('float64', types.DoubleType()), - types.StructField('bool', types.BooleanType()), - types.StructField('binary', types.BinaryType()), - types.StructField('decimal', types.DecimalType(5, 3)), - types.StructField('date32', types.DateType()), - types.StructField('timestamp', types.TimestampType()), - ]) - - df = spark.createDataFrame([ - ( - str(i), - i, - i, - i, - i, - float(i), - float(i), - i % 2 == 0, - bytes(i), - Decimal('10.000') + i, - date(1970, 1, 1) + timedelta(days=i), - datetime(1970, 1, 1) + timedelta(hours=i) - ) - for i in range(5) - ], schema=schema) + schema = types.StructType( + [ + types.StructField("utf8", types.StringType()), + types.StructField("int64", types.LongType()), + types.StructField("int32", types.IntegerType()), + types.StructField("int16", types.ShortType()), + types.StructField("int8", types.ByteType()), + types.StructField("float32", types.FloatType()), + types.StructField("float64", types.DoubleType()), + types.StructField("bool", types.BooleanType()), + types.StructField("binary", types.BinaryType()), + types.StructField("decimal", types.DecimalType(5, 3)), + types.StructField("date32", types.DateType()), + types.StructField("timestamp", types.TimestampType()), + ] + ) + + df = spark.createDataFrame( + [ + ( + str(i), + i, + i, + i, + i, + float(i), + float(i), + i % 2 == 0, + bytes(i), + Decimal("10.000") + i, + date(1970, 1, 1) + timedelta(days=i), + datetime(1970, 1, 1) + timedelta(hours=i), + ) + for i in range(5) + ], + schema=schema, + ) - df.repartition(1).write.format('delta').save(case.delta_root) + df.repartition(1).write.format("delta").save(case.delta_root) @reference_table( - name='nested_types', - description='Table containing various nested types', + name="nested_types", + description="Table containing various nested types", ) def create_nested_types(case: TestCaseInfo, spark: SparkSession): - schema = types.StructType([ - types.StructField( - 'pk', types.IntegerType() - ), - types.StructField( - 'struct', types.StructType( - [types.StructField( - 'float64', types.DoubleType()), - types.StructField( - 'bool', types.BooleanType()), ])), - types.StructField( - 'array', types.ArrayType( - types.ShortType())), - types.StructField( - 'map', types.MapType( - types.StringType(), - types.IntegerType())), ]) - - df = spark.createDataFrame([ - ( - i, - {'float64': float(i), 'bool': i % 2 == 0}, - list(range(i + 1)), - {str(i): i for i in range(i)} - ) - for i in range(5) - ], schema=schema) + schema = types.StructType( + [ + types.StructField("pk", types.IntegerType()), + types.StructField( + "struct", + types.StructType( + [ + 
types.StructField("float64", types.DoubleType()), + types.StructField("bool", types.BooleanType()), + ] + ), + ), + types.StructField("array", types.ArrayType(types.ShortType())), + types.StructField( + "map", types.MapType(types.StringType(), types.IntegerType()) + ), + ] + ) - df.repartition(1).write.format('delta').save(case.delta_root) + df = spark.createDataFrame( + [ + ( + i, + {"float64": float(i), "bool": i % 2 == 0}, + list(range(i + 1)), + {str(i): i for i in range(i)}, + ) + for i in range(5) + ], + schema=schema, + ) + + df.repartition(1).write.format("delta").save(case.delta_root) def get_sample_data( - spark: SparkSession, seed: int = 42, nrows: int = 5) -> pyspark.sql.DataFrame: + spark: SparkSession, seed: int = 42, nrows: int = 5 +) -> pyspark.sql.DataFrame: # Use seed to get consistent data between runs, for reproducibility random.seed(seed) - return spark.createDataFrame([ - ( - random.choice(['a', 'b', 'c', None]), - random.randint(0, 1000), - date(random.randint(1970, 2020), random.randint(1, 12), 1) - ) - for i in range(nrows) - ], schema=['letter', 'int', 'date']) + return spark.createDataFrame( + [ + ( + random.choice(["a", "b", "c", None]), + random.randint(0, 1000), + date(random.randint(1970, 2020), random.randint(1, 12), 1), + ) + for i in range(nrows) + ], + schema=["letter", "int", "date"], + ) @reference_table( - name='with_checkpoint', - description='Table with a checkpoint', + name="with_checkpoint", + description="Table with a checkpoint", ) def create_with_checkpoint(case: TestCaseInfo, spark: SparkSession): - spark.conf.set('spark.databricks.delta.legacy.allowAmbiguousPathsInCreateTable', 'true') + spark.conf.set( + "spark.databricks.delta.legacy.allowAmbiguousPathsInCreateTable", "true" + ) df = get_sample_data(spark) - (DeltaTable.create(spark) - .location(str(Path(case.delta_root).absolute())) - .addColumns(df.schema) - .property('delta.checkpointInterval', '2') - .execute()) + ( + DeltaTable.create(spark) + .location(str(Path(case.delta_root).absolute())) + .addColumns(df.schema) + .property("delta.checkpointInterval", "2") + .execute() + ) for i in range(3): df = get_sample_data(spark, seed=i, nrows=5) - df.repartition(1).write.format('delta').mode( - 'overwrite').save(case.delta_root) + df.repartition(1).write.format("delta").mode("overwrite").save(case.delta_root) - assert any(path.suffixes == ['.checkpoint', '.parquet'] - for path in (Path(case.delta_root) / '_delta_log').iterdir()) + assert any( + path.suffixes == [".checkpoint", ".parquet"] + for path in (Path(case.delta_root) / "_delta_log").iterdir() + ) def remove_log_file(delta_root: str, version: int): - os.remove(os.path.join(delta_root, '_delta_log', f'{version:0>20}.json')) + os.remove(os.path.join(delta_root, "_delta_log", f"{version:0>20}.json")) @reference_table( - name='no_replay', - description='Table with a checkpoint and prior commits cleaned up', + name="no_replay", + description="Table with a checkpoint and prior commits cleaned up", ) def create_no_replay(case: TestCaseInfo, spark: SparkSession): - spark.conf.set( - 'spark.databricks.delta.retentionDurationCheck.enabled', 'false') + spark.conf.set("spark.databricks.delta.retentionDurationCheck.enabled", "false") df = get_sample_data(spark) - table = (DeltaTable.create(spark) - .location(str(Path(case.delta_root).absolute())) - .addColumns(df.schema) - .property('delta.checkpointInterval', '2') - .execute()) + table = ( + DeltaTable.create(spark) + .location(str(Path(case.delta_root).absolute())) + .addColumns(df.schema) + 
.property("delta.checkpointInterval", "2") + .execute() + ) for i in range(3): df = get_sample_data(spark, seed=i, nrows=5) - df.repartition(1).write.format('delta').mode( - 'overwrite').save(case.delta_root) + df.repartition(1).write.format("delta").mode("overwrite").save(case.delta_root) table.vacuum(retentionHours=0) remove_log_file(case.delta_root, version=0) remove_log_file(case.delta_root, version=1) - files_in_log = list((Path(case.delta_root) / '_delta_log').iterdir()) - assert any(path.suffixes == ['.checkpoint', '.parquet'] - for path in files_in_log) - assert not any(path.name == f'{0:0>20}.json' for path in files_in_log) + files_in_log = list((Path(case.delta_root) / "_delta_log").iterdir()) + assert any(path.suffixes == [".checkpoint", ".parquet"] for path in files_in_log) + assert not any(path.name == f"{0:0>20}.json" for path in files_in_log) @reference_table( - name='stats_as_struct', - description='Table with stats only written as struct (not JSON) with Checkpoint', + name="stats_as_struct", + description="Table with stats only written as struct (not JSON) with Checkpoint", ) def create_stats_as_struct(case: TestCaseInfo, spark: SparkSession): df = get_sample_data(spark) - (DeltaTable.create(spark) - .location(str(Path(case.delta_root).absolute())) - .addColumns(df.schema) - .property('delta.checkpointInterval', '2') - .property('delta.checkpoint.writeStatsAsStruct', 'true') - .property('delta.checkpoint.writeStatsAsJson', 'false') - .execute()) + ( + DeltaTable.create(spark) + .location(str(Path(case.delta_root).absolute())) + .addColumns(df.schema) + .property("delta.checkpointInterval", "2") + .property("delta.checkpoint.writeStatsAsStruct", "true") + .property("delta.checkpoint.writeStatsAsJson", "false") + .execute() + ) for i in range(3): df = get_sample_data(spark, seed=i, nrows=5) - df.repartition(1).write.format('delta').mode( - 'overwrite').save(case.delta_root) + df.repartition(1).write.format("delta").mode("overwrite").save(case.delta_root) @reference_table( - name='no_stats', - description='Table with no stats', + name="no_stats", + description="Table with no stats", ) def create_no_stats(case: TestCaseInfo, spark: SparkSession): df = get_sample_data(spark) - (DeltaTable.create(spark) - .location(str(Path(case.delta_root).absolute())) - .addColumns(df.schema) - .property('delta.checkpointInterval', '2') - .property('delta.checkpoint.writeStatsAsStruct', 'false') - .property('delta.checkpoint.writeStatsAsJson', 'false') - .property('delta.dataSkippingNumIndexedCols', '0') - .execute()) + ( + DeltaTable.create(spark) + .location(str(Path(case.delta_root).absolute())) + .addColumns(df.schema) + .property("delta.checkpointInterval", "2") + .property("delta.checkpoint.writeStatsAsStruct", "false") + .property("delta.checkpoint.writeStatsAsJson", "false") + .property("delta.dataSkippingNumIndexedCols", "0") + .execute() + ) for i in range(3): df = get_sample_data(spark, seed=i, nrows=5) - df.repartition(1).write.format('delta').mode( - 'overwrite').save(case.delta_root) + df.repartition(1).write.format("delta").mode("overwrite").save(case.delta_root) @reference_table( - name='deletion_vectors', - description='Table with deletion vectors', + name="deletion_vectors", + description="Table with deletion vectors", ) def create_deletion_vectors(case: TestCaseInfo, spark: SparkSession): df = get_sample_data(spark) delta_path = str(Path(case.delta_root).absolute()) - delta_table: DeltaTable = (DeltaTable.create(spark) - .location(delta_path) - .addColumns(df.schema) - 
.property('delta.enableDeletionVectors', 'true') - .execute()) + delta_table: DeltaTable = ( + DeltaTable.create(spark) + .location(delta_path) + .addColumns(df.schema) + .property("delta.enableDeletionVectors", "true") + .execute() + ) - df.repartition(1).write.format('delta').mode('append').save(case.delta_root) + df.repartition(1).write.format("delta").mode("append").save(case.delta_root) delta_table.delete(col("letter") == "a") -@reference_table( - name='check_constraints', - description='Table with a check constraint' -) +@reference_table(name="check_constraints", description="Table with a check constraint") def check_constraint_table(case: TestCaseInfo, spark: SparkSession): df = get_sample_data(spark) delta_path = str(Path(case.delta_root).absolute()) - (DeltaTable.create(spark) - .location(delta_path) - .addColumns(df.schema) - .property('delta.enableDeletionVectors', 'true') - .execute()) + ( + DeltaTable.create(spark) + .location(delta_path) + .addColumns(df.schema) + .property("delta.enableDeletionVectors", "true") + .execute() + ) - df.repartition(1).write.format('delta').mode('append').save(case.delta_root) - spark.sql(f"ALTER TABLE delta.`{delta_path}` ADD CONSTRAINT const1 CHECK (int > 0);") + df.repartition(1).write.format("delta").mode("append").save(case.delta_root) + spark.sql( + f"ALTER TABLE delta.`{delta_path}` ADD CONSTRAINT const1 CHECK (int > 0);" + ) @reference_table( - name='cdf', - description='Table with cdf turned on', + name="cdf", + description="Table with cdf turned on", ) def create_change_data_feed(case: TestCaseInfo, spark: SparkSession): df = get_sample_data(spark) delta_path = str(Path(case.delta_root).absolute()) - delta_table: DeltaTable = (DeltaTable.create(spark) - .location(delta_path) - .addColumns(df.schema) - .property('delta.enableChangeDataFeed', 'true') - .execute()) + delta_table: DeltaTable = ( + DeltaTable.create(spark) + .location(delta_path) + .addColumns(df.schema) + .property("delta.enableChangeDataFeed", "true") + .execute() + ) - df.repartition(1).write.format('delta').mode('append').save(case.delta_root) + df.repartition(1).write.format("delta").mode("append").save(case.delta_root) - delta_table.update( - condition=col("letter") == "c", - set={"letter": lit("a")} - ) + delta_table.update(condition=col("letter") == "c", set={"letter": lit("a")}) delta_table.delete(col("letter") == "a") @reference_table( - name='generated_columns', - description='Table with a generated column', + name="generated_columns", + description="Table with a generated column", ) def create_generated_columns(case: TestCaseInfo, spark: SparkSession): df = get_sample_data(spark) delta_path = str(Path(case.delta_root).absolute()) - (DeltaTable.create(spark) - .location(delta_path) - .addColumns(df.schema) - .addColumn("creation", types.DateType(), generatedAlwaysAs="CAST(now() AS DATE)") - .execute()) + ( + DeltaTable.create(spark) + .location(delta_path) + .addColumns(df.schema) + .addColumn( + "creation", types.DateType(), generatedAlwaysAs="CAST(now() AS DATE)" + ) + .execute() + ) - df.repartition(1).write.format('delta').mode('append').save(case.delta_root) + df.repartition(1).write.format("delta").mode("append").save(case.delta_root) @reference_table( - name='column_mapping', - description='Table with column mapping turned on', + name="column_mapping", + description="Table with column mapping turned on", ) def create_column_mapping(case: TestCaseInfo, spark: SparkSession): df = get_sample_data(spark) delta_path = str(Path(case.delta_root).absolute()) - 
(DeltaTable.create(spark) - .location(delta_path) - .addColumns(df.schema) - .property('delta.columnMapping.mode', 'name') - .execute()) + ( + DeltaTable.create(spark) + .location(delta_path) + .addColumns(df.schema) + .property("delta.columnMapping.mode", "name") + .execute() + ) - df.repartition(1).write.format('delta').mode('append').save(case.delta_root) + df.repartition(1).write.format("delta").mode("append").save(case.delta_root) spark.sql(f"ALTER TABLE delta.`{delta_path}` RENAME COLUMN int TO new_int;") - (df.withColumnRenamed('int', 'new_int') - .repartition(1) - .write - .format('delta') - .mode('append') - .save(case.delta_root)) + ( + df.withColumnRenamed("int", "new_int") + .repartition(1) + .write.format("delta") + .mode("append") + .save(case.delta_root) + ) @reference_table( - name='timestamp_ntz', - description='Table with not timezone timestamps in it', + name="timestamp_ntz", + description="Table with not timezone timestamps in it", ) def create_timestamp_ntz(case: TestCaseInfo, spark: SparkSession): df = get_sample_data(spark) delta_path = str(Path(case.delta_root).absolute()) - delta_table: DeltaTable = (DeltaTable.create(spark) - .location(delta_path) - .addColumns(df.schema) - .addColumn("timestampNTZ", types.TimestampNTZType()) - .execute()) + delta_table: DeltaTable = ( + DeltaTable.create(spark) + .location(delta_path) + .addColumns(df.schema) + .addColumn("timestampNTZ", types.TimestampNTZType()) + .execute() + ) delta_table.upgradeTableProtocol(3, 7) - (df.withColumn("timestampNTZ", now().cast(types.TimestampNTZType())) - .repartition(1) - .write - .format('delta') - .mode('append') - .save(case.delta_root)) + ( + df.withColumn("timestampNTZ", now().cast(types.TimestampNTZType())) + .repartition(1) + .write.format("delta") + .mode("append") + .save(case.delta_root) + ) @reference_table( - name='iceberg_compat_v1', - description='Table with Iceberg compatability v1 turned on', + name="iceberg_compat_v1", + description="Table with Iceberg compatability v1 turned on", ) def create_iceberg_compat_v1(case: TestCaseInfo, spark: SparkSession): df = get_sample_data(spark) delta_path = str(Path(case.delta_root).absolute()) - delta_table: DeltaTable = (DeltaTable.create(spark) - .location(delta_path) - .addColumns(df.schema) - .property('delta.enableIcebergCompatV1', 'true') - .execute()) + delta_table: DeltaTable = ( + DeltaTable.create(spark) + .location(delta_path) + .addColumns(df.schema) + .property("delta.enableIcebergCompatV1", "true") + .execute() + ) delta_table.upgradeTableProtocol(3, 7) - df.repartition(1).write.format('delta').mode('append').save(case.delta_root) + df.repartition(1).write.format("delta").mode("append").save(case.delta_root) diff --git a/dat/main.py b/dat/main.py index 5ea298d..917c7c2 100644 --- a/dat/main.py +++ b/dat/main.py @@ -33,22 +33,21 @@ def cli(): @click.command() -@click.option('--table-name') +@click.option("--table-name") def write_generated_reference_tables(table_name: Optional[str]): if table_name: for metadata, create_table in generated_tables.registered_reference_tables: if metadata.name == table_name: logging.info("Writing table '%s'", metadata.name) - out_base = Path('out/reader_tests/generated') / table_name + out_base = Path("out/reader_tests/generated") / table_name shutil.rmtree(out_base, ignore_errors=True) create_table() break else: - raise ValueError( - f"Could not find generated table named '{table_name}'") + raise ValueError(f"Could not find generated table named '{table_name}'") else: - out_base = 
Path('out/reader_tests/generated') + out_base = Path("out/reader_tests/generated") shutil.rmtree(out_base, ignore_errors=True) for metadata, create_table in generated_tables.registered_reference_tables: @@ -58,18 +57,18 @@ def write_generated_reference_tables(table_name: Optional[str]): @click.command() def write_model_schemas(): - out_base = Path('out/schemas') + out_base = Path("out/schemas") os.makedirs(out_base, exist_ok=True) - with open(out_base / 'TestCaseInfo.json', 'w') as f: + with open(out_base / "TestCaseInfo.json", "w") as f: f.write(TestCaseInfo.schema_json(indent=2)) - with open(out_base / 'TableVersionMetadata.json', 'w') as f: + with open(out_base / "TableVersionMetadata.json", "w") as f: f.write(TableVersionMetadata.schema_json(indent=2)) cli.add_command(write_generated_reference_tables) cli.add_command(write_model_schemas) -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/dat/models.py b/dat/models.py index 27f702b..a2279ff 100644 --- a/dat/models.py +++ b/dat/models.py @@ -3,7 +3,7 @@ from pydantic import BaseModel -OUT_ROOT = Path('./out/reader_tests/generated/') +OUT_ROOT = Path("./out/reader_tests/generated/") class TestCaseInfo(BaseModel): @@ -16,14 +16,14 @@ def root(self) -> Path: @property def delta_root(self) -> str: - return str(self.root / 'delta') + return str(self.root / "delta") def expected_root(self, version: Optional[int] = None) -> Path: - version_path = 'latest' if version is None else f'v{version}' - return self.root / 'expected' / version_path + version_path = "latest" if version is None else f"v{version}" + return self.root / "expected" / version_path def expected_path(self, version: Optional[int] = None) -> str: - return str(self.expected_root(version) / 'table_content') + return str(self.expected_root(version) / "table_content") class TableVersionMetadata(BaseModel): diff --git a/dat/spark_builder.py b/dat/spark_builder.py index 1c4520a..21ba2ed 100644 --- a/dat/spark_builder.py +++ b/dat/spark_builder.py @@ -8,14 +8,18 @@ def get_spark_session(): global builder # Only configure the builder once if builder is None: - builder = SparkSession.builder.appName( - 'DAT', - ).config( - 'spark.sql.extensions', - 'io.delta.sql.DeltaSparkSessionExtension', - ).config( - 'spark.sql.catalog.spark_catalog', - 'org.apache.spark.sql.delta.catalog.DeltaCatalog', + builder = ( + SparkSession.builder.appName( + "DAT", + ) + .config( + "spark.sql.extensions", + "io.delta.sql.DeltaSparkSessionExtension", + ) + .config( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.delta.catalog.DeltaCatalog", + ) ) builder = delta.configure_spark_with_delta_pip(builder) return builder.getOrCreate() diff --git a/tests/conftest.py b/tests/conftest.py index bd4e43e..48245b3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,9 +3,9 @@ from dat import spark_builder -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def spark_session(request): spark = spark_builder.get_spark_session() - spark.sparkContext.setLogLevel('ERROR') + spark.sparkContext.setLogLevel("ERROR") request.addfinalizer(spark.stop) return spark diff --git a/tests/pyspark_delta/test_pyspark_delta.py b/tests/pyspark_delta/test_pyspark_delta.py index d38c658..6e8d872 100644 --- a/tests/pyspark_delta/test_pyspark_delta.py +++ b/tests/pyspark_delta/test_pyspark_delta.py @@ -5,7 +5,7 @@ import chispa import pytest -TEST_ROOT = Path('out/reader_tests/') +TEST_ROOT = Path("out/reader_tests/") MAX_SUPPORTED_READER_VERSION = 2 @@ -22,42 +22,42 @@ class 
ReadCase(NamedTuple): cases: List[ReadCase] = [] -for path in (TEST_ROOT / 'generated').iterdir(): +for path in (TEST_ROOT / "generated").iterdir(): if path.is_dir(): - with open(path / 'test_case_info.json') as f: + with open(path / "test_case_info.json") as f: case_metadata = json.load(f) - for version_path in (path / 'expected').iterdir(): + for version_path in (path / "expected").iterdir(): if version_path.is_dir(): - if version_path.name[0] == 'v': + if version_path.name[0] == "v": version = int(version_path.name[1:]) - elif version_path.name == 'latest': + elif version_path.name == "latest": version = None else: continue - with open(version_path / 'table_version_metadata.json') as f: + with open(version_path / "table_version_metadata.json") as f: expected_metadata = json.load(f) case = ReadCase( - delta_root=path / 'delta', + delta_root=path / "delta", version=version, parquet_root=version_path / "table_content", - name=case_metadata['name'], - description=case_metadata['description'], - min_reader_version=expected_metadata['min_reader_version'], - min_writer_version=expected_metadata['min_writer_version'], + name=case_metadata["name"], + description=case_metadata["description"], + min_reader_version=expected_metadata["min_reader_version"], + min_writer_version=expected_metadata["min_writer_version"], ) cases.append(case) -@pytest.mark.parametrize('case', cases, - ids=lambda - case: f'{case.name} (version={case.version})') +@pytest.mark.parametrize( + "case", cases, ids=lambda case: f"{case.name} (version={case.version})" +) def test_readers_dat(spark_session, case: ReadCase): - query = spark_session.read.format('delta') + query = spark_session.read.format("delta") if case.version is not None: - query = query.option('versionAsOf', case.version) + query = query.option("versionAsOf", case.version) if case.min_reader_version > MAX_SUPPORTED_READER_VERSION: # If it's a reader version we don't support, assert failure @@ -66,12 +66,13 @@ def test_readers_dat(spark_session, case: ReadCase): else: actual_df = query.load(str(case.delta_root)) - expected_df = spark_session.read.format('parquet').load( - str(case.parquet_root) + '/*.parquet') + expected_df = spark_session.read.format("parquet").load( + str(case.parquet_root) + "/*.parquet" + ) - if 'pk' in actual_df.columns: - actual_df = actual_df.orderBy('pk') - expected_df = expected_df.orderBy('pk') + if "pk" in actual_df.columns: + actual_df = actual_df.orderBy("pk") + expected_df = expected_df.orderBy("pk") chispa.assert_df_equality(actual_df, expected_df) else: chispa.assert_df_equality(actual_df, expected_df, ignore_row_order=True)
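
For reference, the whole patch above is a mechanical Black formatting pass: string literals are normalized to double quotes and expressions longer than Black's default 88-character limit are wrapped across lines. A minimal sketch of that transformation, assuming the `black` package is installed (the exact Black version used for this commit is not recorded in the patch):

    import black

    # A write call taken from the diff above, in its pre-patch single-quote form.
    src = "df.repartition(1).write.format('delta').save(case.delta_root)\n"

    # Black normalizes the quotes; lines over the default 88-character limit
    # would also be reflowed, which is where the multi-line call chains in the
    # diff come from.
    print(black.format_str(src, mode=black.Mode()))
    # df.repartition(1).write.format("delta").save(case.delta_root)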