diff --git a/arrowbic/core/base_extension_array.py b/arrowbic/core/base_extension_array.py index baf635f..eb6a3ce 100644 --- a/arrowbic/core/base_extension_array.py +++ b/arrowbic/core/base_extension_array.py @@ -1,7 +1,11 @@ """Implementation of base extension array class used in Arrowbic. """ +from typing import Optional, Sequence, TypeVar + import pyarrow as pa +TItem = TypeVar("TItem") + -class BaseExtensionArray(pa.ExtensionArray): +class BaseExtensionArray(pa.ExtensionArray, Sequence[Optional[TItem]]): """Base extension array, adding interface to make simple operations easier.""" diff --git a/arrowbic/core/base_extension_type.py b/arrowbic/core/base_extension_type.py index 319b9fa..f414123 100644 --- a/arrowbic/core/base_extension_type.py +++ b/arrowbic/core/base_extension_type.py @@ -1,12 +1,14 @@ """Implementation of base extension type class used in Arrowbic. """ import json -from typing import Any, Dict, Iterable, Optional, Type +from typing import Any, Dict, Generic, Iterable, Optional, Type, TypeVar import pyarrow as pa from .base_extension_array import BaseExtensionArray +TItem = TypeVar("TItem") + def make_extension_name(extension_basename: str, package_name: str) -> str: """Make a full Arrowbic extension name. @@ -21,7 +23,7 @@ def make_extension_name(extension_basename: str, package_name: str) -> str: return extension_name -class BaseExtensionType(pa.ExtensionType): +class BaseExtensionType(pa.ExtensionType, Generic[TItem]): """Base class for all Arrowbic extension type. This class must be the parent class of any extension type registered in Arrowbic. It defines the standard @@ -41,7 +43,7 @@ class BaseExtensionType(pa.ExtensionType): def __init__( self, storage_type: Optional[pa.DataType], - item_pyclass: Optional[Type[Any]], + item_pyclass: Optional[Type[TItem]], package_name: Optional[str] = None, ): self._package_name: str = package_name or "core" @@ -68,7 +70,7 @@ def package_name(self) -> str: return self._package_name @property - def item_pyclass(self) -> Optional[Type[Any]]: + def item_pyclass(self) -> Optional[Type[TItem]]: """Get the item Python class associated with the extension type. None if the extension type instance is a root instance. @@ -136,8 +138,8 @@ def __arrowbic_make_item_pyclass__(cls, storage_type: pa.DataType, ext_metadata: @classmethod def __arrowbic_from_item_iterator__( - cls, it_items: Iterable[Any], size: Optional[int] = None, registry: Optional[Any] = None - ) -> BaseExtensionArray: + cls, it_items: Iterable[Optional[TItem]], size: Optional[int] = None, registry: Optional[Any] = None + ) -> BaseExtensionArray[TItem]: """Build the extension array from a Python item iterator. Args: @@ -160,7 +162,7 @@ def __arrow_ext_serialize__(self) -> bytes: return json.dumps(ext_metadata).encode() @classmethod - def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> "BaseExtensionType": + def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes) -> "BaseExtensionType[None]": """Deserialization of Arrowbic extension type based on the storage type and the metadata. Args: diff --git a/arrowbic/core/extension_type_registry.py b/arrowbic/core/extension_type_registry.py index e7071f9..9736886 100644 --- a/arrowbic/core/extension_type_registry.py +++ b/arrowbic/core/extension_type_registry.py @@ -1,12 +1,15 @@ """Arrowbic extension type main registry implementation. """ import logging -from typing import Any, Dict, List, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union import pyarrow as pa from .base_extension_type import BaseExtensionType +TItem = TypeVar("TItem") +TExtType = TypeVar("TExtType", bound=BaseExtensionType[Any]) + class ExtensionTypeRegistry: """The Arrowbic extension type registry is wrapping the PyArrow extension registry, with additional @@ -26,11 +29,11 @@ class ExtensionTypeRegistry: def __init__(self, sync_with_pyarrow: bool = False): self._sync_with_pyarrow = sync_with_pyarrow # Root extension types, i.e. not attached to a particular item Python class. - self._root_extension_types: Dict[str, BaseExtensionType] = {} + self._root_extension_types: Dict[str, BaseExtensionType[None]] = {} # Cache associating item Python classes to extension types (with all variations of storage type). - self._item_pyclasses_cache: Dict[Type[Any], Dict[pa.DataType, BaseExtensionType]] = {} + self._item_pyclasses_cache: Dict[Type[Any], Dict[pa.DataType, BaseExtensionType[Any]]] = {} - def register_root_extension_type(self, extension_type: BaseExtensionType) -> None: + def register_root_extension_type(self, extension_type: BaseExtensionType[None]) -> None: """Register a (root) extension type in an Arrowbic registry. Args: @@ -59,7 +62,7 @@ def register_root_extension_type(self, extension_type: BaseExtensionType) -> Non if self._sync_with_pyarrow: pa.register_extension_type(extension_type) - def register_item_pyclass(self, item_pyclass: Type[Any]) -> BaseExtensionType: + def register_item_pyclass(self, item_pyclass: Type[TItem]) -> BaseExtensionType[None]: """Add an item Python class in the registry (with future caching of extension types associated to it). Args: @@ -78,7 +81,7 @@ def register_item_pyclass(self, item_pyclass: Type[Any]) -> BaseExtensionType: self._item_pyclasses_cache[item_pyclass] = {pa.null(): root_ext_type} return root_ext_type - def unregister_item_pyclass(self, item_pyclass: Type[Any]): + def unregister_item_pyclass(self, item_pyclass: Type[Any]) -> None: """Unregister a Python item class from the registry. Args: @@ -87,8 +90,8 @@ def unregister_item_pyclass(self, item_pyclass: Type[Any]): self._item_pyclasses_cache.pop(item_pyclass) def find_extension_type( - self, item_pyclass: Type[Any], storage_type: Optional[pa.DataType] = None - ) -> BaseExtensionType: + self, item_pyclass: Type[TItem], storage_type: Optional[pa.DataType] = None + ) -> BaseExtensionType[TItem]: """Find (or make if not yet cached) the extension type corresponding to an item Python class and a storage type. Args: @@ -117,7 +120,7 @@ def find_extension_type( item_pyclass_types_cache[storage_type] = ext_type return ext_type - def _associate_item_pyclass_to_root_extension_type(self, item_pyclass: Type[Any]) -> BaseExtensionType: + def _associate_item_pyclass_to_root_extension_type(self, item_pyclass: Type[TItem]) -> BaseExtensionType[None]: """Find the root extension type to associate to an item Python class. Args: @@ -133,7 +136,7 @@ def _associate_item_pyclass_to_root_extension_type(self, item_pyclass: Type[Any] raise KeyError(f"Could not find any Arrowbic extension type to associate to the Python class '{item_pyclass}'.") @property - def root_extension_types(self) -> List[BaseExtensionType]: + def root_extension_types(self) -> List[BaseExtensionType[None]]: """Get all the root registered extension types.""" return list(self._root_extension_types.values()) @@ -144,27 +147,27 @@ def root_extension_types(self) -> List[BaseExtensionType]: def register_extension_type( - extension_type_cls: Type[BaseExtensionType] = None, + extension_type_cls: Type[TExtType] = None, *, package_name: Optional[str] = None, registry: Optional[ExtensionTypeRegistry] = None, -): +) -> Callable[[Type[TExtType]], Type[TExtType]]: """Extension type class decorator: registering the extension type class in Arrowbic. Args: package_name: Optional package name to use in the extension name. registry: Registry to use. Global one by default. """ - registry = registry or _global_registry + reg = registry if registry is not None else _global_registry - def wrap(_extension_type_cls): + def wrap(_extension_type_cls: Type[TExtType]) -> Type[TExtType]: # Build the default/root instance of the extension type. ext_type = _extension_type_cls( storage_type=None, item_pyclass=None, package_name=package_name, ) - registry.register_root_extension_type(ext_type) + reg.register_root_extension_type(ext_type) return _extension_type_cls # Decorator called with parens: register_extension_type(...) @@ -175,10 +178,10 @@ def wrap(_extension_type_cls): def register_item_pyclass( - item_pyclass: Type[Any] = None, + item_pyclass: Type[TItem] = None, *, registry: Optional[ExtensionTypeRegistry] = None, -): +) -> Union[Type[TItem], Callable[[Type[TItem]], Type[TItem]]]: """Item Python class decorator: registering the item Python class in Arrowbic. The Arrowbic registry is caching all extension type instances and the association to @@ -187,10 +190,10 @@ def register_item_pyclass( Args: registry: Registry to use. Global one by default. """ - registry = registry or _global_registry + reg = registry if registry is not None else _global_registry - def wrap(_item_cls): - registry.register_item_pyclass(_item_cls) + def wrap(_item_cls: Type[TItem]) -> Type[TItem]: + reg.register_item_pyclass(_item_cls) return _item_cls # Decorator called with parens: register_item_pyclass(...) @@ -200,18 +203,18 @@ def wrap(_item_cls): return wrap(item_pyclass) -def unregister_item_pyclass(item_pyclass: Type[Any], *, registry: Optional[ExtensionTypeRegistry] = None): +def unregister_item_pyclass(item_pyclass: Type[TItem], *, registry: Optional[ExtensionTypeRegistry] = None) -> None: """Unregister item Python class from the global Arrowbic registry.""" registry = registry or _global_registry registry.unregister_item_pyclass(item_pyclass) def find_registry_extension_type( - item_pyclass: Type[Any], + item_pyclass: Type[TItem], storage_type: Optional[pa.DataType] = None, *, registry: Optional[ExtensionTypeRegistry] = None, -): +) -> BaseExtensionType[TItem]: """Find an extension type in the Arrowbic registry. Args: diff --git a/arrowbic/extensions/int_enum_array.py b/arrowbic/extensions/int_enum_array.py index 001746b..86852f6 100644 --- a/arrowbic/extensions/int_enum_array.py +++ b/arrowbic/extensions/int_enum_array.py @@ -1,5 +1,10 @@ +from enum import IntEnum +from typing import TypeVar + from arrowbic.core.base_extension_array import BaseExtensionArray +TItem = TypeVar("TItem", bound=IntEnum) + -class IntEnumArray(BaseExtensionArray): +class IntEnumArray(BaseExtensionArray[TItem]): """IntEnum extension array.""" diff --git a/arrowbic/extensions/int_enum_type.py b/arrowbic/extensions/int_enum_type.py index 6e83f73..a52fd17 100644 --- a/arrowbic/extensions/int_enum_type.py +++ b/arrowbic/extensions/int_enum_type.py @@ -1,5 +1,5 @@ from enum import IntEnum -from typing import Any, Dict, Iterable, Optional, Type +from typing import Any, Dict, Iterable, Optional, Type, TypeVar import pyarrow as pa @@ -13,9 +13,11 @@ from .int_enum_array import IntEnumArray +TItem = TypeVar("TItem", bound=IntEnum) + @register_extension_type -class IntEnumType(BaseExtensionType): +class IntEnumType(BaseExtensionType[TItem]): """IntEnum Arrowbic extension type. This extension type enables the support in Arrowbic of standard Python IntEnum, @@ -42,8 +44,8 @@ def __init__( if not is_valid_storage: raise TypeError(f"Invalid Arrow storage type for an IntEnum extension type: {storage_type}.") - def __arrow_ext_class__(self): - return IntEnumArray + def __arrow_ext_class__(self) -> Type[IntEnumArray[TItem]]: + return IntEnumArray[TItem] def __arrowbic_ext_metadata__(self) -> Dict[str, Any]: """Generate the IntEnum extension type metadata, with the full IntEnum @@ -89,10 +91,10 @@ def __arrowbic_make_item_pyclass__(cls, storage_type: pa.DataType, ext_metadata: @classmethod def __arrowbic_from_item_iterator__( cls, - it_items: Iterable[Optional[IntEnum]], + it_items: Iterable[Optional[TItem]], size: Optional[int] = None, registry: Optional[ExtensionTypeRegistry] = None, - ) -> IntEnumArray: + ) -> IntEnumArray[TItem]: """Build the IntEnum extension array from a Python item iterator. Args: diff --git a/setup.cfg b/setup.cfg index b9d1739..d523c3b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,5 +20,19 @@ known_first_party = arrowbic # combine_as_imports = True [mypy] +# Config heavily inspired by Pydantic! python_version = 3.9 show_error_codes = True +warn_redundant_casts = True +warn_unused_ignores = True +warn_unused_configs = True + +[mypy-arrowbic.*] +disallow_any_generics = True +check_untyped_defs = True +# disallow_subclassing_any = True +disallow_incomplete_defs = True +disallow_untyped_decorators = True +disallow_untyped_calls = True +# for strict mypy: (this is the tricky one :-)) +disallow_untyped_defs = True