Skip to content

Commit

Permalink
Add generic typing to Arrowbic extension type and array classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Jan 18, 2022
1 parent fe6efff commit 7c5a303
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 38 deletions.
6 changes: 5 additions & 1 deletion arrowbic/core/base_extension_array.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""Implementation of base extension array class used in Arrowbic.
"""
from typing import Optional, Sequence, TypeVar

import pyarrow as pa

TItem = TypeVar("TItem")


class BaseExtensionArray(pa.ExtensionArray):
class BaseExtensionArray(pa.ExtensionArray, Sequence[Optional[TItem]]):
"""Base extension array, adding interface to make simple operations easier."""
16 changes: 9 additions & 7 deletions arrowbic/core/base_extension_type.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Implementation of base extension type class used in Arrowbic.
"""
import json
from typing import Any, Dict, Iterable, Optional, Type
from typing import Any, Dict, Generic, Iterable, Optional, Type, TypeVar

import pyarrow as pa

from .base_extension_array import BaseExtensionArray

TItem = TypeVar("TItem")


def make_extension_name(extension_basename: str, package_name: str) -> str:
"""Make a full Arrowbic extension name.
Expand All @@ -21,7 +23,7 @@ def make_extension_name(extension_basename: str, package_name: str) -> str:
return extension_name


class BaseExtensionType(pa.ExtensionType):
class BaseExtensionType(pa.ExtensionType, Generic[TItem]):
"""Base class for all Arrowbic extension types.
This class must be the parent class of any extension type registered in Arrowbic. It defines the standard
Expand All @@ -41,7 +43,7 @@ class BaseExtensionType(pa.ExtensionType):
def __init__(
self,
storage_type: Optional[pa.DataType],
item_pyclass: Optional[Type[Any]],
item_pyclass: Optional[Type[TItem]],
package_name: Optional[str] = None,
):
self._package_name: str = package_name or "core"
Expand All @@ -68,7 +70,7 @@ def package_name(self) -> str:
return self._package_name

@property
def item_pyclass(self) -> Optional[Type[Any]]:
def item_pyclass(self) -> Optional[Type[TItem]]:
"""Get the item Python class associated with the extension type.
None if the extension type instance is a root instance.
Expand Down Expand Up @@ -136,8 +138,8 @@ def __arrowbic_make_item_pyclass__(cls, storage_type: pa.DataType, ext_metadata:

@classmethod
def __arrowbic_from_item_iterator__(
cls, it_items: Iterable[Any], size: Optional[int] = None, registry: Optional[Any] = None
) -> BaseExtensionArray:
cls, it_items: Iterable[Optional[TItem]], size: Optional[int] = None, registry: Optional[Any] = None
) -> BaseExtensionArray[TItem]:
"""Build the extension array from a Python item iterator.
Args:
Expand All @@ -160,7 +162,7 @@ def __arrow_ext_serialize__(self) -> bytes:
return json.dumps(ext_metadata).encode()

@classmethod
def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> "BaseExtensionType":
def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> "BaseExtensionType[None]":
"""Deserialization of Arrowbic extension type based on the storage type and the metadata.
Args:
Expand Down
49 changes: 26 additions & 23 deletions arrowbic/core/extension_type_registry.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Arrowbic extension type main registry implementation.
"""
import logging
from typing import Any, Dict, List, Optional, Type
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union

import pyarrow as pa

from .base_extension_type import BaseExtensionType

TItem = TypeVar("TItem")
TExtType = TypeVar("TExtType", bound=BaseExtensionType[Any])


class ExtensionTypeRegistry:
"""The Arrowbic extension type registry is wrapping the PyArrow extension registry, with additional
Expand All @@ -26,11 +29,11 @@ class ExtensionTypeRegistry:
def __init__(self, sync_with_pyarrow: bool = False):
self._sync_with_pyarrow = sync_with_pyarrow
# Root extension types, i.e. not attached to a particular item Python class.
self._root_extension_types: Dict[str, BaseExtensionType] = {}
self._root_extension_types: Dict[str, BaseExtensionType[None]] = {}
# Cache associating item Python classes to extension types (with all variations of storage type).
self._item_pyclasses_cache: Dict[Type[Any], Dict[pa.DataType, BaseExtensionType]] = {}
self._item_pyclasses_cache: Dict[Type[Any], Dict[pa.DataType, BaseExtensionType[Any]]] = {}

def register_root_extension_type(self, extension_type: BaseExtensionType) -> None:
def register_root_extension_type(self, extension_type: BaseExtensionType[None]) -> None:
"""Register a (root) extension type in an Arrowbic registry.
Args:
Expand Down Expand Up @@ -59,7 +62,7 @@ def register_root_extension_type(self, extension_type: BaseExtensionType) -> Non
if self._sync_with_pyarrow:
pa.register_extension_type(extension_type)

def register_item_pyclass(self, item_pyclass: Type[Any]) -> BaseExtensionType:
def register_item_pyclass(self, item_pyclass: Type[TItem]) -> BaseExtensionType[None]:
"""Add an item Python class to the registry (with future caching of the extension types associated with it).
Args:
Expand All @@ -78,7 +81,7 @@ def register_item_pyclass(self, item_pyclass: Type[Any]) -> BaseExtensionType:
self._item_pyclasses_cache[item_pyclass] = {pa.null(): root_ext_type}
return root_ext_type

def unregister_item_pyclass(self, item_pyclass: Type[Any]):
def unregister_item_pyclass(self, item_pyclass: Type[Any]) -> None:
"""Unregister a Python item class from the registry.
Args:
Expand All @@ -87,8 +90,8 @@ def unregister_item_pyclass(self, item_pyclass: Type[Any]):
self._item_pyclasses_cache.pop(item_pyclass)

def find_extension_type(
self, item_pyclass: Type[Any], storage_type: Optional[pa.DataType] = None
) -> BaseExtensionType:
self, item_pyclass: Type[TItem], storage_type: Optional[pa.DataType] = None
) -> BaseExtensionType[TItem]:
"""Find (or make if not yet cached) the extension type corresponding to an item Python class and a storage type.
Args:
Expand Down Expand Up @@ -117,7 +120,7 @@ def find_extension_type(
item_pyclass_types_cache[storage_type] = ext_type
return ext_type

def _associate_item_pyclass_to_root_extension_type(self, item_pyclass: Type[Any]) -> BaseExtensionType:
def _associate_item_pyclass_to_root_extension_type(self, item_pyclass: Type[TItem]) -> BaseExtensionType[None]:
"""Find the root extension type to associate to an item Python class.
Args:
Expand All @@ -133,7 +136,7 @@ def _associate_item_pyclass_to_root_extension_type(self, item_pyclass: Type[Any]
raise KeyError(f"Could not find any Arrowbic extension type to associate to the Python class '{item_pyclass}'.")

@property
def root_extension_types(self) -> List[BaseExtensionType]:
def root_extension_types(self) -> List[BaseExtensionType[None]]:
"""Get all the root registered extension types."""
return list(self._root_extension_types.values())

Expand All @@ -144,27 +147,27 @@ def root_extension_types(self) -> List[BaseExtensionType]:


def register_extension_type(
extension_type_cls: Type[BaseExtensionType] = None,
extension_type_cls: Type[TExtType] = None,
*,
package_name: Optional[str] = None,
registry: Optional[ExtensionTypeRegistry] = None,
):
) -> Callable[[Type[TExtType]], Type[TExtType]]:
"""Extension type class decorator: registering the extension type class in Arrowbic.
Args:
package_name: Optional package name to use in the extension name.
registry: Registry to use. Global one by default.
"""
registry = registry or _global_registry
reg = registry if registry is not None else _global_registry

def wrap(_extension_type_cls):
def wrap(_extension_type_cls: Type[TExtType]) -> Type[TExtType]:
# Build the default/root instance of the extension type.
ext_type = _extension_type_cls(
storage_type=None,
item_pyclass=None,
package_name=package_name,
)
registry.register_root_extension_type(ext_type)
reg.register_root_extension_type(ext_type)
return _extension_type_cls

# Decorator called with parens: register_extension_type(...)
Expand All @@ -175,10 +178,10 @@ def wrap(_extension_type_cls):


def register_item_pyclass(
item_pyclass: Type[Any] = None,
item_pyclass: Type[TItem] = None,
*,
registry: Optional[ExtensionTypeRegistry] = None,
):
) -> Union[Type[TItem], Callable[[Type[TItem]], Type[TItem]]]:
"""Item Python class decorator: registering the item Python class in Arrowbic.
The Arrowbic registry is caching all extension type instances and the association to
Expand All @@ -187,10 +190,10 @@ def register_item_pyclass(
Args:
registry: Registry to use. Global one by default.
"""
registry = registry or _global_registry
reg = registry if registry is not None else _global_registry

def wrap(_item_cls):
registry.register_item_pyclass(_item_cls)
def wrap(_item_cls: Type[TItem]) -> Type[TItem]:
reg.register_item_pyclass(_item_cls)
return _item_cls

# Decorator called with parens: register_item_pyclass(...)
Expand All @@ -200,18 +203,18 @@ def wrap(_item_cls):
return wrap(item_pyclass)


def unregister_item_pyclass(item_pyclass: Type[TItem], *, registry: Optional[ExtensionTypeRegistry] = None) -> None:
    """Unregister an item Python class from an Arrowbic registry.

    Args:
        item_pyclass: Item Python class to remove from the registry.
        registry: Registry to use. Global one by default.
    """
    # Explicit `None` check (instead of `registry or ...`), as done by the other module-level
    # registration helpers: a registry passed by the caller is always used, whatever its truthiness.
    reg = registry if registry is not None else _global_registry
    reg.unregister_item_pyclass(item_pyclass)


def find_registry_extension_type(
item_pyclass: Type[Any],
item_pyclass: Type[TItem],
storage_type: Optional[pa.DataType] = None,
*,
registry: Optional[ExtensionTypeRegistry] = None,
):
) -> BaseExtensionType[TItem]:
"""Find an extension type in the Arrowbic registry.
Args:
Expand Down
7 changes: 6 additions & 1 deletion arrowbic/extensions/int_enum_array.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from enum import IntEnum
from typing import TypeVar

from arrowbic.core.base_extension_array import BaseExtensionArray

TItem = TypeVar("TItem", bound=IntEnum)


class IntEnumArray(BaseExtensionArray):
class IntEnumArray(BaseExtensionArray[TItem]):
"""IntEnum extension array."""
14 changes: 8 additions & 6 deletions arrowbic/extensions/int_enum_type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import IntEnum
from typing import Any, Dict, Iterable, Optional, Type
from typing import Any, Dict, Iterable, Optional, Type, TypeVar

import pyarrow as pa

Expand All @@ -13,9 +13,11 @@

from .int_enum_array import IntEnumArray

TItem = TypeVar("TItem", bound=IntEnum)


@register_extension_type
class IntEnumType(BaseExtensionType):
class IntEnumType(BaseExtensionType[TItem]):
"""IntEnum Arrowbic extension type.
This extension type enables the support in Arrowbic of standard Python IntEnum,
Expand All @@ -42,8 +44,8 @@ def __init__(
if not is_valid_storage:
raise TypeError(f"Invalid Arrow storage type for an IntEnum extension type: {storage_type}.")

def __arrow_ext_class__(self) -> Type[IntEnumArray[TItem]]:
    """Return the PyArrow array class associated with this extension type.

    Returns:
        The `IntEnumArray` class itself.
    """
    # Return the real class rather than `IntEnumArray[TItem]`: subscripting a generic class at
    # runtime yields a typing alias object, not a class, while PyArrow documents this hook as
    # returning an `ExtensionArray` subclass. The generic parameter only lives in the annotation.
    return IntEnumArray

def __arrowbic_ext_metadata__(self) -> Dict[str, Any]:
"""Generate the IntEnum extension type metadata, with the full IntEnum
Expand Down Expand Up @@ -89,10 +91,10 @@ def __arrowbic_make_item_pyclass__(cls, storage_type: pa.DataType, ext_metadata:
@classmethod
def __arrowbic_from_item_iterator__(
cls,
it_items: Iterable[Optional[IntEnum]],
it_items: Iterable[Optional[TItem]],
size: Optional[int] = None,
registry: Optional[ExtensionTypeRegistry] = None,
) -> IntEnumArray:
) -> IntEnumArray[TItem]:
"""Build the IntEnum extension array from a Python item iterator.
Args:
Expand Down
14 changes: 14 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,19 @@ known_first_party = arrowbic
# combine_as_imports = True

[mypy]
# Config heavily inspired by Pydantic!
python_version = 3.9
show_error_codes = True
warn_redundant_casts = True
warn_unused_ignores = True
warn_unused_configs = True

[mypy-arrowbic.*]
disallow_any_generics = True
check_untyped_defs = True
# disallow_subclassing_any = True
disallow_incomplete_defs = True
disallow_untyped_decorators = True
disallow_untyped_calls = True
# for strict mypy: (this is the tricky one :-))
disallow_untyped_defs = True

0 comments on commit 7c5a303

Please sign in to comment.