From 4d591e58edf70d220c44306c2f7db36927784c9f Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Tue, 24 Sep 2024 13:45:29 +0200 Subject: [PATCH] Stricter equality check --- cpp/src/arrow/extension/json.cc | 25 ++++++++++++---------- cpp/src/arrow/extension/json.h | 4 +++- cpp/src/arrow/extension/json_test.cc | 14 ++++++++++++ cpp/src/parquet/arrow/arrow_schema_test.cc | 6 +++--- 4 files changed, 34 insertions(+), 15 deletions(-) diff --git a/cpp/src/arrow/extension/json.cc b/cpp/src/arrow/extension/json.cc index d793233c2b573..5e396738e62df 100644 --- a/cpp/src/arrow/extension/json.cc +++ b/cpp/src/arrow/extension/json.cc @@ -28,17 +28,13 @@ namespace arrow::extension { bool JsonExtensionType::ExtensionEquals(const ExtensionType& other) const { - return other.extension_name() == this->extension_name(); + return other.extension_name() == this->extension_name() && + other.storage_type()->Equals(storage_type_); } Result> JsonExtensionType::Deserialize( std::shared_ptr storage_type, const std::string& serialized) const { - if (storage_type->id() != Type::STRING && storage_type->id() != Type::STRING_VIEW && - storage_type->id() != Type::LARGE_STRING) { - return Status::Invalid("Invalid storage type for JsonExtensionType: ", - storage_type->ToString()); - } - return std::make_shared(storage_type); + return JsonExtensionType::Make(storage_type); } std::string JsonExtensionType::Serialize() const { return ""; } @@ -51,11 +47,18 @@ std::shared_ptr JsonExtensionType::MakeArray( return std::make_shared(data); } -std::shared_ptr json(const std::shared_ptr storage_type) { - ARROW_CHECK(storage_type->id() != Type::STRING || - storage_type->id() != Type::STRING_VIEW || - storage_type->id() != Type::LARGE_STRING); +Result> JsonExtensionType::Make( + const std::shared_ptr& storage_type) { + if (storage_type->id() != Type::STRING && storage_type->id() != Type::STRING_VIEW && + storage_type->id() != Type::LARGE_STRING) { + return Status::Invalid("Invalid storage type for JsonExtensionType: ", + storage_type->ToString()); + } return std::make_shared(storage_type); } +std::shared_ptr json(const std::shared_ptr& storage_type) { + return JsonExtensionType::Make(storage_type).ValueOrDie(); +} + } // namespace arrow::extension diff --git a/cpp/src/arrow/extension/json.h b/cpp/src/arrow/extension/json.h index 4793ab2bc9b36..4d475536cff59 100644 --- a/cpp/src/arrow/extension/json.h +++ b/cpp/src/arrow/extension/json.h @@ -45,12 +45,14 @@ class ARROW_EXPORT JsonExtensionType : public ExtensionType { std::shared_ptr MakeArray(std::shared_ptr data) const override; + static Result> Make(const std::shared_ptr& storage_type); + private: std::shared_ptr storage_type_; }; /// \brief Return a JsonExtensionType instance. ARROW_EXPORT std::shared_ptr json( - std::shared_ptr storage_type = utf8()); + const std::shared_ptr& storage_type = utf8()); } // namespace arrow::extension diff --git a/cpp/src/arrow/extension/json_test.cc b/cpp/src/arrow/extension/json_test.cc index 143e4f9ceeac7..b938ddb2cfef3 100644 --- a/cpp/src/arrow/extension/json_test.cc +++ b/cpp/src/arrow/extension/json_test.cc @@ -80,4 +80,18 @@ TEST_F(TestJsonExtensionType, InvalidUTF8) { } } +TEST_F(TestJsonExtensionType, StorageTypeValidation) { + ASSERT_TRUE(json(utf8())->Equals(json(utf8()))); + ASSERT_FALSE(json(large_utf8())->Equals(json(utf8()))); + ASSERT_FALSE(json(utf8_view())->Equals(json(utf8()))); + ASSERT_FALSE(json(utf8_view())->Equals(json(large_utf8()))); + + for (const auto& storage_type : {int16(), binary(), float64(), null()}) { + ASSERT_RAISES_WITH_MESSAGE(Invalid, + "Invalid: Invalid storage type for JsonExtensionType: " + + storage_type->ToString(), + extension::JsonExtensionType::Make(storage_type)); + } +} + } // namespace arrow diff --git a/cpp/src/parquet/arrow/arrow_schema_test.cc b/cpp/src/parquet/arrow/arrow_schema_test.cc index 31ead461aa6e2..334e2919d46ac 100644 --- a/cpp/src/parquet/arrow/arrow_schema_test.cc +++ b/cpp/src/parquet/arrow/arrow_schema_test.cc @@ -763,7 +763,7 @@ TEST_F(TestConvertParquetSchema, ParquetSchemaArrowExtensions) { props.set_arrow_extensions_enabled(true); auto arrow_schema = ::arrow::schema( {::arrow::field("json_1", ::arrow::extension::json(), true), - ::arrow::field("json_2", ::arrow::extension::json(::arrow::large_utf8()), + ::arrow::field("json_2", ::arrow::extension::json(::arrow::utf8()), true)}); std::shared_ptr metadata{}; ASSERT_OK(ConvertSchema(parquet_fields, metadata, props)); @@ -780,7 +780,7 @@ TEST_F(TestConvertParquetSchema, ParquetSchemaArrowExtensions) { ::arrow::key_value_metadata({"foo", "bar"}, {"biz", "baz"}); auto arrow_schema = ::arrow::schema( {::arrow::field("json_1", ::arrow::extension::json(), true, field_metadata), - ::arrow::field("json_2", ::arrow::extension::json(::arrow::large_utf8()), + ::arrow::field("json_2", ::arrow::extension::json(::arrow::utf8()), true)}); std::shared_ptr metadata; @@ -798,7 +798,7 @@ TEST_F(TestConvertParquetSchema, ParquetSchemaArrowExtensions) { ::arrow::key_value_metadata({"foo", "bar"}, {"biz", "baz"}); auto arrow_schema = ::arrow::schema( {::arrow::field("json_1", ::arrow::extension::json(), true, field_metadata), - ::arrow::field("json_2", ::arrow::extension::json(::arrow::large_utf8()), + ::arrow::field("json_2", ::arrow::extension::json(::arrow::utf8()), true)}); std::shared_ptr metadata;