feat(plan-builder): Add TableWriterBuilder
Summary:
Add TableWriterBuilder API following the same pattern as
TableScanBuilder to simplify the creation of complex TableWriter nodes.

Differential Revision: D67422718
pedroerp authored and facebook-github-bot committed Dec 19, 2024
1 parent 3a9089a commit ff965b7
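To illustrate the new API before the diff, here is a minimal usage sketch assembled from the setter names and endTableWriter() call visible in this commit. The startTableWriter() entry point is an assumption by analogy with TableScanBuilder's startTableScan(), since the header diff is not shown below; the column names "ds" and "user_id" and the output path are likewise hypothetical.

#include "velox/exec/tests/utils/PlanBuilder.h"

using namespace facebook::velox;

// Sketch only, not part of this commit. startTableWriter() is assumed by
// analogy with startTableScan(); the setters and endTableWriter() come from
// the refactored tableWrite() overloads in this diff.
core::PlanNodePtr makeWritePlan(const RowTypePtr& rowType) {
  return exec::test::PlanBuilder()
      .tableScan(rowType) // upstream node; TableWrite cannot be the source
      .startTableWriter()
      .outputDirectoryPath("/tmp/output")
      .partitionBy({"ds"})
      .bucketCount(4)
      .bucketedBy({"user_id"})
      .fileFormat(dwio::common::FileFormat::DWRF)
      .endTableWriter()
      .planNode();
}

The fluent builder replaces the long positional-argument overloads, so call sites only name the parameters they actually set.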
Showing 2 changed files with 257 additions and 101 deletions.
215 changes: 114 additions & 101 deletions velox/exec/tests/utils/PlanBuilder.cpp
@@ -226,6 +226,82 @@ core::PlanNodePtr PlanBuilder::TableScanBuilder::build(core::PlanNodeId id) {
       id, outputType_, tableHandle_, assignments_);
 }
 
+core::PlanNodePtr PlanBuilder::TableWriterBuilder::build(core::PlanNodeId id) {
+  auto upstreamNode = planBuilder_.planNode();
+  VELOX_CHECK_NOT_NULL(upstreamNode, "TableWrite cannot be the source node");
+
+  // If outputType wasn't explicitly specified, fall back to the output of the
+  // upstream operator.
+  auto outputType = outputType_ ? outputType_ : upstreamNode->outputType();
+
+  // Create column handles.
+  std::vector<std::shared_ptr<const connector::hive::HiveColumnHandle>>
+      columnHandles;
+  for (auto i = 0; i < outputType->size(); ++i) {
+    const auto column = outputType->nameOf(i);
+    const bool isPartitionKey =
+        std::find(partitionBy_.begin(), partitionBy_.end(), column) !=
+        partitionBy_.end();
+    columnHandles.push_back(std::make_shared<connector::hive::HiveColumnHandle>(
+        column,
+        isPartitionKey
+            ? connector::hive::HiveColumnHandle::ColumnType::kPartitionKey
+            : connector::hive::HiveColumnHandle::ColumnType::kRegular,
+        outputType->childAt(i),
+        outputType->childAt(i)));
+  }
+
+  auto locationHandle = std::make_shared<connector::hive::LocationHandle>(
+      outputDirectoryPath_,
+      outputDirectoryPath_,
+      connector::hive::LocationHandle::TableType::kNew,
+      outputFileName_);
+
+  std::shared_ptr<HiveBucketProperty> bucketProperty;
+  if (bucketCount_ != 0) {
+    bucketProperty = buildHiveBucketProperty(
+        outputType, bucketCount_, bucketedBy_, sortBy_);
+  }
+
+  auto hiveHandle = std::make_shared<connector::hive::HiveInsertTableHandle>(
+      columnHandles,
+      locationHandle,
+      fileFormat_,
+      bucketProperty,
+      compressionKind_,
+      serdeParameters_,
+      options_);
+
+  auto insertHandle =
+      std::make_shared<core::InsertTableHandle>(connectorId_, hiveHandle);
+
+  std::shared_ptr<core::AggregationNode> aggregationNode;
+  if (!aggregates_.empty()) {
+    auto aggregatesAndNames = planBuilder_.createAggregateExpressionsAndNames(
+        aggregates_, {}, core::AggregationNode::Step::kPartial);
+    aggregationNode = std::make_shared<core::AggregationNode>(
+        planBuilder_.nextPlanNodeId(),
+        core::AggregationNode::Step::kPartial,
+        std::vector<core::FieldAccessTypedExprPtr>{}, // groupingKeys
+        std::vector<core::FieldAccessTypedExprPtr>{}, // preGroupedKeys
+        aggregatesAndNames.names,
+        aggregatesAndNames.aggregates,
+        false, // ignoreNullKeys
+        upstreamNode);
+  }
+
+  return std::make_shared<core::TableWriteNode>(
+      planBuilder_.nextPlanNodeId(),
+      outputType,
+      outputType->names(),
+      aggregationNode,
+      insertHandle,
+      false, // hasPartitioningScheme
+      TableWriteTraits::outputType(aggregationNode),
+      connector::CommitStrategy::kNoCommit,
+      upstreamNode);
+}
+
 PlanBuilder& PlanBuilder::values(
     const std::vector<RowVectorPtr>& values,
     bool parallelizable,
Expand Down Expand Up @@ -377,18 +453,13 @@ PlanBuilder& PlanBuilder::tableWrite(
const std::vector<std::string>& aggregates,
const std::shared_ptr<dwio::common::WriterOptions>& options,
const std::string& outputFileName) {
return tableWrite(
outputDirectoryPath,
{},
0,
{},
{},
fileFormat,
aggregates,
kHiveDefaultConnectorId,
{},
options,
outputFileName);
return TableWriterBuilder(*this)
.outputDirectoryPath(outputDirectoryPath)
.outputFileName(outputFileName)
.fileFormat(fileFormat)
.aggregates(aggregates)
.options(options)
.endTableWriter();
}

PlanBuilder& PlanBuilder::tableWrite(
@@ -397,17 +468,13 @@ PlanBuilder& PlanBuilder::tableWrite(
     const dwio::common::FileFormat fileFormat,
     const std::vector<std::string>& aggregates,
     const std::shared_ptr<dwio::common::WriterOptions>& options) {
-  return tableWrite(
-      outputDirectoryPath,
-      partitionBy,
-      0,
-      {},
-      {},
-      fileFormat,
-      aggregates,
-      kHiveDefaultConnectorId,
-      {},
-      options);
+  return TableWriterBuilder(*this)
+      .outputDirectoryPath(outputDirectoryPath)
+      .partitionBy(partitionBy)
+      .fileFormat(fileFormat)
+      .aggregates(aggregates)
+      .options(options)
+      .endTableWriter();
 }
 
 PlanBuilder& PlanBuilder::tableWrite(
@@ -418,17 +485,15 @@ PlanBuilder& PlanBuilder::tableWrite(
     const dwio::common::FileFormat fileFormat,
     const std::vector<std::string>& aggregates,
     const std::shared_ptr<dwio::common::WriterOptions>& options) {
-  return tableWrite(
-      outputDirectoryPath,
-      partitionBy,
-      bucketCount,
-      bucketedBy,
-      {},
-      fileFormat,
-      aggregates,
-      kHiveDefaultConnectorId,
-      {},
-      options);
+  return TableWriterBuilder(*this)
+      .outputDirectoryPath(outputDirectoryPath)
+      .partitionBy(partitionBy)
+      .bucketCount(bucketCount)
+      .bucketedBy(bucketedBy)
+      .fileFormat(fileFormat)
+      .aggregates(aggregates)
+      .options(options)
+      .endTableWriter();
 }
 
 PlanBuilder& PlanBuilder::tableWrite(
@@ -445,73 +510,21 @@ PlanBuilder& PlanBuilder::tableWrite(
     const std::string& outputFileName,
     const common::CompressionKind compressionKind,
     const RowTypePtr& schema) {
-  VELOX_CHECK_NOT_NULL(planNode_, "TableWrite cannot be the source node");
-  auto rowType = schema ? schema : planNode_->outputType();
-
-  std::vector<std::shared_ptr<const connector::hive::HiveColumnHandle>>
-      columnHandles;
-  for (auto i = 0; i < rowType->size(); ++i) {
-    const auto column = rowType->nameOf(i);
-    const bool isPartitionKey =
-        std::find(partitionBy.begin(), partitionBy.end(), column) !=
-        partitionBy.end();
-    columnHandles.push_back(std::make_shared<connector::hive::HiveColumnHandle>(
-        column,
-        isPartitionKey
-            ? connector::hive::HiveColumnHandle::ColumnType::kPartitionKey
-            : connector::hive::HiveColumnHandle::ColumnType::kRegular,
-        rowType->childAt(i),
-        rowType->childAt(i)));
-  }
-
-  auto locationHandle = std::make_shared<connector::hive::LocationHandle>(
-      outputDirectoryPath,
-      outputDirectoryPath,
-      connector::hive::LocationHandle::TableType::kNew,
-      outputFileName);
-  std::shared_ptr<HiveBucketProperty> bucketProperty;
-  if (bucketCount != 0) {
-    bucketProperty =
-        buildHiveBucketProperty(rowType, bucketCount, bucketedBy, sortBy);
-  }
-  auto hiveHandle = std::make_shared<connector::hive::HiveInsertTableHandle>(
-      columnHandles,
-      locationHandle,
-      fileFormat,
-      bucketProperty,
-      compressionKind,
-      serdeParameters,
-      options);
-
-  auto insertHandle = std::make_shared<core::InsertTableHandle>(
-      std::string(connectorId), hiveHandle);
-
-  std::shared_ptr<core::AggregationNode> aggregationNode;
-  if (!aggregates.empty()) {
-    auto aggregatesAndNames = createAggregateExpressionsAndNames(
-        aggregates, {}, core::AggregationNode::Step::kPartial);
-    aggregationNode = std::make_shared<core::AggregationNode>(
-        nextPlanNodeId(),
-        core::AggregationNode::Step::kPartial,
-        std::vector<core::FieldAccessTypedExprPtr>{}, // groupingKeys
-        std::vector<core::FieldAccessTypedExprPtr>{}, // preGroupedKeys
-        aggregatesAndNames.names,
-        aggregatesAndNames.aggregates,
-        false, // ignoreNullKeys
-        planNode_);
-  }
-
-  planNode_ = std::make_shared<core::TableWriteNode>(
-      nextPlanNodeId(),
-      rowType,
-      rowType->names(),
-      aggregationNode,
-      insertHandle,
-      false,
-      TableWriteTraits::outputType(aggregationNode),
-      connector::CommitStrategy::kNoCommit,
-      planNode_);
-  return *this;
+  return TableWriterBuilder(*this)
+      .outputDirectoryPath(outputDirectoryPath)
+      .outputFileName(outputFileName)
+      .outputType(schema)
+      .partitionBy(partitionBy)
+      .bucketCount(bucketCount)
+      .bucketedBy(bucketedBy)
+      .sortBy(sortBy)
+      .fileFormat(fileFormat)
+      .aggregates(aggregates)
+      .connectorId(connectorId)
+      .serdeParameters(serdeParameters)
+      .options(options)
+      .compressionKind(compressionKind)
+      .endTableWriter();
 }
 
 PlanBuilder& PlanBuilder::tableWriteMerge(
143 changes: 143 additions & 0 deletions (second changed file; diff did not load)
