Skip to content

Commit

Permalink
[GLUTEN-6856][CH]Support arrays_overlap and fix array_join diff (apac…
Browse files Browse the repository at this point in the history
…he#6857)

* Fix orc timezone read

* on test

* fix ci

* fix ci

* ci build error

* support arrays_overlap

* remove useless code

* fix ci

* fix ci

* fix array_join diff

* remove useless code

* remove useless code

* solve conflict

* ci fix

* remove useless code

* use default impl for array join

* remove useless code
  • Loading branch information
KevinyhZou authored Oct 9, 2024
1 parent 849b5d7 commit 92d3793
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,6 @@ case class EncodeDecodeValidator() extends FunctionValidator {
}
}

case class ArrayJoinValidator() extends FunctionValidator {
override def doValidate(expr: Expression): Boolean = expr match {
case t: ArrayJoin => !t.children.head.isInstanceOf[Literal]
case _ => true
}
}

case class FormatStringValidator() extends FunctionValidator {
override def doValidate(expr: Expression): Boolean = {
val formatString = expr.asInstanceOf[FormatString]
Expand All @@ -181,13 +174,11 @@ object CHExpressionUtil {
)

final val CH_BLACKLIST_SCALAR_FUNCTION: Map[String, FunctionValidator] = Map(
ARRAY_JOIN -> ArrayJoinValidator(),
SPLIT_PART -> DefaultValidator(),
TO_UNIX_TIMESTAMP -> UnixTimeStampValidator(),
UNIX_TIMESTAMP -> UnixTimeStampValidator(),
SEQUENCE -> SequenceValidator(),
GET_JSON_OBJECT -> GetJsonObjectValidator(),
ARRAYS_OVERLAP -> DefaultValidator(),
SPLIT -> StringSplitValidator(),
SUBSTRING_INDEX -> SubstringIndexValidator(),
LPAD -> StringLPadValidator(),
Expand Down
135 changes: 41 additions & 94 deletions cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,69 +46,28 @@ class SparkFunctionArrayJoin : public IFunction
size_t getNumberOfArguments() const override { return 0; }
String getName() const override { return name; }
bool isVariadic() const override { return true; }
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }

DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override
{
auto data_type = std::make_shared<DataTypeString>();
return makeNullable(data_type);
}

ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const override
{
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
if (arguments.size() != 2 && arguments.size() != 3)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 or 3 arguments", getName());

const auto * arg_null_col = checkAndGetColumn<ColumnNullable>(arguments[0].column.get());
const ColumnArray * array_col;
if (!arg_null_col)
array_col = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
else
array_col = checkAndGetColumn<ColumnArray>(arg_null_col->getNestedColumnPtr().get());
if (!array_col)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st argument must be array type", getName());

auto res_col = ColumnString::create();
auto null_col = ColumnUInt8::create(array_col->size(), 0);
auto null_col = ColumnUInt8::create(input_rows_count, 0);
PaddedPODArray<UInt8> & null_result = null_col->getData();
std::pair<bool, StringRef> delim_p, null_replacement_p;
bool return_result = false;
auto checkAndGetConstString = [&](const ColumnPtr & col) -> std::pair<bool, StringRef>
{
StringRef res;
const auto * str_null_col = checkAndGetColumnConstData<ColumnNullable>(col.get());
if (str_null_col)
{
if (str_null_col->isNullAt(0))
{
for (size_t i = 0; i < array_col->size(); ++i)
{
res_col->insertDefault();
null_result[i] = 1;
}
return_result = true;
return std::pair<bool, StringRef>(false, res);
}
}
else
{
const auto * string_col = checkAndGetColumnConstData<ColumnString>(col.get());
if (!string_col)
return std::pair<bool, StringRef>(false, res);
else
return std::pair<bool, StringRef>(true, string_col->getDataAt(0));
}
};
delim_p = checkAndGetConstString(arguments[1].column);
if (return_result)
if (input_rows_count == 0)
return ColumnNullable::create(std::move(res_col), std::move(null_col));

if (arguments.size() == 3)
{
null_replacement_p = checkAndGetConstString(arguments[2].column);
if (return_result)
return ColumnNullable::create(std::move(res_col), std::move(null_col));
}
const ColumnArray * array_col = array_col = checkAndGetColumn<ColumnArray>(arguments[0].column.get());;
if (!array_col)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st argument must be array type", getName());

const ColumnNullable * array_nested_col = checkAndGetColumn<ColumnNullable>(&array_col->getData());
const ColumnString * string_col;
if (array_nested_col)
Expand All @@ -118,57 +77,42 @@ class SparkFunctionArrayJoin : public IFunction
const ColumnArray::Offsets & array_offsets = array_col->getOffsets();
const ColumnString::Offsets & string_offsets = string_col->getOffsets();
const ColumnString::Chars & string_data = string_col->getChars();
const ColumnNullable * delim_col = checkAndGetColumn<ColumnNullable>(arguments[1].column.get());
const ColumnNullable * null_replacement_col = arguments.size() == 3 ? checkAndGetColumn<ColumnNullable>(arguments[2].column.get()) : nullptr;

auto extractColumnString = [&](const ColumnPtr & col) -> const ColumnString *
{
const ColumnString * res = nullptr;
if (col->isConst())
{
const ColumnConst * const_col = checkAndGetColumn<ColumnConst>(col.get());
if (const_col)
res = checkAndGetColumn<ColumnString>(const_col->getDataColumnPtr().get());
}
else
res = checkAndGetColumn<ColumnString>(col.get());
return res;
};
bool const_delim_col = arguments[1].column->isConst();
bool const_null_replacement_col = false;
const ColumnString * delim_col = extractColumnString(arguments[1].column);
const ColumnString * null_replacement_col = nullptr;
if (arguments.size() == 3)
{
const_null_replacement_col = arguments[2].column->isConst();
null_replacement_col = extractColumnString(arguments[2].column);
}
size_t current_offset = 0, array_pos = 0;
for (size_t i = 0; i < array_col->size(); ++i)
{
String res;
auto setResultNull = [&]() -> void
const StringRef delim = const_delim_col ? delim_col->getDataAt(0) : delim_col->getDataAt(i);
StringRef null_replacement = StringRef(nullptr, 0);
if (null_replacement_col)
{
res_col->insertDefault();
null_result[i] = 1;
current_offset = array_offsets[i];
};
auto getDelimiterOrNullReplacement = [&](const std::pair<bool, StringRef> & s, const ColumnNullable * col) -> StringRef
{
if (s.first)
return s.second;
else
{
if (col->isNullAt(i))
return StringRef(nullptr, 0);
else
{
const ColumnString * col_string = checkAndGetColumn<ColumnString>(col->getNestedColumnPtr().get());
return col_string->getDataAt(i);
}
}
};
if (arg_null_col->isNullAt(i))
{
setResultNull();
continue;
}
const StringRef delim = getDelimiterOrNullReplacement(delim_p, delim_col);
if (!delim.data)
{
setResultNull();
continue;
null_replacement = const_null_replacement_col ? null_replacement_col->getDataAt(0) : null_replacement_col->getDataAt(i);
}
StringRef null_replacement;
if (arguments.size() == 3)
{
null_replacement = getDelimiterOrNullReplacement(null_replacement_p, null_replacement_col);
if (!null_replacement.data)
{
setResultNull();
continue;
}
}

size_t array_size = array_offsets[i] - current_offset;
size_t data_pos = array_pos == 0 ? 0 : string_offsets[array_pos - 1];
size_t last_not_null_pos = 0;
for (size_t j = 0; j < array_size; ++j)
{
if (array_nested_col && array_nested_col->isNullAt(j + array_pos))
Expand All @@ -179,11 +123,14 @@ class SparkFunctionArrayJoin : public IFunction
if (j != array_size - 1)
res += delim.toString();
}
else if (j == array_size - 1)
res = res.substr(0, last_not_null_pos);
}
else
{
const StringRef s(&string_data[data_pos], string_offsets[j + array_pos] - data_pos - 1);
res += s.toString();
last_not_null_pos = res.size();
if (j != array_size - 1)
res += delim.toString();
}
Expand All @@ -194,7 +141,7 @@ class SparkFunctionArrayJoin : public IFunction
current_offset = array_offsets[i];
}
return ColumnNullable::create(std::move(res_col), std::move(null_col));
}
}
};

REGISTER_FUNCTION(SparkArrayJoin)
Expand Down
139 changes: 139 additions & 0 deletions cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Columns/ColumnString.h>
#include <Columns/ColumnNullable.h>
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>

using namespace DB;

namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}

namespace local_engine
{
class SparkFunctionArraysOverlap : public IFunction
{
public:
static constexpr auto name = "sparkArraysOverlap";
static FunctionPtr create(ContextPtr) { return std::make_shared<SparkFunctionArraysOverlap>(); }
SparkFunctionArraysOverlap() = default;
~SparkFunctionArraysOverlap() override = default;
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
size_t getNumberOfArguments() const override { return 2; }
String getName() const override { return name; }
bool useDefaultImplementationForConstants() const override { return true; }

DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override
{
auto data_type = std::make_shared<DataTypeUInt8>();
return makeNullable(data_type);
}

ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
if (arguments.size() != 2)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} must have 2 arguments", getName());

auto res = ColumnUInt8::create(input_rows_count, 0);
auto null_map = ColumnUInt8::create(input_rows_count, 0);
PaddedPODArray<UInt8> & res_data = res->getData();
PaddedPODArray<UInt8> & null_map_data = null_map->getData();
if (input_rows_count == 0)
return ColumnNullable::create(std::move(res), std::move(null_map));

const ColumnArray * array_col_1 = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
const ColumnArray * array_col_2 = checkAndGetColumn<ColumnArray>(arguments[1].column.get());
if (!array_col_1 || !array_col_2)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {} 1st/2nd argument must be array type", getName());

const ColumnArray::Offsets & array_offsets_1 = array_col_1->getOffsets();
const ColumnArray::Offsets & array_offsets_2 = array_col_2->getOffsets();

size_t current_offset_1 = 0, current_offset_2 = 0;
size_t array_pos_1 = 0, array_pos_2 = 0;
for (size_t i = 0; i < array_col_1->size(); ++i)
{
size_t array_size_1 = array_offsets_1[i] - current_offset_1;
size_t array_size_2 = array_offsets_2[i] - current_offset_2;
auto executeCompare = [&](const IColumn & col1, const IColumn & col2, const ColumnUInt8 * null_map1, const ColumnUInt8 * null_map2) -> void
{
for (size_t j = 0; j < array_size_1 && !res_data[i]; ++j)
{
for (size_t k = 0; k < array_size_2; ++k)
{
if ((null_map1 && null_map1->getElement(j + array_pos_1)) || (null_map2 && null_map2->getElement(k + array_pos_2)))
{
null_map_data[i] = 1;
}
else if (col1.compareAt(j + array_pos_1, k + array_pos_2, col2, -1) == 0)
{
res_data[i] = 1;
null_map_data[i] = 0;
break;
}
}
}
};
if (array_col_1->getData().isNullable() || array_col_2->getData().isNullable())
{
if (array_col_1->getData().isNullable() && array_col_2->getData().isNullable())
{
const ColumnNullable * array_null_col_1 = assert_cast<const ColumnNullable *>(&array_col_1->getData());
const ColumnNullable * array_null_col_2 = assert_cast<const ColumnNullable *>(&array_col_2->getData());
executeCompare(array_null_col_1->getNestedColumn(), array_null_col_2->getNestedColumn(),
&array_null_col_1->getNullMapColumn(), &array_null_col_2->getNullMapColumn());
}
else if (array_col_1->getData().isNullable())
{
const ColumnNullable * array_null_col_1 = assert_cast<const ColumnNullable *>(&array_col_1->getData());
executeCompare(array_null_col_1->getNestedColumn(), array_col_2->getData(), &array_null_col_1->getNullMapColumn(), nullptr);
}
else if (array_col_2->getData().isNullable())
{
const ColumnNullable * array_null_col_2 = assert_cast<const ColumnNullable *>(&array_col_2->getData());
executeCompare(array_col_1->getData(), array_null_col_2->getNestedColumn(), nullptr, &array_null_col_2->getNullMapColumn());
}
}
else if (array_col_1->getData().getDataType() == array_col_2->getData().getDataType())
{
executeCompare(array_col_1->getData(), array_col_2->getData(), nullptr, nullptr);
}

current_offset_1 = array_offsets_1[i];
current_offset_2 = array_offsets_2[i];
array_pos_1 += array_size_1;
array_pos_2 += array_size_2;
}
return ColumnNullable::create(std::move(res), std::move(null_map));
}
};

REGISTER_FUNCTION(SparkArraysOverlap)
{
factory.registerFunction<SparkFunctionArraysOverlap>();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Shuffle, shuffle, arrayShuffle);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Range, range, range);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Flatten, flatten, sparkArrayFlatten);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArrayJoin, array_join, sparkArrayJoin);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArraysOverlap, arrays_overlap, sparkArraysOverlap);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArraysZip, arrays_zip, arrayZipUnaligned);

// map functions
Expand Down

0 comments on commit 92d3793

Please sign in to comment.