Skip to content

Commit

Permalink
Stricter equality check
Browse files: browse the repository at this point in the history
  • Loading branch information
rok committed Sep 24, 2024
1 parent aa6ab95 commit 4d591e5
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 15 deletions.
25 changes: 14 additions & 11 deletions cpp/src/arrow/extension/json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,13 @@
namespace arrow::extension {

/// \brief Compare this JSON extension type with another extension type.
///
/// Two types are considered equal only when they share the same extension name
/// and the same storage type, so that e.g. a JSON type stored as large_utf8 is
/// not equal to one stored as utf8.
///
/// \param[in] other the extension type to compare against
/// \return true when both the extension name and the storage type match
bool JsonExtensionType::ExtensionEquals(const ExtensionType& other) const {
  // Comparing the name alone would treat every storage variant as the same
  // type; the storage type comparison is what makes the check strict.
  return other.extension_name() == this->extension_name() &&
         other.storage_type()->Equals(storage_type_);
}

/// \brief Reconstruct a JsonExtensionType from its storage type.
///
/// \param[in] storage_type the storage type the extension type is built on
/// \param[in] serialized the serialized extension payload; it is not consulted
/// here, the type is rebuilt from the storage type alone
/// \return the extension type, or the error produced by Make() when the
/// storage type is not supported
Result<std::shared_ptr<DataType>> JsonExtensionType::Deserialize(
    std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
  // Storage type validation lives in Make() so that every construction path
  // rejects unsupported types with the same error message.
  return JsonExtensionType::Make(storage_type);
}

/// \brief Serialize this extension type.
///
/// \return always an empty string
std::string JsonExtensionType::Serialize() const {
  return std::string{};
}
Expand All @@ -51,11 +47,18 @@ std::shared_ptr<Array> JsonExtensionType::MakeArray(
return std::make_shared<ExtensionArray>(data);
}

std::shared_ptr<DataType> json(const std::shared_ptr<DataType> storage_type) {
ARROW_CHECK(storage_type->id() != Type::STRING ||
storage_type->id() != Type::STRING_VIEW ||
storage_type->id() != Type::LARGE_STRING);
Result<std::shared_ptr<DataType>> JsonExtensionType::Make(
const std::shared_ptr<DataType>& storage_type) {
if (storage_type->id() != Type::STRING && storage_type->id() != Type::STRING_VIEW &&
storage_type->id() != Type::LARGE_STRING) {
return Status::Invalid("Invalid storage type for JsonExtensionType: ",
storage_type->ToString());
}
return std::make_shared<JsonExtensionType>(storage_type);
}

std::shared_ptr<DataType> json(const std::shared_ptr<DataType>& storage_type) {
return JsonExtensionType::Make(storage_type).ValueOrDie();
}

} // namespace arrow::extension
4 changes: 3 additions & 1 deletion cpp/src/arrow/extension/json.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ class ARROW_EXPORT JsonExtensionType : public ExtensionType {

/// \brief Create an Array from the given ArrayData.
std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;

/// \brief Create a JsonExtensionType backed by the given storage type.
///
/// \param[in] storage_type the type used to store the JSON text
/// \return the extension type, or an error Status when the storage type is
/// not accepted
static Result<std::shared_ptr<DataType>> Make(const std::shared_ptr<DataType>& storage_type);

private:
std::shared_ptr<DataType> storage_type_;
};

/// \brief Return a JsonExtensionType instance.
///
/// \param[in] storage_type the type used to store the JSON text; defaults to
/// utf8()
/// \return the extension type built on storage_type
ARROW_EXPORT std::shared_ptr<DataType> json(
    const std::shared_ptr<DataType>& storage_type = utf8());

} // namespace arrow::extension
14 changes: 14 additions & 0 deletions cpp/src/arrow/extension/json_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,18 @@ TEST_F(TestJsonExtensionType, InvalidUTF8) {
}
}

// Checks that JSON extension types compare equal only when their storage types
// match, and that Make() rejects storage types that are not string types.
TEST_F(TestJsonExtensionType, StorageTypeValidation) {
  // Equality: same storage type is equal, differing storage types are not.
  const auto json_utf8 = json(utf8());
  const auto json_large_utf8 = json(large_utf8());
  const auto json_utf8_view = json(utf8_view());

  ASSERT_TRUE(json(utf8())->Equals(json_utf8));
  ASSERT_FALSE(json_large_utf8->Equals(json_utf8));
  ASSERT_FALSE(json_utf8_view->Equals(json_utf8));
  ASSERT_FALSE(json_utf8_view->Equals(json_large_utf8));

  // Construction: each unsupported storage type must yield an Invalid status
  // whose message names the offending type.
  for (const auto& storage_type : {int16(), binary(), float64(), null()}) {
    const std::string expected_message =
        "Invalid: Invalid storage type for JsonExtensionType: " +
        storage_type->ToString();
    ASSERT_RAISES_WITH_MESSAGE(Invalid, expected_message,
                               extension::JsonExtensionType::Make(storage_type));
  }
}

} // namespace arrow
6 changes: 3 additions & 3 deletions cpp/src/parquet/arrow/arrow_schema_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ TEST_F(TestConvertParquetSchema, ParquetSchemaArrowExtensions) {
props.set_arrow_extensions_enabled(true);
auto arrow_schema = ::arrow::schema(
{::arrow::field("json_1", ::arrow::extension::json(), true),
::arrow::field("json_2", ::arrow::extension::json(::arrow::large_utf8()),
::arrow::field("json_2", ::arrow::extension::json(::arrow::utf8()),
true)});
std::shared_ptr<KeyValueMetadata> metadata{};
ASSERT_OK(ConvertSchema(parquet_fields, metadata, props));
Expand All @@ -780,7 +780,7 @@ TEST_F(TestConvertParquetSchema, ParquetSchemaArrowExtensions) {
::arrow::key_value_metadata({"foo", "bar"}, {"biz", "baz"});
auto arrow_schema = ::arrow::schema(
{::arrow::field("json_1", ::arrow::extension::json(), true, field_metadata),
::arrow::field("json_2", ::arrow::extension::json(::arrow::large_utf8()),
::arrow::field("json_2", ::arrow::extension::json(::arrow::utf8()),
true)});

std::shared_ptr<KeyValueMetadata> metadata;
Expand All @@ -798,7 +798,7 @@ TEST_F(TestConvertParquetSchema, ParquetSchemaArrowExtensions) {
::arrow::key_value_metadata({"foo", "bar"}, {"biz", "baz"});
auto arrow_schema = ::arrow::schema(
{::arrow::field("json_1", ::arrow::extension::json(), true, field_metadata),
::arrow::field("json_2", ::arrow::extension::json(::arrow::large_utf8()),
::arrow::field("json_2", ::arrow::extension::json(::arrow::utf8()),
true)});

std::shared_ptr<KeyValueMetadata> metadata;
Expand Down

0 comments on commit 4d591e5

Please sign in to comment.