diff --git a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala index 8b01c04..dee701e 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/daria/sql/types/StructTypeHelpers.scala @@ -40,43 +40,33 @@ object StructTypeHelpers { } private def schemaToSortedSelectExpr[A](schema: StructType, f: StructField => A)(implicit ord: Ordering[A]): Seq[Column] = { - def handleNestedType(t: DataType, name: String, outerCol: Column, firstLevel: Boolean = false): Column = - t match { + def childFieldToCol(childFieldType: DataType, childFieldName: String, parentCol: Column, firstLevel: Boolean = false): Column = + childFieldType match { case st: StructType => struct( st.fields .sortBy(f) .map(field => - handleNestedType( + childFieldToCol( field.dataType, field.name, field.dataType match { - case StructType(_) | ArrayType(_: StructType, _) => outerCol(field.name) - case _ => outerCol + case StructType(_) | ArrayType(_: StructType, _) => parentCol(field.name) + case _ => parentCol } ).as(field.name) ): _* - ).as(name) - case ArrayType(_, _) => handleArrayType(t, name, outerCol).as(name) - case _ if firstLevel => outerCol - case _ if !firstLevel => outerCol(name) + ).as(childFieldName) + case ArrayType(innerType, _) => + transform(parentCol, childCol => childFieldToCol(innerType, childFieldName, childCol)).as(childFieldName) + case _ if firstLevel => parentCol + case _ if !firstLevel => parentCol(childFieldName) } - // For handling reordering of nested arrays - def handleArrayType(t: DataType, name: String, outer: Column): Column = - t match { - case ArrayType(innerType: ArrayType, _) => - transform(outer, inner => handleArrayType(innerType, name, inner)).as(name) - case ArrayType(innerType: StructType, _) => - transform(outer, inner => handleNestedType(innerType, name, inner).as(name)).as(name) - case _ => outer.as(name) - } - - val result = schema.fields.sortBy(f).foldLeft(Seq.empty[Column]) { + schema.fields.sortBy(f).foldLeft(Seq.empty[Column]) { case (acc, field) => - acc :+ handleNestedType(field.dataType, field.name, col(field.name), firstLevel = true) + acc :+ childFieldToCol(field.dataType, field.name, col(field.name), firstLevel = true) } - result } /**