
Commit 50a328b

chris-twiner authored and hvanhovell committed
[SPARK-49960][SQL] Custom ExpressionEncoder support and TransformingEncoder fixes
### What changes were proposed in this pull request?

4.0.0-preview2 introduced, as part of SPARK-49025 / PR apache#47785, changes that drive ExpressionEncoder derivation purely from AgnosticEncoders. This PR adds a trait:

```scala
@DeveloperApi
trait AgnosticExpressionPathEncoder[T] extends AgnosticEncoder[T] {
  def toCatalyst(input: Expression): Expression
  def fromCatalyst(inputPath: Expression): Expression
}
```

and hooks it into the DeserializerBuildHelper / SerializerBuildHelper matches to allow seamless extension by non-Connect custom encoders (such as [frameless](https://github.com/typelevel/frameless) or [sparksql-scalapb](https://github.com/scalapb/sparksql-scalapb)). SPARK-49960 provides the same information.

Additionally, this PR provides the fixes necessary to use a TransformingEncoder as a root encoder, with an OptionEncoder, and as an ArrayType element or MapType key/value.

### Why are the changes needed?

Without this change (or something similar) there is no way for custom encoders to integrate with the 4.0.0-preview2 derived encoders, something that has worked, and that developers have benefited from, since pre-2.4 days. This breaks code such as Dataset.joinWith, which can no longer derive a working tuple encoder (the provided ExpressionEncoder is discarded under preview2). Supplying a custom AgnosticEncoder under preview2 also fails, because only the built-in preview2 AgnosticEncoders are matched in De/SerializerBuildHelper, triggering a MatchError.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

A test was added that uses a "custom" string encoder with joinWith, based on an existing joinWith test. Removing the new case statements from either build helper makes it fail with the MatchError.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#50023 from chris-twiner/temp/expressionEncoder_compat_TransformingEncoder_fixes.

Authored-by: Chris Twiner <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
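For illustration, here is a minimal sketch of what a custom encoder built on the new trait could look like, loosely modeled on the "custom" string encoder used in the added test. The `CustomStringEncoder` name and its `StaticInvoke`/`Invoke` wiring are hypothetical, not code from this PR, and assume the `AgnosticEncoder` abstract members (`isPrimitive`, `dataType`, `clsTag`) as of 4.0:

```scala
import scala.reflect.{classTag, ClassTag}

import org.apache.spark.sql.catalyst.encoders.AgnosticExpressionPathEncoder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.types.{DataType, ObjectType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical stand-in for an encoder a library like frameless or
// sparksql-scalapb would derive for String.
case class CustomStringEncoder() extends AgnosticExpressionPathEncoder[String] {
  override def isPrimitive: Boolean = false
  override def dataType: DataType = StringType
  override def clsTag: ClassTag[String] = classTag[String]

  // Serializer side: external java.lang.String -> Catalyst UTF8String.
  override def toCatalyst(input: Expression): Expression =
    StaticInvoke(classOf[UTF8String], StringType, "fromString", input :: Nil)

  // Deserializer side: Catalyst UTF8String -> external java.lang.String.
  override def fromCatalyst(inputPath: Expression): Expression =
    Invoke(inputPath, "toString", ObjectType(classOf[String]))
}
```

With the new match arms in the build helpers, such an encoder participates in derived encoders (for example the tuple encoder built by Dataset.joinWith) instead of triggering a MatchError.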
1 parent 496fe7a commit 50a328b

File tree

9 files changed: +291 -20 lines changed

sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala

+4 -1

```diff
@@ -276,11 +276,14 @@ object AgnosticEncoders {
    * another encoder. This is fallback for scenarios where objects can't be represented using
    * standard encoders, an example of this is where we use a different (opaque) serialization
    * format (i.e. java serialization, kryo serialization, or protobuf).
+   * @param nullable
+   *   defaults to false indicating the codec guarantees decode / encode results are non-nullable
    */
   case class TransformingEncoder[I, O](
       clsTag: ClassTag[I],
       transformed: AgnosticEncoder[O],
-      codecProvider: () => Codec[_ >: I, O])
+      codecProvider: () => Codec[_ >: I, O],
+      override val nullable: Boolean = false)
       extends AgnosticEncoder[I] {
     override def isPrimitive: Boolean = transformed.isPrimitive
     override def dataType: DataType = transformed.dataType
```
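As a sketch of when the new flag could be set (the `NullableBytesCodec` below is hypothetical, and assumes the `Codec[I, O]` shape with `encode: I => O` / `decode: O => I` used by the build helpers):

```scala
import scala.reflect.classTag

import org.apache.spark.sql.catalyst.encoders.Codec
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, TransformingEncoder}

// Hypothetical codec whose decode can produce null, so the encoder
// must not advertise non-nullable results.
class NullableBytesCodec extends Codec[Array[Byte], Array[Byte]] {
  override def encode(in: Array[Byte]): Array[Byte] = in
  override def decode(out: Array[Byte]): Array[Byte] =
    if (out.isEmpty) null else out
}

val enc = TransformingEncoder(
  classTag[Array[Byte]],
  BinaryEncoder,
  () => new NullableBytesCodec,
  nullable = true) // without the new parameter this defaults to false
```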

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala

+6 -4

```diff
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst

 import org.apache.spark.sql.catalyst.{expressions => exprs}
 import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder}
 import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
@@ -270,6 +270,8 @@ object DeserializerBuildHelper {
       enc: AgnosticEncoder[_],
       path: Expression,
       walkedTypePath: WalkedTypePath): Expression = enc match {
+    case ae: AgnosticExpressionPathEncoder[_] =>
+      ae.fromCatalyst(path)
     case _ if isNativeEncoder(enc) =>
       path
     case _: BoxedLeafEncoder[_, _] =>
@@ -447,13 +449,13 @@ object DeserializerBuildHelper {
       val result = InitializeJavaBean(newInstance, setters.toMap)
       exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result)

-    case TransformingEncoder(tag, _, codec) if codec == JavaSerializationCodec =>
+    case TransformingEncoder(tag, _, codec, _) if codec == JavaSerializationCodec =>
       DecodeUsingSerializer(path, tag, kryo = false)

-    case TransformingEncoder(tag, _, codec) if codec == KryoSerializationCodec =>
+    case TransformingEncoder(tag, _, codec, _) if codec == KryoSerializationCodec =>
       DecodeUsingSerializer(path, tag, kryo = true)

-    case TransformingEncoder(tag, encoder, provider) =>
+    case TransformingEncoder(tag, encoder, provider, _) =>
       Invoke(
         Literal.create(provider(), ObjectType(classOf[Codec[_, _]])),
         "decode",
```
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala

+10 -5

```diff
@@ -21,7 +21,7 @@ import scala.language.existentials

 import org.apache.spark.sql.catalyst.{expressions => exprs}
 import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
@@ -306,6 +306,7 @@ object SerializerBuildHelper {
    * by encoder `enc`.
    */
   private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = enc match {
+    case ae: AgnosticExpressionPathEncoder[_] => ae.toCatalyst(input)
     case _ if isNativeEncoder(enc) => input
     case BoxedBooleanEncoder => createSerializerForBoolean(input)
     case BoxedByteEncoder => createSerializerForByte(input)
@@ -418,18 +419,21 @@ object SerializerBuildHelper {
       }
       createSerializerForObject(input, serializedFields)

-    case TransformingEncoder(_, _, codec) if codec == JavaSerializationCodec =>
+    case TransformingEncoder(_, _, codec, _) if codec == JavaSerializationCodec =>
       EncodeUsingSerializer(input, kryo = false)

-    case TransformingEncoder(_, _, codec) if codec == KryoSerializationCodec =>
+    case TransformingEncoder(_, _, codec, _) if codec == KryoSerializationCodec =>
       EncodeUsingSerializer(input, kryo = true)

-    case TransformingEncoder(_, encoder, codecProvider) =>
+    case TransformingEncoder(_, encoder, codecProvider, _) =>
       val encoded = Invoke(
         Literal(codecProvider(), ObjectType(classOf[Codec[_, _]])),
         "encode",
         externalDataTypeFor(encoder),
-        input :: Nil)
+        input :: Nil,
+        propagateNull = input.nullable,
+        returnNullable = input.nullable
+      )
       createSerializer(encoder, encoded)
   }

@@ -486,6 +490,7 @@ object SerializerBuildHelper {
     nullable: Boolean): Expression => Expression = { input =>
     val expected = enc match {
       case OptionEncoder(_) => lenientExternalDataTypeFor(enc)
+      case TransformingEncoder(_, transformed, _, _) => lenientExternalDataTypeFor(transformed)
       case _ => enc.dataType
     }
```
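The `propagateNull`/`returnNullable` additions to the `Invoke` above are what let a `TransformingEncoder` sit under nullable positions (an Option's payload, an array element, a map value). Roughly, the generated serializer now behaves like the following sketch (a hypothetical helper, not PR code) instead of calling `encode` on a null input:

```scala
import org.apache.spark.sql.catalyst.encoders.Codec

// Null short-circuit equivalent to Invoke(propagateNull = true):
// a null input yields null without invoking the codec.
def encodeNullSafe[I >: Null, O >: Null](codec: Codec[I, O], input: I): O =
  if (input == null) null else codec.encode(input)
```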

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala

+25

```diff
@@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.encoders

 import scala.collection.Map

+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, CalendarIntervalEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, SparkDecimalEncoder, VariantEncoder}
 import org.apache.spark.sql.catalyst.expressions.Expression
@@ -26,6 +27,30 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, VariantType, YearMonthIntervalType}
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}

+/**
+ * :: DeveloperApi ::
+ * Extensible [[AgnosticEncoder]] providing conversion extension points over type T
+ * @tparam T over T
+ */
+@DeveloperApi
+@deprecated("This trait is intended only as a migration tool and will be removed in 4.1")
+trait AgnosticExpressionPathEncoder[T]
+    extends AgnosticEncoder[T] {
+  /**
+   * Converts from T to InternalRow
+   * @param input the starting input path
+   * @return
+   */
+  def toCatalyst(input: Expression): Expression
+
+  /**
+   * Converts from InternalRow to T
+   * @param inputPath path expression from InternalRow
+   * @return
+   */
+  def fromCatalyst(inputPath: Expression): Expression
+}
+
 /**
  * Helper class for Generating [[ExpressionEncoder]]s.
  */
```
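Assuming the `ExpressionEncoder.apply(AgnosticEncoder)` factory introduced by the agnostic-encoder refactor, a custom path encoder (such as the hypothetical `CustomStringEncoder` sketched earlier) can be lifted into an implicit `Encoder` and used as before:

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Wrap the custom agnostic encoder; serializer and deserializer expressions
// are now derived through the new match arms rather than hitting a MatchError.
implicit val customStringEnc: ExpressionEncoder[String] =
  ExpressionEncoder(CustomStringEncoder())

// e.g. left.joinWith(right, cond) can now derive a working tuple encoder.
```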

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

+10 -1

```diff
@@ -24,6 +24,7 @@ import org.apache.spark.SparkRuntimeException
 import org.apache.spark.sql.{Encoder, Row}
 import org.apache.spark.sql.catalyst.{DeserializerBuildHelper, InternalRow, JavaTypeInference, ScalaReflection, SerializerBuildHelper}
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{OptionEncoder, TransformingEncoder}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
@@ -215,6 +216,13 @@ case class ExpressionEncoder[T](
       StructField(s.name, s.dataType, s.nullable)
     })

+  private def transformerOfOption(enc: AgnosticEncoder[_]): Boolean =
+    enc match {
+      case t: TransformingEncoder[_, _] => transformerOfOption(t.transformed)
+      case _: OptionEncoder[_] => true
+      case _ => false
+    }
+
   /**
    * Returns true if the type `T` is serialized as a struct by `objSerializer`.
    */
@@ -228,7 +236,8 @@ case class ExpressionEncoder[T](
    * returns true if `T` is serialized as struct and is not `Option` type.
    */
   def isSerializedAsStructForTopLevel: Boolean = {
-    isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)
+    isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass) &&
+      !transformerOfOption(encoder)
   }

   // serializer expressions are used to encode an object to a row, while the object is usually an
```
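To see what `transformerOfOption` guards against, consider a root `TransformingEncoder` whose target is an `OptionEncoder` over a struct. A minimal sketch, where the `Maybe` type and `MaybePointCodec` are hypothetical and `ScalaReflection.encoderFor` is assumed to derive the agnostic encoder:

```scala
import scala.reflect.classTag

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.Codec
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.TransformingEncoder

// Hypothetical wrapper type encoded through Option[Point]; Point is a
// case class, so the option's payload is serialized as a struct.
case class Point(x: Int, y: Int)
case class Maybe[A](value: A)

class MaybePointCodec extends Codec[Maybe[Point], Option[Point]] {
  override def encode(in: Maybe[Point]): Option[Point] = Option(in).map(_.value)
  override def decode(out: Option[Point]): Maybe[Point] = out.map(Maybe(_)).orNull
}

// Root encoder that transforms into an Option: transformerOfOption sees
// through the transformation, so isSerializedAsStructForTopLevel now
// treats this like a top-level Option rather than a plain struct.
val enc = TransformingEncoder(
  classTag[Maybe[Point]],
  ScalaReflection.encoderFor[Option[Point]],
  () => new MaybePointCodec)
```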
