Skip to content

Commit

Permalink
Merge pull request #97 from philippgille/add-get-by-id
Browse files Browse the repository at this point in the history
Add Collection.GetByID()
  • Loading branch information
philippgille authored Sep 1, 2024
2 parents a194428 + 366cbca commit f1e4956
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
26 changes: 26 additions & 0 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"maps"
"path/filepath"
"slices"
"sync"
Expand Down Expand Up @@ -291,6 +292,31 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
return nil
}

// GetByID returns a document by its ID.
// The returned document is a copy of the original document, so it can be safely
// modified without affecting the collection.
func (c *Collection) GetByID(ctx context.Context, id string) (Document, error) {
if id == "" {
return Document{}, errors.New("document ID is empty")
}

c.documentsLock.RLock()
defer c.documentsLock.RUnlock()

doc, ok := c.documents[id]
if ok {
// Clone the document
res := *doc
// Above copies the simple fields, but we need to copy the slices and maps
res.Metadata = maps.Clone(doc.Metadata)
res.Embedding = slices.Clone(doc.Embedding)

return res, nil
}

return Document{}, fmt.Errorf("document with ID '%v' not found", id)
}

// Delete removes document(s) from the collection.
//
// - where: Conditional filtering on metadata. Optional.
Expand Down
54 changes: 54 additions & 0 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,60 @@ func TestCollection_QueryError(t *testing.T) {
}
}

func TestCollection_Get(t *testing.T) {
ctx := context.Background()

// Create collection
db := NewDB()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
if c == nil {
t.Fatal("expected collection, got nil")
}

// Add documents
ids := []string{"1", "2"}
metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
contents := []string{"hello world", "hallo welt"}
err = c.Add(context.Background(), ids, nil, metadatas, contents)
if err != nil {
t.Fatal("expected nil, got", err)
}

// Get by ID
doc, err := c.GetByID(ctx, ids[0])
if err != nil {
t.Fatal("expected nil, got", err)
}
// Check fields
if doc.ID != ids[0] {
t.Fatal("expected", ids[0], "got", doc.ID)
}
if len(doc.Metadata) != 1 {
t.Fatal("expected 1, got", len(doc.Metadata))
}
if !slices.Equal(doc.Embedding, vectors) {
t.Fatal("expected", vectors, "got", doc.Embedding)
}
if doc.Content != contents[0] {
t.Fatal("expected", contents[0], "got", doc.Content)
}

// Check error
_, err = c.GetByID(ctx, "3")
if err == nil {
t.Fatal("expected error, got nil")
}
}

func TestCollection_Count(t *testing.T) {
// Create collection
db := NewDB()
Expand Down

0 comments on commit f1e4956

Please sign in to comment.