Fix Parquet Writer tests on [databricks] 14.3 #11673

Merged 3 commits on Nov 6, 2024
38 changes: 20 additions & 18 deletions integration_tests/src/main/python/parquet_write_test.py
@@ -29,6 +29,11 @@

pytestmark = pytest.mark.nightly_resource_consuming_test

+conf_key_parquet_datetimeRebaseModeInWrite = 'spark.sql.parquet.datetimeRebaseModeInWrite'
+conf_key_parquet_int96RebaseModeInWrite = 'spark.sql.parquet.int96RebaseModeInWrite'
+conf_key_parquet_datetimeRebaseModeInRead = 'spark.sql.parquet.datetimeRebaseModeInRead'
+conf_key_parquet_int96RebaseModeInRead = 'spark.sql.parquet.int96RebaseModeInRead'

# test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for
# non-cloud
original_parquet_file_reader_conf={'spark.rapids.sql.format.parquet.reader.type': 'PERFILE'}
@@ -37,11 +42,8 @@
reader_opt_confs = [original_parquet_file_reader_conf, multithreaded_parquet_file_reader_conf,
coalesce_parquet_file_reader_conf]
parquet_decimal_struct_gen= StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(decimal_gens)])
-legacy_parquet_datetimeRebaseModeInWrite='spark.sql.parquet.datetimeRebaseModeInWrite' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite'
-legacy_parquet_int96RebaseModeInWrite='spark.sql.parquet.int96RebaseModeInWrite' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.int96RebaseModeInWrite'
-legacy_parquet_int96RebaseModeInRead='spark.sql.parquet.int96RebaseModeInRead' if is_spark_400_or_later() else 'spark.sql.legacy.parquet.int96RebaseModeInRead'
-writer_confs={legacy_parquet_datetimeRebaseModeInWrite: 'CORRECTED',
-legacy_parquet_int96RebaseModeInWrite: 'CORRECTED'}
+writer_confs={conf_key_parquet_datetimeRebaseModeInWrite: 'CORRECTED',
+conf_key_parquet_int96RebaseModeInWrite: 'CORRECTED'}

parquet_basic_gen =[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen, TimestampGen(), binary_gen]
@@ -161,8 +163,8 @@ def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase):
lambda spark, path: unary_op_df(spark, gen).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
data_path,
-conf={legacy_parquet_datetimeRebaseModeInWrite: ts_rebase,
-legacy_parquet_int96RebaseModeInWrite: ts_rebase,
+conf={conf_key_parquet_datetimeRebaseModeInWrite: ts_rebase,
+conf_key_parquet_int96RebaseModeInWrite: ts_rebase,
'spark.sql.parquet.outputTimestampType': ts_type})


@@ -288,8 +290,8 @@ def test_write_sql_save_table(spark_tmp_path, parquet_gens, spark_tmp_table_fact

def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_factory, int96_rebase, datetime_rebase, ts_write):
spark.conf.set('spark.sql.parquet.outputTimestampType', ts_write)
-spark.conf.set(legacy_parquet_datetimeRebaseModeInWrite, datetime_rebase)
-spark.conf.set(legacy_parquet_int96RebaseModeInWrite, int96_rebase) # for spark 310
+spark.conf.set(conf_key_parquet_datetimeRebaseModeInWrite, datetime_rebase)
+spark.conf.set(conf_key_parquet_int96RebaseModeInWrite, int96_rebase) # for spark 310
with pytest.raises(Exception) as e_info:
df.coalesce(1).write.format("parquet").mode('overwrite').option("path", data_path).saveAsTable(spark_tmp_table_factory.get())
assert e_info.match(r".*SparkUpgradeException.*")
@@ -547,8 +549,8 @@ def generate_map_with_empty_validity(spark, path):
def test_parquet_write_fails_legacy_datetime(spark_tmp_path, data_gen, ts_write, ts_rebase_write):
data_path = spark_tmp_path + '/PARQUET_DATA'
all_confs = {'spark.sql.parquet.outputTimestampType': ts_write,
-legacy_parquet_datetimeRebaseModeInWrite: ts_rebase_write,
-legacy_parquet_int96RebaseModeInWrite: ts_rebase_write}
+conf_key_parquet_datetimeRebaseModeInWrite: ts_rebase_write,
+conf_key_parquet_int96RebaseModeInWrite: ts_rebase_write}
def writeParquetCatchException(spark, data_gen, data_path):
with pytest.raises(Exception) as e_info:
unary_op_df(spark, data_gen).coalesce(1).write.parquet(data_path)
@@ -566,12 +568,12 @@ def test_parquet_write_roundtrip_datetime_with_legacy_rebase(spark_tmp_path, dat
ts_rebase_write, ts_rebase_read):
data_path = spark_tmp_path + '/PARQUET_DATA'
all_confs = {'spark.sql.parquet.outputTimestampType': ts_write,
-legacy_parquet_datetimeRebaseModeInWrite: ts_rebase_write[0],
-legacy_parquet_int96RebaseModeInWrite: ts_rebase_write[1],
+conf_key_parquet_datetimeRebaseModeInWrite: ts_rebase_write[0],
+conf_key_parquet_int96RebaseModeInWrite: ts_rebase_write[1],
# The rebase modes in read configs should be ignored and overridden by the same
# modes in write configs, which are retrieved from the written files.
-'spark.sql.legacy.parquet.datetimeRebaseModeInRead': ts_rebase_read[0],
-legacy_parquet_int96RebaseModeInRead: ts_rebase_read[1]}
+conf_key_parquet_datetimeRebaseModeInRead: ts_rebase_read[0],
+conf_key_parquet_int96RebaseModeInRead: ts_rebase_read[1]}
assert_gpu_and_cpu_writes_are_equal_collect(
lambda spark, path: unary_op_df(spark, data_gen).coalesce(1).write.parquet(path),
lambda spark, path: spark.read.parquet(path),
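The comment in the hunk above captures the behaviour this test depends on: the rebase mode actually used at write time is recorded in the Parquet file footer, so for files Spark wrote itself the read-side rebase configs are effectively ignored. A minimal plain-PySpark sketch of that behaviour (no RAPIDS plugin involved; the /tmp path and the 1582 date are arbitrary illustration values, and the config spellings are the non-legacy ones this PR standardizes on):

from datetime import date
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Write a pre-Gregorian date with LEGACY rebase; Spark records the rebase
# mode it used in the Parquet footer metadata.
spark.conf.set("spark.sql.parquet.datetimeRebaseModeInWrite", "LEGACY")
df = spark.createDataFrame([(date(1582, 10, 4),)], ["d"])
df.write.mode("overwrite").parquet("/tmp/rebase_demo")

# Request CORRECTED on read: the marker in the footer still wins, so the
# value round-trips unchanged rather than coming back shifted.
spark.conf.set("spark.sql.parquet.datetimeRebaseModeInRead", "CORRECTED")
assert spark.read.parquet("/tmp/rebase_demo").head().d == date(1582, 10, 4)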
@@ -600,7 +602,7 @@ def test_it(spark):
spark.sql("CREATE TABLE {} LOCATION '{}/ctas' AS SELECT * FROM {}".format(
ctas_with_existing_name, data_path, src_name))
except pyspark.sql.utils.AnalysisException as e:
-description = e._desc if is_spark_400_or_later() else e.desc
+description = e._desc if (is_spark_400_or_later() or is_databricks_version_or_later(14, 3)) else e.desc
if allow_non_empty or description.find('non-empty directory') == -1:
raise e
with_gpu_session(test_it, conf)
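The one-line change above is the core of the Databricks 14.3 fix: on that runtime, as on Spark 4.0+, pyspark's AnalysisException carries its message on the private `_desc` attribute instead of `desc`. A hypothetical helper (not part of this PR) showing the same check factored out, using the `is_spark_400_or_later` and `is_databricks_version_or_later` utilities the test module already relies on:

def analysis_exception_description(e):
    # Spark 4.0+ and Databricks 14.3 expose the message on the private
    # `_desc` attribute; earlier releases use the public `desc` field.
    if is_spark_400_or_later() or is_databricks_version_or_later(14, 3):
        return e._desc
    return e.desc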
@@ -829,8 +831,8 @@ def write_partitions(spark, table_path):
)

def hive_timestamp_value(spark_tmp_table_factory, spark_tmp_path, ts_rebase, func):
-conf={legacy_parquet_datetimeRebaseModeInWrite: ts_rebase,
-legacy_parquet_int96RebaseModeInWrite: ts_rebase}
+conf={conf_key_parquet_datetimeRebaseModeInWrite: ts_rebase,
+conf_key_parquet_int96RebaseModeInWrite: ts_rebase}

def create_table(spark, path):
tmp_table = spark_tmp_table_factory.get()