diff --git a/collection.go b/collection.go index b0765cd..e01540a 100644 --- a/collection.go +++ b/collection.go @@ -236,19 +236,73 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error { // Persist the document if c.persistDirectory != "" { - safeID := hash2hex(doc.ID) - docPath := filepath.Join(c.persistDirectory, safeID) - docPath += ".gob" - if c.compress { - docPath += ".gz" - } + docPath := c.getDocPath(doc.ID) err := persist(docPath, doc, c.compress, "") if err != nil { - return fmt.Errorf("couldn't persist document: %w", err) + return fmt.Errorf("couldn't persist document to %q: %w", docPath, err) + } + } + + return nil +} + +// Delete removes document(s) from the collection. +// +// - where: Conditional filtering on metadata. Optional. +// - whereDocument: Conditional filtering on documents. Optional. +// - ids: The ids of the documents to delete. If empty, all documents are deleted. +func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error { + + // must have at least one of where, whereDocument or ids + if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 { + return fmt.Errorf("must have at least one of where, whereDocument or ids") + } + + if len(c.documents) == 0 { + return nil + } + + for k := range whereDocument { + if !slices.Contains(supportedFilters, k) { + return errors.New("unsupported whereDocument operator") + } + } + + var docIDs []string + + c.documentsLock.Lock() + defer c.documentsLock.Unlock() + + if where != nil || whereDocument != nil { + // metadata + content filters + filteredDocs := filterDocs(c.documents, where, whereDocument) + for _, doc := range filteredDocs { + docIDs = append(docIDs, doc.ID) + } + } else { + docIDs = ids + } + + // No-op if no docs are left + if len(docIDs) == 0 { + return nil + } + + for _, docID := range docIDs { + delete(c.documents, docID) + + // Remove the document from disk + if c.persistDirectory != "" { + docPath := c.getDocPath(docID) + err := remove(docPath) + if err != nil { + return fmt.Errorf("couldn't remove document at %q: %w", docPath, err) + } } } return nil + } // Count returns the number of documents in the collection. @@ -350,3 +404,14 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3 // Return the top nResults return res, nil } + +// getDocPath generates the path to the document file. +func (c *Collection) getDocPath(docID string) string { + safeID := hash2hex(docID) + docPath := filepath.Join(c.persistDirectory, safeID) + docPath += ".gob" + if c.compress { + docPath += ".gz" + } + return docPath +} diff --git a/collection_test.go b/collection_test.go index fc37689..e37e511 100644 --- a/collection_test.go +++ b/collection_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "math/rand" + "os" "slices" "strconv" "testing" @@ -420,6 +421,95 @@ func TestCollection_Count(t *testing.T) { } } +func TestCollection_Delete(t *testing.T) { + // Create persistent collection + tmpdir, err := os.MkdirTemp(os.TempDir(), "chromem-test-*") + if err != nil { + t.Fatal("expected no error, got", err) + } + db, err := NewPersistentDB(tmpdir, false) + if err != nil { + t.Fatal("expected no error, got", err) + } + name := "test" + metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return vectors, nil + } + c, err := db.CreateCollection(name, metadata, embeddingFunc) + if err != nil { + t.Fatal("expected no error, got", err) + } + if c == nil { + t.Fatal("expected collection, got nil") + } + + // Add documents + ids := []string{"1", "2", "3", "4"} + metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}, {"foo": "bar"}, {"e": "f"}} + contents := []string{"hello world", "hallo welt", "bonjour le monde", "hola mundo"} + err = c.Add(context.Background(), ids, nil, metadatas, contents) + if err != nil { + t.Fatal("expected nil, got", err) + } + + // Check count + if c.Count() != 4 { + t.Fatal("expected 4 documents, got", c.Count()) + } + + // Check number of files in the persist directory + d, err := os.ReadDir(c.persistDirectory) + + if err != nil { + t.Fatal("expected nil, got", err) + } + if len(d) != 5 { // 4 documents + 1 metadata file + t.Fatal("expected 4 document files + 1 metadata file in persist_dir, got", len(d)) + } + + checkCount := func(expected int) { + // Check count + if c.Count() != expected { + t.Fatalf("expected %d documents, got %d", expected, c.Count()) + } + + // Check number of files in the persist directory + d, err = os.ReadDir(c.persistDirectory) + if err != nil { + t.Fatal("expected nil, got", err) + } + if len(d) != expected+1 { // 3 document + 1 metadata file + t.Fatalf("expected %d document files + 1 metadata file in persist_dir, got %d", expected, len(d)) + } + } + + // Test 1 - Remove document by ID: should delete one document + err = c.Delete(context.Background(), nil, nil, "4") + if err != nil { + t.Fatal("expected nil, got", err) + } + checkCount(3) + + // Test 2 - Remove document by metadata + err = c.Delete(context.Background(), map[string]string{"foo": "bar"}, nil) + if err != nil { + t.Fatal("expected nil, got", err) + } + + checkCount(1) + + // Test 3 - Remove document by content + err = c.Delete(context.Background(), nil, map[string]string{"$contains": "hallo welt"}) + if err != nil { + t.Fatal("expected nil, got", err) + } + + checkCount(0) + +} + // Global var for assignment in the benchmark to avoid compiler optimizations. var globalRes []Result diff --git a/persistence.go b/persistence.go index 4a385d4..5748afd 100644 --- a/persistence.go +++ b/persistence.go @@ -225,3 +225,19 @@ func read(filePath string, obj any, encryptionKey string) error { return nil } + +// remove removes a file at the given path. If the file doesn't exist, it's a no-op. +func remove(filePath string) error { + if filePath == "" { + return fmt.Errorf("file path is empty") + } + + err := os.Remove(filePath) + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("couldn't remove file %q: %w", filePath, err) + } + } + + return nil +}