diff --git a/velox/exec/tests/utils/PlanBuilder.cpp b/velox/exec/tests/utils/PlanBuilder.cpp
index a9cf1bccb8fd..28751a81c449 100644
--- a/velox/exec/tests/utils/PlanBuilder.cpp
+++ b/velox/exec/tests/utils/PlanBuilder.cpp
@@ -226,6 +226,82 @@ core::PlanNodePtr PlanBuilder::TableScanBuilder::build(core::PlanNodeId id) {
       id, outputType_, tableHandle_, assignments_);
 }
 
+core::PlanNodePtr PlanBuilder::TableWriterBuilder::build(core::PlanNodeId id) {
+  auto upstreamNode = planBuilder_.planNode();
+  VELOX_CHECK_NOT_NULL(upstreamNode, "TableWrite cannot be the source node");
+
+  // If outputType wasn't explicitly specified, fall back to the output of the
+  // upstream operator.
+  auto outputType = outputType_ ? outputType_ : upstreamNode->outputType();
+
+  // Create one column handle per output column; columns named in partitionBy_
+  // become partition keys, all others are regular data columns.
+  std::vector<std::shared_ptr<const connector::hive::HiveColumnHandle>>
+      columnHandles;
+  for (auto i = 0; i < outputType->size(); ++i) {
+    const auto column = outputType->nameOf(i);
+    const bool isPartitionKey =
+        std::find(partitionBy_.begin(), partitionBy_.end(), column) !=
+        partitionBy_.end();
+    columnHandles.push_back(std::make_shared<connector::hive::HiveColumnHandle>(
+        column,
+        isPartitionKey
+            ? connector::hive::HiveColumnHandle::ColumnType::kPartitionKey
+            : connector::hive::HiveColumnHandle::ColumnType::kRegular,
+        outputType->childAt(i),
+        outputType->childAt(i)));
+  }
+
+  auto locationHandle = std::make_shared<connector::hive::LocationHandle>(
+      outputDirectoryPath_,
+      outputDirectoryPath_,
+      connector::hive::LocationHandle::TableType::kNew,
+      outputFileName_);
+
+  // Use the resolved outputType (not the possibly-null outputType_ member)
+  // when deriving the bucket property.
+  std::shared_ptr<connector::hive::HiveBucketProperty> bucketProperty;
+  if (bucketCount_ != 0) {
+    bucketProperty =
+        buildHiveBucketProperty(outputType, bucketCount_, bucketedBy_, sortBy_);
+  }
+
+  auto hiveHandle = std::make_shared<connector::hive::HiveInsertTableHandle>(
+      columnHandles,
+      locationHandle,
+      fileFormat_,
+      bucketProperty,
+      compressionKind_,
+      serdeParameters_,
+      options_);
+
+  auto insertHandle =
+      std::make_shared<core::InsertTableHandle>(connectorId_, hiveHandle);
+
+  // Optional partial aggregation for column-statistics collection during the
+  // write.
+  std::shared_ptr<core::AggregationNode> aggregationNode;
+  if (!aggregates_.empty()) {
+    auto aggregatesAndNames = planBuilder_.createAggregateExpressionsAndNames(
+        aggregates_, {}, core::AggregationNode::Step::kPartial);
+    aggregationNode = std::make_shared<core::AggregationNode>(
+        planBuilder_.nextPlanNodeId(),
+        core::AggregationNode::Step::kPartial,
+        std::vector<core::FieldAccessTypedExprPtr>{}, // groupingKeys
+        std::vector<core::FieldAccessTypedExprPtr>{}, // preGroupedKeys
+        aggregatesAndNames.names,
+        aggregatesAndNames.aggregates,
+        false, // ignoreNullKeys
+        upstreamNode);
+  }
+
+  // Use the caller-supplied plan node id and the resolved outputType.
+  return std::make_shared<core::TableWriteNode>(
+      id,
+      outputType,
+      outputType->names(),
+      aggregationNode,
+      insertHandle,
+      false,
+      TableWriteTraits::outputType(aggregationNode),
+      connector::CommitStrategy::kNoCommit,
+      upstreamNode);
+}
+
 PlanBuilder& PlanBuilder::values(
     const std::vector<RowVectorPtr>& values,
     bool parallelizable,
@@ -377,18 +453,13 @@ PlanBuilder& PlanBuilder::tableWrite(
     const std::vector<std::string>& aggregates,
     const std::shared_ptr<dwio::common::WriterOptions>& options,
     const std::string& outputFileName) {
-  return tableWrite(
-      outputDirectoryPath,
-      {},
-      0,
-      {},
-      {},
-      fileFormat,
-      aggregates,
-      kHiveDefaultConnectorId,
-      {},
-      options,
-      outputFileName);
+  return TableWriterBuilder(*this)
+      .outputDirectoryPath(outputDirectoryPath)
+      .outputFileName(outputFileName)
+      .fileFormat(fileFormat)
+      .aggregates(aggregates)
+      .options(options)
+      .endTableWriter();
 }
 
 PlanBuilder& PlanBuilder::tableWrite(
@@ -397,17 +468,13 @@ PlanBuilder& PlanBuilder::tableWrite(
     const dwio::common::FileFormat fileFormat,
     const std::vector<std::string>& aggregates,
     const std::shared_ptr<dwio::common::WriterOptions>& options) {
-  return tableWrite(
-      outputDirectoryPath,
-      partitionBy,
-      0,
-      {},
-      {},
-      fileFormat,
-      aggregates,
-      kHiveDefaultConnectorId,
-      {},
-      options);
+  return TableWriterBuilder(*this)
+      .outputDirectoryPath(outputDirectoryPath)
+      .partitionBy(partitionBy)
+      .fileFormat(fileFormat)
+      .aggregates(aggregates)
+      .options(options)
+      .endTableWriter();
 }
 
 PlanBuilder& PlanBuilder::tableWrite(
@@ -418,17 +485,15 @@ PlanBuilder& PlanBuilder::tableWrite(
     const dwio::common::FileFormat fileFormat,
     const std::vector<std::string>& aggregates,
     const std::shared_ptr<dwio::common::WriterOptions>& options) {
-  return tableWrite(
-      outputDirectoryPath,
-      partitionBy,
-      bucketCount,
-      bucketedBy,
-      {},
-      fileFormat,
-      aggregates,
-      kHiveDefaultConnectorId,
-      {},
-      options);
+  return TableWriterBuilder(*this)
+      .outputDirectoryPath(outputDirectoryPath)
+      .partitionBy(partitionBy)
+      .bucketCount(bucketCount)
+      .bucketedBy(bucketedBy)
+      .fileFormat(fileFormat)
+      .aggregates(aggregates)
+      .options(options)
+      .endTableWriter();
 }
 
 PlanBuilder& PlanBuilder::tableWrite(
@@ -445,73 +510,21 @@ PlanBuilder& PlanBuilder::tableWrite(
     const std::string& outputFileName,
     const common::CompressionKind compressionKind,
     const RowTypePtr& schema) {
-  VELOX_CHECK_NOT_NULL(planNode_, "TableWrite cannot be the source node");
-  auto rowType = schema ? schema : planNode_->outputType();
-
-  std::vector<std::shared_ptr<const connector::hive::HiveColumnHandle>>
-      columnHandles;
-  for (auto i = 0; i < rowType->size(); ++i) {
-    const auto column = rowType->nameOf(i);
-    const bool isPartitionKey =
-        std::find(partitionBy.begin(), partitionBy.end(), column) !=
-        partitionBy.end();
-    columnHandles.push_back(std::make_shared<connector::hive::HiveColumnHandle>(
-        column,
-        isPartitionKey
-            ? connector::hive::HiveColumnHandle::ColumnType::kPartitionKey
-            : connector::hive::HiveColumnHandle::ColumnType::kRegular,
-        rowType->childAt(i),
-        rowType->childAt(i)));
-  }
-
-  auto locationHandle = std::make_shared<connector::hive::LocationHandle>(
-      outputDirectoryPath,
-      outputDirectoryPath,
-      connector::hive::LocationHandle::TableType::kNew,
-      outputFileName);
-  std::shared_ptr<connector::hive::HiveBucketProperty> bucketProperty;
-  if (bucketCount != 0) {
-    bucketProperty =
-        buildHiveBucketProperty(rowType, bucketCount, bucketedBy, sortBy);
-  }
-  auto hiveHandle = std::make_shared<connector::hive::HiveInsertTableHandle>(
-      columnHandles,
-      locationHandle,
-      fileFormat,
-      bucketProperty,
-      compressionKind,
-      serdeParameters,
-      options);
-
-  auto insertHandle = std::make_shared<core::InsertTableHandle>(
-      std::string(connectorId), hiveHandle);
-
-  std::shared_ptr<core::AggregationNode> aggregationNode;
-  if (!aggregates.empty()) {
-    auto aggregatesAndNames = createAggregateExpressionsAndNames(
-        aggregates, {}, core::AggregationNode::Step::kPartial);
-    aggregationNode = std::make_shared<core::AggregationNode>(
-        nextPlanNodeId(),
-        core::AggregationNode::Step::kPartial,
-        std::vector<core::FieldAccessTypedExprPtr>{}, // groupingKeys
-        std::vector<core::FieldAccessTypedExprPtr>{}, // preGroupedKeys
-        aggregatesAndNames.names, // ignoreNullKeys
-        aggregatesAndNames.aggregates,
-        false,
-        planNode_);
-  }
-
-  planNode_ = std::make_shared<core::TableWriteNode>(
-      nextPlanNodeId(),
-      rowType,
-      rowType->names(),
-      aggregationNode,
-      insertHandle,
-      false,
-      TableWriteTraits::outputType(aggregationNode),
-      connector::CommitStrategy::kNoCommit,
-      planNode_);
-  return *this;
+  return TableWriterBuilder(*this)
+      .outputDirectoryPath(outputDirectoryPath)
+      .outputFileName(outputFileName)
+      .outputType(schema)
+      .partitionBy(partitionBy)
+      .bucketCount(bucketCount)
+      .bucketedBy(bucketedBy)
+      .sortBy(sortBy)
+      .fileFormat(fileFormat)
+      .aggregates(aggregates)
+      .connectorId(connectorId)
+      .serdeParameters(serdeParameters)
+      .options(options)
+      .compressionKind(compressionKind)
+      .endTableWriter();
 }
 
 PlanBuilder& PlanBuilder::tableWriteMerge(
diff --git a/velox/exec/tests/utils/PlanBuilder.h b/velox/exec/tests/utils/PlanBuilder.h
index 4bd09fc680f9..5445963c0c53 100644
--- a/velox/exec/tests/utils/PlanBuilder.h
+++ b/velox/exec/tests/utils/PlanBuilder.h
@@ -110,6 +110,10 @@ class PlanBuilder {
   static constexpr const std::string_view kHiveDefaultConnectorId{"test-hive"};
   static constexpr const std::string_view kTpchDefaultConnectorId{"test-tpch"};
 
+  ///
+  /// TableScan
+  ///
+
   /// Add a TableScanNode to scan a Hive table.
   ///
   /// @param outputType List of column names and types to read from the table.
@@ -302,6 +306,144 @@ class PlanBuilder {
     return *tableScanBuilder_;
   }
 
+  ///
+  /// TableWriter
+  ///
+
+  /// Helper class to build a custom TableWriterNode.
+  /// Uses a planBuilder instance to get the next plan id, memory pool, and
+  /// upstream node (the node which will produce the data).
+  ///
+  /// Uses the hive connector by default.
+  class TableWriterBuilder {
+   public:
+    explicit TableWriterBuilder(PlanBuilder& builder) : planBuilder_(builder) {}
+
+    /// @param outputType The schema that will be written to the output file. It
+    /// may reference a subset or change the order of columns from the input
+    /// (upstream operator output).
+    TableWriterBuilder& outputType(RowTypePtr outputType) {
+      outputType_ = std::move(outputType);
+      return *this;
+    }
+
+    /// @param outputDirectoryPath Path in which output files will be created.
+    TableWriterBuilder& outputDirectoryPath(std::string outputDirectoryPath) {
+      outputDirectoryPath_ = std::move(outputDirectoryPath);
+      return *this;
+    }
+
+    /// @param outputFileName File name of the output (optional). If specified
+    /// (non-empty), use it instead of generating the file name in Velox. Should
+    /// only be specified in non-bucketing write.
+    TableWriterBuilder& outputFileName(std::string outputFileName) {
+      outputFileName_ = std::move(outputFileName);
+      return *this;
+    }
+
+    /// @param connectorId The id of the connector to write to.
+    TableWriterBuilder& connectorId(std::string_view connectorId) {
+      connectorId_ = connectorId;
+      return *this;
+    }
+
+    /// @param partitionBy Specifies the partition key columns.
+    TableWriterBuilder& partitionBy(std::vector<std::string> partitionBy) {
+      partitionBy_ = std::move(partitionBy);
+      return *this;
+    }
+
+    /// @param bucketCount Specifies the bucket count.
+    TableWriterBuilder& bucketCount(int32_t count) {
+      bucketCount_ = count;
+      return *this;
+    }
+
+    /// @param bucketedBy Specifies the bucket by columns.
+    TableWriterBuilder& bucketedBy(std::vector<std::string> bucketedBy) {
+      bucketedBy_ = std::move(bucketedBy);
+      return *this;
+    }
+
+    /// @param aggregates Aggregations for column statistics collection during
+    /// write.
+    TableWriterBuilder& aggregates(std::vector<std::string> aggregates) {
+      aggregates_ = std::move(aggregates);
+      return *this;
+    }
+
+    /// @param sortBy Specifies the sort by columns.
+    TableWriterBuilder& sortBy(
+        std::vector<std::shared_ptr<const connector::hive::HiveSortingColumn>>
+            sortBy) {
+      sortBy_ = std::move(sortBy);
+      return *this;
+    }
+
+    /// @param serdeParameters Additional parameters passed to the writer.
+    TableWriterBuilder& serdeParameters(
+        std::unordered_map<std::string, std::string> serdeParameters) {
+      serdeParameters_ = std::move(serdeParameters);
+      return *this;
+    }
+
+    /// @param options Options passed to the writer.
+    TableWriterBuilder& options(
+        std::shared_ptr<dwio::common::WriterOptions> options) {
+      options_ = std::move(options);
+      return *this;
+    }
+
+    /// @param fileFormat File format to use for the written data.
+    TableWriterBuilder& fileFormat(dwio::common::FileFormat fileFormat) {
+      fileFormat_ = fileFormat;
+      return *this;
+    }
+
+    /// @param compressionKind Compression scheme to use for writing the
+    /// output data files.
+    TableWriterBuilder& compressionKind(
+        common::CompressionKind compressionKind) {
+      compressionKind_ = compressionKind;
+      return *this;
+    }
+
+    /// Stop the TableWriterBuilder.
+    PlanBuilder& endTableWriter() {
+      planBuilder_.planNode_ = build(planBuilder_.nextPlanNodeId());
+      return planBuilder_;
+    }
+
+   private:
+    /// Build the plan node TableWriterNode.
+    core::PlanNodePtr build(core::PlanNodeId id);
+
+    PlanBuilder& planBuilder_;
+    RowTypePtr outputType_;
+    std::string outputDirectoryPath_;
+    std::string outputFileName_;
+    std::string connectorId_{kHiveDefaultConnectorId};
+
+    std::vector<std::string> partitionBy_;
+    int32_t bucketCount_{0};
+    std::vector<std::string> bucketedBy_;
+    std::vector<std::string> aggregates_;
+    std::vector<std::shared_ptr<const connector::hive::HiveSortingColumn>>
+        sortBy_;
+
+    std::unordered_map<std::string, std::string> serdeParameters_;
+    std::shared_ptr<dwio::common::WriterOptions> options_;
+
+    dwio::common::FileFormat fileFormat_{dwio::common::FileFormat::DWRF};
+    common::CompressionKind compressionKind_{common::CompressionKind_NONE};
+  };
+
+  /// Start a TableWriterBuilder.
+  TableWriterBuilder& startTableWriter() {
+    tableWriterBuilder_.reset(new TableWriterBuilder(*this));
+    return *tableWriterBuilder_;
+  }
+
   /// Add a ValuesNode using specified data.
   ///
   /// @param values The data to use.
@@ -1181,6 +1323,7 @@ class PlanBuilder {
   core::PlanNodePtr planNode_;
   parse::ParseOptions options_;
   std::shared_ptr<TableScanBuilder> tableScanBuilder_;
+  std::shared_ptr<TableWriterBuilder> tableWriterBuilder_;
 
  private:
   std::shared_ptr<core::PlanNodeIdGenerator> planNodeIdGenerator_;