Skip to content

Commit

Permalink
Use aggregate pool in Spark query runner (facebookincubator#10798)
Browse files Browse the repository at this point in the history
Summary:
In facebookincubator#10568, an 'aggregate' pool is passed to the query runner's
constructor. This PR uses that pool to create child pools in the Spark query runner.

Pull Request resolved: facebookincubator#10798

Reviewed By: xiaoxmeng

Differential Revision: D62142058

Pulled By: bikramSingh91

fbshipit-source-id: 750f589e04abdd540d09d33da02d840fc41f4347
  • Loading branch information
rui-mo authored and facebook-github-bot committed Sep 3, 2024
1 parent 71d35e5 commit fd06bd9
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 17 deletions.
2 changes: 1 addition & 1 deletion velox/functions/sparksql/fuzzer/SparkQueryRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ std::vector<RowVectorPtr> SparkQueryRunner::executeVector(
// Write the input to a Parquet file.
auto tempFile = exec::test::TempFilePath::create();
const auto& filePath = tempFile->getPath();
auto writerPool = rootPool()->addAggregateChild("writer");
auto writerPool = aggregatePool()->addAggregateChild("writer");
writeToFile(filePath, input, writerPool.get());

// Create temporary view 'tmp' in Spark by reading the generated Parquet file.
Expand Down
21 changes: 8 additions & 13 deletions velox/functions/sparksql/fuzzer/SparkQueryRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,20 @@ class SparkQueryRunner : public velox::exec::test::ReferenceQueryRunner {
public:
/// @param coordinatorUri Spark connect server endpoint, e.g. localhost:15002.
SparkQueryRunner(
memory::MemoryPool* aggregatePool,
memory::MemoryPool* pool,
const std::string& coordinatorUri,
const std::string& userId,
const std::string& userName)
: ReferenceQueryRunner(aggregatePool),
: ReferenceQueryRunner(pool),
userId_(userId),
userName_(userName),
sessionId_(generateUUID()),
stub_(spark::connect::SparkConnectService::NewStub(grpc::CreateChannel(
coordinatorUri,
grpc::InsecureChannelCredentials()))){};
grpc::InsecureChannelCredentials()))) {
pool_ = aggregatePool()->addLeafChild("leaf");
copyPool_ = aggregatePool()->addLeafChild("copy");
};

/// Converts a Velox query plan to Spark SQL. Supports Values -> Aggregation.
/// The Values node is converted into a read from the 'tmp' table.
Expand Down Expand Up @@ -90,10 +93,6 @@ class SparkQueryRunner : public velox::exec::test::ReferenceQueryRunner {
return pool_.get();
}

velox::memory::MemoryPool* rootPool() {
return rootPool_.get();
}

// Reads Arrow IPC-format string data with the Arrow IPC reader and converts
// it into Velox RowVectors.
std::vector<velox::RowVectorPtr> readArrowData(const std::string& data);
Expand All @@ -111,11 +110,7 @@ class SparkQueryRunner : public velox::exec::test::ReferenceQueryRunner {
const std::string sessionId_;
// Used to make gRPC calls to the SparkConnectService.
std::unique_ptr<spark::connect::SparkConnectService::Stub> stub_;
std::shared_ptr<velox::memory::MemoryPool> rootPool_{
velox::memory::memoryManager()->addRootPool()};
std::shared_ptr<velox::memory::MemoryPool> pool_{
rootPool_->addLeafChild("leaf")};
std::shared_ptr<velox::memory::MemoryPool> copyPool_{
rootPool_->addLeafChild("copy")};
std::shared_ptr<velox::memory::MemoryPool> pool_;
std::shared_ptr<velox::memory::MemoryPool> copyPool_;
};
} // namespace facebook::velox::functions::sparksql::fuzzer
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ class SparkQueryRunnerTest : public ::testing::Test,
// This test requires a Spark coordinator running at localhost, so disable it
// by default.
TEST_F(SparkQueryRunnerTest, DISABLED_basic) {
auto aggregatePool = rootPool_->addAggregateChild("basic");
auto queryRunner = std::make_unique<fuzzer::SparkQueryRunner>(
pool(), "localhost:15002", "test", "basic");
aggregatePool.get(), "localhost:15002", "test", "basic");

auto input = makeRowVector({
makeConstant<int64_t>(1, 25),
Expand Down Expand Up @@ -93,8 +94,9 @@ TEST_F(SparkQueryRunnerTest, DISABLED_fuzzer) {
.project({"a0", "array_sort(a1)"})
.planNode();

auto aggregatePool = rootPool_->addAggregateChild("fuzzer");
auto queryRunner = std::make_unique<fuzzer::SparkQueryRunner>(
pool(), "localhost:15002", "test", "fuzzer");
aggregatePool.get(), "localhost:15002", "test", "fuzzer");
auto sql = queryRunner->toSql(plan);
ASSERT_TRUE(sql.has_value());

Expand All @@ -107,8 +109,9 @@ TEST_F(SparkQueryRunnerTest, DISABLED_fuzzer) {
}

TEST_F(SparkQueryRunnerTest, toSql) {
auto aggregatePool = rootPool_->addAggregateChild("toSql");
auto queryRunner = std::make_unique<fuzzer::SparkQueryRunner>(
pool(), "unused", "unused", "unused");
aggregatePool.get(), "unused", "unused", "unused");

auto dataType = ROW({"c0", "c1", "c2"}, {DOUBLE(), DOUBLE(), BOOLEAN()});
auto plan = exec::test::PlanBuilder()
Expand Down

0 comments on commit fd06bd9

Please sign in to comment.