diff --git a/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index e64077f..f2a9a8b 100644 --- a/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -40,31 +40,52 @@ object StructTypeHelpers { } private def schemaToSortedSelectExpr[A](schema: StructType, f: StructField => A, baseField: String = "")(implicit ord: Ordering[A]): Seq[Column] = { - def handleDataType(t: DataType, colName: String, simpleName: String): Column = + def handleNestedType(t: DataType, name: String, outerCol: Column): Column = t match { case st: StructType => - struct(schemaToSortedSelectExpr(st, f, colName): _*).as(simpleName) + val sortedFields = st.fields.sortBy(f) + struct( + sortedFields.map(f => + f.dataType match { + case st: StructType => + handleNestedType(st, f.name, outerCol(f.name)).as(f.name) + case ArrayType(innerType: StructType, _) => + handleArrayType(f.dataType, name, outerCol(f.name)).as(f.name) + case _ => + handleNestedType(f.dataType, f.name, outerCol).as(f.name) + } + ): _* + ).as(name) case ArrayType(_, _) => - handleArrayType(t, col(colName), simpleName).as(simpleName) + handleArrayType(t, name, outerCol).as(name) case _ => - col(colName) + outerCol(name) } // For handling reordering of nested arrays - def handleArrayType(t: DataType, innerCol: Column, simpleName: String): Column = + def handleArrayType(t: DataType, name: String, outer: Column): Column = t match { - case ArrayType(innerType: ArrayType, _) => transform(innerCol, outer => handleArrayType(innerType, outer, simpleName)) + case ArrayType(innerType: ArrayType, _) => + transform(outer, inner => handleArrayType(innerType, name, inner)).as(name) case ArrayType(innerType: StructType, _) => - val cols = schemaToSortedSelectExpr(innerType, f) - transform(innerCol, innerCol1 => struct(cols.map(c => innerCol1.getField(c.toString).as(c.toString)): _*)) - case _ => innerCol + transform(outer, inner => handleNestedType(innerType, name, inner).as(name)).as(name) + case _ => outer.as(name) } - schema.fields.sortBy(f).foldLeft(Seq.empty[Column]) { + val result = schema.fields.sortBy(f).foldLeft(Seq.empty[Column]) { case (acc, field) => - val colName = if (baseField.isEmpty) field.name else s"$baseField.${field.name}" - acc :+ handleDataType(field.dataType, colName, field.name) + val name = field.name + val sortedCol = field.dataType match { + case st: StructType => + handleNestedType(st, name, col(name)) + case arr: ArrayType => + handleArrayType(arr, name, col(name)) + case _ => col(name) + } + + acc :+ sortedCol } + result } /**