Skip to content

Commit

Permalink
[ENH] Atomically delete Collection w segments (chroma-core#3039)
Browse files Browse the repository at this point in the history
## Atomically delete collection & segments in 1 transaction.
This PR only changes the backend GRPC/sysdb code.

*Summarize the changes made by this PR.*
 - Improvements
	 - Deletion of segments happens along with deletion of collections, so it removes the state where orphaned segments could be lying around.

## Test plan
- [ ] Tests pass locally with `pytest` for python, `make test` for
golang

## Documentation Changes
N/A
  • Loading branch information
rohitcpbot authored Nov 19, 2024
1 parent 5bd2ed7 commit 9116840
Show file tree
Hide file tree
Showing 11 changed files with 188 additions and 73 deletions.
3 changes: 1 addition & 2 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,7 @@ def delete_collection(
self._sysdb.delete_collection(
existing[0].id, tenant=tenant, database=database
)
for s in self._manager.delete_segments(existing[0].id):
self._sysdb.delete_segment(existing[0].id, s)
self._manager.delete_segments(existing[0].id)
else:
raise ValueError(f"Collection {name} does not exist.")

Expand Down
5 changes: 4 additions & 1 deletion chromadb/db/impl/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,10 @@ def create_collection(

@overrides
def delete_collection(
self, id: UUID, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
self,
id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
try:
request = DeleteCollectionRequest(
Expand Down
21 changes: 19 additions & 2 deletions chromadb/db/impl/grpc/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from concurrent import futures
from typing import Any, Dict, cast
from typing import Any, Dict, List, cast
from uuid import UUID
from overrides import overrides
from chromadb.api.configuration import CollectionConfigurationInternal
Expand Down Expand Up @@ -56,6 +56,7 @@ class GrpcMockSysDB(SysDBServicer, Component):
_server: grpc.Server
_server_port: int
_segments: Dict[str, Segment] = {}
_collection_to_segments: Dict[str, List[str]] = {}
_tenants_to_databases_to_collections: Dict[
str, Dict[str, Dict[str, Collection]]
] = {}
Expand Down Expand Up @@ -290,13 +291,24 @@ def CreateCollection(
tenant=tenant,
version=0,
)
collections[request.id] = new_collection

# Check that segments are unique and do not already exist
# Keep a track of the segments that are being added
segments_added = []
# Create segments for the collection
for segment_proto in request.segments:
segment = from_proto_segment(segment_proto)
if segment["id"].hex in self._segments:
# Remove the already added segment since we need to roll back
for s in segments_added:
self.DeleteSegment(DeleteSegmentRequest(id=s), context)
context.abort(grpc.StatusCode.ALREADY_EXISTS, f"Segment {segment['id']} already exists")
self.CreateSegmentHelper(segment, context)
segments_added.append(segment["id"].hex)

collections[request.id] = new_collection
collection_unique_key = f"{tenant}:{database}:{request.id}"
self._collection_to_segments[collection_unique_key] = segments_added
return CreateCollectionResponse(
collection=to_proto_collection(new_collection),
created=True,
Expand All @@ -316,6 +328,11 @@ def DeleteCollection(
collections = self._tenants_to_databases_to_collections[tenant][database]
if collection_id in collections:
del collections[collection_id]
collection_unique_key = f"{tenant}:{database}:{collection_id}"
segment_ids = self._collection_to_segments[collection_unique_key]
if segment_ids: # Delete segments if provided.
for segment_id in segment_ids:
del self._segments[segment_id]
return DeleteCollectionResponse()
else:
context.abort(
Expand Down
16 changes: 16 additions & 0 deletions chromadb/db/mixins/sysdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,19 @@ def delete_segment(self, collection: UUID, id: UUID) -> None:
if not result:
raise NotFoundError(f"Segment {id} not found")

# Used by delete_collection to delete all segments for a collection along with
# the collection itself in a single transaction.
def delete_segments_for_collection(self, cur: Cursor, collection: UUID) -> None:
segments_t = Table("segments")
q = (
self.querybuilder()
.from_(segments_t)
.where(segments_t.collection == ParameterValue(self.uuid_to_db(collection)))
.delete()
)
sql, params = get_sql(q, self.parameter_format())
cur.execute(sql, params)

@trace_method("SqlSysDB.delete_collection", OpenTelemetryGranularity.ALL)
@override
def delete_collection(
Expand Down Expand Up @@ -550,6 +563,9 @@ def delete_collection(
result = cur.execute(sql, params).fetchone()
if not result:
raise NotFoundError(f"Collection {id} not found")
# Delete segments.
self.delete_segments_for_collection(cur, id)

self._producer.delete_log(result[0])

@trace_method("SqlSysDB.update_segment", OpenTelemetryGranularity.ALL)
Expand Down
5 changes: 4 additions & 1 deletion chromadb/db/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ def create_collection(

@abstractmethod
def delete_collection(
self, id: UUID, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
self,
id: UUID,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
"""Delete a collection, all associated segments and any associate resources (log stream)
from the SysDB and the system at large."""
Expand Down
67 changes: 33 additions & 34 deletions chromadb/test/db/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import shutil
import tempfile
import pytest
import sqlite3
from typing import Generator, List, Callable, Dict, Union

from chromadb.db.impl.grpc.client import GrpcSysDB
Expand Down Expand Up @@ -180,21 +181,16 @@ def test_create_get_delete_collections(sysdb: SysDB) -> None:
logger.debug("Resetting state")
sysdb.reset_state()

segments_created_with_collection = []
for collection in sample_collections:
logger.debug(f"Creating collection: {collection.name}")
segment = sample_segment(collection_id=collection.id)
segments_created_with_collection.append(segment)
sysdb.create_collection(
id=collection.id,
name=collection.name,
configuration=collection.get_configuration(),
segments=[
Segment(
id=uuid.uuid4(),
type="test_type_a",
scope=SegmentScope.VECTOR,
collection=collection.id,
metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3},
)
],
segments=[segment],
metadata=collection["metadata"],
dimension=collection["dimension"],
)
Expand All @@ -213,15 +209,7 @@ def test_create_get_delete_collections(sysdb: SysDB) -> None:
name=sample_collections[0].name,
id=sample_collections[0].id,
configuration=sample_collections[0].get_configuration(),
segments=[
Segment(
id=uuid.uuid4(),
type="test_type_a",
scope=SegmentScope.VECTOR,
collection=sample_collections[0].id,
metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3},
)
],
segments=[segments_created_with_collection[0]],
)

# Find by name
Expand All @@ -236,7 +224,7 @@ def test_create_get_delete_collections(sysdb: SysDB) -> None:

# Delete
c1 = sample_collections[0]
sysdb.delete_collection(c1.id)
sysdb.delete_collection(id=c1.id)

results = sysdb.get_collections()
assert c1 not in results
Expand All @@ -246,11 +234,34 @@ def test_create_get_delete_collections(sysdb: SysDB) -> None:
by_id_result = sysdb.get_collections(id=c1["id"])
assert by_id_result == []

# Check that the segment was deleted
by_collection_result = sysdb.get_segments(collection=c1.id)
assert by_collection_result == []

# Duplicate delete throws an exception
with pytest.raises(NotFoundError):
sysdb.delete_collection(c1.id)


# Create a new collection with two segments that have same id.
# Creation should fail.
# Check that collection was not created.
# Check that segments were not created.
with pytest.raises((InternalError, UniqueConstraintError, sqlite3.IntegrityError)):
sysdb.create_collection(
name=sample_collections[0].name,
id=sample_collections[0].id,
configuration=sample_collections[0].get_configuration(),
segments=[segments_created_with_collection[0], segments_created_with_collection[0]],
)
# Check that collection was not created.
# Check that segments were not created.
by_id_result = sysdb.get_collections(id=sample_collections[0].id)
assert by_id_result == []
by_collection_result = sysdb.get_segments(collection=sample_collections[0].id)
assert by_collection_result == []


def test_update_collections(sysdb: SysDB) -> None:
coll = Collection(
name=sample_collections[0].name,
Expand Down Expand Up @@ -831,24 +842,12 @@ def test_create_get_delete_segments(sysdb: SysDB) -> None:
result = sysdb.get_segments(type="test_type_b", collection=sample_collections[0].id)
assert len(result) == 0

# Delete
# Delete Segments will be a NoOp due to atomic delete of collection and segments.
# See comments in coordinator.go for more details.
# Execute the call and expect no error or exception.
s1 = segments_created_with_collection[0]
sysdb.delete_segment(s1["collection"], s1["id"])

results = []
for collection in sample_collections:
results.extend(sysdb.get_segments(collection=collection.id))
assert s1 not in results
assert len(results) == len(segments_created_with_collection) - 1
assert sorted(results, key=lambda c: c["id"]) == sorted(
segments_created_with_collection[1:], key=lambda c: c["id"]
)

# Duplicate delete throws an exception
with pytest.raises(NotFoundError):
sysdb.delete_segment(s1["collection"], s1["id"])


def test_update_segment(sysdb: SysDB) -> None:
metadata: Dict[str, Union[str, int, float]] = {
"test_str": "str1",
Expand Down
4 changes: 4 additions & 0 deletions go/pkg/sysdb/coordinator/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ func (s *Coordinator) GetSegments(ctx context.Context, segmentID types.UniqueID,
return s.catalog.GetSegments(ctx, segmentID, segmentType, scope, collectionID)
}

// DeleteSegment is a no-op.
// Segments are deleted as part of atomic delete of collection.
// Keeping this API so that older clients continue to work, since older clients will issue DeleteSegment
// after a DeleteCollection.
func (s *Coordinator) DeleteSegment(ctx context.Context, segmentID types.UniqueID, collectionID types.UniqueID) error {
return s.catalog.DeleteSegment(ctx, segmentID, collectionID)
}
Expand Down
86 changes: 74 additions & 12 deletions go/pkg/sysdb/coordinator/coordinator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sort"
"strconv"
"testing"
"time"

"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dao"
"github.com/pingcap/log"
Expand Down Expand Up @@ -406,6 +407,74 @@ func (suite *APIsTestSuite) TestCreateGetDeleteCollections() {
// Duplicate delete throws an exception
err = suite.coordinator.DeleteCollection(ctx, deleteCollection)
suite.Error(err)

// Re-create the deleted collection
// Recreating the deleted collection with new ID since the old ID is already in use by the soft deleted collection.
createCollection := &model.CreateCollection{
ID: types.NewUniqueID(),
Name: suite.sampleCollections[0].Name,
Dimension: suite.sampleCollections[0].Dimension,
Metadata: suite.sampleCollections[0].Metadata,
TenantID: suite.tenantName,
DatabaseName: suite.databaseName,
Ts: types.Timestamp(time.Now().Unix()),
}
_, _, err = suite.coordinator.CreateCollection(ctx, createCollection)
suite.NoError(err)

// Verify collection was re-created
results, err = suite.coordinator.GetCollections(ctx, createCollection.ID, nil, suite.tenantName, suite.databaseName, nil, nil)
suite.NoError(err)
suite.Len(results, 1)
suite.Equal(createCollection.ID, results[0].ID)
suite.Equal(createCollection.Name, results[0].Name)
suite.Equal(createCollection.Dimension, results[0].Dimension)
suite.Equal(createCollection.Metadata, results[0].Metadata)

// Create segments associated with collection
segment := &model.CreateSegment{
ID: types.MustParse("00000000-0000-0000-0000-000000000001"),
CollectionID: createCollection.ID,
Type: "test_segment",
Scope: "test_scope",
Ts: types.Timestamp(time.Now().Unix()),
}
err = suite.coordinator.CreateSegment(ctx, segment)
suite.NoError(err)

// Verify segment was created
segments, err := suite.coordinator.GetSegments(ctx, segment.ID, nil, nil, createCollection.ID)
suite.NoError(err)
suite.Len(segments, 1)
suite.Equal(segment.ID, segments[0].ID)
suite.Equal(segment.CollectionID, segments[0].CollectionID)
suite.Equal(segment.Type, segments[0].Type)
suite.Equal(segment.Scope, segments[0].Scope)

// Delete the re-created collection with segments
deleteCollection = &model.DeleteCollection{
ID: createCollection.ID,
DatabaseName: suite.databaseName,
TenantID: suite.tenantName,
}
err = suite.coordinator.DeleteCollection(ctx, deleteCollection)
suite.NoError(err)

// Verify collection and segment were deleted
results, err = suite.coordinator.GetCollections(ctx, createCollection.ID, nil, suite.tenantName, suite.databaseName, nil, nil)
suite.NoError(err)
suite.Empty(results)
// Segments will not be deleted since the collection is only soft deleted.
// Hard deleting the collection will also delete the segments.
segments, err = suite.coordinator.GetSegments(ctx, segment.ID, nil, nil, createCollection.ID)
suite.NoError(err)
suite.NotEmpty(segments)
suite.coordinator.deleteMode = HardDelete
err = suite.coordinator.DeleteCollection(ctx, deleteCollection)
suite.NoError(err)
segments, err = suite.coordinator.GetSegments(ctx, segment.ID, nil, nil, createCollection.ID)
suite.NoError(err)
suite.Empty(segments)
}

func (suite *APIsTestSuite) TestUpdateCollections() {
Expand Down Expand Up @@ -890,24 +959,17 @@ func (suite *APIsTestSuite) TestCreateGetDeleteSegments() {
suite.NoError(err)
suite.Empty(result)

// Delete
// Delete Segments will not delete the collection, hence no tests for that.
// Reason - After introduction of Atomic delete of collection & segments,
// the DeleteSegment API will not delete the collection. See comments in
// coordinator.go/DeleteSegment for more details.
s1 := sampleSegments[0]
err = c.DeleteSegment(ctx, s1.ID, s1.CollectionID)
suite.NoError(err)

results, err = c.GetSegments(ctx, types.NilUniqueID(), nil, nil, s1.CollectionID)
suite.NoError(err)
suite.NotContains(results, s1)
suite.Len(results, 0)

// Duplicate delete throws an exception
err = c.DeleteSegment(ctx, s1.ID, s1.CollectionID)
suite.Error(err)

// clean up segments
for _, segment := range sampleSegments {
_ = c.DeleteSegment(ctx, segment.ID, segment.CollectionID)
}
suite.Contains(results, s1)
}

func (suite *APIsTestSuite) TestUpdateSegment() {
Expand Down
Loading

0 comments on commit 9116840

Please sign in to comment.