Do not rename nested dataframe (#587)
sfc-gh-yuwang authored Oct 2, 2024
1 parent 176fbaf commit 3ce554e
Showing 3 changed files with 72 additions and 9 deletions.
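
In brief: when writing via Parquet, mapColumn previously applied Snowflake-style renaming (special-character replacement plus uppercasing) at every nesting level, so fields inside nested structs were renamed as well. This commit threads a snowflakeStyle / toUpperCase flag through mapColumn and replaceSpecialCharacter so that only top-level column names are normalized. A minimal standalone sketch of the resulting naming rule (not the connector's actual code, which also deduplicates staging names through a column map):

def normalize(name: String, topLevel: Boolean): String = {
  // Replace a leading digit or any character outside [a-zA-Z0-9_] with '_'
  val sanitized = name.replaceAll("(^\\d|[^a-zA-Z0-9_])", "_")
  if (topLevel) sanitized.toUpperCase else sanitized
}

normalize("obj", topLevel = true)  // "OBJ" -- top-level column, normalized
normalize("str", topLevel = false) // "str" -- nested field, kept as-is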
52 changes: 52 additions & 0 deletions src/it/scala/net/snowflake/spark/snowflake/ParquetSuite.scala
@@ -23,6 +23,7 @@ class ParquetSuite extends IntegrationSuiteBase {
val test_special_char_to_exist: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
val test_column_map_parquet: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
val test_column_map_not_match: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString
val test_nested_dataframe: String = Random.alphanumeric.filter(_.isLetter).take(10).mkString

override def afterAll(): Unit = {
runSql(s"drop table if exists $test_all_type")
@@ -548,9 +549,60 @@ class ParquetSuite extends IntegrationSuiteBase {
df1.write
.format(SNOWFLAKE_SOURCE_SHORT_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_column_map_not_match)
.mode(SaveMode.Append)
.save()
}
}

test("test nested dataframe"){

val data = sc.parallelize(
List(
Row(123, Array(1, 2, 3), Map("a" -> 1), Row("abc")),
Row(456, Array(4, 5, 6), Map("b" -> 2), Row("def")),
Row(789, Array(7, 8, 9), Map("c" -> 3), Row("ghi"))
)
)

val schema1 = new StructType(
Array(
StructField("NUM", IntegerType, nullable = true),
StructField("ARR", ArrayType(IntegerType), nullable = true),
StructField("MAP", MapType(StringType, IntegerType), nullable = true),
StructField(
"OBJ",
StructType(Array(StructField("str", StringType, nullable = true)))
)
)
)

val df = sparkSession.createDataFrame(data, schema1)

df.write
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option(Parameters.PARAM_USE_PARQUET_IN_WRITE, "true")
.option("dbtable", test_nested_dataframe)
.mode(SaveMode.Overwrite)
.save()

val out = sparkSession.read
.format(SNOWFLAKE_SOURCE_NAME)
.options(connectorOptionsNoTable)
.option("dbtable", test_nested_dataframe)
.schema(schema1)
.load()
val result = out.collect()
assert(result.length == 3)

assert(result(0).getInt(0) == 123)
assert(result(0).getList[Int](1).get(0) == 1)
assert(result(1).getList[Int](1).get(1) == 5)
assert(result(2).getList[Int](1).get(2) == 9)
assert(result(1).getMap[String, Int](2)("b") == 2)
assert(result(2).getStruct(3).getString(0) == "ghi")
assert(result(2).getAs[Row]("OBJ").getAs[String]("str") == "ghi")
}
}
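
The read back at the end of the test reuses schema1, so the assertions only pass if nested field names survive the write unchanged. A hedged before/after of the staged schema for the OBJ column, reconstructed from the old versus new mapColumn behavior:

// before this commit: OBJ: struct<STR: string>  (nested name uppercased)
// after this commit:  OBJ: struct<str: string>  (nested name preserved)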
6 changes: 4 additions & 2 deletions src/main/scala/net/snowflake/spark/snowflake/Parameters.scala
@@ -440,12 +440,14 @@ object Parameters {
snowflakeTableSchema
}

def replaceSpecialCharacter(name: String): String = {
def replaceSpecialCharacter(name: String, toUpperCase: Boolean): String = {
var res = name.replaceAll("(^\\d|[^a-zA-Z0-9_])", "_")
while(stagingToSnowflakeColumnMap.contains(res.toUpperCase)){
res += "_"
}
res = res.toUpperCase
if (toUpperCase) {
res = res.toUpperCase
}
snowflakeToStagingColumnMap += (name -> res)
stagingToSnowflakeColumnMap += (res -> name)
res
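With the new flag, replaceSpecialCharacter still sanitizes and deduplicates staging names, but only uppercases on request. A self-contained sketch of the updated method (in the connector it is a member of MergedParameters, and the two maps are instance state):

import scala.collection.mutable

val snowflakeToStagingColumnMap = mutable.Map.empty[String, String]
val stagingToSnowflakeColumnMap = mutable.Map.empty[String, String]

def replaceSpecialCharacter(name: String, toUpperCase: Boolean): String = {
  // Replace a leading digit or any character outside [a-zA-Z0-9_] with '_'
  var res = name.replaceAll("(^\\d|[^a-zA-Z0-9_])", "_")
  // Append '_' until the (case-insensitively compared) staging name is unique
  while (stagingToSnowflakeColumnMap.contains(res.toUpperCase)) {
    res += "_"
  }
  if (toUpperCase) {
    res = res.toUpperCase
  }
  snowflakeToStagingColumnMap += (name -> res)
  stagingToSnowflakeColumnMap += (res -> name)
  res
}

replaceSpecialCharacter("order-id", toUpperCase = true) // "ORDER_ID"
replaceSpecialCharacter("str", toUpperCase = false)     // "str"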
23 changes: 16 additions & 7 deletions src/main/scala/net/snowflake/spark/snowflake/SnowflakeWriter.scala
@@ -126,18 +126,23 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
* function that map spark style column name to snowflake style column name
*/
def mapColumn(schema: StructType,
params: MergedParameters
params: MergedParameters,
snowflakeStyle: Boolean
): StructType = {
params.columnMap match {
case Some(map) =>
StructType(schema.map {
case StructField(name, dataType, nullable, metadata) =>
StructField(
params.replaceSpecialCharacter(
snowflakeStyleString(map.getOrElse(name, name), params)),
if (snowflakeStyle) {
snowflakeStyleString(map.getOrElse(name, name), params)
} else {
map.getOrElse(name, name)
}, snowflakeStyle),
dataType match {
case datatype: StructType =>
mapColumn(datatype, params)
mapColumn(datatype, params, snowflakeStyle = false)
case _ =>
dataType
},
@@ -147,7 +152,11 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
})
case _ =>
val newSchema = if (params.snowflakeTableSchema == null) {
snowflakeStyleSchema(schema, params)
if (snowflakeStyle) {
snowflakeStyleSchema(schema, params)
} else {
schema
}
} else {
StructType(schema.zip(params.snowflakeTableSchema).map{
case (field1, field2) =>
@@ -163,10 +172,10 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {
StructType(newSchema.map {
case StructField(name, dataType, nullable, metadata) =>
StructField(
params.replaceSpecialCharacter(name),
params.replaceSpecialCharacter(name, snowflakeStyle),
dataType match {
case datatype: StructType =>
mapColumn(datatype, params)
mapColumn(datatype, params, snowflakeStyle = false)
case _ =>
dataType
},
@@ -188,7 +197,7 @@ private[snowflake] class SnowflakeWriter(jdbcWrapper: JDBCWrapper) {

format match {
case SupportedFormat.PARQUET =>
val snowflakeStyleSchema = mapColumn(data.schema, params)
val snowflakeStyleSchema = mapColumn(data.schema, params, snowflakeStyle = true)
val schema = io.ParquetUtils.convertStructToAvro(snowflakeStyleSchema)
(data.rdd.map (row => {
def rowToAvroRecord(row: Row,
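Both recursive calls now pass snowflakeStyle = false, so normalization stops at the first nesting level. A minimal sketch of the recursion pattern, assuming no column map and no pre-fetched Snowflake table schema (the real method also consults both):

import org.apache.spark.sql.types._

def mapColumnSketch(schema: StructType, snowflakeStyle: Boolean): StructType =
  StructType(schema.map { case StructField(name, dataType, nullable, metadata) =>
    // Normalize only at the top level; nested struct fields keep their case
    val sanitized = name.replaceAll("(^\\d|[^a-zA-Z0-9_])", "_")
    val newName = if (snowflakeStyle) sanitized.toUpperCase else sanitized
    val newType = dataType match {
      case st: StructType => mapColumnSketch(st, snowflakeStyle = false)
      case other => other
    }
    StructField(newName, newType, nullable, metadata)
  })

val in = StructType(Seq(
  StructField("num", IntegerType),
  StructField("obj", StructType(Seq(StructField("str", StringType))))
))
mapColumnSketch(in, snowflakeStyle = true)
// -> struct<NUM: int, OBJ: struct<str: string>>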
