diff --git a/velox/docs/functions/spark/string.rst b/velox/docs/functions/spark/string.rst index d104aae1dc5a2..d959d9974ade7 100644 --- a/velox/docs/functions/spark/string.rst +++ b/velox/docs/functions/spark/string.rst @@ -28,11 +28,11 @@ String Functions .. spark:function:: concat_ws(separator, [string/array], ...) -> varchar Returns the concatenation result for ``string`` and all elements in ``array``, separated - by ``separator``. The type of ``separator`` is VARCHAR. It can take variable number of remaining - arguments, and it allows mixed use of ``string`` and ``array``. Skips NULL argument or - NULL array element during the concatenation. If ``separator`` is NULL, returns NULL, regardless - of the following inputs. For non-NULL ``separator``, if no remaining input or all remaining inputs - are NULL, returns an empty string. :: + by ``separator``. The first argument is ``separator`` whose type is VARCHAR. Then, this function + can take variable number of remaining arguments , and it allows mixed use of ``string`` type and + ``array`` type. Skips NULL argument or NULL array element during the concatenation. If + ``separator`` is NULL, returns NULL, regardless of the following inputs. For non-NULL ``separator``, + if no remaining input exists or all remaining inputs are NULL, returns an empty string. :: SELECT concat_ws('~', 'a', 'b', 'c'); -- 'a~b~c' SELECT concat_ws('~', ['a', 'b', 'c'], ['d']); -- 'a~b~c~d' diff --git a/velox/functions/sparksql/ConcatWs.cpp b/velox/functions/sparksql/ConcatWs.cpp index 0bf9db91f96d4..68877e6a52af0 100644 --- a/velox/functions/sparksql/ConcatWs.cpp +++ b/velox/functions/sparksql/ConcatWs.cpp @@ -99,8 +99,18 @@ class ConcatWs : public exec::VectorFunction { return totalResultBytes; } - // Initialize vectors to hold decoded inputs. Concatenate consecutive constant - // string args in advance. + // Initialize some vectors for inputs. And concatenate consecutive + // constant string arguments in advance. + // @param rows The rows to process. + // @param args The arguments to the function. + // @param context The evaluation context. + // @param decodedArrays The decoded vectors for array arguments. + // @param decodedElements The decoded vectors for array elements. + // @param argMapping The mapping of the string arguments. + // @param constantStrings The constant string arguments concatenated in + // advance. + // @param decodedStringArgs The decoded vectors for non-constant string + // arguments. void initVectors( const SelectivityVector& rows, const std::vector& args, @@ -249,7 +259,7 @@ class ConcatWs : public exec::VectorFunction { auto size = arrayVector->sizeAt(indices[row]); auto offset = arrayVector->offsetAt(indices[row]); - for (int k = 0; k < size; ++k) { + for (auto k = 0; k < size; ++k) { if (!decodedElements[i].isNullAt(offset + k)) { auto element = decodedElements[i].valueAt(offset + k); copyToBuffer(element.data(), element.size()); diff --git a/velox/functions/sparksql/tests/ConcatWsTest.cpp b/velox/functions/sparksql/tests/ConcatWsTest.cpp index 4cf6703c68f2b..63de9ce0255c4 100644 --- a/velox/functions/sparksql/tests/ConcatWsTest.cpp +++ b/velox/functions/sparksql/tests/ConcatWsTest.cpp @@ -13,11 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" #include "velox/type/Type.h" -#include - namespace facebook::velox::functions::sparksql::test { namespace { @@ -36,17 +35,17 @@ class ConcatWsTest : public SparkFunctionBaseTest { void testConcatWsFlatVector( const std::vector>& inputTable, - const size_t argsCount, + const size_t& argsCount, const std::string& separator) { std::vector inputVectors; - for (int i = 0; i < argsCount; i++) { + for (auto i = 0; i < argsCount; i++) { inputVectors.emplace_back( BaseVector::create(VARCHAR(), inputTable.size(), execCtx_.pool())); } - for (int row = 0; row < inputTable.size(); row++) { - for (int col = 0; col < argsCount; col++) { + for (auto row = 0; row < inputTable.size(); row++) { + for (auto col = 0; col < argsCount; col++) { std::static_pointer_cast>(inputVectors[col]) ->set(row, StringView(inputTable[row][col])); } @@ -55,7 +54,7 @@ class ConcatWsTest : public SparkFunctionBaseTest { auto buildConcatQuery = [&]() { std::string output = "concat_ws('" + separator + "'"; - for (int i = 0; i < argsCount; i++) { + for (auto i = 0; i < argsCount; i++) { output += ",c" + std::to_string(i); } output += ")"; @@ -67,7 +66,7 @@ class ConcatWsTest : public SparkFunctionBaseTest { auto produceExpectedResult = [&](const std::vector& inputs) { auto isFirst = true; std::string output; - for (int i = 0; i < inputs.size(); i++) { + for (auto i = 0; i < inputs.size(); i++) { auto value = inputs[i]; if (isFirst) { isFirst = false; @@ -79,7 +78,7 @@ class ConcatWsTest : public SparkFunctionBaseTest { return output; }; - for (int i = 0; i < inputTable.size(); ++i) { + for (auto i = 0; i < inputTable.size(); ++i) { EXPECT_EQ(result->valueAt(i), produceExpectedResult(inputTable[i])) << "at " << i; } @@ -93,22 +92,22 @@ TEST_F(ConcatWsTest, stringArgs) { auto c1 = generateRandomString(20); auto result = evaluate>( fmt::format("concat_ws('-', '{}', '{}')", c0, c1), rows); - for (int i = 0; i < 10; ++i) { + for (auto i = 0; i < 10; ++i) { EXPECT_EQ(result->valueAt(i), c0 + "-" + c1); } // Test with variable arguments. - size_t maxArgsCount = 10; - size_t rowCount = 100; - size_t maxStringLength = 100; + const size_t maxArgsCount = 10; + const size_t rowCount = 100; + const size_t maxStringLength = 100; std::vector> inputTable; - for (int argsCount = 1; argsCount <= maxArgsCount; argsCount++) { + for (auto argsCount = 1; argsCount <= maxArgsCount; argsCount++) { inputTable.clear(); inputTable.resize(rowCount, std::vector(argsCount)); - for (int row = 0; row < rowCount; row++) { - for (int col = 0; col < argsCount; col++) { + for (auto row = 0; row < rowCount; row++) { + for (auto col = 0; col < argsCount; col++) { inputTable[row][col] = generateRandomString(folly::Random::rand32() % maxStringLength); } @@ -137,7 +136,7 @@ TEST_F(ConcatWsTest, stringArgsWithNulls) { velox::test::assertEqualVectors(expected, result); } -TEST_F(ConcatWsTest, mixedConstantAndNonconstantStringArgs) { +TEST_F(ConcatWsTest, mixedConstantAndNonConstantStringArgs) { size_t maxStringLength = 100; std::string value; auto data = makeRowVector({ @@ -256,7 +255,7 @@ TEST_F(ConcatWsTest, mixedStringAndArrayArgs) { velox::test::assertEqualVectors(expected, result); } -TEST_F(ConcatWsTest, nonconstantSeparator) { +TEST_F(ConcatWsTest, nonConstantSeparator) { auto separatorVector = makeNullableFlatVector( {"##", "--", "~~", "**", std::nullopt}); auto arrayVector = makeNullableArrayVector({