Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not generate, instantiate, nor carry Arrow extensions around #8153

Merged
merged 2 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 19 additions & 29 deletions crates/build/re_types_builder/src/codegen/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,14 +335,13 @@ impl PythonCodeGenerator {
let name = &obj.name;

if obj.is_delegating_component() {
vec![name.clone(), format!("{name}Batch"), format!("{name}Type")]
vec![name.clone(), format!("{name}Batch")]
} else {
vec![
format!("{name}"),
format!("{name}ArrayLike"),
format!("{name}Batch"),
format!("{name}Like"),
format!("{name}Type"),
]
}
}
Expand Down Expand Up @@ -408,7 +407,6 @@ impl PythonCodeGenerator {
from {rerun_path}error_utils import catch_and_log_exceptions
from {rerun_path}_baseclasses import (
Archetype,
BaseExtensionType,
BaseBatch,
ComponentBatchMixin,
ComponentMixin,
Expand Down Expand Up @@ -1851,7 +1849,6 @@ fn quote_arrow_support_from_obj(
};

if obj.kind == ObjectKind::Datatype {
type_superclasses.push("BaseExtensionType".to_owned());
batch_superclasses.push(format!("BaseBatch[{many_aliases}]"));
} else if obj.kind == ObjectKind::Component {
if let Some(data_type) = obj.delegate_datatype(objects) {
Expand All @@ -1864,15 +1861,13 @@ fn quote_arrow_support_from_obj(
type_superclasses.push(data_extension_type);
batch_superclasses.push(data_extension_array);
} else {
type_superclasses.push("BaseExtensionType".to_owned());
batch_superclasses.push(format!("BaseBatch[{many_aliases}]"));
}
batch_superclasses.push("ComponentBatchMixin".to_owned());
}

let datatype = quote_arrow_datatype(&arrow_registry.get(fqname));
let extension_batch = format!("{name}Batch");
let extension_type = format!("{name}Type");

let native_to_pa_array_impl = match quote_arrow_serialization(
reporter,
Expand Down Expand Up @@ -1906,32 +1901,32 @@ fn quote_arrow_support_from_obj(
}
};

let type_superclass_decl = if type_superclasses.is_empty() {
String::new()
} else {
format!("({})", type_superclasses.join(","))
};

let batch_superclass_decl = if batch_superclasses.is_empty() {
String::new()
} else {
format!("({})", batch_superclasses.join(","))
};

if obj.kind == ObjectKind::Datatype || obj.is_non_delegating_component() {
if obj.kind == ObjectKind::Datatype {
// Datatypes and non-delegating components declare init
let mut code = unindent(&format!(
r#"
class {extension_type}{type_superclass_decl}:
_TYPE_NAME: str = "{fqname}"

def __init__(self) -> None:
pa.ExtensionType.__init__(
self, {datatype}, self._TYPE_NAME
)
class {extension_batch}{batch_superclass_decl}:
_ARROW_DATATYPE = {datatype}

@staticmethod
def _native_to_pa_array(data: {many_aliases}, data_type: pa.DataType) -> pa.Array:
"#
));
code.push_indented(2, native_to_pa_array_impl, 1);
code
} else if obj.is_non_delegating_component() {
// Datatypes and non-delegating components declare init
let mut code = unindent(&format!(
r#"
class {extension_batch}{batch_superclass_decl}:
_ARROW_TYPE = {extension_type}()
_ARROW_DATATYPE = {datatype}
_COMPONENT_NAME: str = "{fqname}"

@staticmethod
def _native_to_pa_array(data: {many_aliases}, data_type: pa.DataType) -> pa.Array:
Expand All @@ -1943,11 +1938,8 @@ fn quote_arrow_support_from_obj(
// Delegating components are already inheriting from their base type
unindent(&format!(
r#"
class {extension_type}{type_superclass_decl}:
_TYPE_NAME: str = "{fqname}"

class {extension_batch}{batch_superclass_decl}:
_ARROW_TYPE = {extension_type}()
_COMPONENT_NAME: str = "{fqname}"
"#
))
}
Expand Down Expand Up @@ -2072,7 +2064,7 @@ fn quote_arrow_serialization(
// Type checker struggles with this occasionally, exact pattern is unclear.
// Tried casting the array earlier via `cast(Sequence[{name}], data)` but to no avail.
let field_fwd =
format!("{field_batch_type}({field_array}).as_arrow_array().storage, # type: ignore[misc, arg-type]");
format!("{field_batch_type}({field_array}).as_arrow_array(), # type: ignore[misc, arg-type]");
code.push_indented(2, &field_fwd, 1);
}
}
Expand Down Expand Up @@ -2165,9 +2157,7 @@ return pa.array(pa_data, type=data_type)
let variant_list_to_pa_array = match &field.typ {
Type::Object(fqname) => {
let field_type_name = &objects[fqname].name;
format!(
"{field_type_name}Batch({variant_kind_list}).as_arrow_array().storage"
)
format!("{field_type_name}Batch({variant_kind_list}).as_arrow_array()")
}
Type::Unit => {
format!("pa.nulls({variant_kind_list})")
Expand Down
49 changes: 9 additions & 40 deletions rerun_py/rerun_sdk/rerun/_baseclasses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Generic, Iterable, Protocol, TypeVar
from typing import Generic, Iterable, Protocol, TypeVar

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -56,7 +56,7 @@ def __str__(self) -> str:
comp = getattr(self, fld.name)
datatype = getattr(comp, "type", None)
if datatype:
s += f" {datatype.extension_name}<{datatype.storage_type}>(\n {comp.to_pylist()}\n )\n"
s += f" {datatype.extension_name}<{datatype}>(\n {comp.to_pylist()}\n )\n"
s += ")"

return s
Expand Down Expand Up @@ -112,36 +112,8 @@ def as_component_batches(self) -> Iterable[ComponentBatchLike]:
__repr__ = __str__


class BaseExtensionType(pa.ExtensionType): # type: ignore[misc]
"""Extension type for datatypes and non-delegating components."""

_TYPE_NAME: str
"""The name used when constructing the extension type.

Should following rerun typing conventions:
- `rerun.datatypes.<TYPE>` for datatypes
- `rerun.components.<TYPE>` for components

Many component types simply subclass a datatype type and override
the `_TYPE_NAME` field.
"""

_ARRAY_TYPE: type[pa.ExtensionArray] = pa.ExtensionArray
"""The extension array class associated with this class."""

# Note: (de)serialization is not used in the Python SDK

def __arrow_ext_serialize__(self) -> bytes:
return b""

# noinspection PyMethodOverriding
@classmethod
def __arrow_ext_deserialize__(cls, storage_type: Any, serialized: Any) -> pa.ExtensionType:
return cls()


class BaseBatch(Generic[T]):
_ARROW_TYPE: BaseExtensionType = None # type: ignore[assignment]
_ARROW_DATATYPE: pa.DataType | None = None
"""The pyarrow type of this batch."""

def __init__(self, data: T | None, strict: bool | None = None) -> None:
Expand Down Expand Up @@ -173,17 +145,14 @@ def __init__(self, data: T | None, strict: bool | None = None) -> None:
if data is not None:
with catch_and_log_exceptions(self.__class__.__name__, strict=strict):
# If data is already an arrow array, use it
if isinstance(data, pa.Array) and data.type == self._ARROW_TYPE:
if isinstance(data, pa.Array) and data.type == self._ARROW_DATATYPE:
self.pa_array = data
elif isinstance(data, pa.Array) and data.type == self._ARROW_TYPE.storage_type:
self.pa_array = self._ARROW_TYPE.wrap_array(data)
else:
array = self._native_to_pa_array(data, self._ARROW_TYPE.storage_type)
self.pa_array = self._ARROW_TYPE.wrap_array(array)
self.pa_array = self._native_to_pa_array(data, self._ARROW_DATATYPE)
return

# If we didn't return above, default to the empty array
self.pa_array = _empty_pa_array(self._ARROW_TYPE)
self.pa_array = _empty_pa_array(self._ARROW_DATATYPE)

@classmethod
def _required(cls, data: T | None) -> BaseBatch[T]:
Expand Down Expand Up @@ -317,7 +286,7 @@ def component_name(self) -> str:

Part of the `ComponentBatchLike` logging interface.
"""
return self._ARROW_TYPE._TYPE_NAME # type: ignore[attr-defined, no-any-return]
return self._COMPONENT_NAME # type: ignore[attr-defined, no-any-return]

def partition(self, lengths: npt.ArrayLike) -> ComponentColumn:
"""
Expand Down Expand Up @@ -355,15 +324,15 @@ def arrow_type(cls) -> pa.DataType:

Part of the `ComponentBatchLike` logging interface.
"""
return cls._BATCH_TYPE._ARROW_TYPE.storage_type # type: ignore[attr-defined, no-any-return]
return cls._BATCH_TYPE._ARROW_DATATYPE # type: ignore[attr-defined, no-any-return]

def component_name(self) -> str:
"""
The name of the component.

Part of the `ComponentBatchLike` logging interface.
"""
return self._BATCH_TYPE._ARROW_TYPE._TYPE_NAME # type: ignore[attr-defined, no-any-return]
return self._BATCH_TYPE._COMPONENT_NAME # type: ignore[attr-defined, no-any-return]

def as_arrow_array(self) -> pa.Array:
"""
Expand Down
5 changes: 0 additions & 5 deletions rerun_py/rerun_sdk/rerun/_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,6 @@ def log_components(
else:
added.add(name)

# Strip off the ExtensionArray if it's present. We will always log via component_name.
# TODO(jleibs): Maybe warn if there is a name mismatch here.
if isinstance(array, pa.ExtensionArray):
array = array.storage

instanced[name] = array

bindings.log_arrow_msg( # pyright: ignore[reportGeneralTypeIssues]
Expand Down
2 changes: 1 addition & 1 deletion rerun_py/rerun_sdk/rerun/archetypes/asset_video_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,6 @@ def read_frame_timestamps_ns(self: Any) -> npt.NDArray[np.int64]:
raise RuntimeError("Asset video has no video buffer")

if self.media_type is not None:
media_type = self.media_type.as_arrow_array().storage[0].as_py()
media_type = self.media_type.as_arrow_array()[0].as_py()

return np.array(bindings.asset_video_read_frame_timestamps_ns(video_buffer, media_type), dtype=np.int64)
2 changes: 1 addition & 1 deletion rerun_py/rerun_sdk/rerun/archetypes/bar_chart_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def values__field_converter_override(data: TensorDataArrayLike) -> TensorDataBat

# TODO(jleibs): Doing this on raw arrow data is not great. Clean this up
# once we coerce to a canonical non-arrow type.
shape_dims = tensor_data.as_arrow_array()[0].value["shape"].values.field(0).to_numpy()
shape_dims = tensor_data.as_arrow_array()[0][0].values.field(0).to_numpy()

if len([d for d in shape_dims if d != 1]) != 1:
_send_warning_or_raise(
Expand Down
8 changes: 2 additions & 6 deletions rerun_py/rerun_sdk/rerun/archetypes/image_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def compress(self: Any, jpeg_quality: int = 95) -> EncodedImage | Image:
if self.format is None:
raise ValueError("Cannot JPEG compress an image without a known image_format")

image_format_arrow = self.format.as_arrow_array().storage[0].as_py()
image_format_arrow = self.format.as_arrow_array()[0].as_py()

image_format = ImageFormat(
width=image_format_arrow["width"],
Expand Down Expand Up @@ -292,11 +292,7 @@ def compress(self: Any, jpeg_quality: int = 95) -> EncodedImage | Image:

buf = None
if self.buffer is not None:
buf = (
self.buffer.as_arrow_array()
.storage.values.to_numpy()
.view(image_format.channel_datatype.to_np_dtype())
)
buf = self.buffer.as_arrow_array().values.to_numpy().view(image_format.channel_datatype.to_np_dtype())

if buf is None:
raise ValueError("Cannot JPEG compress an image without data")
Expand Down
Loading
Loading