From 52100dd3ef04546c2c071cbc5fe2f39dc5768e91 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Wed, 21 Aug 2024 14:44:12 -0800 Subject: [PATCH] [C++][Go]: support casting nullable fields to non-nullable if there are no null values Fixes https://github.com/apache/arrow/issues/33592 --- cpp/src/arrow/compute/kernels/scalar_cast_nested.cc | 8 ++++---- cpp/src/arrow/compute/kernels/scalar_cast_test.cc | 11 ++++------- go/arrow/compute/cast.go | 4 ---- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc b/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc index ec5291ef608a3..027d92c7c03f6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_nested.cc @@ -347,6 +347,7 @@ struct CastStruct { const int out_field_count = out_type.num_fields(); std::vector fields_to_select(out_field_count, -1); + const ArraySpan& in_array = batch[0].array; int out_field_index = 0; for (int in_field_index = 0; @@ -355,9 +356,9 @@ struct CastStruct { const auto& in_field = in_type.field(in_field_index); const auto& out_field = out_type.field(out_field_index); if (in_field->name() == out_field->name()) { - if (in_field->nullable() && !out_field->nullable()) { - return Status::TypeError("cannot cast nullable field to non-nullable field: ", - in_type.ToString(), " ", out_type.ToString()); + if (!out_field->nullable() && in_array.null_count > 0) { + return Status::TypeError("casting field ", in_field->name(), + " with nulls to non-nullable type", out_type.ToString()); } fields_to_select[out_field_index++] = in_field_index; } @@ -369,7 +370,6 @@ struct CastStruct { in_type.ToString(), " output fields: ", out_type.ToString()); } - const ArraySpan& in_array = batch[0].array; ArrayData* out_array = out->array_data().get(); if (in_array.buffers[0].data != nullptr) { diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 140789e59665b..c126da8244431 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -2855,7 +2855,7 @@ TEST(Cast, StructToDifferentNullabilityStruct) { CheckCast(src_non_nullable, dest3_nullable); } { - // But NOT OK to go from nullable to non-nullable... + // But when going from nullable to non-nullable, all data must be non-null... std::vector> fields_src_nullable = { std::make_shared("a", int8(), true), std::make_shared("b", int8(), true), @@ -2877,7 +2877,7 @@ TEST(Cast, StructToDifferentNullabilityStruct) { const auto options1_non_nullable = CastOptions::Safe(dest1_non_nullable); EXPECT_RAISES_WITH_MESSAGE_THAT( TypeError, - ::testing::HasSubstr("cannot cast nullable field to non-nullable field"), + ::testing::HasSubstr("casting field a with nulls to non-nullable type int not null"), Cast(src_nullable, options1_non_nullable)); std::vector> fields_dest2_non_nullable = { @@ -2887,17 +2887,14 @@ TEST(Cast, StructToDifferentNullabilityStruct) { const auto options2_non_nullable = CastOptions::Safe(dest2_non_nullable); EXPECT_RAISES_WITH_MESSAGE_THAT( TypeError, - ::testing::HasSubstr("cannot cast nullable field to non-nullable field"), + ::testing::HasSubstr("casting field a with nulls to non-nullable type int not null"), Cast(src_nullable, options2_non_nullable)); std::vector> fields_dest3_non_nullable = { std::make_shared("c", int64(), false)}; const auto dest3_non_nullable = arrow::struct_(fields_dest3_non_nullable); const auto options3_non_nullable = CastOptions::Safe(dest3_non_nullable); - EXPECT_RAISES_WITH_MESSAGE_THAT( - TypeError, - ::testing::HasSubstr("cannot cast nullable field to non-nullable field"), - Cast(src_nullable, options3_non_nullable)); + CheckCast(src_nullable, dest3_non_nullable, options3_non_nullable); } } diff --git a/go/arrow/compute/cast.go b/go/arrow/compute/cast.go index 6ef6fdddd16ff..02ad2595225f9 100644 --- a/go/arrow/compute/cast.go +++ b/go/arrow/compute/cast.go @@ -280,10 +280,6 @@ func CastStruct(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) inField := inType.Field(inFieldIndex) outField := outType.Field(outFieldIndex) if inField.Name == outField.Name { - if inField.Nullable && !outField.Nullable { - return fmt.Errorf("%w: cannot cast nullable field to non-nullable field: %s %s", - arrow.ErrType, inType, outType) - } fieldsToSelect[outFieldIndex] = inFieldIndex outFieldIndex++ }