diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala index a0f28c5ab8a2..ad7694ea21af 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala @@ -26,7 +26,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, IdentityBroadcastMode, Partitioning} -import org.apache.spark.sql.execution.joins.{HashedRelation, HashedRelationBroadcastMode} +import org.apache.spark.sql.execution.joins.{HashedRelation, HashedRelationBroadcastMode, LongHashedRelation} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.TaskResources @@ -96,9 +96,8 @@ object BroadcastUtils { // HashedRelation to ColumnarBuildSideRelation. val fromBroadcast = from.asInstanceOf[Broadcast[HashedRelation]] val fromRelation = fromBroadcast.value.asReadOnlyCopy() - val keys = fromRelation.keys() val toRelation = TaskResources.runUnsafe { - val batchItr: Iterator[ColumnarBatch] = fn(keys.flatMap(key => fromRelation.get(key))) + val batchItr: Iterator[ColumnarBatch] = fn(reconstructRows(fromRelation)) val serialized: Array[Array[Byte]] = serializeStream(batchItr) match { case ColumnarBatchSerializeResult.EMPTY => Array() @@ -170,4 +169,17 @@ object BroadcastUtils { } serializeResult } + + private def reconstructRows(relation: HashedRelation): Iterator[InternalRow] = { + // It seems that LongHashedRelation and UnsafeHashedRelation don't follow the same + // criteria while getting values from them. + // Should review the internals of this part of code. + relation match { + case relation: LongHashedRelation if relation.keyIsUnique => + relation.keys().map(k => relation.getValue(k)) + case relation: LongHashedRelation if !relation.keyIsUnique => + relation.keys().flatMap(k => relation.get(k)) + case other => other.valuesWithKeyIndex().map(_.getValue) + } + } }