[SPARK-50979][CONNECT] Remove .expr/.typedExpr implicits
### What changes were proposed in this pull request?
This PR removes the `.expr`/`.typedExpr` Column conversion implicits from the Connect client.
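Concretely, call sites that previously relied on the implicit `RichColumn` conversions (e.g. `column.expr`, `column.typedExpr(encoder)`) now call `ColumnNodeToProtoConverter` directly, as the diff below shows. A minimal before/after sketch — the `condition` column is hypothetical, and since the converter is an internal Connect client API, the sketch assumes code living alongside it in `org.apache.spark.sql.connect`, as the changed files here do:

```scala
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toLiteral}
import org.apache.spark.sql.functions

// Hypothetical column, used only to illustrate the call-site change.
val condition = functions.col("age") > 21

// Before: the RichColumn implicit supplied `.expr`:
//   builder.setCondition(condition.expr)
// After: convert explicitly through ColumnNodeToProtoConverter:
val conditionProto = toExpr(condition)        // proto expression for a Column
val literalProto   = toLiteral(21).getLiteral // proto literal for a raw value
```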

### Why are the changes needed?
Code clean-up.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49657 from hvanhovell/SPARK-50979.

Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
hvanhovell committed Jan 29, 2025
1 parent 840f74a commit 6bbfa2d
Showing 12 changed files with 72 additions and 74 deletions.
@@ -127,11 +127,13 @@ object CurrentOrigin {
}
}

private val sparkCodePattern = Pattern.compile("(org\\.apache\\.spark\\.sql\\." +
"(?:(classic|connect)\\.)?" +
"(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions|DatasetHolder)" +
"(?:|\\..*|\\$.*))" +
"|(scala\\.collection\\..*)")
private val sparkCodePattern = Pattern.compile(
"(org\\.apache\\.spark\\.sql\\." +
"(?:(classic|connect)\\.)?" +
"(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions|DatasetHolder" +
"|SparkSession|ColumnNodeToProtoConverter)" +
"(?:|\\..*|\\$.*))" +
"|(scala\\.collection\\..*)")

private def sparkCode(ste: StackTraceElement): Boolean = {
sparkCodePattern.matcher(ste.getClassName).matches()
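`sparkCodePattern` decides whether a stack frame belongs to Spark itself when `CurrentOrigin` derives the user call site; because the converted call sites now go through `ColumnNodeToProtoConverter` (and `SparkSession` helpers) instead of the removed implicits, those class names are added to the pattern. A rough sanity check of what the updated regex matches — a standalone sketch that just reuses the pattern from the hunk above:

```scala
import java.util.regex.Pattern

// The updated pattern from the hunk above, reproduced verbatim.
val sparkCodePattern = Pattern.compile(
  "(org\\.apache\\.spark\\.sql\\." +
    "(?:(classic|connect)\\.)?" +
    "(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions|DatasetHolder" +
    "|SparkSession|ColumnNodeToProtoConverter)" +
    "(?:|\\..*|\\$.*))" +
    "|(scala\\.collection\\..*)")

// Newly treated as Spark-internal after this change:
sparkCodePattern.matcher("org.apache.spark.sql.connect.ColumnNodeToProtoConverter$").matches() // true
sparkCodePattern.matcher("org.apache.spark.sql.SparkSession").matches()                        // true
// Application classes are still treated as user code:
sparkCodePattern.matcher("com.example.MyApp").matches()                                        // false
```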
@@ -23,8 +23,8 @@ import org.apache.spark.connect.proto.{NAReplace, Relation}
import org.apache.spark.connect.proto.Expression.{Literal => GLiteral}
import org.apache.spark.connect.proto.NAReplace.Replacement
import org.apache.spark.sql
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.functions

/**
* Functionality for working with missing data in `DataFrame`s.
@@ -33,7 +33,6 @@ import org.apache.spark.sql.functions
*/
final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation)
extends sql.DataFrameNaFunctions {
import sparkSession.RichColumn

override protected def drop(minNonNulls: Option[Int]): DataFrame =
buildDropDataFrame(None, minNonNulls)
@@ -103,7 +102,7 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
sparkSession.newDataFrame { builder =>
val fillNaBuilder = builder.getFillNaBuilder.setInput(root)
values.map { case (colName, replaceValue) =>
fillNaBuilder.addCols(colName).addValues(functions.lit(replaceValue).expr.getLiteral)
fillNaBuilder.addCols(colName).addValues(toLiteral(replaceValue).getLiteral)
}
}
}
@@ -143,8 +142,8 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
replacementMap.map { case (oldValue, newValue) =>
Replacement
.newBuilder()
.setOldValue(functions.lit(oldValue).expr.getLiteral)
.setNewValue(functions.lit(newValue).expr.getLiteral)
.setOldValue(toLiteral(oldValue).getLiteral)
.setNewValue(toLiteral(newValue).getLiteral)
.build()
}
}
@@ -23,9 +23,9 @@ import org.apache.spark.connect.proto.{Relation, StatSampleBy}
import org.apache.spark.sql
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder}
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toLiteral}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.DataFrameStatFunctions.approxQuantileResultEncoder
import org.apache.spark.sql.functions.lit

/**
* Statistic functions for `DataFrame`s.
@@ -120,20 +120,19 @@ final class DataFrameStatFunctions private[sql] (protected val df: DataFrame)

/** @inheritdoc */
def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = {
import sparkSession.RichColumn
require(
fractions.values.forall(p => p >= 0.0 && p <= 1.0),
s"Fractions must be in [0, 1], but got $fractions.")
sparkSession.newDataFrame { builder =>
val sampleByBuilder = builder.getSampleByBuilder
.setInput(root)
.setCol(col.expr)
.setCol(toExpr(col))
.setSeed(seed)
fractions.foreach { case (k, v) =>
sampleByBuilder.addFractions(
StatSampleBy.Fraction
.newBuilder()
.setStratum(lit(k).expr.getLiteral)
.setStratum(toLiteral(k).getLiteral)
.setFraction(v))
}
}
@@ -23,6 +23,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.connect.proto
import org.apache.spark.sql
import org.apache.spark.sql.Column
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr

/**
* Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2
@@ -33,7 +34,6 @@ import org.apache.spark.sql.Column
@Experimental
final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
extends sql.DataFrameWriterV2[T] {
import ds.sparkSession.RichColumn

private val builder = proto.WriteOperationV2
.newBuilder()
@@ -73,7 +73,7 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
/** @inheritdoc */
@scala.annotation.varargs
override def partitionedBy(column: Column, columns: Column*): this.type = {
builder.addAllPartitioningColumns((column +: columns).map(_.expr).asJava)
builder.addAllPartitioningColumns((column +: columns).map(toExpr).asJava)
this
}

@@ -106,7 +106,7 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])

/** @inheritdoc */
def overwrite(condition: Column): Unit = {
builder.setOverwriteCondition(condition.expr)
builder.setOverwriteCondition(toExpr(condition))
executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE)
}

@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.OrderUtils
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toLiteral, toTypedExpr}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.client.SparkResult
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter}
@@ -140,7 +141,6 @@ class Dataset[T] private[sql] (
@DeveloperApi val plan: proto.Plan,
val encoder: Encoder[T])
extends sql.Dataset[T] {
import sparkSession.RichColumn

// Make sure we don't forget to set plan id.
assert(plan.getRoot.getCommon.hasPlanId)
@@ -336,7 +336,7 @@ class Dataset[T] private[sql] (
buildJoin(right, Seq(joinExprs)) { builder =>
builder
.setJoinType(toJoinType(joinType))
.setJoinCondition(joinExprs.expr)
.setJoinCondition(toExpr(joinExprs))
}
}

@@ -375,7 +375,7 @@ class Dataset[T] private[sql] (
.setLeft(plan.getRoot)
.setRight(other.plan.getRoot)
.setJoinType(joinTypeValue)
.setJoinCondition(condition.expr)
.setJoinCondition(toExpr(condition))
.setJoinDataType(joinBuilder.getJoinDataTypeBuilder
.setIsLeftStruct(this.agnosticEncoder.isStruct)
.setIsRightStruct(other.agnosticEncoder.isStruct))
@@ -396,7 +396,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(joinExprs.toSeq) { builder =>
val lateralJoinBuilder = builder.getLateralJoinBuilder
lateralJoinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot)
joinExprs.foreach(c => lateralJoinBuilder.setJoinCondition(c.expr))
joinExprs.foreach(c => lateralJoinBuilder.setJoinCondition(toExpr(c)))
lateralJoinBuilder.setJoinType(joinTypeValue)
}
}
@@ -440,7 +440,7 @@ class Dataset[T] private[sql] (
builder.getHintBuilder
.setInput(plan.getRoot)
.setName(name)
.addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
.addAllParameters(parameters.map(p => toLiteral(p)).asJava)
}

private def getPlanId: Option[Long] =
@@ -486,7 +486,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(encoder) { builder =>
builder.getProjectBuilder
.setInput(plan.getRoot)
.addExpressions(col.typedExpr(this.encoder))
.addExpressions(toTypedExpr(col, this.encoder))
}
}

@@ -504,14 +504,14 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(encoder, cols) { builder =>
builder.getProjectBuilder
.setInput(plan.getRoot)
.addAllExpressions(cols.map(_.typedExpr(this.encoder)).asJava)
.addAllExpressions(cols.map(c => toTypedExpr(c, this.encoder)).asJava)
}
}

/** @inheritdoc */
def filter(condition: Column): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder, Seq(condition)) { builder =>
builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
builder.getFilterBuilder.setInput(plan.getRoot).setCondition(toExpr(condition))
}
}

@@ -523,12 +523,12 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(ids.toSeq ++ valuesOption.toSeq.flatten) { builder =>
val unpivot = builder.getUnpivotBuilder
.setInput(plan.getRoot)
.addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava)
.addAllIds(ids.toImmutableArraySeq.map(toExpr).asJava)
.setVariableColumnName(variableColumnName)
.setValueColumnName(valueColumnName)
valuesOption.foreach { values =>
unpivot.getValuesBuilder
.addAllValues(values.toImmutableArraySeq.map(_.expr).asJava)
.addAllValues(values.toImmutableArraySeq.map(toExpr).asJava)
}
}
}
@@ -537,7 +537,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(indices) { builder =>
val transpose = builder.getTransposeBuilder.setInput(plan.getRoot)
indices.foreach { indexColumn =>
transpose.addIndexColumns(indexColumn.expr)
transpose.addIndexColumns(toExpr(indexColumn))
}
}

@@ -553,7 +553,7 @@ class Dataset[T] private[sql] (
function = func,
inputEncoders = agnosticEncoder :: agnosticEncoder :: Nil,
outputEncoder = agnosticEncoder)
val reduceExpr = Column.fn("reduce", udf.apply(col("*"), col("*"))).expr
val reduceExpr = toExpr(Column.fn("reduce", udf.apply(col("*"), col("*"))))

val result = sparkSession
.newDataset(agnosticEncoder) { builder =>
@@ -590,7 +590,7 @@ class Dataset[T] private[sql] (
val groupingSetMsgs = groupingSets.map { groupingSet =>
val groupingSetMsg = proto.Aggregate.GroupingSets.newBuilder()
for (groupCol <- groupingSet) {
groupingSetMsg.addGroupingSet(groupCol.expr)
groupingSetMsg.addGroupingSet(toExpr(groupCol))
}
groupingSetMsg.build()
}
@@ -779,7 +779,7 @@ class Dataset[T] private[sql] (
s"The size of column names: ${names.size} isn't equal to " +
s"the size of columns: ${values.size}")
val aliases = values.zip(names).map { case (value, name) =>
value.name(name).expr.getAlias
toExpr(value.name(name)).getAlias
}
sparkSession.newDataFrame(values) { builder =>
builder.getWithColumnsBuilder
@@ -812,7 +812,7 @@ class Dataset[T] private[sql] (
def withMetadata(columnName: String, metadata: Metadata): DataFrame = {
val newAlias = proto.Expression.Alias
.newBuilder()
.setExpr(col(columnName).expr)
.setExpr(toExpr(col(columnName)))
.addName(columnName)
.setMetadata(metadata.json)
sparkSession.newDataFrame { builder =>
@@ -845,7 +845,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(cols) { builder =>
builder.getDropBuilder
.setInput(plan.getRoot)
.addAllColumns(cols.map(_.expr).asJava)
.addAllColumns(cols.map(toExpr).asJava)
}
}

@@ -915,7 +915,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset[T](agnosticEncoder) { builder =>
builder.getFilterBuilder
.setInput(plan.getRoot)
.setCondition(udf.apply(col("*")).expr)
.setCondition(toExpr(udf.apply(col("*"))))
}
}

@@ -944,7 +944,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(outputEncoder) { builder =>
builder.getMapPartitionsBuilder
.setInput(plan.getRoot)
.setFunc(udf.apply(col("*")).expr.getCommonInlineUserDefinedFunction)
.setFunc(toExpr(udf.apply(col("*"))).getCommonInlineUserDefinedFunction)
}
}

@@ -1020,7 +1020,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(agnosticEncoder, partitionExprs) { builder =>
val repartitionBuilder = builder.getRepartitionByExpressionBuilder
.setInput(plan.getRoot)
.addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
.addAllPartitionExprs(partitionExprs.map(toExpr).asJava)
numPartitions.foreach(repartitionBuilder.setNumPartitions)
}
}
@@ -1036,7 +1036,7 @@ class Dataset[T] private[sql] (
// The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments.
// However, we don't want to complicate the semantics of this API method.
// Instead, let's give users a friendly error message, pointing them to the new method.
val sortOrders = partitionExprs.filter(_.expr.hasSortOrder)
val sortOrders = partitionExprs.filter(e => toExpr(e).hasSortOrder)
if (sortOrders.nonEmpty) {
throw new IllegalArgumentException(
s"Invalid partitionExprs specified: $sortOrders\n" +
@@ -1050,7 +1050,7 @@ class Dataset[T] private[sql] (
partitionExprs: Seq[Column]): Dataset[T] = {
require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
val sortExprs = partitionExprs.map {
case e if e.expr.hasSortOrder => e
case e if toExpr(e).hasSortOrder => e
case e => e.asc
}
buildRepartitionByExpression(numPartitions, sortExprs)
@@ -1158,7 +1158,7 @@ class Dataset[T] private[sql] (
builder.getCollectMetricsBuilder
.setInput(plan.getRoot)
.setName(name)
.addAllMetrics((expr +: exprs).map(_.expr).asJava)
.addAllMetrics((expr +: exprs).map(toExpr).asJava)
}
}

@@ -28,7 +28,7 @@ import org.apache.spark.sql
import org.apache.spark.sql.{Column, Encoder, TypedColumn}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder}
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toTypedExpr}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfUtils}
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
@@ -394,7 +394,6 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
private val valueMapFunc: Option[IV => V],
private val keysFunc: () => Dataset[IK])
extends KeyValueGroupedDataset[K, V] {
import sparkSession.RichColumn

override def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
new KeyValueGroupedDatasetImpl[L, V, IK, IV](
@@ -436,7 +435,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
sparkSession.newDataset[U](outputEncoder) { builder =>
builder.getGroupMapBuilder
.setInput(plan.getRoot)
.addAllSortingExpressions(sortExprs.map(e => e.expr).asJava)
.addAllSortingExpressions(sortExprs.map(toExpr).asJava)
.addAllGroupingExpressions(groupingExprs)
.setFunc(getUdf(nf, outputEncoder)(ivEncoder))
}
@@ -453,10 +452,10 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
builder.getCoGroupMapBuilder
.setInput(plan.getRoot)
.addAllInputGroupingExpressions(groupingExprs)
.addAllInputSortingExpressions(thisSortExprs.map(e => e.expr).asJava)
.addAllInputSortingExpressions(thisSortExprs.map(toExpr).asJava)
.setOther(otherImpl.plan.getRoot)
.addAllOtherGroupingExpressions(otherImpl.groupingExprs)
.addAllOtherSortingExpressions(otherSortExprs.map(e => e.expr).asJava)
.addAllOtherSortingExpressions(otherSortExprs.map(toExpr).asJava)
.setFunc(getUdf(nf, outputEncoder)(ivEncoder, otherImpl.ivEncoder))
}
}
@@ -469,7 +468,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
.setInput(plan.getRoot)
.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
.addAllGroupingExpressions(groupingExprs)
.addAllAggregateExpressions(columns.map(_.typedExpr(vEncoder)).asJava)
.addAllAggregateExpressions(columns.map(c => toTypedExpr(c, vEncoder)).asJava)
}
}

@@ -534,7 +533,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
function = nf,
inputEncoders = inputEncoders,
outputEncoder = outputEncoder)
udf.apply(inputEncoders.map(_ => col("*")): _*).expr.getCommonInlineUserDefinedFunction
toExpr(udf.apply(inputEncoders.map(_ => col("*")): _*)).getCommonInlineUserDefinedFunction
}

private def getUdf[U: Encoder, S: Encoder](
@@ -549,7 +548,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
function = nf,
inputEncoders = inputEncoders,
outputEncoder = outputEncoder)
udf.apply(inputEncoders.map(_ => col("*")): _*).expr.getCommonInlineUserDefinedFunction
toExpr(udf.apply(inputEncoders.map(_ => col("*")): _*)).getCommonInlineUserDefinedFunction
}

/**