Skip to content

Commit

Permalink
Refactor types to use simpler design
Browse files (browse the repository at this point in the history)
  • Loading branch information
blythed committed Nov 27, 2024
1 parent e7008a3 commit b39d5da
Show file tree
Hide file tree
Showing 85 changed files with 952 additions and 1,921 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

#### Changed defaults / behaviours

- Deprecate vanilla `DataType`
- Remove `_Encodable` from project

#### New Features & Functionality

- Streamlit component and server
Expand Down
4 changes: 2 additions & 2 deletions plugins/anthropic/superduper_anthropic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class Anthropic(APIBaseModel):

client_kwargs: t.Dict[str, t.Any] = dc.field(default_factory=dict)

def __post_init__(self, db, artifacts, example):
def __post_init__(self, db, example):
self.model = self.model or self.identifier
super().__post_init__(db, artifacts, example=example)
super().__post_init__(db, example=example)

def init(self, db=None):
"""Initialize the model.
Expand Down
20 changes: 4 additions & 16 deletions plugins/cohere/superduper_cohere/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from cohere.error import CohereAPIError, CohereConnectionError
from superduper.backends.query_dataset import QueryDataset
from superduper.components.model import APIBaseModel
from superduper.components.vector_index import vector
from superduper.ext.utils import format_prompt, get_key
from superduper.misc.retry import Retry

Expand All @@ -23,8 +22,8 @@ class Cohere(APIBaseModel):

client_kwargs: t.Dict[str, t.Any] = dc.field(default_factory=dict)

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)
def __post_init__(self, db, example):
super().__post_init__(db, example=example)
self.identifier = self.identifier or self.model


Expand All @@ -47,22 +46,11 @@ class CohereEmbed(Cohere):
batch_size: int = 100
signature: str = 'singleton'

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)
def __post_init__(self, db, example):
super().__post_init__(db, example=example)
if self.shape is None:
self.shape = self.shapes[self.identifier]

def _pre_create(self, db):
"""Pre create method for the model.
If the datalayer is Ibis, the datatype will be set to the appropriate
SQL datatype.
:param db: The datalayer to use for the model.
"""
if self.datatype is None:
self.datatype = vector(shape=self.shape)

@retry
def predict(self, X: str):
"""Predict the embedding of a single text.
Expand Down
6 changes: 3 additions & 3 deletions plugins/ibis/superduper_ibis/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from superduper.backends.local.artifacts import FileSystemArtifactStore
from superduper.base import exceptions
from superduper.base.enums import DBType
from superduper.components.datatype import DataType
from superduper.components.datatype import BaseDataType
from superduper.components.schema import Schema
from superduper.components.table import Table

Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None):
self.overwrite = False
self._setup(conn)

if uri.startswith('snowflake://') or uri.startswith('sqlite://'):
if uri.startswith('snowflake://'):
self.bytes_encoding = 'base64'

self.datatype_presets = {'vector': 'superduper.ext.numpy.encoder.Array'}
Expand Down Expand Up @@ -190,7 +190,7 @@ def drop_table_or_collection(self, name: str):
def create_output_dest(
self,
predict_id: str,
datatype: t.Union[FieldType, DataType],
datatype: t.Union[FieldType, BaseDataType],
flatten: bool = False,
):
"""Create a table for the output of the model.
Expand Down
4 changes: 2 additions & 2 deletions plugins/ibis/superduper_ibis/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from superduper.base.cursor import SuperDuperCursor
from superduper.base.exceptions import DatabackendException
from superduper.components.datatype import Encodable
from superduper.components.datatype import _Encodable
from superduper.components.schema import Schema
from superduper.misc.special_dicts import SuperDuperFlatEncode

Expand Down Expand Up @@ -81,7 +81,7 @@ def _model_update_impl(
d = {
"_source": str(source_id),
f"{CFG.output_prefix}{predict_id}": output.x
if isinstance(output, Encodable)
if isinstance(output, _Encodable)
else output,
"id": str(uuid.uuid4()),
}
Expand Down
8 changes: 0 additions & 8 deletions plugins/ibis/superduper_ibis/utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
from ibis.expr.datatypes import dtype
from superduper.components.datatype import (
Artifact,
BaseDataType,
File,
LazyArtifact,
LazyFile,
Native,
)
from superduper.components.schema import ID, FieldType, Schema

SPECIAL_ENCODABLES_FIELDS = {
File: "str",
LazyFile: "str",
Artifact: "str",
LazyArtifact: "str",
Native: "json",
}


Expand Down
20 changes: 4 additions & 16 deletions plugins/jina/superduper_jina/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import tqdm
from superduper.backends.query_dataset import QueryDataset
from superduper.components.model import APIBaseModel
from superduper.components.vector_index import vector

from superduper_jina.client import JinaAPIClient

Expand All @@ -16,8 +15,8 @@ class Jina(APIBaseModel):

api_key: t.Optional[str] = None

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)
def __post_init__(self, db, example):
super().__post_init__(db, example=example)
self.identifier = self.identifier or self.model
self.client = JinaAPIClient(model_name=self.identifier, api_key=self.api_key)

Expand All @@ -41,22 +40,11 @@ class JinaEmbedding(Jina):
shape: t.Optional[t.Sequence[int]] = None
signature: str = 'singleton'

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example)
def __post_init__(self, db, example):
super().__post_init__(db, example)
if self.shape is None:
self.shape = (len(self.client.encode_batch(['shape'])[0]),)

def _pre_create(self, db):
"""Pre create method for the model.
If the datalayer is Ibis, the datatype will be set to the appropriate
SQL datatype.
:param db: The datalayer to use for the model.
"""
if self.datatype is None:
self.datatype = vector(shape=self.shape)

def predict(self, X: str):
"""Predict the embedding of a single text.
Expand Down
7 changes: 5 additions & 2 deletions plugins/mongodb/plugin_test/test_atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import superduper as s
from superduper import CFG, superduper
from superduper.base.document import Document
from superduper.components.datatype import Vector
from superduper.components.listener import Listener
from superduper.components.model import ObjectModel
from superduper.components.vector_index import VectorIndex, vector
from superduper.components.vector_index import VectorIndex

from superduper_mongodb.query import MongoQuery

Expand Down Expand Up @@ -50,7 +51,9 @@ def atlas_search_config():
@pytest.mark.skipif(DO_SKIP, reason='Only atlas deployments relevant.')
def test_setup_atlas_vector_search(atlas_search_config):
model = ObjectModel(
identifier='test-model', object=random_vector_model, encoder=vector(shape=(16,))
identifier='test-model',
object=random_vector_model,
encoder=Vector(dtype='float64', shape=(16,)),
)
db = superduper()
collection = MongoQuery(table='docs')
Expand Down
9 changes: 1 addition & 8 deletions plugins/mongodb/plugin_test/test_mongodb_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
import pytest
from superduper import CFG
from superduper.components.component import Component
from superduper.components.datatype import (
DataType,
file_serializer,
)

DO_SKIP = not CFG.data_backend.startswith("mongodb")

Expand All @@ -18,10 +14,7 @@
class TestComponent(Component):
path: str
type_id: t.ClassVar[str] = "TestComponent"

_artifacts: t.ClassVar[t.Sequence[t.Tuple[str, "DataType"]]] = (
("path", file_serializer),
)
fields = {'path': 'file'}


@pytest.fixture
Expand Down
4 changes: 2 additions & 2 deletions plugins/mongodb/plugin_test/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from superduper.base.document import Document
from superduper.components.schema import Schema
from superduper.components.table import Table
from superduper.ext.numpy.encoder import array
from superduper.ext.numpy.encoder import Array

from superduper_mongodb.query import MongoQuery

Expand All @@ -14,7 +14,7 @@
def schema(request):
bytes_encoding = request.param if hasattr(request, 'param') else None

array_tensor = array(dtype="float64", shape=(32,), bytes_encoding=bytes_encoding)
array_tensor = Array(dtype="float64", shape=(32,))
schema = Schema(
identifier=f'documents-{bytes_encoding}',
fields={
Expand Down
4 changes: 2 additions & 2 deletions plugins/mongodb/superduper_mongodb/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from superduper.backends.base.data_backend import BaseDataBackend
from superduper.backends.base.metadata import MetaDataStoreProxy
from superduper.base.enums import DBType
from superduper.components.datatype import DataType
from superduper.components.datatype import BaseDataType
from superduper.components.schema import Schema
from superduper.misc.colors import Colors

Expand Down Expand Up @@ -140,7 +140,7 @@ def disconnect(self):
def create_output_dest(
self,
predict_id: str,
datatype: t.Union[str, DataType],
datatype: t.Union[str, BaseDataType],
flatten: bool = False,
):
"""Create an output collection for a component.
Expand Down
1 change: 1 addition & 0 deletions plugins/mongodb/superduper_mongodb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ def process_find_part(part):
method, args, kwargs = part
# args: (filter, projection, *args)
filter = copy.deepcopy(args[0]) if len(args) > 0 else {}
filter = dict(filter)
filter.update(self._get_filter_conditions())
args = tuple((filter, *args[1:]))

Expand Down
8 changes: 4 additions & 4 deletions plugins/openai/superduper_openai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ class _OpenAI(APIBaseModel):
openai_api_base: t.Optional[str] = None
client_kwargs: t.Optional[dict] = dc.field(default_factory=dict)

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example)
def __post_init__(self, db, example):
super().__post_init__(db, example)

assert isinstance(self.client_kwargs, dict)

Expand Down Expand Up @@ -151,8 +151,8 @@ class OpenAIChatCompletion(_OpenAI):
batch_size: int = 1
prompt: str = ''

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example)
def __post_init__(self, db, example):
super().__post_init__(db, example)
self.takes_context = True

def _format_prompt(self, context, X):
Expand Down
2 changes: 0 additions & 2 deletions plugins/sentence_transformers/plugin_test/test_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from test.utils.component import utils as component_utils

import sentence_transformers
from superduper import vector

from superduper_sentence_transformers import SentenceTransformer

Expand All @@ -10,7 +9,6 @@ def test_encode_and_decode():
model = SentenceTransformer(
identifier="embedding",
object=sentence_transformers.SentenceTransformer("all-MiniLM-L6-v2"),
datatype=vector(shape=(1024,)),
postprocess=lambda x: x.tolist(),
predict_kwargs={"show_progress_bar": True},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from superduper.backends.query_dataset import QueryDataset
from superduper.base.enums import DBType
from superduper.components.component import ensure_initialized
from superduper.components.datatype import DataType, dill_lazy
from superduper.components.model import Model, Signature, _DeviceManaged

DEFAULT_PREDICT_KWARGS = {
Expand Down Expand Up @@ -39,9 +38,11 @@ class SentenceTransformer(Model, _DeviceManaged):
"""

_artifacts: t.ClassVar[t.Sequence[t.Tuple[str, 'DataType']]] = (
('object', dill_lazy),
)
_fields = {
'object': 'default',
'postprocess': 'default',
'preprocess': 'default',
}

object: t.Optional[_SentenceTransformer] = None
model: t.Optional[str] = None
Expand All @@ -50,8 +51,8 @@ class SentenceTransformer(Model, _DeviceManaged):
postprocess: t.Union[None, t.Callable] = None
signature: Signature = 'singleton'

def __post_init__(self, db, artifacts, example):
super().__post_init__(db, artifacts, example=example)
def __post_init__(self, db, example):
super().__post_init__(db, example=example)

if self.model is None:
self.model = self.identifier
Expand Down
2 changes: 1 addition & 1 deletion plugins/sklearn/plugin_test/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_sklearn(db):
identifier='test',
object=SVC(),
)
assert 'object' in m.artifact_schema.fields
assert 'object' in m._fields
db.apply(m, force=True)
assert db.show('model') == ['test']

Expand Down
8 changes: 3 additions & 5 deletions plugins/sklearn/superduper_sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from superduper import logging
from superduper.backends.query_dataset import QueryDataset
from superduper.base.datalayer import Datalayer
from superduper.components.datatype import DataType, pickle_serializer
from superduper.components.datatype import pickle_serializer
from superduper.components.model import (
Model,
ModelInputType,
Expand Down Expand Up @@ -93,7 +93,7 @@ def fit(
metrics.update(dataset_metrics)

model.metric_values = metrics
db.replace(model, upsert=True)
db.replace(model)


class Estimator(Model):
Expand All @@ -117,9 +117,7 @@ class Estimator(Model):
"""

_artifacts: t.ClassVar[t.Sequence[t.Tuple[str, DataType]]] = (
('object', pickle_serializer),
)
_fields = {'object': pickle_serializer}

object: BaseEstimator
trainer: t.Optional[SklearnTrainer] = None
Expand Down
4 changes: 2 additions & 2 deletions plugins/torch/plugin_test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from superduper import superduper
from superduper.base.datalayer import Datalayer
from superduper.components.datatype import DataType
from superduper.components.datatype import pickle_encoder

from superduper_torch.model import TorchModel
from superduper_torch.training import TorchTrainer
Expand Down Expand Up @@ -67,7 +67,7 @@ def model():
identifier='test',
preferred_devices=('cpu',),
postprocess=lambda x: int(torch.sigmoid(x).item() > 0.5),
datatype=DataType(identifier='base'),
datatype=pickle_encoder,
)


Expand Down
Loading

0 comments on commit b39d5da

Please sign in to comment.