From 69a17337fdf4524f03eba5953cb44be3791d5522 Mon Sep 17 00:00:00 2001 From: PHILO-HE Date: Fri, 13 Dec 2024 13:25:03 -0800 Subject: [PATCH] feat: Add Spark concat_ws function (#8854) Summary: Add concat_ws Spark function which returns the concatenation for the input, separated by a separator (the first argument). It accepts a variable number of VARCHAR or ARRAY<VARCHAR> arguments, and these two types can be used in combination. This function is a bit similar to [ConcatFunction](https://github.com/facebookincubator/velox/blob/main/velox/functions/prestosql/StringFunctions.cpp#L140), except that `concat_ws` requires a separator and supports using ARRAY type and mixed types. This PR is based on https://github.com/facebookincubator/velox/pull/6292 (author: unigof). There are a few bug fixes and improvements. Also made some changes to align with Spark. Doc [link](https://docs.databricks.com/en/sql/language-manual/functions/concat_ws.html). Pull Request resolved: https://github.com/facebookincubator/velox/pull/8854 Reviewed By: kgpai Differential Revision: D66898251 Pulled By: bikramSingh91 fbshipit-source-id: 1fcd193a245bea4062c4e20d1e1db9ad6cc3290b --- velox/docs/functions/spark/string.rst | 19 + velox/expression/fuzzer/ExpressionFuzzer.cpp | 18 +- velox/functions/sparksql/CMakeLists.txt | 1 + velox/functions/sparksql/ConcatWs.cpp | 390 ++++++++++++++++++ velox/functions/sparksql/ConcatWs.h | 35 ++ .../sparksql/registration/RegisterString.cpp | 5 + velox/functions/sparksql/tests/CMakeLists.txt | 1 + .../functions/sparksql/tests/ConcatWsTest.cpp | 305 ++++++++++++++ 8 files changed, 771 insertions(+), 3 deletions(-) create mode 100644 velox/functions/sparksql/ConcatWs.cpp create mode 100644 velox/functions/sparksql/ConcatWs.h create mode 100644 velox/functions/sparksql/tests/ConcatWsTest.cpp diff --git a/velox/docs/functions/spark/string.rst b/velox/docs/functions/spark/string.rst index 1bb01c65eae6..d959d9974ade 100644 --- a/velox/docs/functions/spark/string.rst +++ 
b/velox/docs/functions/spark/string.rst @@ -25,6 +25,25 @@ String Functions If ``n < 0``, the result is an empty string. If ``n >= 256``, the result is equivalent to chr(``n % 256``). +.. spark:function:: concat_ws(separator, [string/array], ...) -> varchar + + Returns the concatenation result for ``string`` and all elements in ``array``, separated + by ``separator``. The first argument is ``separator`` whose type is VARCHAR. Then, this function + can take a variable number of remaining arguments, and it allows mixed use of the ``string`` and + ``array`` types. NULL arguments and NULL array elements are skipped during the concatenation. If + ``separator`` is NULL, returns NULL, regardless of the following inputs. For non-NULL ``separator``, + if no remaining input exists or all remaining inputs are NULL, returns an empty string. :: + + SELECT concat_ws('~', 'a', 'b', 'c'); -- 'a~b~c' + SELECT concat_ws('~', ['a', 'b', 'c'], ['d']); -- 'a~b~c~d' + SELECT concat_ws('~', 'a', ['b', 'c']); -- 'a~b~c' + SELECT concat_ws('~', '', [''], ['a', '']); -- '~~a~' + SELECT concat_ws(NULL, 'a'); -- NULL + SELECT concat_ws('~'); -- '' + SELECT concat_ws('~', NULL, [NULL], 'a', 'b'); -- 'a~b' + SELECT concat_ws('~', NULL, NULL); -- '' + SELECT concat_ws('~', [NULL]); -- '' + .. spark:function:: contains(left, right) -> boolean Returns true if 'right' is found in 'left'. Otherwise, returns false. :: diff --git a/velox/expression/fuzzer/ExpressionFuzzer.cpp b/velox/expression/fuzzer/ExpressionFuzzer.cpp index 876501414501..3f381b5fccac 100644 --- a/velox/expression/fuzzer/ExpressionFuzzer.cpp +++ b/velox/expression/fuzzer/ExpressionFuzzer.cpp @@ -354,11 +354,23 @@ static void appendSpecialForms( }, { "cast", - /// TODO: Add supported Cast signatures to CastTypedExpr and - /// expose - /// them to fuzzer instead of hard-coding signatures here. + // TODO: Add supported Cast signatures to CastTypedExpr and + // expose + // them to fuzzer instead of hard-coding signatures here. 
getSignaturesForCast(), }, + { + // For Spark SQL only. + "concat_ws", + std::vector{ + // Signature: concat_ws (separator, input, ...) -> output: + // varchar, varchar, varchar, ... -> varchar + facebook::velox::exec::FunctionSignatureBuilder() + .argumentType("varchar") + .variableArity("varchar") + .returnType("varchar") + .build()}, + }, }; auto specialFormNames = splitNames(specialForms); diff --git a/velox/functions/sparksql/CMakeLists.txt b/velox/functions/sparksql/CMakeLists.txt index 5e2f5ad58271..2e60515c925c 100644 --- a/velox/functions/sparksql/CMakeLists.txt +++ b/velox/functions/sparksql/CMakeLists.txt @@ -18,6 +18,7 @@ velox_add_library( ArrayGetFunction.cpp ArraySort.cpp Comparisons.cpp + ConcatWs.cpp DecimalArithmetic.cpp DecimalCompare.cpp Hash.cpp diff --git a/velox/functions/sparksql/ConcatWs.cpp b/velox/functions/sparksql/ConcatWs.cpp new file mode 100644 index 000000000000..68877e6a52af --- /dev/null +++ b/velox/functions/sparksql/ConcatWs.cpp @@ -0,0 +1,390 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "velox/functions/sparksql/ConcatWs.h"
+#include "velox/expression/ConstantExpr.h"
+#include "velox/expression/VectorFunction.h"
+
+namespace facebook::velox::functions::sparksql {
+
+namespace {
+class ConcatWs : public exec::VectorFunction {
+ public:
+  explicit ConcatWs(const std::optional<std::string>& separator)
+      : separator_(separator) {}
+
+  bool isConstantSeparator() const {
+    return separator_.has_value();
+  }
+
+  // Returns the total number of result bytes over all selected rows.
+  size_t calculateTotalResultBytes(
+      const SelectivityVector& rows,
+      exec::EvalCtx& context,
+      std::vector<exec::LocalDecodedVector>& decodedArrays,
+      const std::vector<DecodedVector>& decodedElements,
+      const std::vector<std::optional<std::string>>& constantStrings,
+      const std::vector<exec::LocalDecodedVector>& decodedStringArgs,
+      const exec::LocalDecodedVector& decodedSeparator) const {
+    uint64_t totalResultBytes = 0;
+    rows.applyToSelected([&](auto row) {
+      // NULL separator produces NULL result.
+      if (!isConstantSeparator() && decodedSeparator->isNullAt(row)) {
+        return;
+      }
+      int32_t allElements = 0;
+      // Calculate size for array columns data.
+      for (auto i = 0; i < decodedArrays.size(); i++) {
+        if (decodedArrays[i]->isNullAt(row)) {
+          continue;
+        }
+        auto indices = decodedArrays[i].get()->indices();
+        auto arrayVector = decodedArrays[i].get()->base()->as<ArrayVector>();
+        auto size = arrayVector->sizeAt(indices[row]);
+        auto offset = arrayVector->offsetAt(indices[row]);
+
+        for (auto j = 0; j < size; ++j) {
+          if (!decodedElements[i].isNullAt(offset + j)) {
+            // Empty strings count too: they still take a separator.
+            ++allElements;
+            totalResultBytes +=
+                decodedElements[i].valueAt<StringView>(offset + j).size();
+          }
+        }
+      }
+
+      // Calculate size for string arg.
+      auto it = decodedStringArgs.begin();
+      for (const auto& constantString : constantStrings) {
+        int32_t valueSize;
+        if (constantString.has_value()) {
+          valueSize = constantString->size();
+        } else {
+          VELOX_CHECK(
+              it < decodedStringArgs.end(),
+              "Unexpected end when iterating over decodedStringArgs.");
+          // Skip NULL.
+          if ((*it)->isNullAt(row)) {
+            ++it;
+            continue;
+          }
+          valueSize = (*it++)->valueAt<StringView>(row).size();
+        }
+        // Empty strings count too: they still take a separator.
+        allElements++;
+        totalResultBytes += valueSize;
+      }
+
+      const auto separatorSize = isConstantSeparator()
+          ? separator_.value().size()
+          : decodedSeparator->valueAt<StringView>(row).size();
+
+      if (allElements > 1) {
+        totalResultBytes += (allElements - 1) * separatorSize;
+      }
+    });
+    VELOX_USER_CHECK_LE(totalResultBytes, UINT32_MAX);
+    return totalResultBytes;
+  }
+
+  // Initialize some vectors for inputs. And concatenate consecutive
+  // constant string arguments in advance.
+  // @param rows The rows to process.
+  // @param args The arguments to the function.
+  // @param context The evaluation context.
+  // @param decodedArrays The decoded vectors for array arguments.
+  // @param decodedElements The decoded vectors for array elements.
+  // @param argMapping Index in args where each constantStrings entry starts.
+  // @param constantStrings The constant string arguments concatenated in
+  // advance.
+  // @param decodedStringArgs The decoded vectors for non-constant string
+  // arguments.
+  void initVectors(
+      const SelectivityVector& rows,
+      const std::vector<VectorPtr>& args,
+      const exec::EvalCtx& context,
+      std::vector<exec::LocalDecodedVector>& decodedArrays,
+      std::vector<DecodedVector>& decodedElements,
+      std::vector<column_index_t>& argMapping,
+      std::vector<std::optional<std::string>>& constantStrings,
+      std::vector<exec::LocalDecodedVector>& decodedStringArgs) const {
+    for (auto i = 1; i < args.size(); ++i) {
+      if (args[i] && args[i]->typeKind() == TypeKind::ARRAY) {
+        decodedArrays.emplace_back(context, *args[i], rows);
+        auto arrayVector =
+            decodedArrays.back().get()->base()->as<ArrayVector>();
+        SelectivityVector nestedRows(arrayVector->elements()->size());
+        decodedElements.emplace_back(*arrayVector->elements(), nestedRows);
+        continue;
+      }
+      argMapping.push_back(i);
+      if (!isConstantSeparator()) {
+        // Cannot concat consecutive constant string args in advance.
+        constantStrings.push_back(std::nullopt);
+        continue;
+      }
+      if (args[i] && args[i]->as<ConstantVector<StringView>>() &&
+          !args[i]->isNullAt(0)) {
+        std::ostringstream out;
+        out << args[i]->as<ConstantVector<StringView>>()->valueAt(0);
+        column_index_t j = i + 1;
+        // Concat constant string args in advance.
+        for (; j < args.size(); ++j) {
+          if (!args[j] || args[j]->typeKind() == TypeKind::ARRAY ||
+              !args[j]->as<ConstantVector<StringView>>() ||
+              args[j]->isNullAt(0)) {
+            break;
+          }
+          out << separator_.value()
+              << args[j]->as<ConstantVector<StringView>>()->valueAt(0);
+        }
+        constantStrings.emplace_back(out.str());
+        i = j - 1;
+      } else {
+        constantStrings.push_back(std::nullopt);
+      }
+    }
+
+    for (auto i = 0; i < constantStrings.size(); ++i) {
+      if (!constantStrings[i].has_value()) {
+        auto index = argMapping[i];
+        decodedStringArgs.emplace_back(context, *args[index], rows);
+      }
+    }
+  }
+
+  // Concatenates the arguments, in argument order, with the separator. VARCHAR
+  // and ARRAY args may be mixed. If the separator is constant, runs of
+  // consecutive constant string args were already joined by initVectors; each
+  // such run is emitted once, at the position of its first arg (argMapping),
+  // so that array and non-constant string args keep their relative order.
+  void doApply(
+      const SelectivityVector& rows,
+      std::vector<VectorPtr>& args,
+      exec::EvalCtx& context,
+      VectorPtr& result) const {
+    auto& flatResult = *result->asFlatVector<StringView>();
+    // For each constantStrings entry, the args index where it starts.
+    std::vector<column_index_t> argMapping;
+    std::vector<std::optional<std::string>> constantStrings;
+    const auto numArgs = args.size();
+    argMapping.reserve(numArgs - 1);
+    // Save intermediate result for consecutive constant string args.
+    // They will be concatenated in advance.
+    constantStrings.reserve(numArgs - 1);
+    std::vector<exec::LocalDecodedVector> decodedArrays;
+    decodedArrays.reserve(numArgs - 1);
+    // For column string arg decoding.
+    std::vector<exec::LocalDecodedVector> decodedStringArgs;
+    decodedStringArgs.reserve(numArgs);
+
+    std::vector<DecodedVector> decodedElements;
+    decodedElements.reserve(numArgs - 1);
+    initVectors(
+        rows,
+        args,
+        context,
+        decodedArrays,
+        decodedElements,
+        argMapping,
+        constantStrings,
+        decodedStringArgs);
+    exec::LocalDecodedVector decodedSeparator(context);
+    if (!isConstantSeparator()) {
+      decodedSeparator = exec::LocalDecodedVector(context, *args[0], rows);
+    }
+
+    const auto totalResultBytes = calculateTotalResultBytes(
+        rows,
+        context,
+        decodedArrays,
+        decodedElements,
+        constantStrings,
+        decodedStringArgs,
+        decodedSeparator);
+
+    // Allocate a string buffer.
+    auto rawBuffer =
+        flatResult.getRawStringBufferWithSpace(totalResultBytes, true);
+    rows.applyToSelected([&](auto row) {
+      // NULL separator produces NULL result.
+      if (!isConstantSeparator() && decodedSeparator->isNullAt(row)) {
+        result->setNull(row, true);
+        return;
+      }
+      uint32_t bufferOffset = 0;
+      auto isFirst = true;
+      // Index of the next array arg in decodedArrays / decodedElements.
+      int32_t i = 0;
+      // Index of the next unconsumed entry in constantStrings / argMapping.
+      int32_t j = 0;
+      auto it = decodedStringArgs.begin();
+
+      const auto separator = isConstantSeparator()
+          ? separator_.value()
+          : decodedSeparator->valueAt<StringView>(row);
+
+      const auto copyToBuffer = [&](const char* value, const size_t valueSize) {
+        if (isFirst) {
+          isFirst = false;
+        } else {
+          // Add separator before the current value.
+          memcpy(rawBuffer + bufferOffset, separator.data(), separator.size());
+          bufferOffset += separator.size();
+        }
+        memcpy(rawBuffer + bufferOffset, value, valueSize);
+        bufferOffset += valueSize;
+      };
+
+      for (auto itArgs = args.begin() + 1; itArgs != args.end(); ++itArgs) {
+        if ((*itArgs)->typeKind() == TypeKind::ARRAY) {
+          if ((*itArgs)->isNullAt(row)) {
+            ++i;
+            continue;
+          }
+          auto indices = decodedArrays[i].get()->indices();
+          auto arrayVector = decodedArrays[i].get()->base()->as<ArrayVector>();
+          auto size = arrayVector->sizeAt(indices[row]);
+          auto offset = arrayVector->offsetAt(indices[row]);
+
+          for (auto k = 0; k < size; ++k) {
+            if (!decodedElements[i].isNullAt(offset + k)) {
+              auto element = decodedElements[i].valueAt<StringView>(offset + k);
+              copyToBuffer(element.data(), element.size());
+            }
+          }
+          ++i;
+          continue;
+        }
+        // Skip constants already folded into an earlier constantStrings entry.
+        if (j >= argMapping.size() || argMapping[j] != itArgs - args.begin()) {
+          continue;
+        }
+
+        if (constantStrings[j].has_value()) {
+          copyToBuffer(constantStrings[j]->data(), constantStrings[j]->size());
+        } else {
+          VELOX_CHECK(
+              it < decodedStringArgs.end(),
+              "Unexpected end when iterating over decodedStringArgs.");
+          // Skip NULL.
+          if ((*it)->isNullAt(row)) {
+            ++it;
+            ++j;
+            continue;
+          }
+          const auto value = (*it++)->valueAt<StringView>(row);
+          copyToBuffer(value.data(), value.size());
+        }
+        ++j;
+      }
+      VELOX_USER_CHECK_LE(bufferOffset, INT32_MAX);
+      flatResult.setNoCopy(row, StringView(rawBuffer, bufferOffset));
+      rawBuffer += bufferOffset;
+    });
+  }
+
+  void apply(
+      const SelectivityVector& rows,
+      std::vector<VectorPtr>& args,
+      const TypePtr& outputType,
+      exec::EvalCtx& context,
+      VectorPtr& result) const override {
+    context.ensureWritable(rows, VARCHAR(), result);
+    auto flatResult = result->asFlatVector<StringView>();
+    const auto numArgs = args.size();
+    // If separator is NULL, result is NULL.
+    if (isConstantSeparator() && args[0]->isNullAt(0)) {
+      auto localResult = BaseVector::createNullConstant(
+          outputType, rows.end(), context.pool());
+      context.moveOrCopyResult(localResult, rows, result);
+      return;
+    }
+    // If only separator (not a NULL) is provided, result is an empty string.
+    if (numArgs == 1) {
+      auto decodedSeparator = exec::LocalDecodedVector(context, *args[0], rows);
+      // 1. Separator is constant and not a NULL.
+      // 2. Separator is a column and has no NULL.
+      if (isConstantSeparator() || !decodedSeparator->mayHaveNulls()) {
+        auto localResult = BaseVector::createConstant(
+            VARCHAR(), "", rows.end(), context.pool());
+        context.moveOrCopyResult(localResult, rows, result);
+      } else {
+        rows.applyToSelected([&](auto row) {
+          if (decodedSeparator->isNullAt(row)) {
+            result->setNull(row, true);
+          } else {
+            flatResult->setNoCopy(row, StringView(""));
+          }
+        });
+      }
+      return;
+    }
+    doApply(rows, args, context, result);
+  }
+
+ private:
+  // Set only when the separator is a constant expression; otherwise nullopt.
+  const std::optional<std::string> separator_;
+};
+} // namespace
+
+TypePtr ConcatWsCallToSpecialForm::resolveType(
+    const std::vector<TypePtr>& /*argTypes*/) {
+  return VARCHAR();
+}
+
+exec::ExprPtr ConcatWsCallToSpecialForm::constructSpecialForm(
+    const TypePtr& type,
+    std::vector<exec::ExprPtr>&& args,
+    bool trackCpuUsage,
+    const core::QueryConfig& config) {
+  auto numArgs = args.size();
+  VELOX_USER_CHECK_GE(
+      numArgs,
+      1,
+      "concat_ws requires one arguments at least, but got {}.",
+      numArgs);
+  VELOX_USER_CHECK(
+      args[0]->type()->isVarchar(),
+      "The first argument of concat_ws must be a varchar.");
+  for (const auto& arg : args) {
+    VELOX_USER_CHECK(
+        arg->type()->isVarchar() ||
+            (arg->type()->isArray() &&
+             arg->type()->asArray().elementType()->isVarchar()),
+        "The 2nd and following arguments for concat_ws should be varchar or array(varchar), but got {}.",
+        arg->type()->toString());
+  }
+
+  std::optional<std::string> separator = std::nullopt;
+  auto constantExpr = std::dynamic_pointer_cast<exec::ConstantExpr>(args[0]);
+
+  if (constantExpr) {
+    separator = constantExpr->value()
+                    ->asUnchecked<ConstantVector<StringView>>()
+                    ->valueAt(0);
+  }
+  auto concatWsFunction = std::make_shared<ConcatWs>(separator);
+  return std::make_shared<exec::Expr>(
+      type,
+      std::move(args),
+      std::move(concatWsFunction),
+      exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(),
+      kConcatWs,
+      trackCpuUsage);
+}
+
+} // namespace facebook::velox::functions::sparksql
diff --git a/velox/functions/sparksql/ConcatWs.h b/velox/functions/sparksql/ConcatWs.h
new file mode 100644
index 000000000000..a0b6d38dd22c
--- /dev/null
+++ b/velox/functions/sparksql/ConcatWs.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/expression/FunctionCallToSpecialForm.h" + +namespace facebook::velox::functions::sparksql { + +class ConcatWsCallToSpecialForm : public exec::FunctionCallToSpecialForm { + public: + TypePtr resolveType(const std::vector& argTypes) override; + + exec::ExprPtr constructSpecialForm( + const TypePtr& type, + std::vector&& args, + bool trackCpuUsage, + const core::QueryConfig& config) override; + + static constexpr const char* kConcatWs = "concat_ws"; +}; +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/registration/RegisterString.cpp b/velox/functions/sparksql/registration/RegisterString.cpp index 38a01bc4941d..007b77868a1a 100644 --- a/velox/functions/sparksql/registration/RegisterString.cpp +++ b/velox/functions/sparksql/registration/RegisterString.cpp @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +#include "velox/expression/SpecialFormRegistry.h" #include "velox/functions/lib/Re2Functions.h" #include "velox/functions/prestosql/StringFunctions.h" #include "velox/functions/prestosql/URLFunctions.h" +#include "velox/functions/sparksql/ConcatWs.h" #include "velox/functions/sparksql/MaskFunction.h" #include "velox/functions/sparksql/Split.h" #include "velox/functions/sparksql/String.h" @@ -142,6 +144,9 @@ void registerStringFunctions(const std::string& prefix) { Varchar, Varchar, Varchar>({prefix + "mask"}); + registerFunctionCallToSpecialForm( + ConcatWsCallToSpecialForm::kConcatWs, + std::make_unique()); } } // namespace sparksql } // namespace facebook::velox::functions diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index d17a9e0ed855..24772505d82d 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -26,6 +26,7 @@ add_executable( AtLeastNNonNullsTest.cpp BitwiseTest.cpp ComparisonsTest.cpp + ConcatWsTest.cpp DateTimeFunctionsTest.cpp DecimalArithmeticTest.cpp DecimalCompareTest.cpp diff --git a/velox/functions/sparksql/tests/ConcatWsTest.cpp b/velox/functions/sparksql/tests/ConcatWsTest.cpp new file mode 100644 index 000000000000..63de9ce0255c --- /dev/null +++ b/velox/functions/sparksql/tests/ConcatWsTest.cpp @@ -0,0 +1,305 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" +#include "velox/type/Type.h" + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class ConcatWsTest : public SparkFunctionBaseTest { + protected: + std::string generateRandomString(size_t length) { + const std::string chars = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + std::string randomString; + for (std::size_t i = 0; i < length; ++i) { + randomString += chars[folly::Random::rand32() % chars.size()]; + } + return randomString; + } + + void testConcatWsFlatVector( + const std::vector>& inputTable, + const size_t& argsCount, + const std::string& separator) { + std::vector inputVectors; + + for (auto i = 0; i < argsCount; i++) { + inputVectors.emplace_back( + BaseVector::create(VARCHAR(), inputTable.size(), execCtx_.pool())); + } + + for (auto row = 0; row < inputTable.size(); row++) { + for (auto col = 0; col < argsCount; col++) { + std::static_pointer_cast>(inputVectors[col]) + ->set(row, StringView(inputTable[row][col])); + } + } + + auto buildConcatQuery = [&]() { + std::string output = "concat_ws('" + separator + "'"; + + for (auto i = 0; i < argsCount; i++) { + output += ",c" + std::to_string(i); + } + output += ")"; + return output; + }; + auto result = evaluate>( + buildConcatQuery(), makeRowVector(inputVectors)); + + auto produceExpectedResult = [&](const std::vector& inputs) { + auto isFirst = true; + std::string output; + for (auto i = 0; i < inputs.size(); i++) { + auto value = inputs[i]; + if (isFirst) { + isFirst = false; + } else { + output += separator; + } + output += value; + } + return output; + }; + + for (auto i = 0; i < inputTable.size(); ++i) { + EXPECT_EQ(result->valueAt(i), produceExpectedResult(inputTable[i])) + << "at " << i; + } + } +}; + +TEST_F(ConcatWsTest, stringArgs) { + // Test with constant args. 
+ auto rows = makeRowVector(makeRowType({VARCHAR(), VARCHAR()}), 10); + auto c0 = generateRandomString(20); + auto c1 = generateRandomString(20); + auto result = evaluate>( + fmt::format("concat_ws('-', '{}', '{}')", c0, c1), rows); + for (auto i = 0; i < 10; ++i) { + EXPECT_EQ(result->valueAt(i), c0 + "-" + c1); + } + + // Test with variable arguments. + const size_t maxArgsCount = 10; + const size_t rowCount = 100; + const size_t maxStringLength = 100; + + std::vector> inputTable; + for (auto argsCount = 1; argsCount <= maxArgsCount; argsCount++) { + inputTable.clear(); + inputTable.resize(rowCount, std::vector(argsCount)); + + for (auto row = 0; row < rowCount; row++) { + for (auto col = 0; col < argsCount; col++) { + inputTable[row][col] = + generateRandomString(folly::Random::rand32() % maxStringLength); + } + } + + SCOPED_TRACE(fmt::format("Number of arguments: {}", argsCount)); + testConcatWsFlatVector(inputTable, argsCount, "--testSep--"); + // Test with empty separator. + testConcatWsFlatVector(inputTable, argsCount, ""); + } +} + +TEST_F(ConcatWsTest, stringArgsWithNulls) { + auto input = + makeNullableFlatVector({"", std::nullopt, "a", "*", "b"}); + + auto result = evaluate>( + "concat_ws('~','',c0,'x',NULL::VARCHAR)", makeRowVector({input})); + auto expected = makeFlatVector({ + "~~x", + "~x", + "~a~x", + "~*~x", + "~b~x", + }); + velox::test::assertEqualVectors(expected, result); +} + +TEST_F(ConcatWsTest, mixedConstantAndNonConstantStringArgs) { + size_t maxStringLength = 100; + std::string value; + auto data = makeRowVector({ + makeFlatVector( + 1'000, + [&](auto /* row */) { + value = + generateRandomString(folly::Random::rand32() % maxStringLength); + return StringView(value); + }), + makeFlatVector( + 1'000, + [&](auto /* row */) { + value = + generateRandomString(folly::Random::rand32() % maxStringLength); + return StringView(value); + }), + }); + + auto c0 = data->childAt(0)->as>()->rawValues(); + auto c1 = data->childAt(1)->as>()->rawValues(); 
+ + // Test with consecutive constant inputs. + auto result = evaluate>( + "concat_ws('--', c0, c1, 'foo', 'bar')", data); + + auto expected = makeFlatVector(1'000, [&](auto row) { + const std::string& s0 = c0[row].str(); + const std::string& s1 = c1[row].str(); + value = s0 + "--" + s1 + "--foo--bar"; + return StringView(value); + }); + velox::test::assertEqualVectors(expected, result); + + // Test with non-ASCII characters. + result = evaluate>( + "concat_ws('$*@', 'aaa', 'åæ', c0, 'eee', 'ddd', c1, '\u82f9\u679c', 'fff')", + data); + expected = makeFlatVector(1'000, [&](auto row) { + std::string delim = "$*@"; + + value = "aaa" + delim + "åæ" + delim + c0[row].str() + delim + "eee" + + delim + "ddd" + delim + c1[row].str() + delim + "\u82f9\u679c" + delim + + "fff"; + return StringView(value); + }); + velox::test::assertEqualVectors(expected, result); +} + +TEST_F(ConcatWsTest, arrayArgs) { + auto arrayVector = makeNullableArrayVector({ + {"red", "blue"}, + {"blue", std::nullopt, "yellow", std::nullopt, "orange"}, + {}, + {std::nullopt}, + {"red", "purple", "green"}, + }); + + // One array arg. + auto result = evaluate>( + "concat_ws('--', c0)", makeRowVector({arrayVector})); + auto expected = makeFlatVector({ + "red--blue", + "blue--yellow--orange", + "", + "", + "red--purple--green", + }); + velox::test::assertEqualVectors(expected, result); + + // Two array args. + result = evaluate>( + "concat_ws('--', c0, c1)", makeRowVector({arrayVector, arrayVector})); + expected = makeFlatVector({ + "red--blue--red--blue", + "blue--yellow--orange--blue--yellow--orange", + "", + "", + "red--purple--green--red--purple--green", + }); + velox::test::assertEqualVectors(expected, result); + + // Constant arrays. 
+ auto dummyInput = makeRowVector(makeRowType({VARCHAR()}), 1); + result = evaluate>( + "concat_ws('--', array['a','b','c'], array['d'])", dummyInput); + expected = makeFlatVector({"a--b--c--d"}); + velox::test::assertEqualVectors(expected, result); +} + +TEST_F(ConcatWsTest, mixedStringAndArrayArgs) { + auto arrayVector = makeNullableArrayVector({ + {"red", "blue"}, + {"blue", std::nullopt, "yellow", std::nullopt, "orange"}, + {}, + {std::nullopt}, + {"red", "purple", "green"}, + {""}, + {"", "green"}, + }); + + auto result = evaluate>( + "concat_ws('--', c0, 'foo', c1, 'bar', 'end', '')", + makeRowVector({arrayVector, arrayVector})); + // Empty string is also concatenated with its neighboring inputs, + // separated by given separator. + auto expected = makeFlatVector({ + "red--blue--foo--red--blue--bar--end--", + "blue--yellow--orange--foo--blue--yellow--orange--bar--end--", + "foo--bar--end--", + "foo--bar--end--", + "red--purple--green--foo--red--purple--green--bar--end--", + "--foo----bar--end--", + "--green--foo----green--bar--end--", + }); + velox::test::assertEqualVectors(expected, result); +} + +TEST_F(ConcatWsTest, nonConstantSeparator) { + auto separatorVector = makeNullableFlatVector( + {"##", "--", "~~", "**", std::nullopt}); + auto arrayVector = makeNullableArrayVector({ + {"red", "blue"}, + {"blue", std::nullopt, "yellow", std::nullopt, "orange"}, + {"red", "blue"}, + {"blue", std::nullopt, "yellow", std::nullopt, "orange"}, + {"red", "purple", "green"}, + }); + + auto result = evaluate>( + "concat_ws(c0, c1, '|')", makeRowVector({separatorVector, arrayVector})); + auto expected = makeNullableFlatVector({ + "red##blue##|", + "blue--yellow--orange--|", + "red~~blue~~|", + "blue**yellow**orange**|", + std::nullopt, + }); + velox::test::assertEqualVectors(expected, result); +} + +TEST_F(ConcatWsTest, separatorOnly) { + auto separatorVector = makeNullableFlatVector( + {"##", std::nullopt, "~~", "**", std::nullopt}); + auto result = evaluate>( + 
"concat_ws(c0)", makeRowVector({separatorVector})); + auto expected = makeNullableFlatVector({ + "", + std::nullopt, + "", + "", + std::nullopt, + }); + velox::test::assertEqualVectors(expected, result); + + // Uses constant separator. + auto dummyInput = makeRowVector(makeRowType({VARCHAR()}), 1); + result = evaluate>( + "concat_ws(NULL::VARCHAR)", dummyInput); + EXPECT_TRUE(result->isNullAt(0)); + result = evaluate>("concat_ws('-')", dummyInput); + EXPECT_EQ(result->valueAt(0), ""); +} + +} // namespace +} // namespace facebook::velox::functions::sparksql::test