Do not rename nested dataframe #587

Merged
merged 3 commits into from
Oct 2, 2024
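In short: when writing through the parquet path, the connector now applies Snowflake-style renaming (sanitizing special characters and uppercasing) only to top-level column names; field names inside nested StructTypes are left untouched. A minimal sketch of the intended mapping, with illustrative schemas assumed for this example (they are not taken from the diff below):

import org.apache.spark.sql.types._

// Illustrative input schema: lower-case names at both levels.
val input = StructType(Array(
  StructField("num", IntegerType),
  StructField("obj", StructType(Array(StructField("str", StringType))))
))

// Expected Snowflake-side schema after this change: the top level is
// uppercased, but the nested field name "str" is preserved (previously
// it would have been renamed to "STR" as well).
val expected = StructType(Array(
  StructField("NUM", IntegerType),
  StructField("OBJ", StructType(Array(StructField("str", StringType))))
))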
Changes from all commits
52 changes: 52 additions & 0 deletions src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala
@@ -23,6 +23,7 @@ class ParquetSuite extends IntegrationSuiteBase {
val test_special_char_to_exist: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
val test_column_map_parquet: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
val test_column_map_not_match: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
val test_nested_dataframe: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString

override def afterAll(): Unit = {
runSql(s"drop table if exists $test_all_type")
@@ -548,9 +549,60 @@
df1.write
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_column_map_not_match)
.mode(SaveMode.Append)
.save()
}
}

test("test nested dataframe"){

val data = sc.parallelize(
List(
Row(123, Array(1, 2, 3), Map("a" -> 1), Row("abc")),
Row(456, Array(4, 5, 6), Map("b" -> 2), Row("def")),
Row(789, Array(7, 8, 9), Map("c" -> 3), Row("ghi"))
)
)

val schema1 = new StructType(
Array(
StructField("NUM", IntegerType, nullable = true),
StructField("ARR", ArrayType(IntegerType), nullable = true),
StructField("MAP", MapType(StringType, IntegerType), nullable = true),
StructField(
"OBJ",
StructType(Array(StructField("str", StringType, nullable = true)))
)
)
)

val df = sparkSession.createDataFrame(data, schema1)

df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_nested_dataframe)
.mode(SaveMode.Overwrite)
.save()

val out = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_nested_dataframe)
.schema(schema1)
.load()
val result = out.collect()
assert(result.length == 3)

assert(result(0).getInt(0) == 123)
assert(result(0).getList[Int](1).get(0) == 1)
assert(result(1).getList[Int](1).get(1) == 5)
assert(result(2).getList[Int](1).get(2) == 9)
assert(result(1).getMap[String, Int](2)("b") == 2)
assert(result(2).getStruct(3).getString(0) == "ghi")
Contributor @sfc-gh-bli commented Oct 2, 2024
Let's modify this test:

  1. change the name "STR" to "str"
  2. add the assertion assert(df.collect().head.getAs[Row]("OBJ").getAs[String]("str") == "ghi")

assert(result(2).getAs[Row]("OBJ").getAs[String]("str") == "ghi")
}
}
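A usage sketch of what the preserved casing buys, mirroring the reviewer's suggested assertion. It assumes the test above has run, so `out` is in scope and the row order is as the test asserts: the nested struct field keeps its lower-case name "str" and can be fetched by that name.

import org.apache.spark.sql.Row

// The first row of the test data carries Row("abc") in its OBJ column.
val firstObj = out.collect().head.getAs[Row]("OBJ")
assert(firstObj.getAs[String]("str") == "abc")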
src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
@@ -440,12 +440,14 @@ object Parameters
snowflakeTableSchema
}

- def replaceSpecialCharacter(name: String): String = {
+ def replaceSpecialCharacter(name: String, toUpperCase: Boolean): String = {
var res = name.replaceAll("(^\\d|[^a-zA-Z0-9_])", "_")
while(stagingToSnowflakeColumnMap.contains(res.toUpperCase)){
res += "_"
}
- res = res.toUpperCase
+ if (toUpperCase) {
+   res = res.toUpperCase
+ }
snowflakeToStagingColumnMap += (name -> res)
stagingToSnowflakeColumnMap += (res -> name)
res
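For reference, a standalone sketch of what replaceSpecialCharacter now computes. This is an illustrative re-implementation, not the connector's code: the sketch name and the `seen` parameter stand in for the real method and its stagingToSnowflakeColumnMap bookkeeping.

def replaceSpecialCharacterSketch(name: String,
                                  toUpperCase: Boolean,
                                  seen: Set[String]): String = {
  // A leading digit or any character outside [a-zA-Z0-9_] becomes "_".
  var res = name.replaceAll("(^\\d|[^a-zA-Z0-9_])", "_")
  // A collision with an already-staged name gets a trailing "_".
  while (seen.contains(res.toUpperCase)) res += "_"
  // New in this PR: uppercase only on request, i.e. for top-level columns.
  if (toUpperCase) res.toUpperCase else res
}

// replaceSpecialCharacterSketch("1col-a", toUpperCase = true, Set.empty) == "_COL_A"
// replaceSpecialCharacterSketch("str", toUpperCase = false, Set.empty) == "str"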
src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
@@ -126,18 +126,23 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
* function that maps Spark-style column names to Snowflake-style column names
*/
def mapColumn(schema: StructType,
- params: MergedParameters
+ params: MergedParameters,
+ snowflakeStyle: Boolean
): StructType = {
params.columnMap match {
case Some(map) =>
StructType(schema.map {
case StructField(name, dataType, nullable, metadata) =>
StructField(
params.replaceSpecialCharacter(
- snowflakeStyleString(map.getOrElse(name, name), params)),
+ if (snowflakeStyle) {
+   snowflakeStyleString(map.getOrElse(name, name), params)
+ } else {
+   map.getOrElse(name, name)
+ }, snowflakeStyle),
dataType match {
case datatype: StructType =>
- mapColumn(datatype, params)
+ mapColumn(datatype, params, snowflakeStyle = false)
case _ =>
dataType
},
@@ -147,7 +152,11 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
})
case _ =>
val newSchema = if (params.snowflakeTableSchema == null) {
- snowflakeStyleSchema(schema, params)
+ if (snowflakeStyle) {
+   snowflakeStyleSchema(schema, params)
+ } else {
+   schema
+ }
} else {
StructType(schema.zip(params.snowflakeTableSchema).map{
case (field1, field2) =>
@@ -163,10 +172,10 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
StructType(newSchema.map {
case StructField(name, dataType, nullable, metadata) =>
StructField(
- params.replaceSpecialCharacter(name),
+ params.replaceSpecialCharacter(name, snowflakeStyle),
dataType match {
case datatype: StructType =>
- mapColumn(datatype, params)
+ mapColumn(datatype, params, snowflakeStyle = false)
case _ =>
dataType
},
@@ -188,7 +197,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {

format match {
case SupportedFormat.PARQUET =>
- val snowflakeStyleSchema = mapColumn(data.schema, params)
+ val snowflakeStyleSchema = mapColumn(data.schema, params, snowflakeStyle = true)
val schema = io.ParquetUtils.convertStructToAvro(snowflakeStyleSchema)
(data.rdd.map (row => {
def rowToAvroRecord(row: Row,
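The shape of the recursion above, reduced to its essentials: the top-level call site passes snowflakeStyle = true, and every recursive call for a nested StructType passes snowflakeStyle = false, so renaming applies to exactly one level. A simplified model that omits columnMap handling and the Parameters plumbing, with `.toUpperCase` standing in for replaceSpecialCharacter:

import org.apache.spark.sql.types._

// Simplified model of mapColumn: only the outermost level is renamed.
def mapColumnSketch(schema: StructType, snowflakeStyle: Boolean): StructType =
  StructType(schema.map { case StructField(name, dataType, nullable, metadata) =>
    StructField(
      if (snowflakeStyle) name.toUpperCase else name, // stand-in for replaceSpecialCharacter
      dataType match {
        case st: StructType => mapColumnSketch(st, snowflakeStyle = false)
        case other => other
      },
      nullable,
      metadata
    )
  })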