Skip to content

Commit

Permalink
Merge pull request #63 from iwilltry42/feat/delete-document
Browse files Browse the repository at this point in the history
add: Delete to delete a documents from a collection
  • Loading branch information
philippgille authored Apr 22, 2024
2 parents 9c85711 + 1da54cf commit df6c863
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 7 deletions.
79 changes: 72 additions & 7 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,19 +236,73 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {

// Persist the document
if c.persistDirectory != "" {
safeID := hash2hex(doc.ID)
docPath := filepath.Join(c.persistDirectory, safeID)
docPath += ".gob"
if c.compress {
docPath += ".gz"
}
docPath := c.getDocPath(doc.ID)
err := persist(docPath, doc, c.compress, "")
if err != nil {
return fmt.Errorf("couldn't persist document: %w", err)
return fmt.Errorf("couldn't persist document to %q: %w", docPath, err)
}
}

return nil
}

// Delete removes document(s) from the collection.
//
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
// - ids: The ids of the documents to delete. If empty, all documents are deleted.
func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error {

// must have at least one of where, whereDocument or ids
if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 {
return fmt.Errorf("must have at least one of where, whereDocument or ids")
}

if len(c.documents) == 0 {
return nil
}

for k := range whereDocument {
if !slices.Contains(supportedFilters, k) {
return errors.New("unsupported whereDocument operator")
}
}

var docIDs []string

c.documentsLock.Lock()
defer c.documentsLock.Unlock()

if where != nil || whereDocument != nil {
// metadata + content filters
filteredDocs := filterDocs(c.documents, where, whereDocument)
for _, doc := range filteredDocs {
docIDs = append(docIDs, doc.ID)
}
} else {
docIDs = ids
}

// No-op if no docs are left
if len(docIDs) == 0 {
return nil
}

for _, docID := range docIDs {
delete(c.documents, docID)

// Remove the document from disk
if c.persistDirectory != "" {
docPath := c.getDocPath(docID)
err := remove(docPath)
if err != nil {
return fmt.Errorf("couldn't remove document at %q: %w", docPath, err)
}
}
}

return nil

}

// Count returns the number of documents in the collection.
Expand Down Expand Up @@ -350,3 +404,14 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3
// Return the top nResults
return res, nil
}

// getDocPath generates the path to the document file.
func (c *Collection) getDocPath(docID string) string {
safeID := hash2hex(docID)
docPath := filepath.Join(c.persistDirectory, safeID)
docPath += ".gob"
if c.compress {
docPath += ".gz"
}
return docPath
}
90 changes: 90 additions & 0 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"math/rand"
"os"
"slices"
"strconv"
"testing"
Expand Down Expand Up @@ -420,6 +421,95 @@ func TestCollection_Count(t *testing.T) {
}
}

func TestCollection_Delete(t *testing.T) {
// Create persistent collection
tmpdir, err := os.MkdirTemp(os.TempDir(), "chromem-test-*")
if err != nil {
t.Fatal("expected no error, got", err)
}
db, err := NewPersistentDB(tmpdir, false)
if err != nil {
t.Fatal("expected no error, got", err)
}
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
if c == nil {
t.Fatal("expected collection, got nil")
}

// Add documents
ids := []string{"1", "2", "3", "4"}
metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}, {"foo": "bar"}, {"e": "f"}}
contents := []string{"hello world", "hallo welt", "bonjour le monde", "hola mundo"}
err = c.Add(context.Background(), ids, nil, metadatas, contents)
if err != nil {
t.Fatal("expected nil, got", err)
}

// Check count
if c.Count() != 4 {
t.Fatal("expected 4 documents, got", c.Count())
}

// Check number of files in the persist directory
d, err := os.ReadDir(c.persistDirectory)

if err != nil {
t.Fatal("expected nil, got", err)
}
if len(d) != 5 { // 4 documents + 1 metadata file
t.Fatal("expected 4 document files + 1 metadata file in persist_dir, got", len(d))
}

checkCount := func(expected int) {
// Check count
if c.Count() != expected {
t.Fatalf("expected %d documents, got %d", expected, c.Count())
}

// Check number of files in the persist directory
d, err = os.ReadDir(c.persistDirectory)
if err != nil {
t.Fatal("expected nil, got", err)
}
if len(d) != expected+1 { // 3 document + 1 metadata file
t.Fatalf("expected %d document files + 1 metadata file in persist_dir, got %d", expected, len(d))
}
}

// Test 1 - Remove document by ID: should delete one document
err = c.Delete(context.Background(), nil, nil, "4")
if err != nil {
t.Fatal("expected nil, got", err)
}
checkCount(3)

// Test 2 - Remove document by metadata
err = c.Delete(context.Background(), map[string]string{"foo": "bar"}, nil)
if err != nil {
t.Fatal("expected nil, got", err)
}

checkCount(1)

// Test 3 - Remove document by content
err = c.Delete(context.Background(), nil, map[string]string{"$contains": "hallo welt"})
if err != nil {
t.Fatal("expected nil, got", err)
}

checkCount(0)

}

// Global var for assignment in the benchmark to avoid compiler optimizations.
var globalRes []Result

Expand Down
16 changes: 16 additions & 0 deletions persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,19 @@ func read(filePath string, obj any, encryptionKey string) error {

return nil
}

// remove removes a file at the given path. If the file doesn't exist, it's a no-op.
func remove(filePath string) error {
if filePath == "" {
return fmt.Errorf("file path is empty")
}

err := os.Remove(filePath)
if err != nil {
if !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("couldn't remove file %q: %w", filePath, err)
}
}

return nil
}

0 comments on commit df6c863

Please sign in to comment.