Add changes necessary for services
blythed committed Feb 20, 2025
1 parent c0b966e commit 960e54e
Showing 39 changed files with 1,011 additions and 492 deletions.
9 changes: 2 additions & 7 deletions plugins/openai/superduper_openai/model.py
@@ -4,6 +4,7 @@
 import os
 import typing as t
 
+from functools import lru_cache as cache
 import numpy
 import requests
 import tqdm
@@ -18,8 +19,7 @@
 from superduper.backends.query_dataset import QueryDataset
 from superduper.base import exceptions
 from superduper.base.datalayer import Datalayer
-from superduper.components.model import APIBaseModel, Inputs
-from superduper.misc.compat import cache
+from superduper.components.model import APIBaseModel
 from superduper.misc.retry import Retry, safe_retry
 
 retry = Retry(
@@ -118,11 +118,6 @@ class OpenAIEmbedding(_OpenAI):
     signature: str = 'singleton'
     batch_size: int = 100
 
-    @property
-    def inputs(self):
-        """The inputs of the model."""
-        return Inputs(['input'])
-
     @retry
     def predict(self, X: str):
         """Generates embeddings from text.
13 changes: 13 additions & 0 deletions superduper/backends/base/cluster.py
@@ -78,6 +78,16 @@ def db(self, value):
         self.compute.db = value
         self.cdc.db = value
 
+    def load_custom_plugins(self):
+        """Load user plugins."""
+        from superduper import logging
+
+        if 'Plugin' in self.db.show('Table'):
+            logging.info(f"Found custom plugins - loading...")
+            for plugin in self.db.show('Plugin'):
+                logging.info(f"Loading plugin: {plugin}")
+                plugin = self.db.load('Plugin', plugin)
+
     def initialize(self, with_compute: bool = False):
         """Initialize the cluster.
@@ -88,6 +98,9 @@ def initialize(self, with_compute: bool = False):
 
         start = time.time()
         assert self.db
+
+        self.load_custom_plugins()
+
         if with_compute:
             self.compute.initialize()
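The new load_custom_plugins hook means any Plugin components persisted in the database are imported before the rest of the cluster comes up. A minimal sketch of what this enables, assuming superduper's Plugin component (superduper.components.plugin.Plugin); the plugin file name and connection URI below are illustrative:

# Sketch: persist a custom plugin so that Cluster.initialize() imports it
# on the next start-up. File name and URI are illustrative only.
from superduper import superduper
from superduper.components.plugin import Plugin

db = superduper('mongodb://localhost:27017/test_db')

# After this, a 'Plugin' table exists, so load_custom_plugins() logs and
# loads the plugin by identifier during cluster initialization.
db.apply(Plugin(path='./my_plugin.py'))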
63 changes: 37 additions & 26 deletions superduper/backends/base/metadata.py
@@ -3,11 +3,13 @@
 import typing as t
 import uuid
 
+from superduper import logging
 from superduper.base.exceptions import DatabackendError
 from superduper.base.base import Base
 from superduper.components.cdc import CDC
 from superduper.components.schema import Schema
 from superduper.components.table import Table
+from superduper.misc.importing import import_object
 
 
 class NonExistentMetadataError(Exception):
@@ -90,31 +92,27 @@ def __init__(self, db):
         self.preset_components = {
             ('Table', 'Table'): Table(
                 identifier='Table',
-                cls=Table,
                 primary_id='uuid',
                 uuid='abc',
                 component=True,
+                path='superduper.components.table.Table',
             ).encode(),
             ('Table', 'ParentChildAssociations'): Table(
                 identifier='ParentChildAssociations',
-                cls=ParentChildAssociations,
                 primary_id='uuid',
                 uuid='def',
                 component=True,
+                path='superduper.backends.base.metadata.ParentChildAssociations',
             ).encode(),
             ('Table', 'ArtifactRelations'): Table(
                 identifier='ArtifactRelations',
-                cls=ArtifactRelations,
                 primary_id='uuid',
                 uuid='ghi',
                 component=True,
+                path='superduper.backends.base.metadata.ArtifactRelations',
             ).encode(),
             ('Table', 'Job'): Table(
                 identifier='Job',
-                cls=Job,
                 primary_id='uuid',
                 uuid='jkl',
                 component=True,
@@ -142,8 +140,8 @@ def get_schema(self, table: str):
         r = self.db['Table'].get(identifier=table)
         try:
             r = r.unpack()
-            if r['cls'] is not None:
-                return r['cls'].class_schema
+            if r['path'] is not None:
+                return import_object(r['path']).class_schema
             return Schema.build(r['fields'])
         except AttributeError as e:
             if 'unpack' in str(e) and 'NoneType' in str(e):
@@ -165,7 +163,7 @@ def create(self, cls: t.Type[Base]):
         except DatabackendError as e:
             if 'not found' in str(e):
                 self.db.databackend.create_table_and_schema('Table', Table.class_schema)
-                t = Table('Table', cls=Table, primary_id='uuid', component=True)
+                t = Table('Table', path='superduper.components.table.Table', primary_id='uuid', component=True)
                 r = self.db['Table'].insert(
                     [t.dict(schema=True, path=False)],
                 )
@@ -178,7 +176,7 @@ def create(self, cls: t.Type[Base]):
             )
 
         self.db.databackend.create_table_and_schema(cls.__name__, cls.class_schema)
-        t = Table(identifier=cls.__name__, cls=cls, primary_id='uuid', component=True)
+        t = Table(identifier=cls.__name__, path=f'{cls.__module__}.{cls.__name__}', primary_id='uuid', component=True)
         self.db['Table'].insert([t.dict(path=False)])
         return t
@@ -334,6 +332,18 @@ def create_job(self, info: t.Dict):
         """
         self.create_entry(info, 'Job', raw=False)
 
+    def show_jobs(self, component: str, identifier: str):
+        """
+        Show all jobs in the metadata store.
+
+        :param component: type of component
+        :param identifier: identifier of component
+        """
+        return self.db['Job'].filter(
+            self.db['Job']['component'] == component,
+            self.db['Job']['identifier'] == identifier,
+        ).distinct('job_id')
+
     def show_components(self, component: str | None = None):
         """
         Show all components in the metadata store.
@@ -348,12 +358,16 @@ def show_components(self, component: str | None = None):
             ):
                 if component in metaclasses.keys():
                     continue
-                out.extend(
-                    [
-                        {'component': component, 'identifier': x}
-                        for x in self.db[component].distinct('identifier')
-                    ]
-                )
+
+                try:
+                    out.extend(
+                        [
+                            {'component': component, 'identifier': x}
+                            for x in self.db[component].distinct('identifier')
+                        ]
+                    )
+                except ModuleNotFoundError as e:
+                    logging.error(f'Component type not found: {component}; ', e)
             out.extend(
                 [
                     {'component': 'Table', 'identifier': x}
@@ -363,19 +377,15 @@ def show_components(self, component: str | None = None):
             return out
         return self.db[component].distinct('identifier')
 
-    def get_classes(self):
-        """Get all classes in the metadata store."""
-        data = self['Metadata'].execute()
-        return [r['cls'] for r in data]
-
     def show_cdc_tables(self):
         """List the tables used for CDC."""
         cdc_classes = []
         for r in self.db['Table'].execute():
-            if r['cls'] is None:
+            if r['path'] is None:
                 continue
+            cls = import_object(r['path'])
             r = r.unpack()
-            if issubclass(r['cls'], CDC):
+            if issubclass(cls, CDC):
                 cdc_classes.append(r)
 
         cdc_tables = []
@@ -391,9 +401,10 @@ def show_cdcs(self, table):
         """
         cdc_classes = []
         for r in self.db['Table'].execute():
-            if r['cls'] is None:
+            if r['path'] is None:
                 continue
-            if issubclass(r['cls'], CDC):
+            cls = import_object(r['path'])
+            if issubclass(cls, CDC):
                 cdc_classes.append(r)
 
         cdcs = []
@@ -481,9 +492,8 @@ def get_component_by_uuid(self, component: str, uuid: str):
         if uuid in self.preset_uuids:
             return self.preset_uuids[uuid]
         r = self.db[component].get(uuid=uuid, raw=True)
-        cls = self.db['Table'].get(identifier=component)['cls']
-        _path = cls.__module__ + '.' + cls.__name__
-        r['_path'] = _path
+        path = self.db['Table'].get(identifier=component)['path']
+        r['_path'] = path
         return r
 
     def get_component(
@@ -504,6 +514,7 @@ def get_component(
         if (component, identifier) in self.preset_components:
             return self.preset_components[(component, identifier)]
 
+        # TODO find a more efficient way to do this.
         if version is None:
             version = self.get_latest_version(
                 component=component,
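The pattern running through these hunks replaces stored class objects ('cls') with dotted import strings ('path') that are resolved lazily via import_object; this keeps 'Table' rows plain, serializable data and defers imports until a class is actually needed. The implementation of superduper.misc.importing.import_object is not part of this diff; a sketch of the standard importlib idiom it presumably follows:

# Hypothetical reconstruction of import_object -- the real function lives
# in superduper.misc.importing and is not shown in this commit.
import importlib


def import_object(path: str):
    """Resolve a dotted path like 'superduper.components.table.Table'."""
    module_name, attribute = path.rsplit('.', 1)
    module = importlib.import_module(module_name)  # import the module
    return getattr(module, attribute)              # fetch the named object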
6 changes: 3 additions & 3 deletions superduper/backends/local/cdc.py
@@ -42,11 +42,11 @@ def __delitem__(self, item):
     def initialize(self):
         """Initialize the CDC."""
         for component_data in self.db.show():
-            type_id = component_data['type_id']
+            component = component_data['component']
             identifier = component_data['identifier']
-            r = self.db.show(component=type_id, identifier=identifier, version=-1)
+            r = self.db.show(component=component, identifier=identifier, version=-1)
             if r.get('trigger'):
-                self.put(self.db.load(type_id=type_id, identifier=identifier))
+                self.put(self.db.load(component=component, identifier=identifier))
             # TODO consider re-initialzing CDC jobs since potentially failure
 
     def drop(self, component: t.Optional['Component'] = None):
6 changes: 3 additions & 3 deletions superduper/backends/local/crontab.py
@@ -84,11 +84,11 @@ def drop(self, component: t.Optional['Component'] = None):
     def initialize(self):
         """Initialize the crontab."""
        for component_data in self.db.show():
-            type_id = component_data['type_id']
+            component = component_data['component']
             identifier = component_data['identifier']
-            r = self.db.show(component=type_id, identifier=identifier, version=-1)
+            r = self.db.show(component=component, identifier=identifier, version=-1)
             if r.get('schedule'):
-                obj = self.db.load(type_id=type_id, identifier=identifier)
+                obj = self.db.load(component=component, identifier=identifier)
                 from superduper.components.cron_job import CronJob
 
                 if isinstance(obj, CronJob):
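Both initializers track a rename in the metadata API: rows from db.show() now carry a 'component' key instead of 'type_id', and db.show()/db.load() take component= keywords. A sketch of the shape these loops consume (field names are taken from the diff; the values and the version=-1 semantics, presumably "latest", are illustrative):

# One row as the loops above now see it (values illustrative):
component_data = {'component': 'CronJob', 'identifier': 'nightly-report'}

r = db.show(
    component=component_data['component'],
    identifier=component_data['identifier'],
    version=-1,  # presumably resolves to the latest version's record
)
if r.get('schedule'):
    obj = db.load(component=component_data['component'],
                  identifier=component_data['identifier'])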
15 changes: 8 additions & 7 deletions superduper/base/build.py
@@ -4,13 +4,17 @@
 from prettytable import PrettyTable
 
 import superduper as s
-from superduper import logging
+from superduper import CFG, logging
 from superduper.backends.base.data_backend import DataBackendProxy
 from superduper.base.config import Config
 from superduper.base.datalayer import Datalayer
 from superduper.misc.anonymize import anonymize_url
 from superduper.misc.importing import load_plugin
 
+from superduper.backends.local.artifacts import (
+    FileSystemArtifactStore,
+)
+
 
 class _Loader:
     not_supported: t.Tuple = ()
@@ -74,8 +78,8 @@ class _ArtifactStoreLoader(_Loader):
     }
 
 
-def _build_artifact_store(uri):
-    return _ArtifactStoreLoader.create(uri)
+def _build_artifact_store():
+    return FileSystemArtifactStore(CFG.artifact_store)
 
 
 def _build_databackend(uri):
@@ -103,10 +107,7 @@ def build_datalayer(cfg=None, **kwargs) -> Datalayer:
     cfg = t.cast(Config, cfg)
     databackend_obj = _build_databackend(cfg.data_backend)
 
-    if cfg.artifact_store:
-        artifact_store = _build_artifact_store(cfg.artifact_store)
-    else:
-        artifact_store = databackend_obj.build_artifact_store()
+    artifact_store = _build_artifact_store()
 
     backend = getattr(load_plugin(cfg.cluster_engine), 'Cluster')
     cluster = backend.build(cfg, **kwargs)
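With _ArtifactStoreLoader bypassed, artifact storage is no longer selected from a URI or derived from the data backend: build_datalayer() now always constructs a FileSystemArtifactStore from CFG.artifact_store. A sketch of overriding the location under the new scheme (the path is illustrative; the 'filesystem://' prefix appears to still be expected, per the TODO in config.py below):

# Mirrors the new _build_artifact_store(): the location is now driven by
# config rather than a loader. The path below is illustrative.
from superduper import CFG
from superduper.backends.local.artifacts import FileSystemArtifactStore

CFG.artifact_store = 'filesystem://./my_artifacts'
store = FileSystemArtifactStore(CFG.artifact_store)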
3 changes: 2 additions & 1 deletion superduper/base/config.py
@@ -159,7 +159,8 @@ class Config(BaseConfig):
     data_backend: str = "mongodb://localhost:27017/test_db"
     secrets_volume: str = os.path.join(".superduper", "/session/secrets")
 
-    artifact_store: t.Optional[str] = None
+    # TODO drop the "filesystem://" prefix
+    artifact_store: str = 'filesystem://./artifact_store'
     metadata_store: t.Optional[str] = None
     vector_search_engine: str = 'local'
     cluster_engine: str = 'local'
38 changes: 28 additions & 10 deletions superduper/base/document.py
@@ -92,9 +92,10 @@ def _update(r, s):
 
 
 class _InMemoryArtifactStore(ArtifactStore):
-    def __init__(self, blobs, files):
+    def __init__(self, blobs, files, artifact_store=None):
         self.blobs = blobs
         self.files = files
+        self.artifact_store = artifact_store
 
     def url(self):
         """Artifact store connection url."""
@@ -151,15 +152,25 @@ def get_bytes(self, file_id: str) -> bytes:
 
         :param file_id: Identifier of artifact in the store
         """
-        return self.blobs[file_id]
+        if file_id in self.blobs:
+            return self.blobs[file_id]
+        elif self.artifact_store:
+            return self.artifact_store.get_bytes(file_id)
+        else:
+            raise FileNotFoundError(f'Blob {file_id} not found in in-memory artifact store')
 
     def get_file(self, file_id: str) -> str:
         """
         Load file from artifact store and return path.
 
         :param file_id: Identifier of artifact in the store
         """
-        return self.files[file_id]
+        if file_id in self.files:
+            return self.files[file_id]
+        elif self.artifact_store:
+            return self.artifact_store.get_file(file_id)
+        else:
+            raise FileNotFoundError(f'File {file_id} not found in in-memory artifact store')
 
     def disconnect(self):
         """Disconnect the client."""
@@ -176,9 +187,13 @@ class _TmpDB:
     :param databackend: The databackend to use.
     """
 
-    def __init__(self, artifact_store, databackend):
+    def __init__(self, artifact_store, databackend, db: t.Optional['Datalayer'] = None):
         self.artifact_store = artifact_store
         self.databackend = databackend
+        self.db = db
+
+    def __getitem__(self, item):
+        return self.db[item]
 
     def __getitem__(self, item):
         from superduper.backends.base.query import Query
@@ -305,12 +320,14 @@ def __getitem__(self, key: str) -> t.Any:
         return super().__getitem__(key)
 
     @classmethod
-    def build_in_memory_db(cls, blobs, files):
+    def build_in_memory_db(cls, blobs, files, db: t.Optional['Datalayer'] | None = None):
+        artifact_store = db.artifact_store if db is not None else None
         return _TmpDB(
-            artifact_store=_InMemoryArtifactStore(blobs=blobs, files=files),
+            artifact_store=_InMemoryArtifactStore(blobs=blobs, files=files, artifact_store=artifact_store),
             databackend=namedtuple('tmp_databackend', field_names=('bytes_encoding',))(
                 bytes_encoding='bytes'
             ),
+            db=db,
         )
 
     def dict(self, *args, **kwargs):
@@ -334,10 +351,11 @@ def decode(
         :param schema: The schema to use.
         :param db: The datalayer to use.
         """
-        if db is None:
-            blobs = r.pop('_blobs', {})
-            files = r.pop('_files', {})
-            db = cls.build_in_memory_db(blobs=blobs, files=files)
+        blobs = r.pop('_blobs', {})
+        files = r.pop('_files', {})
+
+        if blobs or files:
+            db = cls.build_in_memory_db(blobs=blobs, files=files, db=db)
 
         if '_variables' in r:
             variables = {**r['_variables'], 'output_prefix': CFG.output_prefix}
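Taken together, decode() now builds a temporary in-memory store whenever a document carries inline '_blobs' or '_files', even when a Datalayer is supplied, and that store falls through to the Datalayer's real artifact store for anything not inlined. A minimal, hypothetical reproduction of the lookup order (not the superduper classes themselves):

# Hypothetical miniature of _InMemoryArtifactStore.get_bytes() above:
# inline blobs win, the persistent store is consulted second, and a
# missing identifier raises FileNotFoundError.
class FallbackBlobStore:
    def __init__(self, blobs: dict, persistent=None):
        self.blobs = blobs            # blobs shipped inline with the document
        self.persistent = persistent  # e.g. a FileSystemArtifactStore, or None

    def get_bytes(self, file_id: str) -> bytes:
        if file_id in self.blobs:
            return self.blobs[file_id]
        if self.persistent is not None:
            return self.persistent.get_bytes(file_id)
        raise FileNotFoundError(f'Blob {file_id} not found')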
[The remaining 31 changed files are not shown.]
