diff --git a/collection.go b/collection.go index ec641a1..8a0bdcb 100644 --- a/collection.go +++ b/collection.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "maps" "path/filepath" "slices" "sync" @@ -291,6 +292,31 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error { return nil } +// GetByID returns a document by its ID. +// The returned document is a copy of the original document, so it can be safely +// modified without affecting the collection. +func (c *Collection) GetByID(ctx context.Context, id string) (Document, error) { + if id == "" { + return Document{}, errors.New("document ID is empty") + } + + c.documentsLock.RLock() + defer c.documentsLock.RUnlock() + + doc, ok := c.documents[id] + if ok { + // Clone the document + res := *doc + // Above copies the simple fields, but we need to copy the slices and maps + res.Metadata = maps.Clone(doc.Metadata) + res.Embedding = slices.Clone(doc.Embedding) + + return res, nil + } + + return Document{}, fmt.Errorf("document with ID '%v' not found", id) +} + // Delete removes document(s) from the collection. // // - where: Conditional filtering on metadata. Optional. diff --git a/collection_test.go b/collection_test.go index 6ec2738..4e47d4b 100644 --- a/collection_test.go +++ b/collection_test.go @@ -391,6 +391,60 @@ func TestCollection_QueryError(t *testing.T) { } } +func TestCollection_Get(t *testing.T) { + ctx := context.Background() + + // Create collection + db := NewDB() + name := "test" + metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return vectors, nil + } + c, err := db.CreateCollection(name, metadata, embeddingFunc) + if err != nil { + t.Fatal("expected no error, got", err) + } + if c == nil { + t.Fatal("expected collection, got nil") + } + + // Add documents + ids := []string{"1", "2"} + metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} + contents := []string{"hello world", "hallo welt"} + err = c.Add(context.Background(), ids, nil, metadatas, contents) + if err != nil { + t.Fatal("expected nil, got", err) + } + + // Get by ID + doc, err := c.GetByID(ctx, ids[0]) + if err != nil { + t.Fatal("expected nil, got", err) + } + // Check fields + if doc.ID != ids[0] { + t.Fatal("expected", ids[0], "got", doc.ID) + } + if len(doc.Metadata) != 1 { + t.Fatal("expected 1, got", len(doc.Metadata)) + } + if !slices.Equal(doc.Embedding, vectors) { + t.Fatal("expected", vectors, "got", doc.Embedding) + } + if doc.Content != contents[0] { + t.Fatal("expected", contents[0], "got", doc.Content) + } + + // Check error + _, err = c.GetByID(ctx, "3") + if err == nil { + t.Fatal("expected error, got nil") + } +} + func TestCollection_Count(t *testing.T) { // Create collection db := NewDB()