Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sortColumnsBy to DataFrameExt #162

Merged
merged 14 commits into from
Oct 8, 2024
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package com.github.mrpowers.spark.daria.sql

import com.github.mrpowers.spark.daria.sql.types.StructTypeHelpers
import com.github.mrpowers.spark.daria.sql.types.StructTypeHelpers.StructTypeOps
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
import org.apache.spark.sql.{Column, DataFrame}

import scala.collection.mutable

/**
 * Exception whose message is the supplied string.
 *
 * NOTE(review): judging by the name, this is raised when a DataFrame's columns fail a
 * validation; the throwing call sites are not visible in this part of the file.
 *
 * @param smth the message forwarded to the [[Exception]] constructor
 */
case class DataFrameColumnsException(smth: String) extends Exception(smth)

object DataFrameExt {
Expand Down Expand Up @@ -407,6 +410,43 @@ object DataFrameExt {
StructType(loop(df.schema))
)
}
}

/**
 * Returns this DataFrame with its columns reordered.
 *
 * Every [[StructField]] of the schema is mapped through `f`, and the columns are then
 * arranged according to the implicit [[scala.math.Ordering]] on the mapped values.
 * The same reordering is applied to the fields of [[StructType]] columns and of
 * [[ArrayType]] columns whose elements are [[StructType]]s.
 *
 * @param f   key function mapping a [[StructField]] to a value of type `A`
 * @param ord ordering used to compare the keys produced by `f`
 * @tparam A  the key type produced by `f` and compared by `ord`
 * @return a DataFrame holding the same fields as this one, where field `x` comes
 *         before field `y` if `ord.lt(f(x), f(y))`
 *
 * @example {{{
 * val df = spark.createDataFrame(
 *   Seq(
 *     ("John", 30, 2000.0),
 *     ("Jane", 25, 3000.0)
 *   )
 * ).toDF("name", "age", "salary")
 *
 * // Order the columns alphabetically by their names
 * df.sortColumnsBy(_.name).show()
 * // +---+----+------+
 * // |age|name|salary|
 * // +---+----+------+
 * // | 30|John|2000.0|
 * // | 25|Jane|3000.0|
 * // +---+----+------+
 * }}}
 */
def sortColumnsBy[A](f: StructField => A)(implicit ord: Ordering[A]): DataFrame = {
  // Build the reordered projection from the schema first, then apply it in one select.
  val orderedColumns: Seq[Column] = df.schema.toSortedSelectExpr(f)(ord)
  df.select(orderedColumns: _*)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package com.github.mrpowers.spark.daria.sql.types

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
import org.apache.spark.sql.functions._

import scala.annotation.tailrec
import scala.reflect.runtime.universe._

object StructTypeHelpers {
Expand Down Expand Up @@ -38,6 +39,33 @@ object StructTypeHelpers {
})
}

/**
 * Builds the select expressions that re-project `schema` with its fields sorted by `f`.
 *
 * Sorting is applied at every struct level: the top-level fields, the fields of nested
 * [[StructType]]s, and the fields of structs held inside (possibly nested) [[ArrayType]]s.
 * Columns of any other type, including arrays that contain no structs, are selected unchanged.
 *
 * @param schema the schema whose fields should be reordered
 * @param f      key function mapping a [[StructField]] to the value it is sorted by
 * @param ord    ordering used to compare the keys produced by `f`
 * @tparam A     the key type produced by `f`
 * @return one [[Column]] per top-level field, in sorted order, each named after its field
 */
private def schemaToSortedSelectExpr[A](schema: StructType, f: StructField => A)(implicit ord: Ordering[A]): Seq[Column] = {
  // True when `dataType` is a struct or an array (at any nesting depth) of structs,
  // i.e. when there are fields below it that need reordering.
  def containsStruct(dataType: DataType): Boolean =
    dataType match {
      case _: StructType             => true
      case ArrayType(elementType, _) => containsStruct(elementType)
      case _                         => false
    }

  // `column` always refers to the value of type `dataType` itself, so each recursive
  // step only has to descend by one field name or one array element.
  def sortedCol(dataType: DataType, column: Column): Column =
    dataType match {
      case st: StructType =>
        // Rebuild the struct with its fields in sorted order.
        struct(
          st.fields
            .sortBy(f)
            .map(field => sortedCol(field.dataType, column(field.name)).as(field.name)): _*
        )
      case ArrayType(elementType, _) if containsStruct(elementType) =>
        // Only rewrite arrays that actually hold structs; arrays of other element types
        // fall through to the default case instead of having a field extracted from
        // their (non-struct) elements.
        transform(column, element => sortedCol(elementType, element))
      case _ =>
        column
    }

  schema.fields
    .sortBy(f)
    .map(field => sortedCol(field.dataType, col(field.name)).as(field.name))
}

/**
* gets a StructType from a Scala type and
* transforms field names from camel case to snake case
Expand All @@ -50,4 +78,7 @@ object StructTypeHelpers {
})
}

/**
 * Extension methods for [[StructType]].
 *
 * @param schema the schema the extension methods operate on
 */
implicit class StructTypeOps(schema: StructType) {

  /**
   * Produces the columns that select this schema's fields in the order given by `f`.
   * Delegates to `schemaToSortedSelectExpr`.
   *
   * @param f   key function mapping a [[StructField]] to the value it is sorted by
   * @param ord ordering used to compare the keys produced by `f`
   * @tparam A  the key type produced by `f`
   * @return the select expressions for the sorted schema
   */
  def toSortedSelectExpr[A](f: StructField => A)(implicit ord: Ordering[A]): Seq[Column] =
    schemaToSortedSelectExpr(schema, f)(ord)
}
}
Loading
Loading