Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-44214: [C++] JsonExtensionType equality check ignores storage type #44215

Merged
merged 16 commits into from
Oct 8, 2024
31 changes: 19 additions & 12 deletions cpp/src/arrow/extension/json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,13 @@
namespace arrow::extension {

bool JsonExtensionType::ExtensionEquals(const ExtensionType& other) const {
return other.extension_name() == this->extension_name();
return other.extension_name() == this->extension_name() &&
other.storage_type()->Equals(storage_type_);
}

Result<std::shared_ptr<DataType>> JsonExtensionType::Deserialize(
std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
if (storage_type->id() != Type::STRING && storage_type->id() != Type::STRING_VIEW &&
storage_type->id() != Type::LARGE_STRING) {
return Status::Invalid("Invalid storage type for JsonExtensionType: ",
storage_type->ToString());
}
return std::make_shared<JsonExtensionType>(storage_type);
return JsonExtensionType::Make(std::move(storage_type));
}

std::string JsonExtensionType::Serialize() const { return ""; }
Expand All @@ -51,11 +47,22 @@ std::shared_ptr<Array> JsonExtensionType::MakeArray(
return std::make_shared<ExtensionArray>(data);
}

std::shared_ptr<DataType> json(const std::shared_ptr<DataType> storage_type) {
ARROW_CHECK(storage_type->id() != Type::STRING ||
storage_type->id() != Type::STRING_VIEW ||
storage_type->id() != Type::LARGE_STRING);
return std::make_shared<JsonExtensionType>(storage_type);
bool JsonExtensionType::IsSupportedStorageType(Type::type type_id) {
return type_id == Type::STRING || type_id == Type::STRING_VIEW ||
type_id == Type::LARGE_STRING;
}

Result<std::shared_ptr<DataType>> JsonExtensionType::Make(
std::shared_ptr<DataType> storage_type) {
if (!IsSupportedStorageType(storage_type->id())) {
return Status::Invalid("Invalid storage type for JsonExtensionType: ",
storage_type->ToString());
}
return std::make_shared<JsonExtensionType>(std::move(storage_type));
}

std::shared_ptr<DataType> json(std::shared_ptr<DataType> storage_type) {
return JsonExtensionType::Make(std::move(storage_type)).ValueOrDie();
}

} // namespace arrow::extension
4 changes: 4 additions & 0 deletions cpp/src/arrow/extension/json.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class ARROW_EXPORT JsonExtensionType : public ExtensionType {

std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;

static Result<std::shared_ptr<DataType>> Make(std::shared_ptr<DataType> storage_type);

static bool IsSupportedStorageType(Type::type type_id);

private:
std::shared_ptr<DataType> storage_type_;
};
Expand Down
14 changes: 14 additions & 0 deletions cpp/src/arrow/extension/json_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,18 @@ TEST_F(TestJsonExtensionType, InvalidUTF8) {
}
}

TEST_F(TestJsonExtensionType, StorageTypeValidation) {
ASSERT_TRUE(json(utf8())->Equals(json(utf8())));
ASSERT_FALSE(json(large_utf8())->Equals(json(utf8())));
ASSERT_FALSE(json(utf8_view())->Equals(json(utf8())));
ASSERT_FALSE(json(utf8_view())->Equals(json(large_utf8())));

for (const auto& storage_type : {int16(), binary(), float64(), null()}) {
ASSERT_RAISES_WITH_MESSAGE(Invalid,
"Invalid: Invalid storage type for JsonExtensionType: " +
storage_type->ToString(),
extension::JsonExtensionType::Make(storage_type));
}
}

} // namespace arrow
26 changes: 19 additions & 7 deletions cpp/src/parquet/arrow/arrow_schema_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -757,23 +757,35 @@ TEST_F(TestConvertParquetSchema, ParquetSchemaArrowExtensions) {

{
// Parquet file does not contain Arrow schema.
// If Arrow extensions are enabled, both fields should be treated as json() extension
// fields.
// If Arrow extensions are enabled, fields will be interpreted as json(utf8())
// extension fields.
ArrowReaderProperties props;
props.set_arrow_extensions_enabled(true);
auto arrow_schema = ::arrow::schema(
{::arrow::field("json_1", ::arrow::extension::json(), true),
::arrow::field("json_2", ::arrow::extension::json(::arrow::large_utf8()),
true)});
::arrow::field("json_2", ::arrow::extension::json(::arrow::utf8()), true)});
std::shared_ptr<KeyValueMetadata> metadata{};
ASSERT_OK(ConvertSchema(parquet_fields, metadata, props));
CheckFlatSchema(arrow_schema);

// If original data was e.g. json(large_utf8()) it will be interpreted as json(utf8())
// in absence of Arrow schema.
arrow_schema = ::arrow::schema(
{::arrow::field("json_1", ::arrow::extension::json(), true),
::arrow::field("json_2", ::arrow::extension::json(::arrow::large_utf8()),
true)});
metadata = std::shared_ptr<KeyValueMetadata>{};
ASSERT_OK(ConvertSchema(parquet_fields, metadata, props));
EXPECT_TRUE(result_schema_->field(1)->type()->Equals(
::arrow::extension::json(::arrow::utf8())));
EXPECT_FALSE(
result_schema_->field(1)->type()->Equals(arrow_schema->field(1)->type()));
}

{
// Parquet file contains Arrow schema.
// Both json_1 and json_2 should be returned as a json() field
// even though extensions are not enabled.
// json_1 and json_2 will be interpreted as json(utf8()) and json(large_utf8())
// fields even though extensions are not enabled.
ArrowReaderProperties props;
props.set_arrow_extensions_enabled(false);
std::shared_ptr<KeyValueMetadata> field_metadata =
Expand All @@ -791,7 +803,7 @@ TEST_F(TestConvertParquetSchema, ParquetSchemaArrowExtensions) {

{
// Parquet file contains Arrow schema. Extensions are enabled.
// Both json_1 and json_2 should be returned as a json() field
// json_1 and json_2 will be interpreted as json(utf8()) and json(large_utf8()).
ArrowReaderProperties props;
props.set_arrow_extensions_enabled(true);
std::shared_ptr<KeyValueMetadata> field_metadata =
Expand Down
17 changes: 14 additions & 3 deletions cpp/src/parquet/arrow/schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -997,9 +997,8 @@ Result<bool> ApplyOriginalMetadata(const Field& origin_field, SchemaField* infer
const auto& ex_type = checked_cast<const ::arrow::ExtensionType&>(*origin_type);
if (inferred_type->id() != ::arrow::Type::EXTENSION &&
ex_type.extension_name() == std::string("arrow.json") &&
(inferred_type->id() == ::arrow::Type::STRING ||
inferred_type->id() == ::arrow::Type::LARGE_STRING ||
inferred_type->id() == ::arrow::Type::STRING_VIEW)) {
::arrow::extension::JsonExtensionType::IsSupportedStorageType(
inferred_type->id())) {
// Schema mismatch.
//
// Arrow extensions are DISABLED in Parquet.
Expand All @@ -1009,6 +1008,18 @@ Result<bool> ApplyOriginalMetadata(const Field& origin_field, SchemaField* infer
// Origin type is restored as Arrow should be considered the source of truth.
inferred->field = inferred->field->WithType(origin_type);
RETURN_NOT_OK(ApplyOriginalStorageMetadata(origin_field, inferred));
} else if (inferred_type->id() == ::arrow::Type::EXTENSION &&
ex_type.extension_name() == std::string("arrow.json")) {
// Potential schema mismatch.
//
// Arrow extensions are ENABLED in Parquet.
// origin_type is arrow::extension::json(...)
// inferred_type is arrow::extension::json(arrow::utf8())
auto origin_storage_field = origin_field.WithType(ex_type.storage_type());

// Apply metadata recursively to storage type
RETURN_NOT_OK(ApplyOriginalStorageMetadata(*origin_storage_field, inferred));
inferred->field = inferred->field->WithType(origin_type);
} else {
auto origin_storage_field = origin_field.WithType(ex_type.storage_type());

Expand Down
Loading