[SPARK-49025] Make Column implementation agnostic (#3913)
#### Which Delta project/connector is this regarding?
- [x] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description
This PR ports the changes made in SPARK-49025 to Delta.

## How was this patch tested?
Existing tests.

## Does this PR introduce _any_ user-facing changes?
No.
hvanhovell authored Jan 13, 2025
1 parent a920885 commit 95e1826
Showing 41 changed files with 138 additions and 43 deletions.
28 changes: 28 additions & 0 deletions spark/src/main/scala-spark-3.5/shims/ColumnConversionShims.scala
@@ -0,0 +1,28 @@
/*
* Copyright (2021) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Expression

/**
* Conversions from a [[org.apache.spark.sql.Column]] to an
* [[org.apache.spark.sql.catalyst.expressions.Expression]], and vice versa.
*/
object ClassicColumnConversions {
def expression(c: Column): Expression = c.expr
}
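On the Spark 3.5 build the shim above is just a thin wrapper over `Column.expr`. A minimal sketch of the equivalence, using an illustrative column that is not part of this commit:

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.functions.col

val c = col("id").isNull
// On Spark 3.5 both lines yield the same Catalyst expression tree,
// but the first no longer relies on Column exposing `expr` directly.
val viaShim: Expression = expression(c)
val viaExpr: Expression = c.expr
assert(viaShim == viaExpr)
```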
@@ -0,0 +1,33 @@
/*
* Copyright (2021) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta

import org.apache.spark.sql.classic.ClassicConversions
import org.apache.spark.sql.classic.ColumnConversions

/**
* Conversions from a [[org.apache.spark.sql.Column]] to an
* [[org.apache.spark.sql.catalyst.expressions.Expression]], and vice versa.
*
* @note [[org.apache.spark.sql.internal.ExpressionUtils#expression]] is a cheap alternative for
* [[org.apache.spark.sql.Column]] to [[org.apache.spark.sql.catalyst.expressions.Expression]]
* conversions. However this can only be used when the produced expression is used in a Column
* later on.
*/
object ClassicColumnConversions
extends ClassicConversions
with ColumnConversions
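Against newer Spark versions the same `expression(...)` entry point, together with the expression-based `Column(...)` constructor used later in this diff, comes from the `ClassicConversions`/`ColumnConversions` mix-ins rather than from `Column.expr`. A hedged sketch of the call-site pattern the rest of this PR switches to (the literal and column name are illustrative only):

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.functions.col

// Column -> Expression: replaces direct call-site uses of Column.expr.
val pred: Expression = expression(col("value").isNull)

// Expression -> Column: Column(expr) keeps compiling on newer Spark because
// the mixed-in conversions supply the expression-based constructor.
val wrapped: Column = Column(Literal(1))
```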
4 changes: 2 additions & 2 deletions spark/src/main/scala/io/delta/tables/DeltaMergeBuilder.scala
@@ -20,17 +20,17 @@ import scala.collection.JavaConverters._
import scala.collection.Map

import org.apache.spark.sql.delta.{DeltaAnalysisException, PostHocResolveUpCast, PreprocessTableMerge, ResolveDeltaMergeInto}
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.DeltaViewHelper
import org.apache.spark.sql.delta.commands.MergeIntoCommand
import org.apache.spark.sql.delta.util.AnalysisHelper

import org.apache.spark.annotation._
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.internal.SQLConf
1 change: 1 addition & 0 deletions spark/src/main/scala/io/delta/tables/DeltaTable.scala
@@ -19,6 +19,7 @@ package io.delta.tables
import scala.collection.JavaConverters._

import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.actions.TableFeatureProtocolUtils
import org.apache.spark.sql.delta.catalog.DeltaTableV2
@@ -20,6 +20,7 @@ import scala.collection.Map

import org.apache.spark.sql.catalyst.TimeTravel
import org.apache.spark.sql.delta.{DeltaErrors, DeltaLog}
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.DeltaTableUtils.withActiveSession
import org.apache.spark.sql.delta.catalog.DeltaTableV2
import org.apache.spark.sql.delta.commands.{DeltaGenerateCommand, DescribeDeltaDetailCommand, VacuumCommand}
@@ -25,6 +25,7 @@ import scala.util.Try
import scala.util.control.NonFatal

// scalastyle:off import.ordering.noEmptyLine
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.actions.{Action, CheckpointMetadata, Metadata, SidecarFile, SingleAction}
import org.apache.spark.sql.delta.logging.DeltaLogKeys
import org.apache.spark.sql.delta.metering.DeltaLogging
@@ -19,6 +19,7 @@ package org.apache.spark.sql.delta
// scalastyle:off import.ordering.noEmptyLine
import scala.collection.mutable

import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.actions.{Metadata, Protocol}
import org.apache.spark.sql.delta.commands.cdc.CDCReader
import org.apache.spark.sql.delta.constraints.{Constraint, Constraints}
@@ -28,6 +28,7 @@ import scala.util.Try
import scala.util.control.NonFatal

import com.databricks.spark.util.TagDefinitions._
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.actions._
import org.apache.spark.sql.delta.commands.WriteIntoDelta
import org.apache.spark.sql.delta.coordinatedcommits.CoordinatedCommitsUtils
@@ -19,6 +19,7 @@ package org.apache.spark.sql.delta
// scalastyle:off import.ordering.noEmptyLine
import java.util.Locale

import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.actions.{Metadata, Protocol}
import org.apache.spark.sql.delta.files.{TahoeBatchFileIndex, TahoeFileIndex}
import org.apache.spark.sql.delta.metering.DeltaLogging
@@ -18,6 +18,7 @@ package org.apache.spark.sql.delta

import scala.collection.mutable

import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.metering.DeltaLogging
import org.apache.spark.sql.delta.sources.DeltaSourceUtils._
import org.apache.spark.sql.delta.sources.DeltaSQLConf
@@ -27,6 +27,7 @@ import scala.collection.mutable.{ArrayBuffer, HashSet}
import scala.util.control.NonFatal

import com.databricks.spark.util.TagDefinitions.TAG_LOG_STORE_CLASS
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.DeltaOperations.{ChangeColumn, CreateTable, Operation, ReplaceColumns, ReplaceTable, UpdateSchema}
import org.apache.spark.sql.delta.RowId.RowTrackingMetadataDomain
import org.apache.spark.sql.delta.actions._
@@ -28,6 +28,7 @@ import scala.collection.mutable
import scala.util.control.NonFatal

import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.commands.DeletionVectorUtils
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.util.{DeltaFileOperations, JsonUtils, Utils => DeltaUtils}
@@ -22,6 +22,7 @@ import scala.collection.generic.Sizing

import org.apache.spark.sql.catalyst.expressions.aggregation.BitmapAggregator
import org.apache.spark.sql.delta.{DeltaLog, DeltaParquetFileFormat, OptimisticTransaction, Snapshot}
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.DeltaParquetFileFormat._
import org.apache.spark.sql.delta.actions.{AddFile, DeletionVectorDescriptor, FileAction}
import org.apache.spark.sql.delta.deletionvectors.{RoaringBitmapArray, RoaringBitmapArrayFormat, StoredBitmap}
@@ -20,6 +20,7 @@ import java.util.concurrent.TimeUnit

import org.apache.spark.sql.delta.metric.IncrementMetric
import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.actions.{Action, AddCDCFile, AddFile, FileAction}
import org.apache.spark.sql.delta.commands.DeleteCommand.{rewritingFilesMsg, FINDING_TOUCHED_FILES_MSG}
import org.apache.spark.sql.delta.commands.MergeIntoCommandBase.totalBytesAndDistinctPartitionValues
@@ -21,6 +21,7 @@ import java.util.concurrent.TimeUnit

import org.apache.spark.sql.delta.metric.IncrementMetric
import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.actions.{AddCDCFile, AddFile, FileAction}
import org.apache.spark.sql.delta.commands.cdc.CDCReader.{CDC_TYPE_COLUMN_NAME, CDC_TYPE_NOT_CDC, CDC_TYPE_UPDATE_POSTIMAGE, CDC_TYPE_UPDATE_PREIMAGE}
import org.apache.spark.sql.delta.files.{TahoeBatchFileIndex, TahoeFileIndex}
@@ -20,6 +20,7 @@ import scala.collection.mutable

// scalastyle:off import.ordering.noEmptyLine
import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.actions._
import org.apache.spark.sql.delta.commands.DMLUtils.TaggedCommitData
import org.apache.spark.sql.delta.commands.cdc.CDCReader
@@ -25,6 +25,7 @@ import scala.util.control.NonFatal
import org.apache.spark.sql.delta.skipping.clustering.ClusteredTableUtils
import org.apache.spark.sql.delta.skipping.clustering.ClusteringColumnInfo
import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.actions.Protocol
import org.apache.spark.sql.delta.actions.TableFeatureProtocolUtils
import org.apache.spark.sql.delta.catalog.DeltaTableV2
@@ -22,6 +22,7 @@ import scala.collection.mutable.{ListBuffer, Map => MutableMap}
import scala.util.Try

import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.actions._
import org.apache.spark.sql.delta.commands.DeletionVectorUtils
import org.apache.spark.sql.delta.deletionvectors.{RoaringBitmapArray, RoaringBitmapArrayFormat}
@@ -19,6 +19,7 @@ package org.apache.spark.sql.delta.commands.merge
import scala.collection.JavaConverters._

import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.actions.{AddCDCFile, AddFile, FileAction}
import org.apache.spark.sql.delta.commands.{DeletionVectorBitmapGenerator, DMLWithDeletionVectorsHelper, MergeIntoCommandBase}
import org.apache.spark.sql.delta.commands.cdc.CDCReader.{CDC_TYPE_COLUMN_NAME, CDC_TYPE_NOT_CDC}
@@ -246,7 +247,7 @@ trait ClassicMergeExecutor extends MergeOutputGeneration {
*/
protected def generateFilterForModifiedRows(): Expression = {
val matchedExpression = if (matchedClauses.nonEmpty) {
And(Column(condition).expr, clauseDisjunction(matchedClauses))
And(condition, clauseDisjunction(matchedClauses))
} else {
Literal.FalseLiteral
}
@@ -18,6 +18,7 @@ package org.apache.spark.sql.delta.commands.merge

import org.apache.spark.sql.delta.metric.IncrementMetric
import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.actions.{AddFile, FileAction}
import org.apache.spark.sql.delta.commands.MergeIntoCommandBase

@@ -19,6 +19,7 @@ package org.apache.spark.sql.delta.commands.merge
import scala.collection.mutable

import org.apache.spark.sql.delta.{RowCommitVersion, RowId}
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.commands.MergeIntoCommandBase
import org.apache.spark.sql.delta.commands.cdc.CDCReader

@@ -164,8 +165,8 @@ trait MergeOutputGeneration { self: MergeIntoCommandBase =>
// That is, conditionally invokes them based on whether there was a match in the outer join.

// Predicates to check whether there was a match in the full outer join.
val ifSourceRowNull = col(SOURCE_ROW_PRESENT_COL).isNull.expr
val ifTargetRowNull = col(TARGET_ROW_PRESENT_COL).isNull.expr
val ifSourceRowNull = expression(col(SOURCE_ROW_PRESENT_COL).isNull)
val ifTargetRowNull = expression(col(TARGET_ROW_PRESENT_COL).isNull)

val outputCols = targetWriteColNames.zipWithIndex.map { case (name, i) =>
// Coupled with the clause conditions, the resultant possibly-nested CaseWhens can
@@ -213,7 +214,8 @@ trait MergeOutputGeneration { self: MergeIntoCommandBase =>
Column(Alias(caseWhen, name)())
}
}
logDebug("writeAllChanges: join output expressions\n\t" + seqToString(outputCols.map(_.expr)))
logDebug("writeAllChanges: join output expressions\n\t" + seqToString(
outputCols.map(expression)))
outputCols
}.toIndexedSeq

@@ -421,36 +423,36 @@

val cdcTypeCol = outputCols.last
val cdcArray = Column(CaseWhen(Seq(
EqualNullSafe(cdcTypeCol.expr, Literal(CDC_TYPE_INSERT)) -> array(
struct(outputCols: _*)).expr,
EqualNullSafe(expression(cdcTypeCol), Literal(CDC_TYPE_INSERT)) -> expression(array(
struct(outputCols: _*))),

EqualNullSafe(cdcTypeCol.expr, Literal(CDC_TYPE_UPDATE_POSTIMAGE)) -> array(
EqualNullSafe(expression(cdcTypeCol), Literal(CDC_TYPE_UPDATE_POSTIMAGE)) -> expression(array(
struct(updatePreimageCdcOutput: _*),
struct(outputCols: _*)).expr,
struct(outputCols: _*))),

EqualNullSafe(cdcTypeCol.expr, CDC_TYPE_DELETE) -> array(
struct(deleteCdcOutput: _*)).expr),
EqualNullSafe(expression(cdcTypeCol), CDC_TYPE_DELETE) -> expression(array(
struct(deleteCdcOutput: _*)))),

// If none of the CDC cases apply (true for purely rewritten target rows, dropped source
// rows, etc.) just stick to the normal output.
array(struct(mainDataOutput: _*)).expr
expression(array(struct(mainDataOutput: _*)))
))

val cdcToMainDataArray = Column(If(
Or(
EqualNullSafe(col(s"packedCdc.$CDC_TYPE_COLUMN_NAME").expr,
EqualNullSafe(expression(col(s"packedCdc.$CDC_TYPE_COLUMN_NAME")),
Literal(CDC_TYPE_INSERT)),
EqualNullSafe(col(s"packedCdc.$CDC_TYPE_COLUMN_NAME").expr,
EqualNullSafe(expression(col(s"packedCdc.$CDC_TYPE_COLUMN_NAME")),
Literal(CDC_TYPE_UPDATE_POSTIMAGE))),
array(
expression(array(
col("packedCdc"),
struct(
outputColNames
.dropRight(1)
.map { n => col(s"packedCdc.`$n`") }
:+ Column(CDC_TYPE_NOT_CDC).as(CDC_TYPE_COLUMN_NAME): _*)
).expr,
array(col("packedCdc")).expr
)),
expression(array(col("packedCdc")))
))

if (deduplicateDeletes.enabled) {
@@ -20,6 +20,7 @@ package org.apache.spark.sql.delta.hooks
import java.net.URI

import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.actions._
import org.apache.spark.sql.delta.commands.DeletionVectorUtils.isTableDVFree
import org.apache.spark.sql.delta.logging.DeltaLogKeys
@@ -16,6 +16,7 @@

package org.apache.spark.sql.delta

import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.OptimizablePartitionExpression._

import org.apache.spark.sql.Column
@@ -22,6 +22,7 @@ import scala.util.control.NonFatal

import org.apache.spark.sql.delta.{DeltaAnalysisException, DeltaColumnMappingMode, DeltaErrors, DeltaLog, GeneratedColumn, NoMapping, TypeWidening, TypeWideningMode}
import org.apache.spark.sql.delta.{RowCommitVersion, RowId}
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.actions.Protocol
import org.apache.spark.sql.delta.commands.cdc.CDCReader
import org.apache.spark.sql.delta.logging.DeltaLogKeys
@@ -18,6 +18,7 @@ package org.apache.spark.sql.delta.skipping

// scalastyle:off import.ordering.noEmptyLine
import org.apache.spark.sql.delta.expressions.{HilbertByteArrayIndex, HilbertLongIndex, InterleaveBits, RangePartitionId}
import org.apache.spark.sql.delta.ClassicColumnConversions._

import org.apache.spark.SparkException
import org.apache.spark.sql.Column
@@ -37,7 +38,7 @@ object MultiDimClusteringFunctions {
* partition range ids as (0, 0, 1, 1, 2, 2).
*/
def range_partition_id(col: Column, numPartitions: Int): Column = withExpr {
RangePartitionId(col.expr, numPartitions)
RangePartitionId(expression(col), numPartitions)
}

/**
@@ -54,7 +55,7 @@
* @note Only supports input expressions of type Int for now.
*/
def interleave_bits(cols: Column*): Column = withExpr {
InterleaveBits(cols.map(_.expr))
InterleaveBits(cols.map(expression))
}

// scalastyle:off line.size.limit
@@ -73,9 +74,9 @@
}
val hilbertBits = cols.length * numBits
if (hilbertBits < 64) {
HilbertLongIndex(numBits, cols.map(_.expr))
HilbertLongIndex(numBits, cols.map(expression))
} else {
Cast(HilbertByteArrayIndex(numBits, cols.map(_.expr)), StringType)
Cast(HilbertByteArrayIndex(numBits, cols.map(expression)), StringType)
}
}
}
@@ -19,6 +19,7 @@ package org.apache.spark.sql.delta.sources
import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.sql.delta._
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.DeltaOperations.StreamingUpdate
import org.apache.spark.sql.delta.actions.{FileAction, Metadata, Protocol, SetTransaction}
import org.apache.spark.sql.delta.logging.DeltaLogKeys
@@ -16,6 +16,7 @@

package org.apache.spark.sql.delta.stats

import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.stats.DeltaStatistics.{MAX, MIN}

import org.apache.spark.sql.Column
@@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.delta.skipping.clustering.{ClusteredTableUtils, ClusteringColumnInfo}
import org.apache.spark.sql.delta.{DeltaColumnMapping, DeltaLog, DeltaTableUtils}
import org.apache.spark.sql.delta.ClassicColumnConversions._
import org.apache.spark.sql.delta.actions.{AddFile, Metadata}
import org.apache.spark.sql.delta.implicits._
import org.apache.spark.sql.delta.metering.DeltaLogging
@@ -17,6 +17,7 @@
package org.apache.spark.sql.delta.stats

import org.apache.spark.sql.delta.{DeltaErrors, DeltaUDF}
import org.apache.spark.sql.delta.ClassicColumnConversions._

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow