Build schema from dataclass signature
blythed committed Feb 14, 2025 · 1 parent 7dfba55 · commit db2782b
Showing 53 changed files with 905 additions and 733 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- No need to add `.signature` to `Model` implementations
- No need to write `Component.__post_init__` to modify attributes (use `Component.postinit`).
- Move from in-line encoding to schema-based encoding with `Leaf._fields`
+- No need to define `_fields`

#### New Features & Functionality

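To make the API change concrete, here is a minimal before/after sketch based on the updated tests in this commit (field names are illustrative):

```python
from superduper.components.table import Table

# Before this commit: an explicit Schema object had to be constructed
# from superduper.components.schema import FieldType, Schema
# schema = Schema(
#     identifier="my_table",
#     fields={"id": FieldType(identifier="str"), "age": FieldType(identifier="int32")},
# )
# t = Table(identifier="my_table", schema=schema)

# After this commit: a plain mapping of field names to type strings suffices;
# the schema is built from the dataclass signature internally.
fields = {
    "id": "str",
    "health": "int32",
    "age": "int32",
}
t = Table(identifier="my_table", fields=fields)
```
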
2 changes: 2 additions & 0 deletions plugins/ibis/plugin_test/config.yaml
@@ -3,3 +3,5 @@ data_backend: sqlite://
auto_schema: false
force_apply: true
json_native: false
+datatype_presets:
+  vector: superduper.components.datatype.Array
23 changes: 9 additions & 14 deletions plugins/ibis/plugin_test/test_end_2_end.py
@@ -2,7 +2,6 @@
from superduper import CFG, superduper
from superduper.base.document import Document as D
from superduper.components.listener import Listener
-from superduper.components.schema import FieldType, Schema


@pytest.mark.skip
@@ -20,17 +19,13 @@ def _end_2_end(db, memory_table=False):
    import torchvision
    from superduper.ext.torch.encoder import tensor
    from superduper.ext.torch.model import TorchModel
-    from superduper_pillow import pil_image

-    schema = Schema(
-        identifier="my_table",
-        fields={
-            "id": FieldType(identifier="str"),
-            "health": FieldType(identifier="int32"),
-            "age": FieldType(identifier="int32"),
-            "image": pil_image,
-        },
-    )

+    fields = {
+        "id": "str",
+        "health": "int32",
+        "age": "int32",
+        "image": 'superduper_pillow.pil_image',
+    }

    im = PIL.Image.open("test/material/data/test-image.jpeg")

    data_to_insert = [
@@ -42,7 +37,7 @@

    from superduper.components.table import Table

-    t = Table(identifier="my_table", schema=schema, db=db)
+    t = Table(identifier="my_table", fields=fields)

    db.apply(t)
    t = db["my_table"]
@@ -92,7 +87,7 @@ def postprocess(x):
        preprocess=preprocess,
        postprocess=postprocess,
        object=torchvision.models.resnet18(pretrained=False),
-        datatype=FieldType("int32"),
+        datatype='int32',
    )

    # Apply the torchvision model

20 changes: 6 additions & 14 deletions plugins/ibis/plugin_test/test_query.py
@@ -3,25 +3,17 @@
import numpy as np
import pytest
from superduper.base.document import Document
-from superduper.components.schema import Schema
from superduper.components.table import Table


def test_serialize_table():
-    schema = Schema(
-        identifier="my_schema",
-        fields={
-            "id": "int64",
-            "health": "int32",
-            "age": "int32",
-        },
-    )
-
-    s = schema.encode()
-    ds = Document.decode(s).unpack()
-    assert isinstance(ds, Schema)
+    fields = {
+        "id": "int",
+        "health": "int",
+        "age": "int",
+    }

-    t = Table(identifier="my_table", schema=schema)
+    t = Table(identifier="my_table", fields=fields)

    s = t.encode()

8 changes: 0 additions & 8 deletions plugins/ibis/superduper_ibis/data_backend.py
@@ -91,14 +91,6 @@ def __init__(self, uri: str, plugin: t.Any, flavour: t.Optional[str] = None):
        self.overwrite = False
        self._setup(conn)

-        self.datatype_presets = {'vector': 'superduper.components.datatype.Array'}
-
-        if uri.startswith('snowflake://') or uri.startswith('clickhouse://'):
-            self.bytes_encoding = 'base64'
-            self.datatype_presets.update(
-                {'vector': 'superduper.components.datatype.NativeVector'}
-            )
-
    def random_id(self):
        """Generate a random ID."""
        return str(uuid.uuid4())

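Note: the `vector` preset removed here is now supplied through configuration rather than hard-coded in the backend constructor — see the `datatype_presets` block added to `plugins/ibis/plugin_test/config.yaml` above. Presumably the Snowflake/ClickHouse-specific `NativeVector` override is likewise expected to come from deployment configuration.
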
24 changes: 6 additions & 18 deletions plugins/ibis/superduper_ibis/utils.py
@@ -1,10 +1,11 @@
from ibis.expr.datatypes import dtype
from superduper.components.datatype import (
    ID,
    BaseDataType,
    FieldType,
    FileItem,
    Vector,
)
-from superduper.components.schema import ID, FieldType, Schema
+from superduper.components.schema import Schema

SPECIAL_ENCODABLES_FIELDS = {
    FileItem: "str",
@@ -31,22 +32,9 @@ def convert_schema_to_fields(schema: Schema):
    for k, v in schema.fields.items():
        if isinstance(v, FieldType):
            fields[k] = _convert_field_type_to_ibis_type(v)
-        elif not isinstance(v, BaseDataType):
-            fields[k] = v.identifier
        else:
-            if v.encodable == 'encodable':
-                fields[k] = dtype(
-                    'str'
-                    if schema.db.databackend.bytes_encoding == 'base64'
-                    else 'bytes'
-                )
-            elif isinstance(v, Vector):
-                fields[k] = dtype('json')
-
-            elif v.encodable == 'native':
-                fields[k] = dtype(v.dtype)
-
-            else:
-                fields[k] = dtype('str')
+            assert isinstance(schema.fields[k], BaseDataType)
+
+            fields[k] = dtype(schema.fields[k].dtype)

    return fields

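The collapsed branch relies on one invariant: every non-`FieldType` field must be a `BaseDataType` whose `.dtype` string `ibis` can parse. A quick illustrative sketch of that parsing step (not part of the commit):

```python
from ibis.expr.datatypes import dtype

# dtype() turns a type string into an ibis datatype object, so any
# BaseDataType exposing a .dtype string maps directly onto a column type.
int_type = dtype('int32')
json_type = dtype('json')  # the old code used this for Vector fields
str_type = dtype('str')    # 'str' is accepted, as SPECIAL_ENCODABLES_FIELDS assumes
```
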
6 changes: 3 additions & 3 deletions plugins/transformers/superduper_transformers/model.py
@@ -365,9 +365,9 @@ def init_pipeline(
tokenizer_kwargs["pretrained_model_name_or_path"] = adapter_id

else:
tokenizer_kwargs[
"pretrained_model_name_or_path"
] = self.model_name_or_path
tokenizer_kwargs["pretrained_model_name_or_path"] = (
self.model_name_or_path
)

tokenizer = AutoTokenizer.from_pretrained(
**tokenizer_kwargs,
6 changes: 3 additions & 3 deletions plugins/transformers/superduper_transformers/training.py
@@ -617,9 +617,9 @@ def ray_train_func(train_loop_config):
        train_loop_config.get("gradient_checkpointing_kwargs", {}) or {}
    )
    gradient_checkpointing_kwargs["use_reentrant"] = False
-    train_loop_config[
-        "gradient_checkpointing_kwargs"
-    ] = gradient_checkpointing_kwargs
+    train_loop_config["gradient_checkpointing_kwargs"] = (
+        gradient_checkpointing_kwargs
+    )
    train_loop_args = LLMTrainer(**train_loop_config)
    # Build the training_args on remote machine
    train_loop_args.build()

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -65,7 +65,7 @@ test = [
"scikit-learn>=1.1.3",
"pandas",
"pre-commit",
"black==23.3",
"black==25.1.0",
"ruff==0.4.4",
"mypy",
"types-PyYAML",
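The pin bump from `black==23.3` to `black==25.1.0` likely explains the formatting-only hunks in this commit (in `superduper_transformers` above and `superduper/backends/local/*` below): newer black releases wrap long subscript assignments with a parenthesized right-hand side instead of breaking inside the subscript.
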
12 changes: 6 additions & 6 deletions superduper/backends/local/cache.py
@@ -41,14 +41,14 @@ def _put(self, component: Component):
            current_component = self._cache[current]
            current_version = current_component.version
            if current_version < component.version:
-                self._component_to_uuid[
-                    component.type_id, component.identifier
-                ] = component.uuid
+                self._component_to_uuid[component.type_id, component.identifier] = (
+                    component.uuid
+                )
                self.expire(current_component.uuid)
        else:
-            self._component_to_uuid[
-                component.type_id, component.identifier
-            ] = component.uuid
+            self._component_to_uuid[component.type_id, component.identifier] = (
+                component.uuid
+            )

    def __delitem__(self, item):
        if isinstance(item, tuple):

6 changes: 3 additions & 3 deletions superduper/backends/local/queue.py
@@ -55,9 +55,9 @@ def _put(self, component):
        msg = 'Table name "_apply" collides with Superduper namespace'
        assert component.cdc_table != '_apply', msg
        assert isinstance(component, CDC)
-        self._component_uuid_mapping[
-            component.type_id, component.identifier
-        ] = component.uuid
+        self._component_uuid_mapping[component.type_id, component.identifier] = (
+            component.uuid
+        )
        if component.cdc_table in self.queue:
            return
        self.queue[component.cdc_table] = []

21 changes: 15 additions & 6 deletions superduper/base/apply.py
@@ -176,7 +176,7 @@ def _apply(

    object.db = db

-    serialized = object.dict(metadata=False)
+    serialized = object.dict(metadata=False, schema=True)

    del serialized['uuid']

@@ -214,7 +214,7 @@ def replace_existing(x):
                x = x.replace(uuid, non_breaking_changes[uuid])

        elif isinstance(x, Query):
-            r = x.dict()
+            r = x.dict(schema=True)
            for uuid in non_breaking_changes:
                r['query'] = r['query'].replace(uuid, non_breaking_changes[uuid])
            for i, doc in enumerate(r['documents']):
@@ -242,7 +242,7 @@

    # only check for diff not in metadata/ uuid
    # also only
-    current_serialized = current.dict(metadata=False, refs=True)
+    current_serialized = current.dict(metadata=False, refs=True, schema=True)
    del current_serialized['uuid']

    serialized = serialized.map(
@@ -269,6 +269,13 @@ def replace_existing(x):

        return create_events, job_events

+    elif '_path' in this_diff:
+        # TODO use custom exception
+        raise ValueError(
+            f'Cannot update a version of {current_serialized["_path"]} '
+            f'with a new class {serialized["_path"]}.'
+        )
+
    elif set(this_diff.keys(deep=True)).intersection(object.breaks):
        # if this is a breaking change then create a new version
        apply_status = 'breaking'
@@ -287,7 +294,7 @@
        # during the `.map` to the children
        # serializer.map...
        # this means replacing components with references
-        serialized = object.dict().update(serialized)
+        serialized = object.dict(schema=True).update(serialized)

        # this is necessary to prevent inconsistencies
        # this takes the difference between
@@ -325,7 +332,7 @@ def replace_existing(x):
        # update the existing component with the change
        # data from the applied component
        serialized = (
-            current.dict()
+            current.dict(schema=True)
            .update(serialized)
            .update(this_diff)
            .encode(keep_schema=False)
@@ -339,7 +346,8 @@
            replace_existing, lambda x: isinstance(x, str) or isinstance(x, Query)
        )
        serialized['version'] = 0
-        serialized = object.dict().update(serialized)
+
+        serialized = object.dict(schema=True).update(serialized)

    # if the metadata includes components, which
    # need to be applied, do that now
@@ -365,6 +373,7 @@
    serialized = db._save_artifact(object.uuid, serialized)

    if apply_status in {'new', 'breaking'}:
+
        metadata_event = Create(
            context=context,
            component=serialized,

21 changes: 10 additions & 11 deletions superduper/base/datalayer.py
@@ -16,7 +16,6 @@
from superduper.base.document import Document
from superduper.components.component import Component
from superduper.components.datatype import LeafType
-from superduper.components.schema import Schema
from superduper.components.table import Table


@@ -217,11 +216,8 @@ def _auto_create_table(self, table_name, documents):

        # Should we need to check all the documents?
        document = documents[0]
-        schema = self.infer_schema(document)
-        table = Table(identifier=table_name, schema=schema)
-        logging.info(
-            f"Creating table {table_name} with schema {list(schema.fields.keys())}"
-        )
+        table = Table(identifier=table_name, fields=self.infer_schema(document))
+        logging.info(f"Creating table {table_name} with schema {table.schema}")
        self.apply(table, force=True)
        return table

@@ -383,9 +379,10 @@ def load(
        for k in builds:
            builds[k]['identifier'] = k.split(':')[-1]

-        c = LeafType('leaf_type', db=self).decode_data(
+        c = LeafType().decode_data(
            {k: v for k, v in info.items() if k != '_builds'},
            builds=builds,
+            db=self,
        )
        if c.cache:
            logging.info(f'Adding {c.huuid} to cache')
@@ -409,8 +406,10 @@
            identifier=identifier,
            allow_hidden=allow_hidden,
        )
-        c = LeafType('leaf_type', db=self).decode_data(
-            info, builds=info.get('_builds', {})
+        c = LeafType().decode_data(
+            info,
+            builds=info.get('_builds', {}),
+            db=self,
        )
        if c.cache:
            logging.info(f'Adding {c.huuid} to cache')
@@ -507,7 +506,7 @@ def replace(self, object: t.Any):
        except FileNotFoundError:
            pass

-        serialized = object.dict()
+        serialized = object.dict(schema=True)

        if old_uuid:

@@ -629,7 +628,7 @@ def disconnect(self):

    def infer_schema(
        self, data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None
-    ) -> Schema:
+    ) -> t.Dict:
        """Infer a schema from a given data object.

        :param data: The data object

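Since `infer_schema` now returns a plain `dict`, auto table creation reduces to the pattern shown in `_auto_create_table` above. A hedged sketch of the equivalent call sequence, assuming an existing datalayer `db` (the sample data and inferred types are illustrative):

```python
from superduper.components.table import Table

# infer_schema now yields a mapping of column name -> type string ...
fields = db.infer_schema({"id": "abc-123", "age": 23})  # e.g. {"id": "str", "age": "int"}

# ... which can be handed straight to Table and applied:
table = Table(identifier="my_table", fields=fields)
db.apply(table, force=True)
```
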
