Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
yma11 committed Apr 1, 2024
1 parent 73328d6 commit 9fc2e1f
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 32 deletions.
68 changes: 40 additions & 28 deletions velox/functions/lib/MapFromEntries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ static const char* kIndeterminateKeyErrorMessage =
"map key cannot be indeterminate";
static const char* kErrorMessageEntryNotNull = "map entry cannot be null";

// allowNullEle If true, will return null if the map has
// an entry with null as key or the map is null (Spark's behavior)
// instead of throwing exceptions (Presto's behavior).
template <bool allowNullEle>
/// @tparam throwForNull If true, throws an exception when the input array has
/// a null entry (Presto's behavior). If false, returns null for that row
/// instead of throwing (Spark's behavior).
template <bool throwForNull>
class MapFromEntriesFunction : public exec::VectorFunction {
public:
void apply(
Expand Down Expand Up @@ -98,7 +98,7 @@ class MapFromEntriesFunction : public exec::VectorFunction {
exec::LocalDecodedVector decodedRowVector(context);
decodedRowVector.get()->decode(*inputValueVector);
if (inputValueVector->typeKind() == TypeKind::UNKNOWN) {
if (!allowNullEle) {
if constexpr (throwForNull) {
try {
VELOX_USER_FAIL(kErrorMessageEntryNotNull);
} catch (...) {
Expand Down Expand Up @@ -144,12 +144,13 @@ class MapFromEntriesFunction : public exec::VectorFunction {
for (auto i = 0; i < size; ++i) {
// Check nulls in the top level row vector.
const bool isMapEntryNull = decodedRowVector->isNullAt(offset + i);
if (isMapEntryNull && allowNullEle) {
// Spark: For nulls in the top level row vector, return null.
bits::setNull(mutableNulls, row);
resetSize(row);
break;
} else if (isMapEntryNull) {
if (isMapEntryNull) {
if constexpr (!throwForNull) {
// Spark: For nulls in the top level row vector, return null.
bits::setNull(mutableNulls, row);
resetSize(row);
break;
}
// Presto: Set the sizes to 0 so that the final map vector generated
// is valid in case we are inside a try. The map vector needs to be
// valid because it's consumed by checkDuplicateKeys before try
Expand Down Expand Up @@ -228,34 +229,45 @@ class MapFromEntriesFunction : public exec::VectorFunction {

// For Presto, the map vector must be constructed from the input nulls to
// support a possible outer expression like try(). For Spark, use the updated
// nulls.
auto mapVetorNulls = allowNullEle ? nulls : inputArray->nulls();
auto mapVector = std::make_shared<MapVector>(
context.pool(),
outputType,
mapVetorNulls,
rows.end(),
inputArray->offsets(),
sizes,
wrappedKeys,
wrappedValues);

std::shared_ptr<MapVector> mapVector;
if constexpr (throwForNull) {
mapVector = std::make_shared<MapVector>(
context.pool(),
outputType,
inputArray->nulls(),
rows.end(),
inputArray->offsets(),
sizes,
wrappedKeys,
wrappedValues);
} else {
mapVector = std::make_shared<MapVector>(
context.pool(),
outputType,
nulls,
rows.end(),
inputArray->offsets(),
sizes,
wrappedKeys,
wrappedValues);
}
checkDuplicateKeys(mapVector, *remianingRows, context);
return mapVector;
}
};
} // namespace

void registerMapFromEntriesFunction(const std::string& name) {
void registerMapFromEntriesThrowForNullFunction(const std::string& name) {
exec::registerVectorFunction(
name,
MapFromEntriesFunction</*AllowNullEle=*/false>::signatures(),
std::make_unique<MapFromEntriesFunction</*AllowNullEle=*/false>>());
MapFromEntriesFunction</*ThrowForNull=*/true>::signatures(),
std::make_unique<MapFromEntriesFunction</*ThrowForNull=*/true>>());
}

void registerMapFromEntriesAllowNullEleFunction(const std::string& name) {
void registerMapFromEntriesFunction(const std::string& name) {
exec::registerVectorFunction(
name,
MapFromEntriesFunction</*AllowNullEle=*/true>::signatures(),
std::make_unique<MapFromEntriesFunction</*AllowNullEle=*/true>>());
MapFromEntriesFunction</*ThrowForNull=*/false>::signatures(),
std::make_unique<MapFromEntriesFunction</*ThrowForNull=*/false>>());
}
} // namespace facebook::velox::functions
4 changes: 2 additions & 2 deletions velox/functions/lib/MapFromEntries.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

namespace facebook::velox::functions {

void registerMapFromEntriesFunction(const std::string& name);
void registerMapFromEntriesThrowForNullFunction(const std::string& name);

void registerMapFromEntriesAllowNullEleFunction(const std::string& name);
void registerMapFromEntriesFunction(const std::string& name);

} // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void registerMapFunctions(const std::string& prefix) {
udf_transform_values, prefix + "transform_values");
VELOX_REGISTER_VECTOR_FUNCTION(udf_map, prefix + "map");
VELOX_REGISTER_VECTOR_FUNCTION(udf_map_entries, prefix + "map_entries");
registerMapFromEntriesFunction(prefix + "map_from_entries");
registerMapFromEntriesThrowForNullFunction(prefix + "map_from_entries");

VELOX_REGISTER_VECTOR_FUNCTION(udf_map_keys, prefix + "map_keys");
VELOX_REGISTER_VECTOR_FUNCTION(udf_map_values, prefix + "map_values");
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ static void workAroundRegistrationMacro(const std::string& prefix) {

VELOX_REGISTER_VECTOR_FUNCTION(
udf_map_allow_duplicates, prefix + "map_from_arrays");
registerMapFromEntriesAllowNullEleFunction(prefix + "map_from_entries");
registerMapFromEntriesFunction(prefix + "map_from_entries");
VELOX_REGISTER_VECTOR_FUNCTION(
udf_concat_row, exec::RowConstructorCallToSpecialForm::kRowConstructor);
// String functions.
Expand Down

0 comments on commit 9fc2e1f

Please sign in to comment.