Skip to content

Commit

Permalink
Use annotations to declare schema
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Nov 29, 2024
1 parent b39d5da commit 9003db2
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 44 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Low-code form builder for frontend
- Add snowflake vector search engine
- Add a meta-datatype `Vector` to handle different databackend requirements
- Add type annotations as a way to declare schema

#### Bug Fixes

Expand Down
4 changes: 2 additions & 2 deletions plugins/ibis/superduper_ibis/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from ibis.expr.datatypes import dtype
from superduper.components.datatype import (
BaseDataType,
File,
FileItem,
)
from superduper.components.schema import ID, FieldType, Schema

SPECIAL_ENCODABLES_FIELDS = {
File: "str",
FileItem: "str",
}


Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ target-version = ["py38"]
ignore_missing_imports = true
no_implicit_optional = true
warn_unused_ignores = true
disable_error_code = ["has-type", "attr-defined", "assignment", "misc", "override", "call-arg", "import-untyped"]
disable_error_code = ["has-type", "attr-defined", "assignment", "misc", "override", "call-arg", "import-untyped", "no-redef", "valid-type", "valid-newtype"]

[tool.pytest.ini_options]
addopts = "-W ignore"
Expand Down Expand Up @@ -136,6 +136,7 @@ ignore = [
"D401",
"D102",
"E402",
"F403"
]
exclude = ["templates", "superduper/templates"]

Expand Down
8 changes: 4 additions & 4 deletions superduper/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from superduper.base.leaf import Leaf, import_item
from superduper.base.variables import _replace_variables
from superduper.components.component import Component
from superduper.components.datatype import BaseDataType, Blob, File
from superduper.components.datatype import BaseDataType, Blob, FileItem
from superduper.components.schema import Schema, get_schema
from superduper.misc.reference import parse_reference
from superduper.misc.special_dicts import MongoStyleDict, SuperDuperFlatEncode
Expand Down Expand Up @@ -294,7 +294,7 @@ def decode(
)

def my_getter(x):
return File(path=r[KEY_FILES].get(x.split(':')[-1]), db=db)
return FileItem(path=r[KEY_FILES].get(x.split(':')[-1]), db=db)

if r.get(KEY_FILES):
getters.add_getter('file', my_getter)
Expand Down Expand Up @@ -521,7 +521,7 @@ def _deep_flat_encode(
blobs[r.identifier] = r.bytes
return '&:blob:' + r.identifier

if isinstance(r, File):
if isinstance(r, FileItem):
files[r.identifier] = r.path
return '&:file:' + r.identifier

Expand Down Expand Up @@ -717,7 +717,7 @@ def _get_component(db, path):

def _get_file_callback(db):
def callback(ref):
return File(identifier=ref, db=db)
return FileItem(identifier=ref, db=db)

return callback

Expand Down
11 changes: 11 additions & 0 deletions superduper/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,17 @@ def __new__(cls, name, bases, dct):
new_cls._fields[field.name] = 'default'
elif annotation is t.Callable or _is_optional_callable(annotation):
new_cls._fields[field.name] = 'default'
# a hack...
elif 'superduper.misc.typing' in str(annotation):
annotation = str(annotation)
import re

match1 = re.match('^typing\.Optional\[(.*)\]$', annotation)
match2 = re.match('^t\.Optional\[(.*)\]$', annotation)
match = match1 or match2
if match:
annotation = match.groups()[0]
new_cls._fields[field.name] = annotation.split('.')[-1]
except KeyError:
continue
return new_cls
Expand Down
80 changes: 50 additions & 30 deletions superduper/components/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@

from superduper import CFG
from superduper.base.leaf import Leaf
from superduper.components.component import Component
from superduper.components.component import Component, ComponentMeta

Decode = t.Callable[[bytes], t.Any]
Encode = t.Callable[[t.Any], bytes]

INBUILT_DATATYPES = {}


class DataTypeFactory:
"""Abstract class for creating a DataType # noqa."""
Expand All @@ -43,7 +45,21 @@ def create(data: t.Any) -> "BaseDataType":
raise NotImplementedError


class BaseDataType(Component):
class DataTypeMeta(ComponentMeta):
    """Metaclass for `BaseDataType` and its descendants # noqa.

    On creation of each subclass, attempts to instantiate it with the
    class name as identifier and registers the instance in the
    module-level ``INBUILT_DATATYPES`` registry under the class name.
    """

    def __new__(mcls, name, bases, dct):
        """Create the class and auto-register a default instance # noqa."""
        cls = super().__new__(mcls, name, bases, dct)
        try:
            # Register a ready-made instance keyed by the class name so that
            # e.g. 'File', 'Pickle', 'JSON' can be looked up by annotation
            # name elsewhere.
            instance = cls(cls.__name__)
            INBUILT_DATATYPES[cls.__name__] = instance
        except TypeError:
            # Classes whose constructor needs additional required arguments
            # raise TypeError here and are deliberately skipped — they cannot
            # be "inbuilt" datatypes. NOTE(review): this also silently skips
            # a class whose __init__ raises TypeError for other reasons.
            pass
        return cls


class BaseDataType(Component, metaclass=DataTypeMeta):
"""Base class for datatype."""

type_id: t.ClassVar[str] = 'datatype'
Expand Down Expand Up @@ -173,7 +189,7 @@ def decode_data(self, item):
return pickle.loads(item)


class PickleSerializer(_Artifact, _PickleMixin, BaseDataType):
class Pickle(_Artifact, _PickleMixin, BaseDataType):
"""Serializer with pickle."""


Expand All @@ -189,26 +205,28 @@ def decode_data(self, item):
return dill.loads(item)


class DillSerializer(_Artifact, _DillMixin, BaseDataType):
class Dill(_Artifact, _DillMixin, BaseDataType):
"""Serializer with dill.
This is also the default serializer.
>>> from superduper.components.datatype import DEFAULT_SERIALIZER
"""


class _DillEncoder(_Encodable, _DillMixin, BaseDataType):
class DillEncoder(_Encodable, _DillMixin, BaseDataType):
"""Encoder with dill."""

...


class FileType(BaseDataType):
class File(BaseDataType):
"""Type for encoding files on disk."""

encodable: t.ClassVar[str] = 'file'

def encode_data(self, item):
assert os.path.exists(item)
return File(path=item)
return FileItem(path=item)

def decode_data(self, item):
return item
Expand Down Expand Up @@ -247,7 +265,7 @@ def unpack(self):
pass


class File(Saveable):
class FileItem(Saveable):
"""Placeholder for a file.
:param path: Path to file.
Expand Down Expand Up @@ -306,25 +324,27 @@ def reference(self):

json_encoder = JSON('json')
pickle_encoder = PickleEncoder('pickle_encoder')
pickle_serializer = PickleSerializer('pickle_serializer')
dill_encoder = _DillEncoder('dill_encoder')
dill_serializer = DillSerializer('dill_serializer')
file = FileType('file')

DEFAULT_ENCODER = PickleEncoder('default_encoder')
DEFAULT_SERIALIZER = DillSerializer('default')


INBUILT_DATATYPES = {
dt.identifier: dt
for dt in [
json_encoder,
pickle_encoder,
pickle_serializer,
dill_encoder,
dill_serializer,
file,
DEFAULT_SERIALIZER,
DEFAULT_ENCODER,
]
}
pickle_serializer = Pickle('pickle_serializer')
dill_encoder = DillEncoder('dill_encoder')
dill_serializer = Dill('dill_serializer')
file = File('file')


INBUILT_DATATYPES.update(
{
dt.identifier: dt
for dt in [
json_encoder,
pickle_encoder,
pickle_serializer,
dill_encoder,
dill_serializer,
file,
]
}
)

DEFAULT_ENCODER = INBUILT_DATATYPES['PickleEncoder']
DEFAULT_SERIALIZER = INBUILT_DATATYPES['Dill']
INBUILT_DATATYPES['default'] = DEFAULT_SERIALIZER
INBUILT_DATATYPES['Blob'] = INBUILT_DATATYPES['Pickle']
6 changes: 3 additions & 3 deletions superduper/components/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import typing as t

from superduper import Component, logging
from superduper.components.datatype import File, file
from superduper.components.datatype import FileItem, file


class Plugin(Component):
Expand All @@ -24,7 +24,7 @@ class Plugin(Component):
cache_path: str = "~/.superduper/plugins"

def __post_init__(self, db):
if isinstance(self.path, File):
if isinstance(self.path, FileItem):
self._prepare_plugin()
else:
path_name = os.path.basename(self.path.rstrip("/"))
Expand Down Expand Up @@ -92,7 +92,7 @@ def _pip_install(self, requirement_path):

def _prepare_plugin(self):
plugin_name_tag = f"{self.identifier}"
assert isinstance(self.path, File)
assert isinstance(self.path, FileItem)
cache_path = os.path.expanduser(self.cache_path)
uuid_path = os.path.join(cache_path, self.uuid)
# Check if plugin is already in cache
Expand Down
9 changes: 9 additions & 0 deletions superduper/misc/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Annotation markers used to declare a schema via type hints.

Each ``NewType`` below is a marker type for field annotations on
components. The metaclass in ``superduper.components.component``
detects annotations whose string form contains
``superduper.misc.typing`` and maps the trailing name (e.g. ``File``)
to the corresponding entry in ``INBUILT_DATATYPES``.
"""
import typing as t

# Star import re-exports the concrete datatype components so that both
# the markers and the datatypes are importable from this module.
from superduper.components.datatype import *

# NOTE(review): the second argument to NewType is a nominal base type
# only; these markers are never instantiated, so the chosen bases
# (t.AnyStr / t.Callable / t.Dict) are informational — confirm they are
# not relied upon by static type checking.
File = t.NewType('File', t.AnyStr)
Blob = t.NewType('Blob', t.Callable)
Dill = t.NewType('Dill', t.Callable)
Pickle = t.NewType('Pickle', t.Callable)
JSON = t.NewType('JSON', t.Dict)
4 changes: 2 additions & 2 deletions test/unittest/component/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from superduper import Component, Schema, Table
from superduper.components.datatype import (
Blob,
File,
FileItem,
dill_serializer,
file,
pickle_encoder,
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_schema_with_file(db, tmp_file):
r = db['documents'].select().tolist()[0]

# loaded document contains a pointer to the file
assert isinstance(r['my_file'], File)
assert isinstance(r['my_file'], FileItem)

# however the path has not been populated
assert not r['my_file'].path
Expand Down
4 changes: 2 additions & 2 deletions test/unittest/misc/test_auto_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def test_infer_datatype():

assert infer_datatype({"a": 1}).identifier == "json"

assert infer_datatype({"a": np.array([1, 2, 3])}).identifier == "default_encoder"
assert infer_datatype({"a": np.array([1, 2, 3])}).identifier == "PickleEncoder"

assert (
infer_datatype(pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})).identifier
== "default_encoder"
== "PickleEncoder"
)

with pytest.raises(UnsupportedDatatype):
Expand Down

0 comments on commit 9003db2

Please sign in to comment.