Skip to content

Commit

Permalink
[C++][Go]: support casting nullable fields to non-nullable if there are no null values
Browse files Browse the repository at this point in the history

Fixes #33592
  • Loading branch information
NickCrews committed Aug 21, 2024
1 parent f078942 commit 52100dd
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 15 deletions.
8 changes: 4 additions & 4 deletions cpp/src/arrow/compute/kernels/scalar_cast_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ struct CastStruct {
const int out_field_count = out_type.num_fields();

std::vector<int> fields_to_select(out_field_count, -1);
const ArraySpan& in_array = batch[0].array;

int out_field_index = 0;
for (int in_field_index = 0;
Expand All @@ -355,9 +356,9 @@ struct CastStruct {
const auto& in_field = in_type.field(in_field_index);
const auto& out_field = out_type.field(out_field_index);
if (in_field->name() == out_field->name()) {
if (in_field->nullable() && !out_field->nullable()) {
return Status::TypeError("cannot cast nullable field to non-nullable field: ",
in_type.ToString(), " ", out_type.ToString());
if (!out_field->nullable() && in_array.null_count > 0) {
return Status::TypeError("casting field ", in_field->name(),
" with nulls to non-nullable type", out_type.ToString());
}
fields_to_select[out_field_index++] = in_field_index;
}
Expand All @@ -369,7 +370,6 @@ struct CastStruct {
in_type.ToString(), " output fields: ", out_type.ToString());
}

const ArraySpan& in_array = batch[0].array;
ArrayData* out_array = out->array_data().get();

if (in_array.buffers[0].data != nullptr) {
Expand Down
11 changes: 4 additions & 7 deletions cpp/src/arrow/compute/kernels/scalar_cast_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2855,7 +2855,7 @@ TEST(Cast, StructToDifferentNullabilityStruct) {
CheckCast(src_non_nullable, dest3_nullable);
}
{
// But NOT OK to go from nullable to non-nullable...
// But when going from nullable to non-nullable, all data must be non-null...
std::vector<std::shared_ptr<Field>> fields_src_nullable = {
std::make_shared<Field>("a", int8(), true),
std::make_shared<Field>("b", int8(), true),
Expand All @@ -2877,7 +2877,7 @@ TEST(Cast, StructToDifferentNullabilityStruct) {
const auto options1_non_nullable = CastOptions::Safe(dest1_non_nullable);
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
::testing::HasSubstr("cannot cast nullable field to non-nullable field"),
::testing::HasSubstr("casting field a with nulls to non-nullable type int not null"),
Cast(src_nullable, options1_non_nullable));

std::vector<std::shared_ptr<Field>> fields_dest2_non_nullable = {
Expand All @@ -2887,17 +2887,14 @@ TEST(Cast, StructToDifferentNullabilityStruct) {
const auto options2_non_nullable = CastOptions::Safe(dest2_non_nullable);
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
::testing::HasSubstr("cannot cast nullable field to non-nullable field"),
::testing::HasSubstr("casting field a with nulls to non-nullable type int not null"),
Cast(src_nullable, options2_non_nullable));

std::vector<std::shared_ptr<Field>> fields_dest3_non_nullable = {
std::make_shared<Field>("c", int64(), false)};
const auto dest3_non_nullable = arrow::struct_(fields_dest3_non_nullable);
const auto options3_non_nullable = CastOptions::Safe(dest3_non_nullable);
EXPECT_RAISES_WITH_MESSAGE_THAT(
TypeError,
::testing::HasSubstr("cannot cast nullable field to non-nullable field"),
Cast(src_nullable, options3_non_nullable));
CheckCast(src_nullable, dest3_non_nullable, options3_non_nullable);
}
}

Expand Down
4 changes: 0 additions & 4 deletions go/arrow/compute/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,6 @@ func CastStruct(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult)
inField := inType.Field(inFieldIndex)
outField := outType.Field(outFieldIndex)
if inField.Name == outField.Name {
if inField.Nullable && !outField.Nullable {
return fmt.Errorf("%w: cannot cast nullable field to non-nullable field: %s %s",
arrow.ErrType, inType, outType)
}
fieldsToSelect[outFieldIndex] = inFieldIndex
outFieldIndex++
}
Expand Down

0 comments on commit 52100dd

Please sign in to comment.