GH-17682: [C++][Python] Bool8 Extension Type Implementation #43488

Merged: 21 commits, Aug 21, 2024
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
@@ -906,6 +906,7 @@ endif()

if(ARROW_JSON)
arrow_add_object_library(ARROW_JSON
extension/bool8.cc
extension/fixed_shape_tensor.cc
extension/opaque.cc
json/options.cc
6 changes: 6 additions & 0 deletions cpp/src/arrow/extension/CMakeLists.txt
@@ -15,6 +15,12 @@
# specific language governing permissions and limitations
# under the License.

add_arrow_test(test
SOURCES
bool8_test.cc
PREFIX
"arrow-extension-bool8")

add_arrow_test(test
SOURCES
fixed_shape_tensor_test.cc
61 changes: 61 additions & 0 deletions cpp/src/arrow/extension/bool8.cc
@@ -0,0 +1,61 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <sstream>

#include "arrow/extension/bool8.h"
#include "arrow/util/logging.h"

namespace arrow::extension {

bool Bool8Type::ExtensionEquals(const ExtensionType& other) const {
return extension_name() == other.extension_name();
}

std::string Bool8Type::ToString(bool show_metadata) const {
std::stringstream ss;
ss << "extension<" << this->extension_name() << ">";
return ss.str();
}

std::string Bool8Type::Serialize() const { return ""; }
Review comment (Member):
Hmm, why is this empty?

Reply (Member Author):
This is what the "description of the serialization" for Bool8 specifies.

This method is generally used to encode type parameters, but bool8 has none. The type is fully defined by its name and storage type.


Result<std::shared_ptr<DataType>> Bool8Type::Deserialize(
std::shared_ptr<DataType> storage_type, const std::string& serialized_data) const {
if (storage_type->id() != Type::INT8) {
return Status::Invalid("Expected INT8 storage type, got ", storage_type->ToString());
}
if (serialized_data != "") {
return Status::Invalid("Serialize data must be empty, got ", serialized_data);
}
return bool8();
}

std::shared_ptr<Array> Bool8Type::MakeArray(std::shared_ptr<ArrayData> data) const {
DCHECK_EQ(data->type->id(), Type::EXTENSION);
DCHECK_EQ("arrow.bool8",
internal::checked_cast<const ExtensionType&>(*data->type).extension_name());
return std::make_shared<Bool8Array>(data);
}

Result<std::shared_ptr<DataType>> Bool8Type::Make() {
return std::make_shared<Bool8Type>();
}

std::shared_ptr<DataType> bool8() { return std::make_shared<Bool8Type>(); }

} // namespace arrow::extension
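
The review exchange about `Serialize()` above can be illustrated with a minimal sketch that does not depend on Arrow. It assumes a simplified stand-in for the extension type interface: `StorageKind`, `DeserializeResult`, `SerializeBool8`, and `DeserializeBool8` are hypothetical names used only for illustration, not part of this PR. The idea is that a type with no parameters serializes to an empty string, so deserialization only has to validate the storage type and confirm the payload is empty.

```cpp
#include <string>

// Hypothetical stand-in for a storage type id; not Arrow's Type::type enum.
enum class StorageKind { Int8, UInt8 };

// Hypothetical result holder; Arrow itself uses Result<std::shared_ptr<DataType>>.
struct DeserializeResult {
  bool ok;
  std::string detail;  // extension name on success, error message on failure
};

// A parameter-free type has nothing to encode, so the payload is empty.
inline std::string SerializeBool8() { return ""; }

// Accept only int8 storage together with an empty payload.
inline DeserializeResult DeserializeBool8(StorageKind storage,
                                          const std::string& serialized) {
  if (storage != StorageKind::Int8) {
    return {false, "expected int8 storage"};
  }
  if (!serialized.empty()) {
    return {false, "serialized data must be empty, got " + serialized};
  }
  return {true, "arrow.bool8"};
}
```

Round-tripping `SerializeBool8()` through `DeserializeBool8(StorageKind::Int8, ...)` succeeds, while a non-empty payload or a `UInt8` storage kind is rejected, which mirrors the cases exercised in the `Deserialize` test of this PR.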
58 changes: 58 additions & 0 deletions cpp/src/arrow/extension/bool8.h
@@ -0,0 +1,58 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/extension_type.h"

namespace arrow::extension {

/// \brief Bool8 is an alternate representation for boolean
/// arrays using 8 bits instead of 1 bit per value. The underlying
/// storage type is int8.
class ARROW_EXPORT Bool8Array : public ExtensionArray {
public:
using ExtensionArray::ExtensionArray;
};

/// \brief Bool8 is an alternate representation for boolean
/// arrays using 8 bits instead of 1 bit per value. The underlying
/// storage type is int8.
class ARROW_EXPORT Bool8Type : public ExtensionType {
public:
/// \brief Construct a Bool8Type.
Bool8Type() : ExtensionType(int8()) {}

std::string extension_name() const override { return "arrow.bool8"; }
std::string ToString(bool show_metadata = false) const override;

bool ExtensionEquals(const ExtensionType& other) const override;

std::string Serialize() const override;

Result<std::shared_ptr<DataType>> Deserialize(
std::shared_ptr<DataType> storage_type,
const std::string& serialized_data) const override;

/// Create a Bool8Array from ArrayData
std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;

static Result<std::shared_ptr<DataType>> Make();
};

/// \brief Return a Bool8Type instance.
ARROW_EXPORT std::shared_ptr<DataType> bool8();

} // namespace arrow::extension
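
To make the storage trade-off described in the doc comments above concrete, here is a small sketch that does not use Arrow at all. It assumes the convention that 0 means false and any non-zero byte means true, and it compares the buffer size against bit-packed booleans. `Bool8ToBool`, `DecodeBool8`, `Bool8Bytes`, and `PackedBoolBytes` are hypothetical helper names, and null handling is deliberately left out.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Interpret one int8 storage value as a boolean.
// Assumption: zero is false, any non-zero value is true.
inline bool Bool8ToBool(std::int8_t value) { return value != 0; }

// Decode a whole storage buffer (validity/null handling is omitted in this sketch).
inline std::vector<bool> DecodeBool8(const std::vector<std::int8_t>& storage) {
  std::vector<bool> out;
  out.reserve(storage.size());
  for (std::int8_t v : storage) out.push_back(Bool8ToBool(v));
  return out;
}

// Bytes needed for n values: one byte each for bool8, one bit each when bit-packed.
inline std::size_t Bool8Bytes(std::size_t n) { return n; }
inline std::size_t PackedBoolBytes(std::size_t n) { return (n + 7) / 8; }
```

The byte-per-value layout costs roughly eight times the memory of the bit-packed form, but each value can be read or written without bit masking.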
91 changes: 91 additions & 0 deletions cpp/src/arrow/extension/bool8_test.cc
@@ -0,0 +1,91 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/extension/bool8.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
#include "arrow/testing/extension_type.h"
#include "arrow/testing/gtest_util.h"

namespace arrow {

TEST(Bool8Type, Basics) {
auto type = internal::checked_pointer_cast<extension::Bool8Type>(extension::bool8());
auto type2 = internal::checked_pointer_cast<extension::Bool8Type>(extension::bool8());
ASSERT_EQ("arrow.bool8", type->extension_name());
ASSERT_EQ(*type, *type);
ASSERT_NE(*arrow::null(), *type);
ASSERT_EQ(*type, *type2);
ASSERT_EQ(*arrow::int8(), *type->storage_type());
ASSERT_EQ("", type->Serialize());
ASSERT_EQ("extension<arrow.bool8>", type->ToString(false));
}

TEST(Bool8Type, CreateFromArray) {
auto type = internal::checked_pointer_cast<extension::Bool8Type>(extension::bool8());
auto storage = ArrayFromJSON(int8(), "[-1,0,1,2,null]");
auto array = ExtensionType::WrapArray(type, storage);
ASSERT_EQ(5, array->length());
ASSERT_EQ(1, array->null_count());
}

TEST(Bool8Type, Deserialize) {
auto type = internal::checked_pointer_cast<extension::Bool8Type>(extension::bool8());
ASSERT_OK_AND_ASSIGN(auto deserialized, type->Deserialize(type->storage_type(), ""));
ASSERT_EQ(*type, *deserialized);
ASSERT_NOT_OK(type->Deserialize(type->storage_type(), "must be empty"));
ASSERT_EQ(*type, *deserialized);
ASSERT_NOT_OK(type->Deserialize(uint8(), ""));
ASSERT_EQ(*type, *deserialized);
}

TEST(Bool8Type, MetadataRoundTrip) {
auto type = internal::checked_pointer_cast<extension::Bool8Type>(extension::bool8());
std::string serialized = type->Serialize();
ASSERT_OK_AND_ASSIGN(auto deserialized,
type->Deserialize(type->storage_type(), serialized));
ASSERT_EQ(*type, *deserialized);
}

TEST(Bool8Type, BatchRoundTrip) {
auto type = internal::checked_pointer_cast<extension::Bool8Type>(extension::bool8());

auto storage = ArrayFromJSON(int8(), "[-1,0,1,2,null]");
auto array = ExtensionType::WrapArray(type, storage);
auto batch =
RecordBatch::Make(schema({field("field", type)}), array->length(), {array});

std::shared_ptr<RecordBatch> written;
{
ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create());
ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(),
out_stream.get()));

ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish());

io::BufferReader reader(complete_ipc_stream);
std::shared_ptr<RecordBatchReader> batch_reader;
ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader));
ASSERT_OK(batch_reader->ReadNext(&written));
}

ASSERT_EQ(*batch->schema(), *written->schema());
ASSERT_BATCHES_EQUAL(*batch, *written);
}

} // namespace arrow
7 changes: 5 additions & 2 deletions cpp/src/arrow/extension_type.cc
@@ -28,6 +28,7 @@
#include "arrow/chunked_array.h"
#include "arrow/config.h"
#ifdef ARROW_JSON
#include "arrow/extension/bool8.h"
#include "arrow/extension/fixed_shape_tensor.h"
#endif
#include "arrow/status.h"
@@ -146,10 +147,12 @@ static void CreateGlobalRegistry() {

#ifdef ARROW_JSON
// Register canonical extension types
- auto ext_type =
+ auto fst_ext_type =
  checked_pointer_cast<ExtensionType>(extension::fixed_shape_tensor(int64(), {}));
+ ARROW_CHECK_OK(g_registry->RegisterType(fst_ext_type));

- ARROW_CHECK_OK(g_registry->RegisterType(ext_type));
+ auto bool8_ext_type = checked_pointer_cast<ExtensionType>(extension::bool8());
+ ARROW_CHECK_OK(g_registry->RegisterType(bool8_ext_type));
#endif
}
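
The registration step above can be pictured with a toy registry keyed by extension name. This is only a sketch under assumptions: `ToyRegistry` is a hypothetical class, not Arrow's actual registry, the storage type is reduced to a plain string, and rejecting duplicate names is a behavior chosen here for illustration.

```cpp
#include <map>
#include <string>

// Hypothetical registry mapping an extension name to its storage type name.
class ToyRegistry {
 public:
  // Returns false if the name is already registered (assumed policy).
  bool RegisterType(const std::string& name, const std::string& storage) {
    return types_.emplace(name, storage).second;
  }

  // Returns the storage type name, or an empty string for unknown extensions.
  std::string Lookup(const std::string& name) const {
    auto it = types_.find(name);
    return it == types_.end() ? std::string() : it->second;
  }

 private:
  std::map<std::string, std::string> types_;
};
```

Giving each registered type its own variable, as the diff does with `fst_ext_type` and `bool8_ext_type`, keeps every registration call independent when more canonical types are added later.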

7 changes: 4 additions & 3 deletions python/pyarrow/__init__.py
@@ -174,6 +174,7 @@ def print_entry(label, value):
run_end_encoded,
fixed_shape_tensor,
opaque,
bool8,
Review comment (Member):
We can keep the PyArrow changes for another PR to make reviewing simpler, IMHO.

Reply (Member Author):
Do you think it's OK to leave this PR as it is?

field,
type_for_alias,
DataType, DictionaryType, StructType,
@@ -184,7 +185,7 @@ def print_entry(label, value):
FixedSizeBinaryType, Decimal128Type, Decimal256Type,
BaseExtensionType, ExtensionType,
RunEndEncodedType, FixedShapeTensorType, OpaqueType,
- PyExtensionType, UnknownExtensionType,
+ Bool8Type, PyExtensionType, UnknownExtensionType,
register_extension_type, unregister_extension_type,
DictionaryMemo,
KeyValueMetadata,
@@ -218,7 +219,7 @@ def print_entry(label, value):
MonthDayNanoIntervalArray,
Decimal128Array, Decimal256Array, StructArray, ExtensionArray,
RunEndEncodedArray, FixedShapeTensorArray, OpaqueArray,
- scalar, NA, _NULL as NULL, Scalar,
+ Bool8Array, scalar, NA, _NULL as NULL, Scalar,
NullScalar, BooleanScalar,
Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar,
UInt8Scalar, UInt16Scalar, UInt32Scalar, UInt64Scalar,
@@ -235,7 +236,7 @@ def print_entry(label, value):
FixedSizeBinaryScalar, DictionaryScalar,
MapScalar, StructScalar, UnionScalar,
RunEndEncodedScalar, ExtensionScalar,
- FixedShapeTensorScalar, OpaqueScalar)
+ FixedShapeTensorScalar, OpaqueScalar, Bool8Scalar)

# Buffers, allocation
from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager,