diff --git a/velox/connectors/hive/HiveConnectorUtil.cpp b/velox/connectors/hive/HiveConnectorUtil.cpp index 8189a92b812f..ecdf6d2c48c2 100644 --- a/velox/connectors/hive/HiveConnectorUtil.cpp +++ b/velox/connectors/hive/HiveConnectorUtil.cpp @@ -319,6 +319,19 @@ void checkColumnNameLowerCase(const core::TypedExprPtr& typeExpr) { } } +namespace { + +void filterOutNullMapKeys(const Type& rootType, common::ScanSpec& rootSpec) { + rootSpec.visit(rootType, [](const Type& type, common::ScanSpec& spec) { + if (type.isMap()) { + spec.childByName(common::ScanSpec::kMapKeysFieldName) + ->addFilter(common::IsNotNull()); + } + }); +} + +} // namespace + std::shared_ptr makeScanSpec( const RowTypePtr& rowType, const folly::F14FastMap>& @@ -348,7 +361,8 @@ std::shared_ptr makeScanSpec( auto& type = rowType->childAt(i); auto it = outputSubfields.find(name); if (it == outputSubfields.end()) { - spec->addFieldRecursively(name, *type, i); + auto* fieldSpec = spec->addFieldRecursively(name, *type, i); + filterOutNullMapKeys(*type, *fieldSpec); filterSubfields.erase(name); continue; } @@ -362,7 +376,9 @@ std::shared_ptr makeScanSpec( } filterSubfields.erase(it); } - addSubfields(*type, subfieldSpecs, 1, pool, *spec->addField(name, i)); + auto* fieldSpec = spec->addField(name, i); + addSubfields(*type, subfieldSpecs, 1, pool, *fieldSpec); + filterOutNullMapKeys(*type, *fieldSpec); subfieldSpecs.clear(); } @@ -376,6 +392,7 @@ std::shared_ptr makeScanSpec( auto& type = dataColumns->findChild(fieldName); auto* fieldSpec = spec->getOrCreateChild(common::Subfield(fieldName)); addSubfields(*type, subfieldSpecs, 1, pool, *fieldSpec); + filterOutNullMapKeys(*type, *fieldSpec); subfieldSpecs.clear(); } } diff --git a/velox/connectors/hive/tests/HiveConnectorTest.cpp b/velox/connectors/hive/tests/HiveConnectorTest.cpp index a21d98210641..3daa1d792dff 100644 --- a/velox/connectors/hive/tests/HiveConnectorTest.cpp +++ b/velox/connectors/hive/tests/HiveConnectorTest.cpp @@ -63,6 +63,11 @@ groupSubfields(const std::vector& subfields) { return grouped; } +bool mapKeyIsNotNull(const ScanSpec& mapSpec) { + return dynamic_cast( + mapSpec.childByName(ScanSpec::kMapKeysFieldName)->filter()); +} + TEST_F(HiveConnectorTest, hiveConfig) { ASSERT_EQ( HiveConfig::insertExistingPartitionsBehaviorString( @@ -210,7 +215,7 @@ TEST_F(HiveConnectorTest, makeScanSpec_requiredSubfields_allSubscripts) { pool_.get()); auto* c0 = scanSpec->childByName("c0"); ASSERT_TRUE(c0->flatMapFeatureSelection().empty()); - ASSERT_FALSE(c0->childByName(ScanSpec::kMapKeysFieldName)->filter()); + ASSERT_TRUE(mapKeyIsNotNull(*c0)); auto* values = c0->childByName(ScanSpec::kMapValuesFieldName); ASSERT_EQ( values->maxArrayElementsCount(), @@ -229,7 +234,7 @@ TEST_F(HiveConnectorTest, makeScanSpec_requiredSubfields_allSubscripts) { {}, pool_.get()); auto* c0 = scanSpec->childByName("c0"); - ASSERT_FALSE(c0->childByName(ScanSpec::kMapKeysFieldName)->filter()); + ASSERT_TRUE(mapKeyIsNotNull(*c0)); auto* values = c0->childByName(ScanSpec::kMapValuesFieldName); ASSERT_EQ( values->maxArrayElementsCount(), diff --git a/velox/dwio/common/ScanSpec.h b/velox/dwio/common/ScanSpec.h index e980baaf9020..9d998cce896a 100644 --- a/velox/dwio/common/ScanSpec.h +++ b/velox/dwio/common/ScanSpec.h @@ -322,6 +322,10 @@ class ScanSpec { flatMapFeatureSelection_ = std::move(features); } + /// Invoke the function provided on each node of the ScanSpec tree. + template + void visit(const Type& type, F&& f); + private: void reorder(); @@ -403,6 +407,31 @@ class ScanSpec { std::vector flatMapFeatureSelection_; }; +template +void ScanSpec::visit(const Type& type, F&& f) { + f(type, *this); + switch (type.kind()) { + case TypeKind::ROW: + for (auto& child : children_) { + VELOX_CHECK_NE(child->channel(), kNoChannel); + child->visit(*type.childAt(child->channel()), std::forward(f)); + } + break; + case TypeKind::MAP: + childByName(kMapKeysFieldName) + ->visit(*type.childAt(0), std::forward(f)); + childByName(kMapValuesFieldName) + ->visit(*type.childAt(1), std::forward(f)); + break; + case TypeKind::ARRAY: + childByName(kArrayElementsFieldName) + ->visit(*type.childAt(0), std::forward(f)); + break; + default: + break; + } +} + // Returns false if no value from a range defined by stats can pass the // filter. True, otherwise. bool testFilter( diff --git a/velox/exec/tests/TableScanTest.cpp b/velox/exec/tests/TableScanTest.cpp index b79c75751c06..624d18619ec2 100644 --- a/velox/exec/tests/TableScanTest.cpp +++ b/velox/exec/tests/TableScanTest.cpp @@ -975,6 +975,23 @@ TEST_F(TableScanTest, subfieldPruningArrayType) { } } +TEST_F(TableScanTest, skipNullMapKeys) { + auto vector = makeRowVector({makeMapVector( + {0, 2}, + makeNullableFlatVector({std::nullopt, 2}), + makeFlatVector({1, 2}))}); + auto rowType = asRowType(vector->type()); + auto filePath = TempFilePath::create(); + writeToFile(filePath->getPath(), {vector}); + auto plan = PlanBuilder().tableScan(rowType).planNode(); + auto split = makeHiveConnectorSplit(filePath->getPath()); + auto expected = makeRowVector({makeMapVector( + {0, 1}, + makeNullableFlatVector(std::vector>(1, 2)), + makeFlatVector(std::vector(1, 2)))}); + AssertQueryBuilder(plan).split(split).assertResults(expected); +} + // Test reading files written before schema change, e.g. missing newly added // columns. TEST_F(TableScanTest, missingColumns) { diff --git a/velox/functions/prestosql/aggregates/tests/MaxSizeForStatsTest.cpp b/velox/functions/prestosql/aggregates/tests/MaxSizeForStatsTest.cpp index 8ba6c0d95837..5a79f23a01ef 100644 --- a/velox/functions/prestosql/aggregates/tests/MaxSizeForStatsTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/MaxSizeForStatsTest.cpp @@ -218,7 +218,7 @@ TEST_F(MaxSizeForStatsTest, complexRecursiveGlobalAggregate) { createMapOfArraysVector({ {{1, std::nullopt}}, {{2, {{4, 5, std::nullopt}}}}, - {{std::nullopt, {{7, 8, 9}}}}, + {{3, {{7, 8, 9}}}}, }), }), })}; @@ -261,7 +261,7 @@ TEST_F(MaxSizeForStatsTest, dictionaryEncodingTest) { createMapOfArraysVector({ {{1, std::nullopt}}, {{2, {{4, 5, std::nullopt}}}}, - {{std::nullopt, {{7, 8, 9}}}}, + {{3, {{7, 8, 9}}}}, }), }); vector_size_t size = 3; diff --git a/velox/functions/prestosql/aggregates/tests/SumDataSizeForStatsTest.cpp b/velox/functions/prestosql/aggregates/tests/SumDataSizeForStatsTest.cpp index eab51e9e6a66..3c3b05893632 100644 --- a/velox/functions/prestosql/aggregates/tests/SumDataSizeForStatsTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/SumDataSizeForStatsTest.cpp @@ -212,7 +212,7 @@ TEST_F(SumDataSizeForStatsTest, complexRecursiveGlobalAggregate) { createMapOfArraysVector({ {{1, std::nullopt}}, {{2, {{4, 5, std::nullopt}}}}, - {{std::nullopt, {{7, 8, 9}}}}, + {{3, {{7, 8, 9}}}}, }), }), })}; @@ -256,7 +256,7 @@ TEST_F(SumDataSizeForStatsTest, dictionaryEncodingTest) { createMapOfArraysVector({ {{1, std::nullopt}}, {{2, {{4, 5, std::nullopt}}}}, - {{std::nullopt, {{7, 8, 9}}}}, + {{3, {{7, 8, 9}}}}, }), }); vector_size_t size = 3;