Skip to content

Commit

Permalink
Fix Parquet Writer tests on [databricks] 14.3 (#11673)
Browse files Browse the repository at this point in the history
* Fix Parquet Writer tests on Databricks 14.3

Fixes #11534.

This commit fixes the test failures in `parquet_write_test.py`, as
listed on #11534.

This is an extension of the changes made in #11615, which attempted to
address similar failures on Apache Spark 4.

Most of the test failures pertain to legacy Parquet writer settings and
conf keys which were removed on Spark 4.  A stray test involves a change
in how the test gleans description strings from exceptions.

Signed-off-by: MithunR <[email protected]>
  • Loading branch information
mythrocks authored Nov 6, 2024
1 parent 6100334 commit 61acf56
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions integration_tests/src/main/python/parquet_write_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@

pytestmark = pytest.mark.nightly_resource_consuming_test

conf_key_parquet_datetimeRebaseModeInWrite = 'spark.sql.parquet.datetimeRebaseModeInWrite'
conf_key_parquet_int96RebaseModeInWrite = 'spark.sql.parquet.int96RebaseModeInWrite'
conf_key_parquet_datetimeRebaseModeInRead = 'spark.sql.parquet.datetimeRebaseModeInRead'
conf_key_parquet_int96RebaseModeInRead = 'spark.sql.parquet.int96RebaseModeInRead'

# test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for
# non-cloud
original_parquet_file_reader_conf={'spark.rapids.sql.format.parquet.reader.type': 'PERFILE'}
Expand All @@ -37,11 +42,8 @@
reader_opt_confs = [original_parquet_file_reader_conf, multithreaded_parquet_file_reader_conf,
coalesce_parquet_file_reader_conf]
parquet_decimal_struct_gen= StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(decimal_gens)])
legacy_parquet_datetimeRebaseModeInWrite='spark.sql.parquet.datetimeRebaseModeInWrite' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite'
legacy_parquet_int96RebaseModeInWrite='spark.sql.parquet.int96RebaseModeInWrite' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.int96RebaseModeInWrite'
legacy_parquet_int96RebaseModeInRead='spark.sql.parquet.int96RebaseModeInRead' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.int96RebaseModeInRead'
writer_confs={legacy_parquet_datetimeRebaseModeInWrite: 'CORRECTED',
legacy_parquet_int96RebaseModeInWrite: 'CORRECTED'}
writer_confs={conf_key_parquet_datetimeRebaseModeInWrite: 'CORRECTED',
conf_key_parquet_int96RebaseModeInWrite: 'CORRECTED'}

parquet_basic_gen =[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen, TimestampGen(), binary_gen]
Expand Down Expand Up @@ -161,8 +163,8 @@ def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase):
lambda spark, path: unary_op_df(spark, gen).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
data_path,
conf={legacy_parquet_datetimeRebaseModeInWrite: ts_rebase,
legacy_parquet_int96RebaseModeInWrite: ts_rebase,
conf={conf_key_parquet_datetimeRebaseModeInWrite: ts_rebase,
conf_key_parquet_int96RebaseModeInWrite: ts_rebase,
'spark.sql.parquet.outputTimestampType': ts_type})


Expand Down Expand Up @@ -288,8 +290,8 @@ def test_write_sql_save_table(spark_tmp_path, parquet_gens, spark_tmp_table_fact

def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_factory, int96_rebase, datetime_rebase, ts_write):
spark.conf.set('spark.sql.parquet.outputTimestampType', ts_write)
spark.conf.set(legacy_parquet_datetimeRebaseModeInWrite, datetime_rebase)
spark.conf.set(legacy_parquet_int96RebaseModeInWrite, int96_rebase) # for spark 310
spark.conf.set(conf_key_parquet_datetimeRebaseModeInWrite, datetime_rebase)
spark.conf.set(conf_key_parquet_int96RebaseModeInWrite, int96_rebase) # for spark 310
with pytest.raises(Exception) as e_info:
df.coalesce(1).write.format("parquet").mode('overwrite').option("path", data_path).saveAsTable(spark_tmp_table_factory.get())
assert e_info.match(r".*SparkUpgradeException.*")
Expand Down Expand Up @@ -547,8 +549,8 @@ def generate_map_with_empty_validity(spark, path):
def test_parquet_write_fails_legacy_datetime(spark_tmp_path, data_gen, ts_write, ts_rebase_write):
data_path = spark_tmp_path + '/PARQUET_DATA'
all_confs = {'spark.sql.parquet.outputTimestampType': ts_write,
legacy_parquet_datetimeRebaseModeInWrite: ts_rebase_write,
legacy_parquet_int96RebaseModeInWrite: ts_rebase_write}
conf_key_parquet_datetimeRebaseModeInWrite: ts_rebase_write,
conf_key_parquet_int96RebaseModeInWrite: ts_rebase_write}
def writeParquetCatchException(spark, data_gen, data_path):
with pytest.raises(Exception) as e_info:
unary_op_df(spark, data_gen).coalesce(1).write.parquet(data_path)
Expand All @@ -566,12 +568,12 @@ def test_parquet_write_roundtrip_datetime_with_legacy_rebase(spark_tmp_path, dat
ts_rebase_write, ts_rebase_read):
data_path = spark_tmp_path + '/PARQUET_DATA'
all_confs = {'spark.sql.parquet.outputTimestampType': ts_write,
legacy_parquet_datetimeRebaseModeInWrite: ts_rebase_write[0],
legacy_parquet_int96RebaseModeInWrite: ts_rebase_write[1],
conf_key_parquet_datetimeRebaseModeInWrite: ts_rebase_write[0],
conf_key_parquet_int96RebaseModeInWrite: ts_rebase_write[1],
# The rebase modes in read configs should be ignored and overridden by the same
# modes in write configs, which are retrieved from the written files.
'spark.sql.legacy.parquet.datetimeRebaseModeInRead': ts_rebase_read[0],
legacy_parquet_int96RebaseModeInRead: ts_rebase_read[1]}
conf_key_parquet_datetimeRebaseModeInRead: ts_rebase_read[0],
conf_key_parquet_int96RebaseModeInRead: ts_rebase_read[1]}
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: unary_op_df(spark, data_gen).coalesce(1).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
Expand Down Expand Up @@ -600,7 +602,7 @@ def test_it(spark):
spark.sql("CREATE TABLE {} LOCATION '{}/ctas' AS SELECT * FROM {}".format(
ctas_with_existing_name, data_path, src_name))
except pyspark.sql.utils.AnalysisException as e:
description = e._desc if is_spark_400_or_later() else e.desc
description = e._desc if (is_spark_400_or_later() or is_databricks_version_or_later(14, 3)) else e.desc
if allow_non_empty or description.find('non-empty directory') == -1:
raise e
with_gpu_session(test_it, conf)
Expand Down Expand Up @@ -829,8 +831,8 @@ def write_partitions(spark, table_path):
)

def hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, ts_rebase, func):
conf={legacy_parquet_datetimeRebaseModeInWrite: ts_rebase,
legacy_parquet_int96RebaseModeInWrite: ts_rebase}
conf={conf_key_parquet_datetimeRebaseModeInWrite: ts_rebase,
conf_key_parquet_int96RebaseModeInWrite: ts_rebase}

def create_table(spark, path):
tmp_table = spark_tmp_table_factory.get()
Expand Down

0 comments on commit 61acf56

Please sign in to comment.