From 9003db2cf89b8e5024f278d624213dbb1379085a Mon Sep 17 00:00:00 2001 From: Duncan Blythe Date: Thu, 28 Nov 2024 09:55:17 +0100 Subject: [PATCH] Use annotations to declare schema --- CHANGELOG.md | 1 + plugins/ibis/superduper_ibis/utils.py | 4 +- pyproject.toml | 3 +- superduper/base/document.py | 8 +-- superduper/components/component.py | 11 ++++ superduper/components/datatype.py | 80 ++++++++++++++++---------- superduper/components/plugin.py | 6 +- superduper/misc/typing.py | 9 +++ test/unittest/component/test_schema.py | 4 +- test/unittest/misc/test_auto_schema.py | 4 +- 10 files changed, 86 insertions(+), 44 deletions(-) create mode 100644 superduper/misc/typing.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9cfd8262e..fbca5fb06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Low-code form builder for frontend - Add snowflake vector search engine - Add a meta-datatype `Vector` to handle different databackend requirements +- Add type annotations as a way to declare schema #### Bug Fixes diff --git a/plugins/ibis/superduper_ibis/utils.py b/plugins/ibis/superduper_ibis/utils.py index b8215bf55..94b8c68bd 100644 --- a/plugins/ibis/superduper_ibis/utils.py +++ b/plugins/ibis/superduper_ibis/utils.py @@ -1,12 +1,12 @@ from ibis.expr.datatypes import dtype from superduper.components.datatype import ( BaseDataType, - File, + FileItem, ) from superduper.components.schema import ID, FieldType, Schema SPECIAL_ENCODABLES_FIELDS = { - File: "str", + FileItem: "str", } diff --git a/pyproject.toml b/pyproject.toml index 3d3f08668..07b8ffd99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,7 @@ target-version = ["py38"] ignore_missing_imports = true no_implicit_optional = true warn_unused_ignores = true -disable_error_code = ["has-type", "attr-defined", "assignment", "misc", "override", "call-arg", "import-untyped"] +disable_error_code = ["has-type",
"attr-defined", "assignment", "misc", "override", "call-arg", "import-untyped", "no-redef", "valid-type", "valid-newtype"] [tool.pytest.ini_options] addopts = "-W ignore" @@ -136,6 +136,7 @@ ignore = [ "D401", "D102", "E402", + "F403" ] exclude = ["templates", "superduper/templates"] diff --git a/superduper/base/document.py b/superduper/base/document.py index 1dd98dce9..85fd66564 100644 --- a/superduper/base/document.py +++ b/superduper/base/document.py @@ -12,7 +12,7 @@ from superduper.base.leaf import Leaf, import_item from superduper.base.variables import _replace_variables from superduper.components.component import Component -from superduper.components.datatype import BaseDataType, Blob, File +from superduper.components.datatype import BaseDataType, Blob, FileItem from superduper.components.schema import Schema, get_schema from superduper.misc.reference import parse_reference from superduper.misc.special_dicts import MongoStyleDict, SuperDuperFlatEncode @@ -294,7 +294,7 @@ def decode( ) def my_getter(x): - return File(path=r[KEY_FILES].get(x.split(':')[-1]), db=db) + return FileItem(path=r[KEY_FILES].get(x.split(':')[-1]), db=db) if r.get(KEY_FILES): getters.add_getter('file', my_getter) @@ -521,7 +521,7 @@ def _deep_flat_encode( blobs[r.identifier] = r.bytes return '&:blob:' + r.identifier - if isinstance(r, File): + if isinstance(r, FileItem): files[r.identifier] = r.path return '&:file:' + r.identifier @@ -717,7 +717,7 @@ def _get_component(db, path): def _get_file_callback(db): def callback(ref): - return File(identifier=ref, db=db) + return FileItem(identifier=ref, db=db) return callback diff --git a/superduper/components/component.py b/superduper/components/component.py index 6fb3db9e6..72fa22405 100644 --- a/superduper/components/component.py +++ b/superduper/components/component.py @@ -152,6 +152,17 @@ def __new__(cls, name, bases, dct): new_cls._fields[field.name] = 'default' elif annotation is t.Callable or _is_optional_callable(annotation): 
new_cls._fields[field.name] = 'default' + # a hack... + elif 'superduper.misc.typing' in str(annotation): + annotation = str(annotation) + import re + + match1 = re.match('^typing\.Optional\[(.*)\]$', annotation) + match2 = re.match('^t\.Optional\[(.*)\]$', annotation) + match = match1 or match2 + if match: + annotation = match.groups()[0] + new_cls._fields[field.name] = annotation.split('.')[-1] except KeyError: continue return new_cls diff --git a/superduper/components/datatype.py b/superduper/components/datatype.py index c1366dac0..4270dfcf1 100644 --- a/superduper/components/datatype.py +++ b/superduper/components/datatype.py @@ -13,11 +13,13 @@ from superduper import CFG from superduper.base.leaf import Leaf -from superduper.components.component import Component +from superduper.components.component import Component, ComponentMeta Decode = t.Callable[[bytes], t.Any] Encode = t.Callable[[t.Any], bytes] +INBUILT_DATATYPES = {} + class DataTypeFactory: """Abstract class for creating a DataType # noqa.""" @@ -43,7 +45,21 @@ def create(data: t.Any) -> "BaseDataType": raise NotImplementedError -class BaseDataType(Component): +class DataTypeMeta(ComponentMeta): + """Metaclass for the `Model` class and descendants # noqa.""" + + def __new__(mcls, name, bases, dct): + """Create a new class with merged docstrings # noqa.""" + cls = super().__new__(mcls, name, bases, dct) + try: + instance = cls(cls.__name__) + INBUILT_DATATYPES[cls.__name__] = instance + except TypeError: + pass + return cls + + +class BaseDataType(Component, metaclass=DataTypeMeta): """Base class for datatype.""" type_id: t.ClassVar[str] = 'datatype' @@ -173,7 +189,7 @@ def decode_data(self, item): return pickle.loads(item) -class PickleSerializer(_Artifact, _PickleMixin, BaseDataType): +class Pickle(_Artifact, _PickleMixin, BaseDataType): """Serializer with pickle.""" @@ -189,7 +205,7 @@ def decode_data(self, item): return dill.loads(item) -class DillSerializer(_Artifact, _DillMixin, BaseDataType): 
+class Dill(_Artifact, _DillMixin, BaseDataType): """Serializer with dill. This is also the default serializer. @@ -197,18 +213,20 @@ class DillSerializer(_Artifact, _DillMixin, BaseDataType): """ -class _DillEncoder(_Encodable, _DillMixin, BaseDataType): +class DillEncoder(_Encodable, _DillMixin, BaseDataType): + """Encoder with dill.""" + ... -class FileType(BaseDataType): +class File(BaseDataType): """Type for encoding files on disk.""" encodable: t.ClassVar[str] = 'file' def encode_data(self, item): assert os.path.exists(item) - return File(path=item) + return FileItem(path=item) def decode_data(self, item): return item @@ -247,7 +265,7 @@ def unpack(self): pass -class File(Saveable): +class FileItem(Saveable): """Placeholder for a file. :param path: Path to file. @@ -306,25 +324,27 @@ def reference(self): json_encoder = JSON('json') pickle_encoder = PickleEncoder('pickle_encoder') -pickle_serializer = PickleSerializer('pickle_serializer') -dill_encoder = _DillEncoder('dill_encoder') -dill_serializer = DillSerializer('dill_serializer') -file = FileType('file') - -DEFAULT_ENCODER = PickleEncoder('default_encoder') -DEFAULT_SERIALIZER = DillSerializer('default') - - -INBUILT_DATATYPES = { - dt.identifier: dt - for dt in [ - json_encoder, - pickle_encoder, - pickle_serializer, - dill_encoder, - dill_serializer, - file, - DEFAULT_SERIALIZER, - DEFAULT_ENCODER, - ] -} +pickle_serializer = Pickle('pickle_serializer') +dill_encoder = DillEncoder('dill_encoder') +dill_serializer = Dill('dill_serializer') +file = File('file') + + +INBUILT_DATATYPES.update( + { + dt.identifier: dt + for dt in [ + json_encoder, + pickle_encoder, + pickle_serializer, + dill_encoder, + dill_serializer, + file, + ] + } +) + +DEFAULT_ENCODER = INBUILT_DATATYPES['PickleEncoder'] +DEFAULT_SERIALIZER = INBUILT_DATATYPES['Dill'] +INBUILT_DATATYPES['default'] = DEFAULT_SERIALIZER +INBUILT_DATATYPES['Blob'] = INBUILT_DATATYPES['Pickle'] diff --git a/superduper/components/plugin.py 
b/superduper/components/plugin.py index 22256a910..417ef04e5 100644 --- a/superduper/components/plugin.py +++ b/superduper/components/plugin.py @@ -6,7 +6,7 @@ import typing as t from superduper import Component, logging -from superduper.components.datatype import File, file +from superduper.components.datatype import FileItem, file class Plugin(Component): @@ -24,7 +24,7 @@ class Plugin(Component): cache_path: str = "~/.superduper/plugins" def __post_init__(self, db): - if isinstance(self.path, File): + if isinstance(self.path, FileItem): self._prepare_plugin() else: path_name = os.path.basename(self.path.rstrip("/")) @@ -92,7 +92,7 @@ def _pip_install(self, requirement_path): def _prepare_plugin(self): plugin_name_tag = f"{self.identifier}" - assert isinstance(self.path, File) + assert isinstance(self.path, FileItem) cache_path = os.path.expanduser(self.cache_path) uuid_path = os.path.join(cache_path, self.uuid) # Check if plugin is already in cache diff --git a/superduper/misc/typing.py b/superduper/misc/typing.py new file mode 100644 index 000000000..2283e980a --- /dev/null +++ b/superduper/misc/typing.py @@ -0,0 +1,9 @@ +import typing as t + +from superduper.components.datatype import * + +File = t.NewType('File', t.AnyStr) +Blob = t.NewType('Blob', t.Callable) +Dill = t.NewType('Dill', t.Callable) +Pickle = t.NewType('Pickle', t.Callable) +JSON = t.NewType('JSON', t.Dict) diff --git a/test/unittest/component/test_schema.py b/test/unittest/component/test_schema.py index 54d735f63..522521035 100644 --- a/test/unittest/component/test_schema.py +++ b/test/unittest/component/test_schema.py @@ -5,7 +5,7 @@ from superduper import Component, Schema, Table from superduper.components.datatype import ( Blob, - File, + FileItem, dill_serializer, file, pickle_encoder, @@ -104,7 +104,7 @@ def test_schema_with_file(db, tmp_file): r = db['documents'].select().tolist()[0] # loaded document contains a pointer to the file - assert isinstance(r['my_file'], File) + assert 
isinstance(r['my_file'], FileItem) # however the path has not been populated assert not r['my_file'].path diff --git a/test/unittest/misc/test_auto_schema.py b/test/unittest/misc/test_auto_schema.py index 196f9cc1a..1caf57e6f 100644 --- a/test/unittest/misc/test_auto_schema.py +++ b/test/unittest/misc/test_auto_schema.py @@ -37,11 +37,11 @@ def test_infer_datatype(): assert infer_datatype({"a": 1}).identifier == "json" - assert infer_datatype({"a": np.array([1, 2, 3])}).identifier == "default_encoder" + assert infer_datatype({"a": np.array([1, 2, 3])}).identifier == "PickleEncoder" assert ( infer_datatype(pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})).identifier - == "default_encoder" + == "PickleEncoder" ) with pytest.raises(UnsupportedDatatype):