diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
index 065be9de2557..9cb49b6a2d30 100644
--- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
+++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java
@@ -44,6 +44,7 @@ private static native long nativeBuild(
       long rowCount,
       String joinKeys,
       int joinType,
+      boolean hasMixedFiltCondition,
       byte[] namedStruct);
 
   private StorageJoinBuilder() {}
@@ -79,6 +80,7 @@ public static long build(
         rowCount,
         joinKey,
         SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal(),
+        broadCastContext.hasMixedFiltCondition(),
         toNameStruct(output).toByteArray());
   }
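
Note: the new flag is threaded through `nativeBuild` positionally, and `BroadCastHashJoinContext` in the Scala changes below is also constructed positionally, which is why the benchmark call site further down has to gain an extra `false` argument. A minimal sketch (all names hypothetical, not the Gluten API) of why named arguments keep such call sites stable when a field is inserted mid-signature:

// Hypothetical stand-in for BroadCastHashJoinContext; only the field order matters here.
case class JoinContext(
    joinKeys: Seq[String],
    joinType: String,
    hasMixedFiltCondition: Boolean, // the newly inserted field
    buildHashTableId: String)

object JoinContextExample {
  def main(args: Array[String]): Unit = {
    // Positional call sites must be edited whenever a field is inserted mid-signature...
    val positional = JoinContext(Seq("key"), "Inner", false, "hash-table-1")
    // ...while named arguments compile unchanged apart from adding the new field itself.
    val named = JoinContext(
      joinKeys = Seq("key"),
      joinType = "Inner",
      hasMixedFiltCondition = false,
      buildHashTableId = "hash-table-1")
    println(positional == named) // true
  }
}
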
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index 6004f7f861bf..a7e7769e7736 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -82,6 +82,7 @@ case class CHBroadcastBuildSideRDD(
 
 case class BroadCastHashJoinContext(
     buildSideJoinKeys: Seq[Expression],
     joinType: JoinType,
+    hasMixedFiltCondition: Boolean,
     buildSideStructure: Seq[Attribute],
     buildHashTableId: String)
@@ -139,9 +140,26 @@ case class CHBroadcastHashJoinExecTransformer(
     }
     val broadcast = buildPlan.executeBroadcast[BuildSideRelation]()
     val context =
-      BroadCastHashJoinContext(buildKeyExprs, joinType, buildPlan.output, buildHashTableId)
+      BroadCastHashJoinContext(
+        buildKeyExprs,
+        joinType,
+        isMixedCondition(condition),
+        buildPlan.output,
+        buildHashTableId)
     val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, context)
     // FIXME: Do we have to make build side a RDD?
     streamedRDD :+ broadcastRDD
   }
+
+  def isMixedCondition(cond: Option[Expression]): Boolean = {
+    val res = if (cond.isDefined) {
+      val leftOutputSet = left.outputSet
+      val rightOutputSet = right.outputSet
+      val allReferences = cond.get.references
+      !(allReferences.subsetOf(leftOutputSet) || allReferences.subsetOf(rightOutputSet))
+    } else {
+      false
+    }
+    res
+  }
 }
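
Note: `isMixedCondition` classifies a join condition as "mixed" when its referenced attributes are not a subset of either side's output, i.e. the non-equi predicate reads columns from both tables at once (the `sql2` case in the test suite below, `t1.value > t2.value and t1.value > t2.key`, has this shape). A minimal sketch of that subset test, using plain string sets in place of Catalyst `AttributeSet`s (names hypothetical):

object MixedConditionSketch {
  // Mirrors the subset test in isMixedCondition: a condition is "mixed" only when
  // its references are contained in neither the left nor the right output set.
  def isMixed(references: Set[String], left: Set[String], right: Set[String]): Boolean =
    !(references.subsetOf(left) || references.subsetOf(right))

  def main(args: Array[String]): Unit = {
    val left = Set("t1.key", "t1.value")
    val right = Set("t2.key", "t2.value")
    // t1.value > t2.value reads from both sides -> mixed
    println(isMixed(Set("t1.value", "t2.value"), left, right)) // true
    // t1.value > 10 reads only one side -> not mixed
    println(isMixed(Set("t1.value"), left, right)) // false
  }
}
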
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index ada980a20bc2..ee495457edee 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -2593,13 +2593,21 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
     spark.sql("create table ineq_join_t2 (key bigint, value bigint) using parquet");
     spark.sql("insert into ineq_join_t1 values(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)");
     spark.sql("insert into ineq_join_t2 values(2, 2), (2, 1), (3, 3), (4, 6), (5, 3)");
-    val sql =
+    val sql1 =
       """
         | select t1.key, t1.value, t2.key, t2.value from ineq_join_t1 as t1
         | left join ineq_join_t2 as t2
         | on t1.key = t2.key and t1.value > t2.value
         |""".stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+    compareResultsAgainstVanillaSpark(sql1, true, { _ => })
+
+    val sql2 =
+      """
+        | select t1.key, t1.value from ineq_join_t1 as t1
+        | left join ineq_join_t2 as t2
+        | on t1.key = t2.key and t1.value > t2.value and t1.value > t2.key
+        |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql2, true, { _ => })
 
     spark.sql("drop table ineq_join_t1")
     spark.sql("drop table ineq_join_t2")
diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala
index 487433c469c1..8d4bee554625 100644
--- a/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/execution/benchmarks/CHHashBuildBenchmark.scala
@@ -104,7 +104,7 @@ object CHHashBuildBenchmark extends SqlBasedBenchmark with CHSqlBasedBenchmark w
     (
       countsAndBytes.flatMap(_._2),
       countsAndBytes.map(_._1).sum,
-      BroadCastHashJoinContext(Seq(child.output.head), Inner, child.output, "")
+      BroadCastHashJoinContext(Seq(child.output.head), Inner, false, child.output, "")
     )
   }
 }
diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
index f1b3ac2fbd9c..1c79a00a7c4c 100644
--- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
+++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp
@@ -82,6 +82,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
     jlong row_count,
     const std::string & join_keys,
     substrait::JoinRel_JoinType join_type,
+    bool has_mixed_join_condition,
     const std::string & named_struct)
 {
     auto join_key_list = Poco::StringTokenizer(join_keys, ",");
@@ -105,6 +106,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
         true,
         kind,
         strictness,
+        has_mixed_join_condition,
         columns_description,
         ConstraintsDescription(),
         key,
diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
index 5aa1e0876ed0..9a6837e35a0a 100644
--- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
+++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
@@ -36,6 +36,7 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
     jlong row_count,
     const std::string & join_keys,
     substrait::JoinRel_JoinType join_type,
+    bool has_mixed_join_condition,
     const std::string & named_struct);
 void cleanBuildHashTable(const std::string & hash_table_id, jlong instance);
 std::shared_ptr<StorageJoinFromReadBuffer> getJoin(const std::string & hash_table_id);
diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
index f0aec6af686d..af306564a4c5 100644
--- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
+++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
@@ -74,6 +74,7 @@ StorageJoinFromReadBuffer::StorageJoinFromReadBuffer(
     bool use_nulls_,
     DB::JoinKind kind,
     DB::JoinStrictness strictness,
+    bool has_mixed_join_condition,
     const ColumnsDescription & columns,
     const ConstraintsDescription & constraints,
     const String & comment,
@@ -91,7 +92,11 @@ StorageJoinFromReadBuffer::StorageJoinFromReadBuffer(
         key_names.push_back(RIHGT_COLUMN_PREFIX + name);
     auto table_join = std::make_shared<DB::TableJoin>(SizeLimits(), true, kind, strictness, key_names);
     right_sample_block = rightSampleBlock(use_nulls, storage_metadata, table_join->kind());
-    buildJoin(in, right_sample_block, table_join);
+    /// If there are mixed join conditions, the hash join has to be built lazily, since it relies on the real table join.
+    if (!has_mixed_join_condition)
+        buildJoin(in, right_sample_block, table_join);
+    else
+        collectAllInputs(in, right_sample_block);
 }
 
 /// The column names may be different in two blocks.
@@ -135,6 +140,51 @@ void StorageJoinFromReadBuffer::buildJoin(DB::ReadBuffer & in, const Block heade
     }
 }
 
+void StorageJoinFromReadBuffer::collectAllInputs(DB::ReadBuffer & in, const DB::Block header)
+{
+    local_engine::NativeReader block_stream(in);
+    ProfileInfo info;
+    while (Block block = block_stream.read())
+    {
+        DB::ColumnsWithTypeAndName columns;
+        for (size_t i = 0; i < block.columns(); ++i)
+        {
+            const auto & column = block.getByPosition(i);
+            columns.emplace_back(convertColumnAsNecessary(column, header.getByPosition(i)));
+        }
+        DB::Block final_block(columns);
+        info.update(final_block);
+        input_blocks.emplace_back(std::move(final_block));
+    }
+}
+
+void StorageJoinFromReadBuffer::buildJoinLazily(DB::Block header, std::shared_ptr<DB::TableJoin> analyzed_join)
+{
+    {
+        std::shared_lock lock(join_mutex);
+        if (join)
+            return;
+    }
+    std::unique_lock lock(join_mutex);
+    if (join)
+        return;
+    join = std::make_shared<HashJoin>(analyzed_join, header, overwrite, row_count);
+    while (!input_blocks.empty())
+    {
+        auto & block = *input_blocks.begin();
+        DB::ColumnsWithTypeAndName columns;
+        for (size_t i = 0; i < block.columns(); ++i)
+        {
+            const auto & column = block.getByPosition(i);
+            columns.emplace_back(convertColumnAsNecessary(column, header.getByPosition(i)));
+        }
+        DB::Block final_block(columns);
+        join->addBlockToJoin(final_block, true);
+        input_blocks.pop_front();
+    }
+}
+
+
 /// The column names of 'right_header' could be different from the ones in `input_blocks`, and we must
 /// use 'right_header' to build the HashJoin. Otherwise, it will cause exceptions with name mismatches.
 ///
@@ -148,7 +198,7 @@ DB::JoinPtr StorageJoinFromReadBuffer::getJoinLocked(std::shared_ptr<DB::TableJoin> analyzed_join,
+    buildJoinLazily(right_sample_block, analyzed_join);
     auto join_clone = std::make_shared<HashJoin>(analyzed_join, right_sample_block);
     /// reuseJoinedData will set the flag `HashJoin::from_storage_join` which is required by `FilledStep`
     join_clone->reuseJoinedData(static_cast<const HashJoin &>(*join));
diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
index af623c0cd717..ddefda69c30f 100644
--- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
+++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h
@@ -15,6 +15,7 @@
  * limitations under the License.
  */
 #pragma once
+#include <shared_mutex>
 #include <Core/Block.h>
 #include <Interpreters/TableJoin.h>
 
@@ -40,6 +41,7 @@ class StorageJoinFromReadBuffer
         bool use_nulls_,
         DB::JoinKind kind,
         DB::JoinStrictness strictness,
+        bool has_mixed_join_condition,
         const DB::ColumnsDescription & columns_,
         const DB::ConstraintsDescription & constraints_,
         const String & comment,
@@ -58,9 +60,13 @@ class StorageJoinFromReadBuffer
     size_t row_count;
     bool overwrite;
     DB::Block right_sample_block;
+    std::shared_mutex join_mutex;
+    std::list<DB::Block> input_blocks;
     std::shared_ptr<DB::HashJoin> join = nullptr;
 
     void readAllBlocksFromInput(DB::ReadBuffer & in);
     void buildJoin(DB::ReadBuffer & in, const DB::Block header, std::shared_ptr<DB::TableJoin> analyzed_join);
+    void collectAllInputs(DB::ReadBuffer & in, const DB::Block header);
+    void buildJoinLazily(DB::Block header, std::shared_ptr<DB::TableJoin> analyzed_join);
 };
 }
diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp
index be28b9fabeff..38f188293726 100644
--- a/cpp-ch/local-engine/local_engine_jni.cpp
+++ b/cpp-ch/local-engine/local_engine_jni.cpp
@@ -1172,7 +1172,15 @@ JNIEXPORT jobject Java_org_apache_spark_sql_execution_datasources_CHDatasourceJn
 }
 
 JNIEXPORT jlong Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild(
-    JNIEnv * env, jclass, jstring key, jbyteArray in, jlong row_count_, jstring join_key_, jint join_type_, jbyteArray named_struct)
+    JNIEnv * env,
+    jclass,
+    jstring key,
+    jbyteArray in,
+    jlong row_count_,
+    jstring join_key_,
+    jint join_type_,
+    jboolean has_mixed_join_condition,
+    jbyteArray named_struct)
 {
     LOCAL_ENGINE_JNI_METHOD_START
     const auto hash_table_id = jstring2string(env, key);
@@ -1186,8 +1194,8 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild
     local_engine::ReadBufferFromByteArray read_buffer_from_java_array(in, length);
     DB::CompressedReadBuffer input(read_buffer_from_java_array);
     local_engine::configureCompressedReadBuffer(input);
-    const auto * obj
-        = make_wrapper(local_engine::BroadCastJoinBuilder::buildJoin(hash_table_id, input, row_count_, join_key, join_type, struct_string));
+    const auto * obj = make_wrapper(local_engine::BroadCastJoinBuilder::buildJoin(
+        hash_table_id, input, row_count_, join_key, join_type, has_mixed_join_condition, struct_string));
     env->ReleaseByteArrayElements(named_struct, struct_address, JNI_ABORT);
     return obj->instance();
     LOCAL_ENGINE_JNI_METHOD_END(env, 0)
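
Note: `buildJoinLazily` above is a double-checked build guarded by a `std::shared_mutex`: callers take a shared lock on the fast path, and only the first caller takes the exclusive lock, creates the `HashJoin` from the real `TableJoin`, and drains the buffered `input_blocks` into it. A rough Scala sketch of the same locking discipline, with a `ReentrantReadWriteLock` standing in for `std::shared_mutex` (all names hypothetical):

import java.util.concurrent.locks.ReentrantReadWriteLock
import scala.collection.mutable

// Builds the join at most once; steady-state callers only pay for a shared read lock.
class LazyJoinHolder[B, J](buffered: mutable.Queue[B], build: Seq[B] => J) {
  private val rwLock = new ReentrantReadWriteLock()
  private var join: Option[J] = None

  def getOrBuild(): J = {
    rwLock.readLock().lock() // fast path: the join is usually built already
    val existing = try join finally rwLock.readLock().unlock()
    existing.getOrElse {
      rwLock.writeLock().lock() // slow path: the first caller builds exclusively
      try {
        if (join.isEmpty)
          join = Some(build(buffered.dequeueAll(_ => true).toSeq)) // drain the buffered blocks once
        join.get
      } finally rwLock.writeLock().unlock()
    }
  }
}

The point of the pattern is that once the join is built, concurrent callers never serialize on the exclusive lock; only the first caller pays for feeding the buffered blocks into the hash table.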