class StructProtocol(Protocol):  # pragma: no cover
    """A generic protocol used by accessors to get and set at positions of an object"""

    @abstractmethod
    def get(self, pos: int) -> Any:
        """Return the value held at position `pos`

        Args:
            pos(int): The position to read from

        Returns:
            Any: The value at position `pos`
        """
        ...

    @abstractmethod
    def set(self, pos: int, value: Any) -> None:
        """Store `value` at position `pos`

        Args:
            pos(int): The position to write to
            value(Any): The value to store
        """
        ...
@dataclass(frozen=True)
class Accessor:
    """An accessor for a specific position in a container that implements the StructProtocol

    Accessors can be chained through `inner` to reach a value nested inside
    other containers: the outer accessor selects a nested container by position
    and the inner accessor is then applied to that nested container.
    """

    # Position to read in the container this accessor is applied to
    position: int
    # Accessor to apply to the value found at `position`; None for a leaf accessor
    inner: Optional["Accessor"] = None

    def __str__(self):
        return f"Accessor(position={self.position},inner={self.inner})"

    def __repr__(self):
        return self.__str__()

    def get(self, container: "StructProtocol") -> Any:
        """Returns the value this accessor (and its chain of inner accessors) points at in `container`

        Args:
            container(StructProtocol): A container to access at position `self.position`

        Returns:
            Any: The value at the end of the accessor chain, or None when a nested
                container along the chain is None (there is nothing further to read)
        """
        value = container.get(self.position)
        accessor = self.inner
        while accessor is not None:
            if value is None:
                # A null nested struct has no fields to descend into; without this
                # guard the next `.get` would raise AttributeError on None
                return None
            value = value.get(accessor.position)
            accessor = accessor.inner

        return value
def _(obj: MapType, visitor: SchemaVisitor[T]) -> T:
    """Visit a MapType with a concrete SchemaVisitor

    The key type is visited between the `before_map_key`/`after_map_key` hooks and
    the value type between the `before_map_value`/`after_map_value` hooks; both
    results are then handed to `visitor.map`.

    Args:
        obj(MapType): The map type to visit
        visitor(SchemaVisitor[T]): The visitor to apply

    Returns:
        T: The result of `visitor.map` for this map type
    """
    visitor.before_map_key(obj.key)
    key_result = visit(obj.key.type, visitor)
    visitor.after_map_key(obj.key)

    visitor.before_map_value(obj.value)
    value_result = visit(obj.value.type, visitor)
    # Close the value scope with the hook that pairs with before_map_value
    # (not the list-element hook).
    # NOTE(review): assumes SchemaVisitor defines after_map_value — confirm
    visitor.after_map_value(obj.value)

    return visitor.map(obj, key_result, value_result)
primitive(self, primitive): + def primitive(self, primitive) -> Dict[int, NestedField]: return self._index @@ -409,24 +436,27 @@ def after_field(self, field: NestedField) -> None: self._field_names.pop() self._short_field_names.pop() - def schema(self, schema, struct_result): + def schema(self, schema: Schema, struct_result: Dict[str, int]) -> Dict[str, int]: return self._index - def struct(self, struct, field_results): + def struct(self, struct: StructType, struct_result: List[Dict[str, int]]) -> Dict[str, int]: return self._index - def field(self, field, field_result): + def field(self, field: NestedField, struct_result: Dict[str, int]) -> Dict[str, int]: """Add the field name to the index""" self._add_field(field.name, field.field_id) + return self._index - def list(self, list_type, result): + def list(self, list_type: ListType, struct_result: Dict[str, int]) -> Dict[str, int]: """Add the list element name to the index""" self._add_field(list_type.element.name, list_type.element.field_id) + return self._index - def map(self, map_type, key_result, value_result): + def map(self, map_type: MapType, key_result: Dict[str, int], value_result: Dict[str, int]) -> Dict[str, int]: """Add the key name and value name as individual items in the index""" self._add_field(map_type.key.name, map_type.key.field_id) self._add_field(map_type.value.name, map_type.value.field_id) + return self._index def _add_field(self, name: str, field_id: int): """Add a field name to the index, mapping its full name to its field ID @@ -451,10 +481,10 @@ def _add_field(self, name: str, field_id: int): short_name = ".".join([".".join(self._short_field_names), name]) self._short_name_to_id[short_name] = field_id - def primitive(self, primitive): + def primitive(self, primitive) -> Dict[str, int]: return self._index - def by_name(self): + def by_name(self) -> Dict[str, int]: """Returns an index of combined full and short names Note: Only short names that do not conflict with full names are included. 
Position = int


class _BuildPositionAccessors(SchemaVisitor[Dict[Position, Accessor]]):
    """A schema visitor for generating a field ID to accessor index

    Every visitor method returns a dictionary keyed by field ID whose values are
    accessors relative to the type being visited. Primitive, list and map types
    contribute an empty dictionary, so a struct field of one of those types is
    addressed directly by its position; fields of a nested struct are addressed
    through an accessor for the nested struct's position wrapping the nested accessor.

    Example:
        >>> from iceberg.schema import Schema
        >>> from iceberg.types import *
        >>> schema = Schema(
        ...     NestedField(field_id=2, name="id", field_type=IntegerType(), is_optional=False),
        ...     NestedField(field_id=1, name="data", field_type=StringType(), is_optional=True),
        ...     NestedField(
        ...         field_id=3,
        ...         name="location",
        ...         field_type=StructType(
        ...             NestedField(field_id=5, name="latitude", field_type=FloatType(), is_optional=False),
        ...             NestedField(field_id=6, name="longitude", field_type=FloatType(), is_optional=False),
        ...         ),
        ...         is_optional=True,
        ...     ),
        ...     schema_id=1,
        ...     identifier_field_ids=[1],
        ... )
        >>> result = build_position_accessors(schema)
        >>> expected = {
        ...     2: Accessor(position=0, inner=None),
        ...     1: Accessor(position=1, inner=None),
        ...     5: Accessor(position=2, inner=Accessor(position=0, inner=None)),
        ...     6: Accessor(position=2, inner=Accessor(position=1, inner=None))
        ... }
        >>> result == expected
        True
    """

    @staticmethod
    def _wrap_leaves(result: Dict[Position, Accessor], position: Position = 0) -> Dict[Position, Accessor]:
        """Re-root every accessor in `result` under an outer accessor for `position`"""
        wrapped: Dict[Position, Accessor] = {}
        for field_id, inner in result.items():
            wrapped[field_id] = Accessor(position, inner=inner)
        return wrapped

    def schema(self, schema: Schema, result: Dict[Position, Accessor]) -> Dict[Position, Accessor]:
        """Pass the accessors built for the schema's top-level struct through unchanged"""
        return result

    def struct(self, struct: StructType, field_results: List[Dict[Position, Accessor]]) -> Dict[Position, Accessor]:
        """Combine the per-field results into one field ID to accessor index for this struct"""
        accessors: Dict[Position, Accessor] = {}

        for position, field in enumerate(struct.fields):
            nested = field_results[position]
            if not nested:
                # Nothing nested was reported for this field: address the field itself
                accessors[field.field_id] = Accessor(position)
                continue
            # Nested accessors are reached through this field's position
            accessors.update(self._wrap_leaves(nested, position))

        return accessors

    def field(self, field: NestedField, result: Dict[Position, Accessor]) -> Dict[Position, Accessor]:
        """Pass the result for the field's type through unchanged"""
        return result

    def list(self, list_type: ListType, result: Dict[Position, Accessor]) -> Dict[Position, Accessor]:
        """Lists contribute no nested accessors"""
        return {}

    def map(
        self, map_type: MapType, key_result: Dict[Position, Accessor], value_result: Dict[Position, Accessor]
    ) -> Dict[Position, Accessor]:
        """Maps contribute no nested accessors"""
        return {}

    def primitive(self, primitive: PrimitiveType) -> Dict[Position, Accessor]:
        """Primitives contribute no nested accessors"""
        return {}
from textwrap import dedent +from typing import Any, Dict import pytest from iceberg import schema +from iceberg.expressions.base import Accessor +from iceberg.files import StructProtocol +from iceberg.schema import build_position_accessors from iceberg.types import ( BooleanType, FloatType, @@ -140,6 +144,17 @@ def test_schema_index_by_id_visitor(table_schema_nested): ), 13: NestedField(field_id=13, name="latitude", field_type=FloatType(), is_optional=False), 14: NestedField(field_id=14, name="longitude", field_type=FloatType(), is_optional=False), + 15: NestedField( + field_id=15, + name="person", + field_type=StructType( + NestedField(field_id=16, name="name", field_type=StringType(), is_optional=False), + NestedField(field_id=17, name="age", field_type=IntegerType(), is_optional=True), + ), + is_optional=False, + ), + 16: NestedField(field_id=16, name="name", field_type=StringType(), is_optional=False), + 17: NestedField(field_id=17, name="age", field_type=IntegerType(), is_optional=True), } @@ -163,6 +178,9 @@ def test_schema_index_by_name_visitor(table_schema_nested): "location.element.longitude": 14, "location.latitude": 13, "location.longitude": 14, + "person": 15, + "person.name": 16, + "person.age": 17, } @@ -295,6 +313,17 @@ def test_index_by_id_schema_visitor(table_schema_nested): ), 13: NestedField(field_id=13, name="latitude", field_type=FloatType(), is_optional=False), 14: NestedField(field_id=14, name="longitude", field_type=FloatType(), is_optional=False), + 15: NestedField( + field_id=15, + name="person", + field_type=StructType( + NestedField(field_id=16, name="name", field_type=StringType(), is_optional=False), + NestedField(field_id=17, name="age", field_type=IntegerType(), is_optional=True), + ), + is_optional=False, + ), + 16: NestedField(field_id=16, name="name", field_type=StringType(), is_optional=False), + 17: NestedField(field_id=17, name="age", field_type=IntegerType(), is_optional=True), } @@ -347,3 +376,34 @@ def 
class TestStruct(StructProtocol):
    """A minimal dict-backed StructProtocol implementation used as a test double"""

    # The class name starts with "Test" and it defines __init__; opt out of pytest
    # collection explicitly so pytest does not warn about an uncollectable test class
    __test__ = False

    def __init__(self, pos: Dict[int, Any] = None):
        # Maps position -> value; a missing/empty argument starts with no values
        self._pos: Dict[int, Any] = pos or {}

    def set(self, pos: int, value) -> None:
        """Store `value` at position `pos`"""
        self._pos[pos] = value

    def get(self, pos: int) -> Any:
        """Return the value at position `pos`; raises KeyError when nothing is stored there"""
        return self._pos[pos]