From 72d20ad719021c5513620e23a0a65fb724f0e299 Mon Sep 17 00:00:00 2001 From: Clif Houck Date: Thu, 4 Apr 2024 08:59:30 -0500 Subject: [PATCH] GH-20213: [C++] Implement cast to/from halffloat (#40067) ### Rationale for this change ### What changes are included in this PR? This PR implements casting to and from float16 types using the vendored float16 library included in arrow at `cpp/arrow/util/float16.*`. ### Are these changes tested? Unit tests are included in this PR. ### Are there any user-facing changes? Yes: casts to and from float16 will now work. * Closes: #20213 ### TODO - [x] Add casts to/from float64. - [x] String <-> float16 casts. - [x] Integer <-> float16 casts. - [x] Tests. - [x] Update https://github.com/apache/arrow/blob/main/docs/source/status.rst about half float. - [x] Rebase. - [x] Run clang format over this PR. * GitHub Issue: #20213 Authored-by: Clif Houck Signed-off-by: Sutou Kouhei --- c_glib/test/test-half-float-scalar.rb | 2 +- cpp/src/arrow/compare.cc | 30 +++++ .../compute/kernels/scalar_cast_internal.cc | 70 ++++++++++++ .../compute/kernels/scalar_cast_numeric.cc | 103 +++++++++++++++--- .../compute/kernels/scalar_cast_string.cc | 4 + .../arrow/compute/kernels/scalar_cast_test.cc | 25 +++-- cpp/src/arrow/ipc/json_simple.cc | 32 +++++- cpp/src/arrow/ipc/json_simple_test.cc | 35 +++++- cpp/src/arrow/record_batch_test.cc | 3 + cpp/src/arrow/type_traits.h | 1 + cpp/src/arrow/util/formatting.cc | 11 ++ cpp/src/arrow/util/formatting.h | 7 ++ cpp/src/arrow/util/value_parsing.cc | 14 +++ cpp/src/arrow/util/value_parsing.h | 17 +++ docs/source/status.rst | 11 +- 15 files changed, 325 insertions(+), 40 deletions(-) diff --git a/c_glib/test/test-half-float-scalar.rb b/c_glib/test/test-half-float-scalar.rb index ac41f91ece621..3073d84d796cf 100644 --- a/c_glib/test/test-half-float-scalar.rb +++ b/c_glib/test/test-half-float-scalar.rb @@ -41,7 +41,7 @@ def test_equal end def test_to_s - assert_equal("[\n #{@half_float}\n]", @scalar.to_s) + 
assert_equal("1.0009765625", @scalar.to_s) end def test_value diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index bb632e2eb912d..e983b47e39dc4 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -44,6 +44,7 @@ #include "arrow/util/bitmap_ops.h" #include "arrow/util/bitmap_reader.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/float16.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/macros.h" @@ -59,6 +60,7 @@ using internal::BitmapReader; using internal::BitmapUInt64Reader; using internal::checked_cast; using internal::OptionalBitmapEquals; +using util::Float16; // ---------------------------------------------------------------------- // Public method implementations @@ -95,6 +97,30 @@ struct FloatingEquality { const T epsilon; }; +// For half-float equality. +template +struct FloatingEquality { + explicit FloatingEquality(const EqualOptions& options) + : epsilon(static_cast(options.atol())) {} + + bool operator()(uint16_t x, uint16_t y) const { + Float16 f_x = Float16::FromBits(x); + Float16 f_y = Float16::FromBits(y); + if (x == y) { + return Flags::signed_zeros_equal || (f_x.signbit() == f_y.signbit()); + } + if (Flags::nans_equal && f_x.is_nan() && f_y.is_nan()) { + return true; + } + if (Flags::approximate && (fabs(f_x.ToFloat() - f_y.ToFloat()) <= epsilon)) { + return true; + } + return false; + } + + const float epsilon; +}; + template struct FloatingEqualityDispatcher { const EqualOptions& options; @@ -259,6 +285,8 @@ class RangeDataEqualsImpl { Status Visit(const DoubleType& type) { return CompareFloating(type); } + Status Visit(const HalfFloatType& type) { return CompareFloating(type); } + // Also matches StringType Status Visit(const BinaryType& type) { return CompareBinary(type); } @@ -863,6 +891,8 @@ class ScalarEqualsVisitor { Status Visit(const DoubleScalar& left) { return CompareFloating(left); } + Status Visit(const HalfFloatScalar& left) { return 
CompareFloating(left); } + template enable_if_t::value, Status> Visit(const T& left) { const auto& right = checked_cast(right_); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc index 8cf5a04addb00..d8c4088759643 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc @@ -19,10 +19,13 @@ #include "arrow/compute/cast_internal.h" #include "arrow/compute/kernels/common_internal.h" #include "arrow/extension_type.h" +#include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/float16.h" namespace arrow { +using arrow::util::Float16; using internal::checked_cast; using internal::PrimitiveScalarBase; @@ -47,6 +50,42 @@ struct CastPrimitive { } }; +// Converting floating types to half float. +template +struct CastPrimitive> { + static void Exec(const ArraySpan& arr, ArraySpan* out) { + using InT = typename InType::c_type; + const InT* in_values = arr.GetValues(1); + uint16_t* out_values = out->GetValues(1); + for (int64_t i = 0; i < arr.length; ++i) { + *out_values++ = Float16(*in_values++).bits(); + } + } +}; + +// Converting from half float to other floating types. 
+template <> +struct CastPrimitive> { + static void Exec(const ArraySpan& arr, ArraySpan* out) { + const uint16_t* in_values = arr.GetValues(1); + float* out_values = out->GetValues(1); + for (int64_t i = 0; i < arr.length; ++i) { + *out_values++ = Float16::FromBits(*in_values++).ToFloat(); + } + } +}; + +template <> +struct CastPrimitive> { + static void Exec(const ArraySpan& arr, ArraySpan* out) { + const uint16_t* in_values = arr.GetValues(1); + double* out_values = out->GetValues(1); + for (int64_t i = 0; i < arr.length; ++i) { + *out_values++ = Float16::FromBits(*in_values++).ToDouble(); + } + } +}; + template struct CastPrimitive::value>> { // memcpy output @@ -56,6 +95,33 @@ struct CastPrimitive: } }; +// Cast int to half float +template +struct CastPrimitive> { + static void Exec(const ArraySpan& arr, ArraySpan* out) { + using InT = typename InType::c_type; + const InT* in_values = arr.GetValues(1); + uint16_t* out_values = out->GetValues(1); + for (int64_t i = 0; i < arr.length; ++i) { + float temp = static_cast(*in_values++); + *out_values++ = Float16(temp).bits(); + } + } +}; + +// Cast half float to int +template +struct CastPrimitive> { + static void Exec(const ArraySpan& arr, ArraySpan* out) { + using OutT = typename OutType::c_type; + const uint16_t* in_values = arr.GetValues(1); + OutT* out_values = out->GetValues(1); + for (int64_t i = 0; i < arr.length; ++i) { + *out_values++ = static_cast(Float16::FromBits(*in_values++).ToFloat()); + } + } +}; + template void CastNumberImpl(Type::type out_type, const ArraySpan& input, ArraySpan* out) { switch (out_type) { @@ -79,6 +145,8 @@ void CastNumberImpl(Type::type out_type, const ArraySpan& input, ArraySpan* out) return CastPrimitive::Exec(input, out); case Type::DOUBLE: return CastPrimitive::Exec(input, out); + case Type::HALF_FLOAT: + return CastPrimitive::Exec(input, out); default: break; } @@ -109,6 +177,8 @@ void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, return 
CastNumberImpl(out_type, input, out); case Type::DOUBLE: return CastNumberImpl(out_type, input, out); + case Type::HALF_FLOAT: + return CastNumberImpl(out_type, input, out); default: DCHECK(false); break; diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index b054e57f04d12..3df86e7d6936c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -23,6 +23,7 @@ #include "arrow/compute/kernels/util_internal.h" #include "arrow/scalar.h" #include "arrow/util/bit_block_counter.h" +#include "arrow/util/float16.h" #include "arrow/util/int_util.h" #include "arrow/util/value_parsing.h" @@ -34,6 +35,7 @@ using internal::IntegersCanFit; using internal::OptionalBitBlockCounter; using internal::ParseValue; using internal::PrimitiveScalarBase; +using util::Float16; namespace compute { namespace internal { @@ -56,18 +58,37 @@ Status CastFloatingToFloating(KernelContext*, const ExecSpan& batch, ExecResult* // ---------------------------------------------------------------------- // Implement fast safe floating point to integer cast +// +template +struct WasTruncated { + static bool Check(OutT out_val, InT in_val) { + return static_cast(out_val) != in_val; + } + + static bool CheckMaybeNull(OutT out_val, InT in_val, bool is_valid) { + return is_valid && static_cast(out_val) != in_val; + } +}; + +// Half float to int +template +struct WasTruncated { + using OutT = typename OutType::c_type; + static bool Check(OutT out_val, uint16_t in_val) { + return static_cast(out_val) != Float16::FromBits(in_val).ToFloat(); + } + + static bool CheckMaybeNull(OutT out_val, uint16_t in_val, bool is_valid) { + return is_valid && static_cast(out_val) != Float16::FromBits(in_val).ToFloat(); + } +}; // InType is a floating point type we are planning to cast to integer template ARROW_DISABLE_UBSAN("float-cast-overflow") Status CheckFloatTruncation(const ArraySpan& 
input, const ArraySpan& output) { - auto WasTruncated = [&](OutT out_val, InT in_val) -> bool { - return static_cast(out_val) != in_val; - }; - auto WasTruncatedMaybeNull = [&](OutT out_val, InT in_val, bool is_valid) -> bool { - return is_valid && static_cast(out_val) != in_val; - }; auto GetErrorMessage = [&](InT val) { return Status::Invalid("Float value ", val, " was truncated converting to ", *output.type); @@ -86,26 +107,28 @@ Status CheckFloatTruncation(const ArraySpan& input, const ArraySpan& output) { if (block.popcount == block.length) { // Fast path: branchless for (int64_t i = 0; i < block.length; ++i) { - block_out_of_bounds |= WasTruncated(out_data[i], in_data[i]); + block_out_of_bounds |= + WasTruncated::Check(out_data[i], in_data[i]); } } else if (block.popcount > 0) { // Indices have nulls, must only boundscheck non-null values for (int64_t i = 0; i < block.length; ++i) { - block_out_of_bounds |= WasTruncatedMaybeNull( + block_out_of_bounds |= WasTruncated::CheckMaybeNull( out_data[i], in_data[i], bit_util::GetBit(bitmap, offset_position + i)); } } if (ARROW_PREDICT_FALSE(block_out_of_bounds)) { if (input.GetNullCount() > 0) { for (int64_t i = 0; i < block.length; ++i) { - if (WasTruncatedMaybeNull(out_data[i], in_data[i], - bit_util::GetBit(bitmap, offset_position + i))) { + if (WasTruncated::CheckMaybeNull( + out_data[i], in_data[i], + bit_util::GetBit(bitmap, offset_position + i))) { return GetErrorMessage(in_data[i]); } } } else { for (int64_t i = 0; i < block.length; ++i) { - if (WasTruncated(out_data[i], in_data[i])) { + if (WasTruncated::Check(out_data[i], in_data[i])) { return GetErrorMessage(in_data[i]); } } @@ -151,6 +174,9 @@ Status CheckFloatToIntTruncation(const ExecValue& input, const ExecResult& outpu return CheckFloatToIntTruncationImpl(input.array, *output.array_span()); case Type::DOUBLE: return CheckFloatToIntTruncationImpl(input.array, *output.array_span()); + case Type::HALF_FLOAT: + return 
CheckFloatToIntTruncationImpl(input.array, + *output.array_span()); default: break; } @@ -293,6 +319,15 @@ struct CastFunctor< } }; +template <> +struct CastFunctor> { + static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + return applicator::ScalarUnaryNotNull>::Exec(ctx, batch, + out); + } +}; + // ---------------------------------------------------------------------- // Decimal to integer @@ -689,6 +724,10 @@ std::shared_ptr GetCastToInteger(std::string name) { DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToInteger)); } + // Cast from half-float + DCHECK_OK(func->AddKernel(Type::HALF_FLOAT, {InputType(Type::HALF_FLOAT)}, out_ty, + CastFloatingToInteger)); + // From other numbers to integer AddCommonNumberCasts(out_ty, func.get()); @@ -715,6 +754,10 @@ std::shared_ptr GetCastToFloating(std::string name) { DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToFloating)); } + // From half-float to float/double + DCHECK_OK(func->AddKernel(Type::HALF_FLOAT, {InputType(Type::HALF_FLOAT)}, out_ty, + CastFloatingToFloating)); + // From other numbers to floating point AddCommonNumberCasts(out_ty, func.get()); @@ -723,6 +766,7 @@ std::shared_ptr GetCastToFloating(std::string name) { CastFunctor::Exec)); DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, out_ty, CastFunctor::Exec)); + return func; } @@ -795,6 +839,32 @@ std::shared_ptr GetCastToDecimal256() { return func; } +std::shared_ptr GetCastToHalfFloat() { + // HalfFloat is a bit brain-damaged for now + auto func = std::make_shared("func", Type::HALF_FLOAT); + AddCommonCasts(Type::HALF_FLOAT, float16(), func.get()); + + // Casts from integer to floating point + for (const std::shared_ptr& in_ty : IntTypes()) { + DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, + TypeTraits::type_singleton(), + CastIntegerToFloating)); + } + + // Cast from other strings to half float. 
+ for (const std::shared_ptr& in_ty : BaseBinaryTypes()) { + auto exec = GenerateVarBinaryBase(*in_ty); + DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, + TypeTraits::type_singleton(), exec)); + } + + DCHECK_OK(func.get()->AddKernel(Type::FLOAT, {InputType(Type::FLOAT)}, float16(), + CastFloatingToFloating)); + DCHECK_OK(func.get()->AddKernel(Type::DOUBLE, {InputType(Type::DOUBLE)}, float16(), + CastFloatingToFloating)); + return func; +} + } // namespace std::vector> GetNumericCasts() { @@ -830,13 +900,14 @@ std::vector> GetNumericCasts() { functions.push_back(GetCastToInteger("cast_uint64")); // HalfFloat is a bit brain-damaged for now - auto cast_half_float = - std::make_shared("cast_half_float", Type::HALF_FLOAT); - AddCommonCasts(Type::HALF_FLOAT, float16(), cast_half_float.get()); + auto cast_half_float = GetCastToHalfFloat(); functions.push_back(cast_half_float); - functions.push_back(GetCastToFloating("cast_float")); - functions.push_back(GetCastToFloating("cast_double")); + auto cast_float = GetCastToFloating("cast_float"); + functions.push_back(cast_float); + + auto cast_double = GetCastToFloating("cast_double"); + functions.push_back(cast_double); functions.push_back(GetCastToDecimal128()); functions.push_back(GetCastToDecimal256()); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc index a6576e4e4c26f..3a8352a9b870f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc @@ -437,6 +437,10 @@ void AddNumberToStringCasts(CastFunction* func) { GenerateNumeric(*in_ty), NullHandling::COMPUTED_NO_PREALLOCATE)); } + + DCHECK_OK(func->AddKernel(Type::HALF_FLOAT, {float16()}, out_ty, + NumericToStringCastFunctor::Exec, + NullHandling::COMPUTED_NO_PREALLOCATE)); } template diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index a8acf68f66c8b..af62b4da2caa5 100644 
--- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -389,7 +389,7 @@ TEST(Cast, ToIntDowncastUnsafe) { } TEST(Cast, FloatingToInt) { - for (auto from : {float32(), float64()}) { + for (auto from : {float16(), float32(), float64()}) { for (auto to : {int32(), int64()}) { // float to int no truncation CheckCast(ArrayFromJSON(from, "[1.0, null, 0.0, -1.0, 5.0]"), @@ -407,6 +407,15 @@ TEST(Cast, FloatingToInt) { } } +TEST(Cast, FloatingToFloating) { + for (auto from : {float16(), float32(), float64()}) { + for (auto to : {float16(), float32(), float64()}) { + CheckCast(ArrayFromJSON(from, "[1.0, 0.0, -1.0, 5.0]"), + ArrayFromJSON(to, "[1.0, 0.0, -1.0, 5.0]")); + } + } +} + TEST(Cast, IntToFloating) { for (auto from : {uint32(), int32()}) { std::string two_24 = "[16777216, 16777217]"; @@ -2220,14 +2229,12 @@ TEST(Cast, IntToString) { } TEST(Cast, FloatingToString) { - for (auto string_type : {utf8(), large_utf8()}) { - CheckCast( - ArrayFromJSON(float32(), "[0.0, -0.0, 1.5, -Inf, Inf, NaN, null]"), - ArrayFromJSON(string_type, R"(["0", "-0", "1.5", "-inf", "inf", "nan", null])")); - - CheckCast( - ArrayFromJSON(float64(), "[0.0, -0.0, 1.5, -Inf, Inf, NaN, null]"), - ArrayFromJSON(string_type, R"(["0", "-0", "1.5", "-inf", "inf", "nan", null])")); + for (auto float_type : {float16(), float32(), float64()}) { + for (auto string_type : {utf8(), large_utf8()}) { + CheckCast(ArrayFromJSON(float_type, "[0.0, -0.0, 1.5, -Inf, Inf, NaN, null]"), + ArrayFromJSON(string_type, + R"(["0", "-0", "1.5", "-inf", "inf", "nan", null])")); + } } } diff --git a/cpp/src/arrow/ipc/json_simple.cc b/cpp/src/arrow/ipc/json_simple.cc index ceeabe01677ed..9fd449831c980 100644 --- a/cpp/src/arrow/ipc/json_simple.cc +++ b/cpp/src/arrow/ipc/json_simple.cc @@ -36,6 +36,7 @@ #include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" #include "arrow/util/decimal.h" +#include "arrow/util/float16.h" #include "arrow/util/logging.h" 
#include "arrow/util/value_parsing.h" @@ -52,6 +53,7 @@ namespace rj = arrow::rapidjson; namespace arrow { using internal::ParseValue; +using util::Float16; namespace ipc { namespace internal { @@ -232,9 +234,9 @@ enable_if_physical_signed_integer ConvertNumber(const rj::Value& json // Convert single unsigned integer value template -enable_if_physical_unsigned_integer ConvertNumber(const rj::Value& json_obj, - const DataType& type, - typename T::c_type* out) { +enable_if_unsigned_integer ConvertNumber(const rj::Value& json_obj, + const DataType& type, + typename T::c_type* out) { if (json_obj.IsUint64()) { uint64_t v64 = json_obj.GetUint64(); *out = static_cast(v64); @@ -249,6 +251,30 @@ enable_if_physical_unsigned_integer ConvertNumber(const rj::Value& js } } +// Convert float16/HalfFloatType +template +enable_if_half_float ConvertNumber(const rj::Value& json_obj, + const DataType& type, uint16_t* out) { + if (json_obj.IsDouble()) { + double f64 = json_obj.GetDouble(); + *out = Float16(f64).bits(); + return Status::OK(); + } else if (json_obj.IsUint()) { + uint32_t u32t = json_obj.GetUint(); + double f64 = static_cast(u32t); + *out = Float16(f64).bits(); + return Status::OK(); + } else if (json_obj.IsInt()) { + int32_t i32t = json_obj.GetInt(); + double f64 = static_cast(i32t); + *out = Float16(f64).bits(); + return Status::OK(); + } else { + *out = static_cast(0); + return JSONTypeError("unsigned int", json_obj.GetType()); + } +} + // Convert single floating point value template enable_if_physical_floating_point ConvertNumber(const rj::Value& json_obj, diff --git a/cpp/src/arrow/ipc/json_simple_test.cc b/cpp/src/arrow/ipc/json_simple_test.cc index ea3a9ae1a14a9..b3f7fc5b3458b 100644 --- a/cpp/src/arrow/ipc/json_simple_test.cc +++ b/cpp/src/arrow/ipc/json_simple_test.cc @@ -44,6 +44,7 @@ #include "arrow/util/bitmap_builders.h" #include "arrow/util/checked_cast.h" #include "arrow/util/decimal.h" +#include "arrow/util/float16.h" #if defined(_MSC_VER) // "warning 
C4307: '+': integral constant overflow" @@ -51,6 +52,9 @@ #endif namespace arrow { + +using util::Float16; + namespace ipc { namespace internal { namespace json { @@ -185,6 +189,21 @@ class TestIntegers : public ::testing::Test { TYPED_TEST_SUITE_P(TestIntegers); +template +std::vector TestIntegersMutateIfNeeded( + std::vector data) { + return data; +} + +// TODO: This works, but is it the right way to do this? +template <> +std::vector TestIntegersMutateIfNeeded( + std::vector data) { + std::for_each(data.begin(), data.end(), + [](HalfFloatType::c_type& value) { value = Float16(value).bits(); }); + return data; +} + TYPED_TEST_P(TestIntegers, Basics) { using T = TypeParam; using c_type = typename T::c_type; @@ -193,16 +212,17 @@ TYPED_TEST_P(TestIntegers, Basics) { auto type = this->type(); AssertJSONArray(type, "[]", {}); - AssertJSONArray(type, "[4, 0, 5]", {4, 0, 5}); - AssertJSONArray(type, "[4, null, 5]", {true, false, true}, {4, 0, 5}); + AssertJSONArray(type, "[4, 0, 5]", TestIntegersMutateIfNeeded({4, 0, 5})); + AssertJSONArray(type, "[4, null, 5]", {true, false, true}, + TestIntegersMutateIfNeeded({4, 0, 5})); // Test limits const auto min_val = std::numeric_limits::min(); const auto max_val = std::numeric_limits::max(); std::string json_string = JSONArray(0, 1, min_val); - AssertJSONArray(type, json_string, {0, 1, min_val}); + AssertJSONArray(type, json_string, TestIntegersMutateIfNeeded({0, 1, min_val})); json_string = JSONArray(0, 1, max_val); - AssertJSONArray(type, json_string, {0, 1, max_val}); + AssertJSONArray(type, json_string, TestIntegersMutateIfNeeded({0, 1, max_val})); } TYPED_TEST_P(TestIntegers, Errors) { @@ -269,7 +289,12 @@ INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt8, TestIntegers, UInt8Type); INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt16, TestIntegers, UInt16Type); INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt32, TestIntegers, UInt32Type); INSTANTIATE_TYPED_TEST_SUITE_P(TestUInt64, TestIntegers, UInt64Type); 
-INSTANTIATE_TYPED_TEST_SUITE_P(TestHalfFloat, TestIntegers, HalfFloatType); +// FIXME: I understand that HalfFloatType is backed by a uint16_t, but does it +// make sense to run this test over it? +// The way ConvertNumber for HalfFloatType is currently written, it allows the +// conversion of floating point notation to a half float, which causes failures +// in this test, one example is asserting 0.0 cannot be parsed as a half float. +// INSTANTIATE_TYPED_TEST_SUITE_P(TestHalfFloat, TestIntegers, HalfFloatType); template class TestStrings : public ::testing::Test { diff --git a/cpp/src/arrow/record_batch_test.cc b/cpp/src/arrow/record_batch_test.cc index 7e0eb1d460555..95f601465b440 100644 --- a/cpp/src/arrow/record_batch_test.cc +++ b/cpp/src/arrow/record_batch_test.cc @@ -36,11 +36,14 @@ #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" #include "arrow/type.h" +#include "arrow/util/float16.h" #include "arrow/util/iterator.h" #include "arrow/util/key_value_metadata.h" namespace arrow { +using util::Float16; + class TestRecordBatch : public ::testing::Test {}; TEST_F(TestRecordBatch, Equals) { diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index ed66c9367dc36..8caf4400fe86d 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -305,6 +305,7 @@ struct TypeTraits { using BuilderType = HalfFloatBuilder; using ScalarType = HalfFloatScalar; using TensorType = HalfFloatTensor; + using CType = uint16_t; static constexpr int64_t bytes_required(int64_t elements) { return elements * static_cast(sizeof(uint16_t)); diff --git a/cpp/src/arrow/util/formatting.cc b/cpp/src/arrow/util/formatting.cc index c16d42ce5cfe2..c5a7e03f8573a 100644 --- a/cpp/src/arrow/util/formatting.cc +++ b/cpp/src/arrow/util/formatting.cc @@ -18,10 +18,12 @@ #include "arrow/util/formatting.h" #include "arrow/util/config.h" #include "arrow/util/double_conversion.h" +#include "arrow/util/float16.h" #include "arrow/util/logging.h" 
namespace arrow { +using util::Float16; using util::double_conversion::DoubleToStringConverter; static constexpr int kMinBufferSize = DoubleToStringConverter::kBase10MaximalLength + 1; @@ -87,5 +89,14 @@ int FloatToStringFormatter::FormatFloat(double v, char* out_buffer, int out_size return builder.position(); } +int FloatToStringFormatter::FormatFloat(uint16_t v, char* out_buffer, int out_size) { + DCHECK_GE(out_size, kMinBufferSize); + util::double_conversion::StringBuilder builder(out_buffer, out_size); + bool result = impl_->converter_.ToShortest(Float16::FromBits(v).ToFloat(), &builder); + DCHECK(result); + ARROW_UNUSED(result); + return builder.position(); +} + } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/util/formatting.h b/cpp/src/arrow/util/formatting.h index 71bae74629e35..6125f792ff988 100644 --- a/cpp/src/arrow/util/formatting.h +++ b/cpp/src/arrow/util/formatting.h @@ -268,6 +268,7 @@ class ARROW_EXPORT FloatToStringFormatter { // Returns the number of characters written int FormatFloat(float v, char* out_buffer, int out_size); int FormatFloat(double v, char* out_buffer, int out_size); + int FormatFloat(uint16_t v, char* out_buffer, int out_size); protected: struct Impl; @@ -301,6 +302,12 @@ class FloatToStringFormatterMixin : public FloatToStringFormatter { } }; +template <> +class StringFormatter : public FloatToStringFormatterMixin { + public: + using FloatToStringFormatterMixin::FloatToStringFormatterMixin; +}; + template <> class StringFormatter : public FloatToStringFormatterMixin { public: diff --git a/cpp/src/arrow/util/value_parsing.cc b/cpp/src/arrow/util/value_parsing.cc index f6a24ac1467f8..e84aac995e35f 100644 --- a/cpp/src/arrow/util/value_parsing.cc +++ b/cpp/src/arrow/util/value_parsing.cc @@ -22,8 +22,11 @@ #include #include +#include "arrow/util/float16.h" #include "arrow/vendored/fast_float/fast_float.h" +using arrow::util::Float16; + namespace arrow { namespace internal { @@ -43,6 +46,17 @@ bool 
StringToFloat(const char* s, size_t length, char decimal_point, double* out return res.ec == std::errc() && res.ptr == s + length; } +// Half float +bool StringToFloat(const char* s, size_t length, char decimal_point, uint16_t* out) { + ::arrow_vendored::fast_float::parse_options options{ + ::arrow_vendored::fast_float::chars_format::general, decimal_point}; + float temp_out; + const auto res = + ::arrow_vendored::fast_float::from_chars_advanced(s, s + length, temp_out, options); + *out = Float16::FromFloat(temp_out).bits(); + return res.ec == std::errc() && res.ptr == s + length; +} + // ---------------------------------------------------------------------- // strptime-like parsing diff --git a/cpp/src/arrow/util/value_parsing.h b/cpp/src/arrow/util/value_parsing.h index b3c711840f3e2..609906052cd20 100644 --- a/cpp/src/arrow/util/value_parsing.h +++ b/cpp/src/arrow/util/value_parsing.h @@ -135,6 +135,9 @@ bool StringToFloat(const char* s, size_t length, char decimal_point, float* out) ARROW_EXPORT bool StringToFloat(const char* s, size_t length, char decimal_point, double* out); +ARROW_EXPORT +bool StringToFloat(const char* s, size_t length, char decimal_point, uint16_t* out); + template <> struct StringConverter { using value_type = float; @@ -163,6 +166,20 @@ struct StringConverter { const char decimal_point; }; +template <> +struct StringConverter { + using value_type = uint16_t; + + explicit StringConverter(char decimal_point = '.') : decimal_point(decimal_point) {} + + bool Convert(const HalfFloatType&, const char* s, size_t length, value_type* out) { + return ARROW_PREDICT_TRUE(StringToFloat(s, length, decimal_point, out)); + } + + private: + const char decimal_point; +}; + // NOTE: HalfFloatType would require a half<->float conversion library inline uint8_t ParseDecimalDigit(char c) { return static_cast(c - '0'); } diff --git a/docs/source/status.rst b/docs/source/status.rst index 9af2fd1921e22..71d33eaa6520c 100644 --- a/docs/source/status.rst +++ 
b/docs/source/status.rst @@ -40,7 +40,7 @@ Data Types +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ | UInt8/16/32/64 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ -| Float16 | ✓ (1) | ✓ (2) | ✓ | ✓ | ✓ (3)| ✓ | ✓ | | +| Float16 | ✓ | ✓ (1) | ✓ | ✓ | ✓ (2)| ✓ | ✓ | | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ | Float32/64 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ @@ -104,7 +104,7 @@ Data Types | Data type | C++ | Java | Go | JavaScript | C# | Rust | Julia | Swift | | (special) | | | | | | | | | +===================+=======+=======+=======+============+=======+=======+=======+=======+ -| Dictionary | ✓ | ✓ (4) | ✓ | ✓ | ✓ | ✓ (3) | ✓ | | +| Dictionary | ✓ | ✓ (3) | ✓ | ✓ | ✓ | ✓ (3) | ✓ | | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ | Extension | ✓ | ✓ | ✓ | | | ✓ | ✓ | | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ @@ -113,10 +113,9 @@ Data Types Notes: -* \(1) Casting to/from Float16 in C++ is not supported. -* \(2) Casting to/from Float16 in Java is not supported. -* \(3) Float16 support in C# is only available when targeting .NET 6+. -* \(4) Nested dictionaries not supported +* \(1) Casting to/from Float16 in Java is not supported. +* \(2) Float16 support in C# is only available when targeting .NET 6+. +* \(3) Nested dictionaries not supported .. seealso:: The :ref:`format_columnar` specification.