diff --git a/velox/functions/sparksql/ConcatWs.cpp b/velox/functions/sparksql/ConcatWs.cpp index 8b745391cfdd0..ef88441b07a13 100644 --- a/velox/functions/sparksql/ConcatWs.cpp +++ b/velox/functions/sparksql/ConcatWs.cpp @@ -39,16 +39,9 @@ class ConcatWs : public exec::VectorFunction { const exec::LocalDecodedVector& decodedSeparator) const { auto arrayArgNum = decodedArrays.size(); std::vector arrayVectors; - std::vector elementsDecodedVectors; for (auto i = 0; i < arrayArgNum; ++i) { auto arrayVector = decodedArrays[i].get()->base()->as(); arrayVectors.push_back(arrayVector); - auto elements = arrayVector->elements(); - exec::LocalSelectivityVector nestedRows(context, elements->size()); - nestedRows.get()->setAll(); - exec::LocalDecodedVector elementsHolder( - context, *elements, *nestedRows.get()); - elementsDecodedVectors.push_back(elementsHolder.get()); } size_t totalResultBytes = 0; @@ -61,16 +54,15 @@ class ConcatWs : public exec::VectorFunction { // Calculate size for array columns data. for (int i = 0; i < arrayArgNum; i++) { auto arrayVector = arrayVectors[i]; - auto rawSizes = arrayVector->rawSizes(); - auto rawOffsets = arrayVector->rawOffsets(); auto indices = decodedArrays[i].get()->indices(); - auto elementsDecoded = elementsDecodedVectors[i]; + SelectivityVector nestedRows(arrayVector->elements()->size()); + DecodedVector elementsDecoded(*arrayVector->elements(), nestedRows); + auto size = arrayVector->sizeAt(indices[row]); + auto offset = arrayVector->offsetAt(indices[row]); - auto size = rawSizes[indices[row]]; - auto offset = rawOffsets[indices[row]]; for (int j = 0; j < size; ++j) { - if (!elementsDecoded->isNullAt(offset + j)) { - auto element = elementsDecoded->valueAt(offset + j); + if (!elementsDecoded.isNullAt(offset + j)) { + auto element = elementsDecoded.valueAt(offset + j); // No matter empty string or not. ++allElements; totalResultBytes += element.size(); @@ -209,16 +201,9 @@ class ConcatWs : public exec::VectorFunction { decodedSeparator); std::vector arrayVectors; - std::vector elementsDecodedVectors; for (auto i = 0; i < decodedArrays.size(); ++i) { auto arrayVector = decodedArrays[i].get()->base()->as(); arrayVectors.push_back(arrayVector); - auto elements = arrayVector->elements(); - exec::LocalSelectivityVector nestedRows(context, elements->size()); - nestedRows.get()->setAll(); - exec::LocalDecodedVector elementsHolder( - context, *elements, *nestedRows.get()); - elementsDecodedVectors.push_back(elementsHolder.get()); } // Allocate a string buffer. auto rawBuffer = @@ -256,16 +241,15 @@ class ConcatWs : public exec::VectorFunction { for (auto itArgs = args.begin() + 1; itArgs != args.end(); ++itArgs) { if ((*itArgs)->typeKind() == TypeKind::ARRAY) { auto arrayVector = arrayVectors[i]; - auto rawSizes = arrayVector->rawSizes(); - auto rawOffsets = arrayVector->rawOffsets(); auto indices = decodedArrays[i].get()->indices(); - auto elementsDecoded = elementsDecodedVectors[i]; + SelectivityVector nestedRows(arrayVector->elements()->size()); + DecodedVector elementsDecoded(*arrayVector->elements(), nestedRows); + auto size = arrayVector->sizeAt(indices[row]); + auto offset = arrayVector->offsetAt(indices[row]); - auto size = rawSizes[indices[row]]; - auto offset = rawOffsets[indices[row]]; for (int k = 0; k < size; ++k) { - if (!elementsDecoded->isNullAt(offset + k)) { - auto element = elementsDecoded->valueAt(offset + k); + if (!elementsDecoded.isNullAt(offset + k)) { + auto element = elementsDecoded.valueAt(offset + k); copyToBuffer( element, isConstantSeparator() diff --git a/velox/functions/sparksql/tests/ConcatWsTest.cpp b/velox/functions/sparksql/tests/ConcatWsTest.cpp index 3e51838b448bc..4cf6703c68f2b 100644 --- a/velox/functions/sparksql/tests/ConcatWsTest.cpp +++ b/velox/functions/sparksql/tests/ConcatWsTest.cpp @@ -127,7 +127,7 @@ TEST_F(ConcatWsTest, stringArgsWithNulls) { auto result = evaluate>( "concat_ws('~','',c0,'x',NULL::VARCHAR)", makeRowVector({input})); - auto expected = makeNullableFlatVector({ + auto expected = makeFlatVector({ "~~x", "~x", "~a~x", @@ -219,6 +219,13 @@ TEST_F(ConcatWsTest, arrayArgs) { "red--purple--green--red--purple--green", }); velox::test::assertEqualVectors(expected, result); + + // Constant arrays. + auto dummyInput = makeRowVector(makeRowType({VARCHAR()}), 1); + result = evaluate>( + "concat_ws('--', array['a','b','c'], array['d'])", dummyInput); + expected = makeFlatVector({"a--b--c--d"}); + velox::test::assertEqualVectors(expected, result); } TEST_F(ConcatWsTest, mixedStringAndArrayArgs) {