Fix the issue where the relationship data of children is not deleted during replacement. #2768

Merged · 5 commits · Feb 21, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fix the random silent failure bug when ibis creates tables.
- Fix the multi-checkbox on frontend
- Fix the issue where the relationship data of children is not deleted during replacement.

## [0.5.0](https://github.com/superduper-io/superduper/compare/0.5.0...0.4.0) (2024-Nov-02)

2 changes: 1 addition & 1 deletion plugins/ibis/superduper_ibis/__init__.py
@@ -1,6 +1,6 @@
from .data_backend import IbisDataBackend as DataBackend
from .query import IbisQuery

__version__ = "0.5.3"
__version__ = "0.5.3"

__all__ = ["IbisQuery", "DataBackend"]
4 changes: 2 additions & 2 deletions plugins/ibis/superduper_ibis/data_backend.py
@@ -105,11 +105,11 @@ def build_metadata(self):

def _check_token(self):
import datetime
auth_token = os.environ['SUPERDUPER_AUTH_TOKEN']
with open(auth_token) as f:
expiration_date = datetime.datetime.strptime(
f.read().split('\n')[0].strip(),
"%Y-%m-%d %H:%M:%S.%f"
f.read().split('\n')[0].strip(), "%Y-%m-%d %H:%M:%S.%f"
)
if expiration_date < datetime.datetime.now():
raise Exception("auth token expired")
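For reference, both `_check_token` implementations in this PR read the file pointed to by `SUPERDUPER_AUTH_TOKEN` and expect its first line to be a `%Y-%m-%d %H:%M:%S.%f` timestamp. A minimal sketch, assuming a hypothetical path (the format matches what the `SUPERDUPER_AUTH_DEBUG` branch writes in `superduper_sqlalchemy/metadata.py` further down):

```python
import datetime
import os

# Hypothetical location; in practice SUPERDUPER_AUTH_TOKEN is set by the deployment.
os.environ['SUPERDUPER_AUTH_TOKEN'] = '/tmp/superduper_auth_token'

# Write a token file in the expected format (first line = expiry timestamp).
with open(os.environ['SUPERDUPER_AUTH_TOKEN'], 'w') as f:
    f.write('2026-01-01 23:59:59.999999\n')

# Parse it the same way _check_token does.
with open(os.environ['SUPERDUPER_AUTH_TOKEN']) as f:
    expiration_date = datetime.datetime.strptime(
        f.read().split('\n')[0].strip(), "%Y-%m-%d %H:%M:%S.%f"
    )
assert expiration_date > datetime.datetime.now()  # token still valid
```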
2 changes: 1 addition & 1 deletion plugins/mongodb/superduper_mongodb/__init__.py
@@ -4,7 +4,7 @@
from .query import MongoQuery
from .vector_search import MongoAtlasVectorSearcher as VectorSearcher

__version__ = "0.5.0"
__version__ = "0.5.1"

__all__ = [
"ArtifactStore",
14 changes: 8 additions & 6 deletions plugins/mongodb/superduper_mongodb/metadata.py
@@ -82,8 +82,8 @@ def delete_parent_child(self, parent: str, child: str) -> None:
"""
self.parent_child_mappings.delete_many(
{
'parent': parent,
'child': child,
'parent_id': parent,
'child_id': child,
}
)
@@ -95,8 +95,8 @@ def create_parent_child(self, parent: str, child: str) -> None:
"""
self.parent_child_mappings.insert_one(
{
'parent': parent,
'child': child,
'parent_id': parent,
'child_id': child,
}
)
@@ -275,7 +275,7 @@ def component_version_has_parents(
{'type_id': type_id, 'identifier': identifier, 'version': version},
{'uuid': 1, 'id': 1},
)['uuid']
doc = {'child': uuid}
doc = {'child_id': uuid}
return self.parent_child_mappings.count_documents(doc)
def delete_component_version(
@@ -358,7 +358,9 @@ def get_component_version_parents(self, uuid: str) -> t.List[str]:

:param uuid: unique identifier of component
"""
return [r['parent'] for r in self.parent_child_mappings.find({'child': uuid})]
return [
r['parent_id'] for r in self.parent_child_mappings.find({'child_id': uuid})
]
def _replace_object(
self,
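The substance of this file's change is renaming the mapping keys from `parent`/`child` to `parent_id`/`child_id`. A minimal pymongo sketch of the resulting document shape and the lookups above, assuming a hypothetical local connection and UUIDs:

```python
from pymongo import MongoClient

# Hypothetical connection; the real collection lives on the metadata store.
mappings = MongoClient()['superduper']['parent_child_mappings']

# After this PR, mapping documents are keyed parent_id/child_id.
mappings.insert_one({'parent_id': 'parent-uuid', 'child_id': 'child-uuid'})

# get_component_version_parents filters on child_id ...
parents = [r['parent_id'] for r in mappings.find({'child_id': 'child-uuid'})]
assert parents == ['parent-uuid']

# ... and delete_parent_child deletes on both renamed keys.
mappings.delete_many({'parent_id': 'parent-uuid', 'child_id': 'child-uuid'})
```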
26 changes: 16 additions & 10 deletions plugins/snowflake/superduper_snowflake/data_backend.py
@@ -6,14 +6,16 @@
import pandas
import snowflake.connector
from superduper_snowflake.schema import ibis_schema_to_snowpark_cols, snowpark_cols_to_schema
from superduper_snowflake.schema import (
ibis_schema_to_snowpark_cols,
snowpark_cols_to_schema,
)
from superduper import logging
from superduper_ibis.data_backend import IbisDataBackend
from snowflake.snowpark import Session
class SnowflakeDataBackend(IbisDataBackend):
@wraps(IbisDataBackend.__init__)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -22,8 +24,9 @@ def __init__(self, *args, **kwargs):

@staticmethod
def _get_snowpark_session(uri):
logging.info('Creating Snowpark session for'
' snowflake vector-search implementation')
logging.info(
'Creating Snowpark session for' ' snowflake vector-search implementation'
)
if uri == 'snowflake://':
connection_parameters = dict(
host=os.environ['SNOWFLAKE_HOST'],
@@ -37,9 +40,7 @@ def _get_snowpark_session(uri):
)
else:
if '?warehouse=' not in uri:
match = re.match(
'^snowflake:\/\/(.*):(.*)\@(.*)\/(.*)\/(.*)$', uri
)
match = re.match('^snowflake:\/\/(.*):(.*)\@(.*)\/(.*)\/(.*)$', uri)
user, password, account, database, schema = match.groups()
warehouse = None
else:
@@ -76,7 +77,13 @@ def _do_connection_callback(uri)
def _connection_callback(self, uri):
if uri != 'snowflake://':
return IbisDataBackend._connection_callback(uri)
return ibis.snowflake.from_connection(self._do_connection_callback(uri), create_object_udfs=False), 'snowflake', False
return (
ibis.snowflake.from_connection(
self._do_connection_callback(uri), create_object_udfs=False
),
'snowflake',
False,
)
def reconnect(self):
super().reconnect()
@@ -93,5 +100,4 @@ def insert(self, table_name, raw_documents):
snowpark_cols = ibis_schema_to_snowpark_cols(ibis_schema)
snowpark_schema = snowpark_cols_to_schema(snowpark_cols, columns)
native_df = self.snowpark.create_dataframe(rows, schema=snowpark_schema)
return native_df.write.saveAsTable(f'"{table_name}"', mode='append')
return native_df.write.saveAsTable(f'"{table_name}"', mode='append')
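As context for the regex reformatted above, a small sketch, with hypothetical credentials, of the URI shape that the no-`?warehouse=` branch of `_get_snowpark_session` parses:

```python
import re

# Hypothetical URI of the form snowflake://user:password@account/database/schema.
uri = 'snowflake://my_user:my_pass@my_account/MY_DB/MY_SCHEMA'

# Same pattern as the diff, written as a raw string to avoid escape warnings.
match = re.match(r'^snowflake:\/\/(.*):(.*)\@(.*)\/(.*)\/(.*)$', uri)
user, password, account, database, schema = match.groups()
assert (user, account, schema) == ('my_user', 'my_account', 'MY_SCHEMA')
```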
18 changes: 9 additions & 9 deletions plugins/snowflake/superduper_snowflake/schema.py
@@ -8,27 +8,27 @@
DateType,
TimestampType,
DecimalType,
VariantType
VariantType,
)
def ibis_type_to_snowpark_type(ibis_dtype):
"""Convert an Ibis data type to the closest Snowpark type."""
# Integer (covers int8, int16, int32, int64 in Ibis)
if ibis_dtype.is_integer():
return IntegerType()
# Boolean
if ibis_dtype.is_boolean():
return BooleanType()
# Floating point (covers float32, float64 in Ibis)
if ibis_dtype.is_floating():
# FloatType is 32-bit, DoubleType is 64-bit
# You could decide based on ibis_dtype here. Example:
return DoubleType()
# Decimal (e.g. Decimal(precision, scale))
if ibis_dtype.is_decimal():
# Get precision and scale from the Ibis type
@@ -38,19 +38,19 @@ def ibis_type_to_snowpark_type(ibis_dtype):

if ibis_dtype.is_json():
return VariantType()
# String
if ibis_dtype.is_string():
return StringType()
# Date
if ibis_dtype.is_date():
return DateType()
# Timestamp
if ibis_dtype.is_timestamp():
return TimestampType()
# Fallback: map everything else to StringType (or VariantType, etc. if desired)
return StringType()
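A short usage sketch of the mapper above, assuming a recent Ibis where `ibis.dtype` parses these type strings:

```python
import ibis
from snowflake.snowpark.types import DoubleType, IntegerType, VariantType

from superduper_snowflake.schema import ibis_type_to_snowpark_type

# Integers of any width collapse to IntegerType, floats to DoubleType,
# and JSON columns ride along as VariantType.
assert isinstance(ibis_type_to_snowpark_type(ibis.dtype('int32')), IntegerType)
assert isinstance(ibis_type_to_snowpark_type(ibis.dtype('float64')), DoubleType)
assert isinstance(ibis_type_to_snowpark_type(ibis.dtype('json')), VariantType)
```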
5 changes: 1 addition & 4 deletions plugins/snowflake/superduper_snowflake/secrets.py
@@ -12,10 +12,7 @@ class UpdatingSecretException(Exception):
def check_secret_updates(db):
result = db.databackend.conn.raw_sql("CALL v1.wrapper('SHOW SECRETS')")
lookup = {
r[1]: json.loads(r[5])['status']['hash']
for r in result
}
lookup = {r[1]: json.loads(r[5])['status']['hash'] for r in result}
updating = []
for k in lookup:
4 changes: 3 additions & 1 deletion plugins/snowflake/superduper_snowflake/vector_search.py
@@ -20,6 +20,7 @@ def wrapper(self, *args, **kwargs):
if 'token' in str(e):
self.session = SnowflakeVectorSearcher.create_session(CFG.data_backend)
return f(self, *args, **kwargs)
return wrapper
@@ -97,7 +98,8 @@ def create_session(cls, vector_search_uri):
warehouse = None
else:
match = re.match(
'^snowflake://(.*):(.*)@(.*)/(.*)/(.*)?warehouse=(.*)$', vector_search_uri
'^snowflake://(.*):(.*)@(.*)/(.*)/(.*)?warehouse=(.*)$',
vector_search_uri,
)
user, password, account, database, schema, warehouse = match.groups()
if match:
5 changes: 2 additions & 3 deletions plugins/sqlalchemy/plugin_test/test_metadata.py
@@ -33,11 +33,11 @@ def test_artifact_relation(metadata):

def test_cleanup_metadata():
db = superduper(DATABASE_URL)
@model
def test(x): return x + 1
def test(x):
return x + 1
db.apply(test, force=True)
@@ -48,4 +48,3 @@ def test(x): return x + 1
assert not db.show(), 'The metadata was not cleared up'
assert not db.metadata._cache, f'Cache not cleared: {db.metadata._cache}'
2 changes: 1 addition & 1 deletion plugins/sqlalchemy/superduper_sqlalchemy/__init__.py
@@ -1,5 +1,5 @@
from .metadata import SQLAlchemyMetadata as MetaDataStore
__version__ = "0.5.6"
__version__ = "0.5.7"
__all__ = ['MetaDataStore']
34 changes: 18 additions & 16 deletions plugins/sqlalchemy/superduper_sqlalchemy/metadata.py
@@ -41,6 +41,7 @@ def _connect_snowflake():

import snowflake.connector
import os
if os.environ.get('SUPERDUPER_AUTH_DEBUG'):
with open(os.environ['SUPERDUPER_AUTH_TOKEN'], 'w') as f:
f.write('2026-01-01 23:59:59.999999\n')
@@ -107,7 +108,7 @@ def expire_version(self, type_id, identifier, version):
del self._uuid2metadata[r['uuid']]
if not self._type_id_identifier2metadata[(type_id, identifier)]:
del self._type_id_identifier2metadata[(type_id, identifier)]
def add_metadata(self, metadata):
metadata = copy.deepcopy(metadata)
if 'dict' in metadata:
@@ -306,6 +307,9 @@ def _init_tables(self):

self._table_mapping = {
'_artifact_relations': self.artifact_table,
'_parent_child': self.parent_child_association_table,
'_component': self.component_table,
'_job': self.job_table,
}
try:
@@ -338,11 +342,11 @@ def _delete_data(self, table_name, filter):
def _check_token(self):
import os
import datetime
auth_token = os.environ['SUPERDUPER_AUTH_TOKEN']
with open(auth_token) as f:
expiration_date = datetime.datetime.strptime(
f.read().split('\n')[0].strip(),
"%Y-%m-%d %H:%M:%S.%f"
f.read().split('\n')[0].strip(), "%Y-%m-%d %H:%M:%S.%f"
)
if expiration_date < datetime.datetime.now():
raise Exception("auth token expired")
@@ -467,16 +471,17 @@ def create_component(
with self.session_context(commit=not self.batched) as session:
if not self.batched:
primary_key_value = new_info['id']
exists = session.execute(
select(self.component_table).
where(self.component_table.c.id == primary_key_value)
).scalar() is not None
exists = (
session.execute(
select(self.component_table).where(
self.component_table.c.id == primary_key_value
)
).scalar()
is not None
)
if exists:
return
stmt = (
insert(self.component_table)
.values(new_info)
)
stmt = insert(self.component_table).values(new_info)
session.execute(stmt)
else:
self._insert_flush['component'].append(copy.deepcopy(new_info))
@@ -746,11 +751,8 @@ def _replace_object(
def show_cdc_tables(self):
"""Show tables to be consumed with cdc."""
with self.session_context() as session:
stmt = (
self.component_table.select()
.where(
self.component_table.c.cdc_table.isnot(None),
)
stmt = self.component_table.select().where(
self.component_table.c.cdc_table.isnot(None),
)
res = self.query_results(self.component_table, stmt, session)
return [r['cdc_table'] for r in res]
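The new `_table_mapping` entries (`_parent_child`, `_component`, `_job`) matter because the base metadata store addresses tables by these logical names in its generic `_get_data`/`_delete_data` helpers. A toy sketch of that dispatch pattern, deliberately simplified (the real `_get_data` issues a SQLAlchemy select, not a dict scan):

```python
# Toy illustration only: logical table names resolve through a mapping,
# and filters are applied to whatever backend sits behind it.
class ToyMetadataStore:
    def __init__(self):
        self._tables = {'_parent_child': [], '_artifact_relations': []}

    def _get_data(self, table_name, filter):
        rows = self._tables[table_name]
        return [r for r in rows if all(r.get(k) == v for k, v in filter.items())]

store = ToyMetadataStore()
store._tables['_parent_child'].append({'parent_id': 'p1', 'child_id': 'c1'})
assert store._get_data('_parent_child', {'parent_id': 'p1'}) == [
    {'parent_id': 'p1', 'child_id': 'c1'}
]
```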
9 changes: 9 additions & 0 deletions superduper/backends/base/metadata.py
@@ -145,6 +145,15 @@ def get_artifact_relations(self, uuid=None, artifact_id=None):
ids = [relation['uuid'] for relation in relations]
return ids
def get_children_relations(self, parent: str):
"""
Get all children of a component.
:param parent: parent component
"""
relations = self._get_data('_parent_child', {'parent_id': parent})
return [relation['child_id'] for relation in relations]
# TODO: Refactor to use _create_data, _delete_data, _get_data
@abstractmethod
def _create_data(self, table_name, datas):
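This helper is the hook for the fix in the PR title. The `_replace_object` bodies are collapsed in this view, so the following is only a sketch of how a replace path can combine it with the `delete_parent_child` method shown earlier to drop stale child relations:

```python
# Sketch only, not the actual _replace_object implementation (which is
# collapsed in this diff): before swapping in a replacement component,
# drop the stale parent-child rows so children relations do not dangle.
def drop_stale_child_relations(metadata_store, parent_uuid: str) -> None:
    for child_uuid in metadata_store.get_children_relations(parent_uuid):
        metadata_store.delete_parent_child(parent_uuid, child_uuid)
```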
12 changes: 6 additions & 6 deletions superduper/backends/local/cache.py
@@ -40,14 +40,14 @@ def _put(self, component: Component):
current_component = self._cache[current]
current_version = current_component.version
if current_version < component.version:
self._component_to_uuid[component.type_id, component.identifier] = (
component.uuid
)
self._component_to_uuid[
component.type_id, component.identifier
] = component.uuid
self.expire(current_component.uuid)
else:
self._component_to_uuid[component.type_id, component.identifier] = (
component.uuid
)
self._component_to_uuid[
component.type_id, component.identifier
] = component.uuid
def __delitem__(self, item):
if isinstance(item, tuple):
6 changes: 3 additions & 3 deletions superduper/backends/local/queue.py
@@ -68,9 +68,9 @@ def _put(self, component):
msg = 'Table name "_apply" collides with Superduper namespace'
assert component.cdc_table != '_apply', msg
assert isinstance(component, CDC)
self._component_uuid_mapping[component.type_id, component.identifier] = (
component.uuid
)
self._component_uuid_mapping[
component.type_id, component.identifier
] = component.uuid
if component.cdc_table in self.queue:
return
self.queue[component.cdc_table] = []