diff --git a/velox/connectors/hive/HiveConnectorUtil.cpp b/velox/connectors/hive/HiveConnectorUtil.cpp index 3ff34cc50e84..a1d296e650a7 100644 --- a/velox/connectors/hive/HiveConnectorUtil.cpp +++ b/velox/connectors/hive/HiveConnectorUtil.cpp @@ -180,6 +180,7 @@ void addSubfields( if (stringKey) { deduplicate(stringSubscripts); filter = std::make_unique(stringSubscripts, false); + spec.setFlatMapFeatureSelection(std::move(stringSubscripts)); } else { deduplicate(longSubscripts); if (keyType->isReal()) { @@ -189,6 +190,11 @@ void addSubfields( } else { filter = common::createBigintValues(longSubscripts, false); } + std::vector features; + for (auto num : longSubscripts) { + features.push_back(std::to_string(num)); + } + spec.setFlatMapFeatureSelection(std::move(features)); } keys->setFilter(std::move(filter)); break; @@ -510,9 +516,16 @@ void configureRowReaderOptions( std::vector columnNames; for (auto& spec : scanSpec->children()) { - if (!spec->isConstant()) { - columnNames.push_back(spec->fieldName()); + if (spec->isConstant()) { + continue; + } + std::string name = spec->fieldName(); + if (!spec->flatMapFeatureSelection().empty()) { + name += "#["; + name += folly::join(',', spec->flatMapFeatureSelection()); + name += ']'; } + columnNames.push_back(std::move(name)); } std::shared_ptr cs; if (columnNames.empty()) { diff --git a/velox/connectors/hive/tests/HiveConnectorTest.cpp b/velox/connectors/hive/tests/HiveConnectorTest.cpp index c58edffe999f..a21d98210641 100644 --- a/velox/connectors/hive/tests/HiveConnectorTest.cpp +++ b/velox/connectors/hive/tests/HiveConnectorTest.cpp @@ -149,6 +149,7 @@ TEST_F(HiveConnectorTest, makeScanSpec_requiredSubfields_mergeArray) { pool_.get()); auto* c0 = scanSpec->childByName("c0"); ASSERT_EQ(c0->maxArrayElementsCount(), 2); + ASSERT_TRUE(c0->flatMapFeatureSelection().empty()); auto* elements = c0->childByName(ScanSpec::kArrayElementsFieldName); ASSERT_FALSE(elements->childByName("c0c0")->isConstant()); ASSERT_FALSE(elements->childByName("c0c2")->isConstant()); @@ -180,6 +181,8 @@ TEST_F(HiveConnectorTest, makeScanSpec_requiredSubfields_mergeMap) { {}, pool_.get()); auto* c0 = scanSpec->childByName("c0"); + ASSERT_EQ( + c0->flatMapFeatureSelection(), std::vector({"10", "20"})); auto* keysFilter = c0->childByName(ScanSpec::kMapKeysFieldName)->filter(); ASSERT_TRUE(keysFilter); ASSERT_TRUE(applyFilter(*keysFilter, 10)); @@ -206,6 +209,7 @@ TEST_F(HiveConnectorTest, makeScanSpec_requiredSubfields_allSubscripts) { {}, pool_.get()); auto* c0 = scanSpec->childByName("c0"); + ASSERT_TRUE(c0->flatMapFeatureSelection().empty()); ASSERT_FALSE(c0->childByName(ScanSpec::kMapKeysFieldName)->filter()); auto* values = c0->childByName(ScanSpec::kMapValuesFieldName); ASSERT_EQ( diff --git a/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp b/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp index 8f4ae79cac68..638ed28d91f2 100644 --- a/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp +++ b/velox/connectors/hive/tests/HiveConnectorUtilTest.cpp @@ -205,4 +205,21 @@ TEST_F(HiveConnectorUtilTest, configureReaderOptions) { hiveConfig->filePreloadThreshold()); } +TEST_F(HiveConnectorUtilTest, configureRowReaderOptions) { + auto split = + std::make_shared("", "", FileFormat::UNKNOWN); + auto rowType = ROW({{"float_features", MAP(INTEGER(), REAL())}}); + auto spec = std::make_shared(""); + spec->addAllChildFields(*rowType); + auto* float_features = spec->childByName("float_features"); + float_features->childByName(common::ScanSpec::kMapKeysFieldName) + ->setFilter(common::createBigintValues({1, 3}, false)); + float_features->setFlatMapFeatureSelection({"1", "3"}); + RowReaderOptions options; + configureRowReaderOptions(options, {}, spec, nullptr, rowType, split); + auto& nodes = options.getSelector()->getProjection(); + ASSERT_EQ(nodes.size(), 1); + ASSERT_EQ(nodes[0].expression, "[1,3]"); +} + } // namespace facebook::velox::connector diff --git a/velox/dwio/common/ScanSpec.h b/velox/dwio/common/ScanSpec.h index e03b33963d0c..a75ff65496ff 100644 --- a/velox/dwio/common/ScanSpec.h +++ b/velox/dwio/common/ScanSpec.h @@ -324,6 +324,14 @@ class ScanSpec { // projected out. void addAllChildFields(const Type&); + const std::vector& flatMapFeatureSelection() const { + return flatMapFeatureSelection_; + } + + void setFlatMapFeatureSelection(std::vector features) { + flatMapFeatureSelection_ = std::move(features); + } + private: void reorder(); @@ -400,6 +408,9 @@ class ScanSpec { // Only take the first maxArrayElementsCount_ elements from each array. vector_size_t maxArrayElementsCount_ = std::numeric_limits::max(); + + // Used only for bulk reader to project flat map features. + std::vector flatMapFeatureSelection_; }; // Returns false if no value from a range defined by stats can pass the