Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Export Document struct and add Go-idiomatic methods for adding them to a collection #34

Merged
merged 5 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 151 additions & 97 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type Collection struct {

persistDirectory string
metadata map[string]string
documents map[string]*document
documents map[string]*Document
documentsLock sync.RWMutex
embed EmbeddingFunc
}
Expand All @@ -38,7 +38,7 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
Name: name,

metadata: m,
documents: make(map[string]*document),
documents: make(map[string]*Document),
embed: embed,
}

Expand Down Expand Up @@ -73,24 +73,166 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
//
// - ids: The ids of the embeddings you wish to add
// - embeddings: The embeddings to add. If nil, embeddings will be computed based
// on the documents using the embeddingFunc set for the Collection. Optional.
// on the contents using the embeddingFunc set for the Collection. Optional.
// - metadatas: The metadata to associate with the embeddings. When querying,
// you can filter on this metadata. Optional.
// - documents: The documents to associate with the embeddings.
// - contents: The contents to associate with the embeddings.
//
// A row-based API will be added when Chroma adds it (they already plan to).
func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, documents []string) error {
return c.add(ctx, ids, documents, embeddings, metadatas, 1)
// This is a Chroma-like method. For a more Go-idiomatic one, see [AddDocuments].
func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string) error {
	// Add is the sequential special case of AddConcurrently: all arguments are
	// forwarded unchanged and exactly one document is processed at a time.
	const sequential = 1
	return c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, sequential)
}

// AddConcurrently is like Add, but adds embeddings concurrently.
// This is mostly useful when you don't pass any embeddings so they have to be created.
// Upon error, concurrently running operations are canceled and the error is returned.
func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, documents []string, concurrency int) error {
//
// This is a Chroma-like method. For a more Go-idiomatic one, see [AddDocuments].
func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string, concurrency int) error {
	// Validate the Chroma-style "column" slices. ids is mandatory; every other
	// slice must either be empty (nil) or have exactly one entry per id.
	if len(ids) == 0 {
		return errors.New("ids are empty")
	}
	if len(embeddings) == 0 && len(contents) == 0 {
		return errors.New("either embeddings or contents must be filled")
	}
	if len(embeddings) != 0 {
		if len(embeddings) != len(ids) {
			return errors.New("ids and embeddings must have the same length")
		}
	} else {
		// Assign empty slice so we can simply access via index later
		embeddings = make([][]float32, len(ids))
	}
	if len(metadatas) != 0 {
		if len(metadatas) != len(ids) {
			return errors.New("ids, metadatas and contents must have the same length")
		}
	} else {
		// metadatas is optional. Without this placeholder the metadatas[i]
		// access in the conversion loop below would panic with an
		// index-out-of-range error whenever the caller passes nil.
		metadatas = make([]map[string]string, len(ids))
	}
	if len(contents) != 0 {
		if len(contents) != len(ids) {
			return errors.New("ids and contents must have the same length")
		}
	} else {
		// Assign empty slice so we can simply access via index later
		contents = make([]string, len(ids))
	}
	if concurrency < 1 {
		return errors.New("concurrency must be at least 1")
	}

	// Convert Chroma-style parameters into a slice of documents.
	// All four slices are guaranteed to have len(ids) entries at this point.
	docs := make([]Document, 0, len(ids))
	for i, id := range ids {
		docs = append(docs, Document{
			ID:        id,
			Metadata:  metadatas[i],
			Embedding: embeddings[i],
			Content:   contents[i],
		})
	}

	// Per-document validation, embedding creation and persistence happen in
	// AddDocuments / AddDocument.
	return c.AddDocuments(ctx, docs, concurrency)
}

// AddDocuments adds documents to the collection with the specified concurrency.
// If the documents don't have embeddings, they will be created using the collection's
// embedding function.
// Upon error, concurrently running operations are canceled and the error is returned.
// If ctx is canceled before every document was handed to AddDocument, the
// cancellation cause is returned, so a nil result means no document was skipped.
//
// It returns an error if documents is empty or concurrency is less than 1.
func (c *Collection) AddDocuments(ctx context.Context, documents []Document, concurrency int) error {
	if len(documents) == 0 {
		// TODO: Should this be a no-op instead?
		return errors.New("documents slice is nil or empty")
	}
	if concurrency < 1 {
		return errors.New("concurrency must be at least 1")
	}
	// For other validations we rely on AddDocument.

	var globalErr error
	globalErrLock := sync.Mutex{}
	ctx, cancel := context.WithCancelCause(ctx)
	defer cancel(nil)
	// setGlobalErr records the first error only and cancels ctx with it.
	setGlobalErr := func(err error) {
		globalErrLock.Lock()
		defer globalErrLock.Unlock()
		// Another goroutine might have already set the error.
		if globalErr == nil {
			globalErr = err
			// Cancel the operation for all other goroutines.
			cancel(globalErr)
		}
	}
	// markSkipped records that a document was not added because ctx is done.
	// When ctx was canceled by a failing goroutine, globalErr is already set
	// (setGlobalErr assigns it before calling cancel) and this is a no-op.
	// When the caller's ctx was canceled, this makes sure we don't report
	// success for documents that were never added.
	markSkipped := func() {
		setGlobalErr(context.Cause(ctx))
	}

	var wg sync.WaitGroup
	semaphore := make(chan struct{}, concurrency)
	for _, doc := range documents {
		wg.Add(1)
		go func(doc Document) {
			defer wg.Done()

			// Wait here while $concurrency other goroutines are creating
			// documents, but stop waiting as soon as the operation is canceled.
			select {
			case semaphore <- struct{}{}:
			case <-ctx.Done():
				markSkipped()
				return
			}
			defer func() { <-semaphore }()

			// select chooses randomly when both cases are ready, so check
			// again: don't even start if another goroutine already failed.
			if ctx.Err() != nil {
				markSkipped()
				return
			}

			err := c.AddDocument(ctx, doc)
			if err != nil {
				setGlobalErr(fmt.Errorf("couldn't add document '%s': %w", doc.ID, err))
				return
			}
		}(doc)
	}

	// wg.Wait establishes a happens-before edge with every goroutine, so
	// globalErr can be read without taking the lock.
	wg.Wait()

	return globalErr
}

// AddDocument adds a document to the collection.
// If the document doesn't have an embedding, it will be created using the collection's
// embedding function.
//
// It returns an error if the document ID is empty, if both embedding and
// content are empty, if the embedding can't be created, or if the collection
// has a persist directory and writing the document fails.
func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
	if doc.ID == "" {
		return errors.New("document ID is empty")
	}
	if len(doc.Embedding) == 0 && doc.Content == "" {
		return errors.New("either document embedding or content must be filled")
	}

	// We copy the metadata to avoid data races in case the caller modifies the
	// map after creating the document while we range over it.
	m := make(map[string]string, len(doc.Metadata))
	for k, v := range doc.Metadata {
		m[k] = v
	}
	// Store the copy. doc is passed by value, so this only changes our own
	// Document and not the caller's struct; without this assignment the
	// stored document would still share the caller's map.
	doc.Metadata = m

	// Create the embedding if it doesn't exist
	if len(doc.Embedding) == 0 {
		embedding, err := c.embed(ctx, doc.Content)
		if err != nil {
			return fmt.Errorf("couldn't create embedding of document: %w", err)
		}
		doc.Embedding = embedding
	}

	c.documentsLock.Lock()
	// We don't defer the unlock because we want to do it earlier.
	c.documents[doc.ID] = &doc
	c.documentsLock.Unlock()

	// Persist the document
	if c.persistDirectory != "" {
		// The ID is hashed so it is safe to use as a file name.
		safeID := hash2hex(doc.ID)
		// NOTE(review): path.Join always joins with forward slashes;
		// filepath.Join is the OS-aware variant. Switching requires an import
		// change outside this function - confirm and adjust separately.
		filePath := path.Join(c.persistDirectory, safeID)
		err := persist(filePath, doc)
		if err != nil {
			return fmt.Errorf("couldn't persist document: %w", err)
		}
	}

	return nil
}

// Count returns the number of documents in the collection.
Expand Down Expand Up @@ -155,91 +297,3 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
// Return the top nResults
return res[:nResults], nil
}

// add validates the Chroma-style parameter slices and then adds one row per
// entry of documents by calling addRow, each in its own goroutine.
//
// ids and documents must be non-empty and of equal length. embeddings and
// metadatas may be empty; when they are not, they must have the same length
// as ids. concurrency is used as the semaphore capacity and is not validated
// here (presumably the callers guarantee a value >= 1 - verify).
//
// The first error returned by addRow is stored, used as the cancellation
// cause for the derived context, and returned after all goroutines finished.
// Goroutines that find the context already done return without recording an
// error, so when only the caller's ctx was canceled the result is nil.
func (c *Collection) add(ctx context.Context, ids []string, documents []string, embeddings [][]float32, metadatas []map[string]string, concurrency int) error {
if len(ids) == 0 || len(documents) == 0 {
return errors.New("ids and documents must not be empty")
}
if len(ids) != len(documents) {
return errors.New("ids and documents must have the same length")
}
if len(embeddings) != 0 && len(ids) != len(embeddings) {
return errors.New("ids, embeddings and documents must have the same length")
}
if len(metadatas) != 0 && len(ids) != len(metadatas) {
return errors.New("ids, metadatas and documents must have the same length")
}

// Derive a cancelable context so the first failure can stop the others.
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)

var wg sync.WaitGroup
var globalErr error
var globalErrLock sync.Mutex
// Buffered channel used as a counting semaphore with capacity concurrency.
semaphore := make(chan struct{}, concurrency)
for i, document := range documents {
// embeddings and metadatas are optional, so only index them when present;
// otherwise the row gets a nil embedding / nil metadata.
var embedding []float32
var metadata map[string]string
if len(embeddings) != 0 {
embedding = embeddings[i]
}
if len(metadatas) != 0 {
metadata = metadatas[i]
}

wg.Add(1)
// The per-row values are passed as arguments so each goroutine works on
// its own copies rather than on the loop variables.
go func(id string, embedding []float32, metadata map[string]string, document string) {
defer wg.Done()

// Don't even start if we already have an error
if ctx.Err() != nil {
return
}

// Wait here while $concurrency other goroutines are creating documents.
semaphore <- struct{}{}
defer func() { <-semaphore }()

err := c.addRow(ctx, id, document, embedding, metadata)
if err != nil {
globalErrLock.Lock()
defer globalErrLock.Unlock()
// Another goroutine might have already set the error.
if globalErr == nil {
globalErr = err
// Cancel the operation for all other goroutines.
cancel(globalErr)
}
return
}
}(ids[i], embedding, metadata, document)
}

// All goroutines are done after Wait, so globalErr is read without the lock.
wg.Wait()

return globalErr
}

func (c *Collection) addRow(ctx context.Context, id string, document string, embedding []float32, metadata map[string]string) error {
	// Build the document first; newDocument is given the collection's
	// embedding function together with the row values.
	created, createErr := newDocument(ctx, id, embedding, metadata, document, c.embed)
	if createErr != nil {
		return fmt.Errorf("couldn't create document '%s': %w", id, createErr)
	}

	// Register the document in memory. The write lock is held only for the
	// map assignment and is released before any persistence work starts.
	func() {
		c.documentsLock.Lock()
		defer c.documentsLock.Unlock()
		c.documents[id] = created
	}()

	// Without a persist directory there is nothing left to do.
	if c.persistDirectory == "" {
		return nil
	}

	// The file name is derived from the hashed ID.
	target := path.Join(c.persistDirectory, hash2hex(id))
	if persistErr := persist(target, created); persistErr != nil {
		return fmt.Errorf("couldn't persist document: %w", persistErr)
	}

	return nil
}
8 changes: 4 additions & 4 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ func TestCollection_Add(t *testing.T) {
// Add document
ids := []string{"1", "2"}
metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
documents := []string{"hello world", "hallo welt"}
err = c.Add(context.Background(), ids, nil, metadatas, documents)
contents := []string{"hello world", "hallo welt"}
err = c.Add(context.Background(), ids, nil, metadatas, contents)
if err != nil {
t.Error("expected nil, got", err)
}
Expand All @@ -54,8 +54,8 @@ func TestCollection_Count(t *testing.T) {
// Add documents
ids := []string{"1", "2"}
metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
documents := []string{"hello world", "hallo welt"}
err = c.Add(context.Background(), ids, nil, metadatas, documents)
contents := []string{"hello world", "hallo welt"}
err = c.Add(context.Background(), ids, nil, metadatas, contents)
if err != nil {
t.Error("expected nil, got", err)
}
Expand Down
8 changes: 4 additions & 4 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import (
"sync"
)

// EmbeddingFunc is a function that creates embeddings for a given document.
// EmbeddingFunc is a function that creates embeddings for a given text.
// chromem-go will use OpenAI's "text-embedding-3-small" model by default,
// but you can provide your own function, using any model you like.
type EmbeddingFunc func(ctx context.Context, document string) ([]float32, error)
type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error)

// DB is the chromem-go database. It holds collections, which hold documents.
//
Expand Down Expand Up @@ -91,7 +91,7 @@ func NewPersistentDB(path string) (*DB, error) {
c := &Collection{
// We can fill Name, persistDirectory and metadata only after reading
// the metadata.
documents: make(map[string]*document),
documents: make(map[string]*Document),
// We can fill embed only when the user calls DB.GetCollection() or
// DB.GetOrCreateCollection().
}
Expand Down Expand Up @@ -119,7 +119,7 @@ func NewPersistentDB(path string) (*DB, error) {
c.metadata = pc.Metadata
} else if filepath.Ext(collectionDirEntry.Name()) == ".gob" {
// Read document
d := &document{}
d := &Document{}
err := read(fPath, d)
if err != nil {
return nil, fmt.Errorf("couldn't read document: %w", err)
Expand Down
Loading