Skip to content

Commit

Permalink
Python: Add visitor to build Accessor for a Schema (apache#4685)
Browse files Browse the repository at this point in the history
  • Loading branch information
Fokko authored May 13, 2022
1 parent 24f1ae0 commit 17a1be6
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 87 deletions.
31 changes: 1 addition & 30 deletions python/src/iceberg/expressions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Any, Generic, TypeVar

from iceberg.files import StructProtocol
from iceberg.schema import Schema
from iceberg.schema import Accessor, Schema
from iceberg.types import NestedField, Singleton

T = TypeVar("T")
Expand Down Expand Up @@ -267,35 +267,6 @@ def __str__(self) -> str:
return "false"


class Accessor:
"""An accessor for a specific position in a container that implements the StructProtocol"""

def __init__(self, position: int):
self._position = position

def __str__(self):
return f"Accessor(position={self._position})"

def __repr__(self):
return f"Accessor(position={self._position})"

@property
def position(self):
"""The position in the container to access"""
return self._position

def get(self, container: StructProtocol) -> Any:
"""Returns the value at self.position in `container`
Args:
container(StructProtocol): A container to access at position `self.position`
Returns:
Any: The value at position `self.position` in the container
"""
return container.get(self.position)


class BoundReference:
"""A reference bound to a field in a schema
Expand Down
4 changes: 3 additions & 1 deletion python/src/iceberg/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from abc import abstractmethod
from enum import Enum, auto
from typing import Any

Expand Down Expand Up @@ -46,8 +46,10 @@ class FileFormat(Enum):
class StructProtocol(Protocol): # pragma: no cover
"""A generic protocol used by accessors to get and set at positions of an object"""

@abstractmethod
def get(self, pos: int) -> Any:
...

@abstractmethod
def set(self, pos: int, value) -> None:
...
184 changes: 128 additions & 56 deletions python/src/iceberg/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, Generic, Iterable, List, TypeVar
from dataclasses import dataclass
from typing import Any, Dict, Generic, Iterable, List, Optional, TypeVar

if TYPE_CHECKING:
from iceberg.expressions.base import Accessor
from iceberg.files import StructProtocol

if sys.version_info >= (3, 8):
from functools import singledispatch # pragma: no cover
Expand All @@ -49,10 +49,10 @@ class Schema:
>>> from iceberg import types
"""

def __init__(self, *columns: Iterable[NestedField], schema_id: int, identifier_field_ids: List[int] = []):
def __init__(self, *columns: Iterable[NestedField], schema_id: int, identifier_field_ids: Optional[List[int]] = None):
self._struct = StructType(*columns) # type: ignore
self._schema_id = schema_id
self._identifier_field_ids = identifier_field_ids
self._identifier_field_ids = identifier_field_ids or []
self._name_to_id: Dict[str, int] = index_by_name(self)
self._name_to_id_lower: Dict[str, int] = {} # Should be accessed through self._lazy_name_to_id_lower()
self._id_to_field: Dict[int, NestedField] = {} # Should be accessed through self._lazy_id_to_field()
Expand Down Expand Up @@ -262,10 +262,44 @@ def primitive(self, primitive: PrimitiveType) -> T:
... # pragma: no cover


@dataclass(init=True, eq=True, frozen=True)
class Accessor:
    """An accessor for a specific position in a container that implements the StructProtocol

    Accessors can be chained through `inner` to reach a value inside nested containers:
    this accessor selects the value at `position`, and `inner` (if set) is then applied
    to that value, and so on until an accessor without `inner` is reached.

    Attributes:
        position (int): The position in the container to access
        inner (Optional[Accessor]): The accessor to apply to the value found at `position`,
            or None when this accessor is the end of the chain
    """

    position: int
    inner: Optional["Accessor"] = None

    def __str__(self):
        return f"Accessor(position={self.position},inner={self.inner})"

    def __repr__(self):
        return self.__str__()

    def get(self, container: "StructProtocol") -> Any:
        """Returns the value reached by following this accessor chain through `container`

        Args:
            container(StructProtocol): A container to access at position `self.position`

        Returns:
            Any: The value at the end of the accessor chain, or None when a value along
                the chain is None (so there is nothing left to descend into)
        """
        val = container.get(self.position)
        accessor = self.inner
        while accessor is not None:
            if val is None:
                # An intermediate container is absent: there is no nested value to read,
                # so report None instead of failing on `None.get(...)`
                return None
            val = val.get(accessor.position)
            accessor = accessor.inner

        return val


@singledispatch
def visit(obj, visitor: SchemaVisitor[T]) -> T:
"""A generic function for applying a schema visitor to any point within a schema
The function traverses the schema in post-order fashion
Args:
obj(Schema | IcebergType): An instance of a Schema or an IcebergType
visitor (SchemaVisitor[T]): An instance of an implementation of the generic SchemaVisitor base class
Expand All @@ -286,13 +320,11 @@ def _(obj: Schema, visitor: SchemaVisitor[T]) -> T:
def _(obj: StructType, visitor: SchemaVisitor[T]) -> T:
"""Visit a StructType with a concrete SchemaVisitor"""
results = []

for field in obj.fields:
visitor.before_field(field)
try:
result = visit(field.type, visitor)
finally:
visitor.after_field(field)

result = visit(field.type, visitor)
visitor.after_field(field)
results.append(visitor.field(field, result))

return visitor.struct(obj, results)
Expand All @@ -301,11 +333,10 @@ def _(obj: StructType, visitor: SchemaVisitor[T]) -> T:
@visit.register(ListType)
def _(obj: ListType, visitor: SchemaVisitor[T]) -> T:
"""Visit a ListType with a concrete SchemaVisitor"""

visitor.before_list_element(obj.element)
try:
result = visit(obj.element.type, visitor)
finally:
visitor.after_list_element(obj.element)
result = visit(obj.element.type, visitor)
visitor.after_list_element(obj.element)

return visitor.list(obj, result)

Expand All @@ -314,16 +345,12 @@ def _(obj: ListType, visitor: SchemaVisitor[T]) -> T:
def _(obj: MapType, visitor: SchemaVisitor[T]) -> T:
    """Visit a MapType with a concrete SchemaVisitor

    The key type and the value type are each visited between their matching
    before/after hooks, and both results are then handed to `visitor.map`.

    Args:
        obj(MapType): The map type to visit
        visitor(SchemaVisitor[T]): The visitor that receives the hooks and results

    Returns:
        T: The result of `visitor.map` for this map type
    """
    visitor.before_map_key(obj.key)
    key_result = visit(obj.key.type, visitor)
    visitor.after_map_key(obj.key)

    visitor.before_map_value(obj.value)
    value_result = visit(obj.value.type, visitor)
    # Close the value scope with the map-value hook (not the list-element hook) so that it
    # pairs with `before_map_value` above
    # NOTE(review): assumes SchemaVisitor defines `after_map_value` alongside `before_map_value` - confirm
    visitor.after_map_value(obj.value)

    return visitor.map(obj, key_result, value_result)

Expand All @@ -340,29 +367,29 @@ class _IndexById(SchemaVisitor[Dict[int, NestedField]]):
def __init__(self) -> None:
self._index: Dict[int, NestedField] = {}

def schema(self, schema, result):
def schema(self, schema: Schema, result) -> Dict[int, NestedField]:
return self._index

def struct(self, struct, results):
def struct(self, struct: StructType, result) -> Dict[int, NestedField]:
return self._index

def field(self, field, result):
def field(self, field: NestedField, result) -> Dict[int, NestedField]:
"""Add the field ID to the index"""
self._index[field.field_id] = field
return self._index

def list(self, list_type, result):
def list(self, list_type: ListType, result) -> Dict[int, NestedField]:
"""Add the list element ID to the index"""
self._index[list_type.element.field_id] = list_type.element
return self._index

def map(self, map_type, key_result, value_result):
def map(self, map_type: MapType, key_result, value_result) -> Dict[int, NestedField]:
"""Add the key ID and value ID as individual items in the index"""
self._index[map_type.key.field_id] = map_type.key
self._index[map_type.value.field_id] = map_type.value
return self._index

def primitive(self, primitive):
def primitive(self, primitive) -> Dict[int, NestedField]:
return self._index


Expand Down Expand Up @@ -409,24 +436,27 @@ def after_field(self, field: NestedField) -> None:
self._field_names.pop()
self._short_field_names.pop()

def schema(self, schema, struct_result):
def schema(self, schema: Schema, struct_result: Dict[str, int]) -> Dict[str, int]:
return self._index

def struct(self, struct, field_results):
def struct(self, struct: StructType, struct_result: List[Dict[str, int]]) -> Dict[str, int]:
return self._index

def field(self, field, field_result):
def field(self, field: NestedField, struct_result: Dict[str, int]) -> Dict[str, int]:
"""Add the field name to the index"""
self._add_field(field.name, field.field_id)
return self._index

def list(self, list_type, result):
def list(self, list_type: ListType, struct_result: Dict[str, int]) -> Dict[str, int]:
"""Add the list element name to the index"""
self._add_field(list_type.element.name, list_type.element.field_id)
return self._index

def map(self, map_type, key_result, value_result):
def map(self, map_type: MapType, key_result: Dict[str, int], value_result: Dict[str, int]) -> Dict[str, int]:
"""Add the key name and value name as individual items in the index"""
self._add_field(map_type.key.name, map_type.key.field_id)
self._add_field(map_type.value.name, map_type.value.field_id)
return self._index

def _add_field(self, name: str, field_id: int):
"""Add a field name to the index, mapping its full name to its field ID
Expand All @@ -451,10 +481,10 @@ def _add_field(self, name: str, field_id: int):
short_name = ".".join([".".join(self._short_field_names), name])
self._short_name_to_id[short_name] = field_id

def primitive(self, primitive):
def primitive(self, primitive) -> Dict[str, int]:
return self._index

def by_name(self):
def by_name(self) -> Dict[str, int]:
"""Returns an index of combined full and short names
Note: Only short names that do not conflict with full names are included.
Expand All @@ -463,13 +493,13 @@ def by_name(self):
combined_index.update(self._index)
return combined_index

def by_id(self) -> Dict[int, str]:
    """Returns an index of ID to full names

    Built by inverting `self._index`, swapping each key with its value. If several
    keys share a value, the key that appears last in `self._index` is kept.

    Returns:
        Dict[int, str]: The values of `self._index` mapped back to their keys
    """
    return {field_id: full_name for full_name, field_id in self._index.items()}


def index_by_name(schema_or_type) -> Dict[str, int]:
def index_by_name(schema_or_type: Schema | IcebergType) -> Dict[str, int]:
"""Generate an index of field names to field IDs
Args:
Expand All @@ -483,7 +513,7 @@ def index_by_name(schema_or_type) -> Dict[str, int]:
return indexer.by_name()


def index_name_by_id(schema_or_type) -> Dict[int, str]:
def index_name_by_id(schema_or_type: Schema | IcebergType) -> Dict[int, str]:
"""Generate an index of field IDs full field names
Args:
Expand All @@ -497,31 +527,73 @@ def index_name_by_id(schema_or_type) -> Dict[int, str]:
return indexer.by_id()


class _BuildPositionAccessors(SchemaVisitor[Dict[int, "Accessor"]]):
"""A schema visitor for generating a field ID to accessor index"""
Position = int

def __init__(self) -> None:
self._index: Dict[int, Accessor] = {}

def schema(self, schema, result: Dict[int, Accessor]) -> Dict[int, Accessor]:
return self._index
class _BuildPositionAccessors(SchemaVisitor[Dict[Position, Accessor]]):
"""A schema visitor for generating a field ID to accessor index
def struct(self, struct, result: List[Dict[int, Accessor]]) -> Dict[int, Accessor]:
# TODO: Populate the `self._index` dictionary where the key is the field ID and the value is an accessor for that field.
# The equivalent java logic can be found here: https://github.com/apache/iceberg/blob/master/api/src/main/java/org/apache/iceberg/Accessors.java#L213-L230
return self._index
Example:
>>> from iceberg.schema import Schema
>>> from iceberg.types import *
>>> schema = Schema(
... NestedField(field_id=2, name="id", field_type=IntegerType(), is_optional=False),
... NestedField(field_id=1, name="data", field_type=StringType(), is_optional=True),
... NestedField(
... field_id=3,
... name="location",
... field_type=StructType(
... NestedField(field_id=5, name="latitude", field_type=FloatType(), is_optional=False),
... NestedField(field_id=6, name="longitude", field_type=FloatType(), is_optional=False),
... ),
... is_optional=True,
... ),
... schema_id=1,
... identifier_field_ids=[1],
... )
>>> result = build_position_accessors(schema)
>>> expected = {
... 2: Accessor(position=0, inner=None),
... 1: Accessor(position=1, inner=None),
... 5: Accessor(position=2, inner=Accessor(position=0, inner=None)),
... 6: Accessor(position=2, inner=Accessor(position=1, inner=None))
... }
>>> result == expected
True
"""

def field(self, field: NestedField, result: Dict[int, Accessor]) -> Dict[int, Accessor]:
return self._index
@staticmethod
def _wrap_leaves(result: Dict[Position, Accessor], position: Position = 0) -> Dict[Position, Accessor]:
return {field_id: Accessor(position, inner=inner) for field_id, inner in result.items()}

def list(self, list_type: ListType, result: Dict[int, Accessor]) -> Dict[int, Accessor]:
return self._index
def schema(self, schema: Schema, result: Dict[Position, Accessor]) -> Dict[Position, Accessor]:
return result

def map(self, map_type: MapType, key_result: Dict[int, Accessor], value_result: Dict[int, Accessor]) -> Dict[int, Accessor]:
return self._index
def struct(self, struct: StructType, field_results: List[Dict[Position, Accessor]]) -> Dict[Position, Accessor]:
result = {}

def primitive(self, primitive: PrimitiveType) -> Dict[int, Accessor]:
return self._index
for position, field in enumerate(struct.fields):
if field_results[position]:
for inner_field_id, acc in field_results[position].items():
result[inner_field_id] = Accessor(position, inner=acc)
else:
result[field.field_id] = Accessor(position)

return result

def field(self, field: NestedField, result: Dict[Position, Accessor]) -> Dict[Position, Accessor]:
return result

def list(self, list_type: ListType, result: Dict[Position, Accessor]) -> Dict[Position, Accessor]:
return {}

def map(
self, map_type: MapType, key_result: Dict[Position, Accessor], value_result: Dict[Position, Accessor]
) -> Dict[Position, Accessor]:
return {}

def primitive(self, primitive: PrimitiveType) -> Dict[Position, Accessor]:
return {}


def build_position_accessors(schema_or_type: Schema | IcebergType) -> Dict[int, Accessor]:
Expand Down
Loading

0 comments on commit 17a1be6

Please sign in to comment.