Skip to content

Commit

Permalink
Merge pull request #34 from philippgille/add-go-idiomatic-methods
Browse files Browse the repository at this point in the history
Export Document struct and add Go-idiomatic methods for adding them to a collection
  • Loading branch information
philippgille authored Mar 4, 2024
2 parents a2df4ad + cb0fe2f commit ed5dca6
Show file tree
Hide file tree
Showing 10 changed files with 248 additions and 180 deletions.
248 changes: 151 additions & 97 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type Collection struct {

persistDirectory string
metadata map[string]string
documents map[string]*document
documents map[string]*Document
documentsLock sync.RWMutex
embed EmbeddingFunc
}
Expand All @@ -38,7 +38,7 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
Name: name,

metadata: m,
documents: make(map[string]*document),
documents: make(map[string]*Document),
embed: embed,
}

Expand Down Expand Up @@ -73,24 +73,166 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
//
// - ids: The ids of the embeddings you wish to add
// - embeddings: The embeddings to add. If nil, embeddings will be computed based
// on the documents using the embeddingFunc set for the Collection. Optional.
// on the contents using the embeddingFunc set for the Collection. Optional.
// - metadatas: The metadata to associate with the embeddings. When querying,
// you can filter on this metadata. Optional.
// - documents: The documents to associate with the embeddings.
// - contents: The contents to associate with the embeddings.
//
// A row-based API will be added when Chroma adds it (they already plan to).
func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, documents []string) error {
return c.add(ctx, ids, documents, embeddings, metadatas, 1)
// This is a Chroma-like method. For a more Go-idiomatic one, see [AddDocuments].
func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string) error {
return c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 1)
}

// AddConcurrently is like Add, but adds embeddings concurrently.
// This is mostly useful when you don't pass any embeddings so they have to be created.
// Upon error, concurrently running operations are canceled and the error is returned.
func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, documents []string, concurrency int) error {
//
// This is a Chroma-like method. For a more Go-idiomatic one, see [AddDocuments].
func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string, concurrency int) error {
if len(ids) == 0 {
return errors.New("ids are empty")
}
if len(embeddings) == 0 && len(contents) == 0 {
return errors.New("either embeddings or contents must be filled")
}
if len(embeddings) != 0 {
if len(embeddings) != len(ids) {
return errors.New("ids and embeddings must have the same length")
}
} else {
// Assign empty slice so we can simply access via index later
embeddings = make([][]float32, len(ids))
}
if len(metadatas) != 0 && len(ids) != len(metadatas) {
return errors.New("ids, metadatas and contents must have the same length")
}
if len(contents) != 0 {
if len(contents) != len(ids) {
return errors.New("ids and contents must have the same length")
}
} else {
// Assign empty slice so we can simply access via index later
contents = make([]string, len(ids))
}
if concurrency < 1 {
return errors.New("concurrency must be at least 1")
}
return c.add(ctx, ids, documents, embeddings, metadatas, concurrency)

// Convert Chroma-style parameters into a slice of documents.
docs := make([]Document, 0, len(ids))
for i, id := range ids {
docs = append(docs, Document{
ID: id,
Metadata: metadatas[i],
Embedding: embeddings[i],
Content: contents[i],
})
}

return c.AddDocuments(ctx, docs, concurrency)
}

// AddDocuments adds documents to the collection with the specified concurrency.
// If the documents don't have embeddings, they will be created using the collection's
// embedding function.
// Upon error, concurrently running operations are canceled and the error is returned.
func (c *Collection) AddDocuments(ctx context.Context, documents []Document, concurrency int) error {
if len(documents) == 0 {
// TODO: Should this be a no-op instead?
return errors.New("documents slice is nil or empty")
}
if concurrency < 1 {
return errors.New("concurrency must be at least 1")
}
// For other validations we rely on AddDocument.

var globalErr error
globalErrLock := sync.Mutex{}
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
setGlobalErr := func(err error) {
globalErrLock.Lock()
defer globalErrLock.Unlock()
// Another goroutine might have already set the error.
if globalErr == nil {
globalErr = err
// Cancel the operation for all other goroutines.
cancel(globalErr)
}
}

var wg sync.WaitGroup
semaphore := make(chan struct{}, concurrency)
for _, doc := range documents {
wg.Add(1)
go func(doc Document) {
defer wg.Done()

// Don't even start if another goroutine already failed.
if ctx.Err() != nil {
return
}

// Wait here while $concurrency other goroutines are creating documents.
semaphore <- struct{}{}
defer func() { <-semaphore }()

err := c.AddDocument(ctx, doc)
if err != nil {
setGlobalErr(fmt.Errorf("couldn't add document '%s': %w", doc.ID, err))
return
}
}(doc)
}

wg.Wait()

return globalErr
}

// AddDocument adds a document to the collection.
// If the document doesn't have an embedding, it will be created using the collection's
// embedding function.
func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
if doc.ID == "" {
return errors.New("document ID is empty")
}
if len(doc.Embedding) == 0 && doc.Content == "" {
return errors.New("either document embedding or content must be filled")
}

// We copy the metadata to avoid data races in case the caller modifies the
// map after creating the document while we range over it.
m := make(map[string]string, len(doc.Metadata))
for k, v := range doc.Metadata {
m[k] = v
}

// Create embedding if they don't exist
if len(doc.Embedding) == 0 {
embedding, err := c.embed(ctx, doc.Content)
if err != nil {
return fmt.Errorf("couldn't create embedding of document: %w", err)
}
doc.Embedding = embedding
}

c.documentsLock.Lock()
// We don't defer the unlock because we want to do it earlier.
c.documents[doc.ID] = &doc
c.documentsLock.Unlock()

// Persist the document
if c.persistDirectory != "" {
safeID := hash2hex(doc.ID)
filePath := path.Join(c.persistDirectory, safeID)
err := persist(filePath, doc)
if err != nil {
return fmt.Errorf("couldn't persist document: %w", err)
}
}

return nil
}

// Count returns the number of documents in the collection.
Expand Down Expand Up @@ -155,91 +297,3 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
// Return the top nResults
return res[:nResults], nil
}

func (c *Collection) add(ctx context.Context, ids []string, documents []string, embeddings [][]float32, metadatas []map[string]string, concurrency int) error {
if len(ids) == 0 || len(documents) == 0 {
return errors.New("ids and documents must not be empty")
}
if len(ids) != len(documents) {
return errors.New("ids and documents must have the same length")
}
if len(embeddings) != 0 && len(ids) != len(embeddings) {
return errors.New("ids, embeddings and documents must have the same length")
}
if len(metadatas) != 0 && len(ids) != len(metadatas) {
return errors.New("ids, metadatas and documents must have the same length")
}

ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)

var wg sync.WaitGroup
var globalErr error
var globalErrLock sync.Mutex
semaphore := make(chan struct{}, concurrency)
for i, document := range documents {
var embedding []float32
var metadata map[string]string
if len(embeddings) != 0 {
embedding = embeddings[i]
}
if len(metadatas) != 0 {
metadata = metadatas[i]
}

wg.Add(1)
go func(id string, embedding []float32, metadata map[string]string, document string) {
defer wg.Done()

// Don't even start if we already have an error
if ctx.Err() != nil {
return
}

// Wait here while $concurrency other goroutines are creating documents.
semaphore <- struct{}{}
defer func() { <-semaphore }()

err := c.addRow(ctx, id, document, embedding, metadata)
if err != nil {
globalErrLock.Lock()
defer globalErrLock.Unlock()
// Another goroutine might have already set the error.
if globalErr == nil {
globalErr = err
// Cancel the operation for all other goroutines.
cancel(globalErr)
}
return
}
}(ids[i], embedding, metadata, document)
}

wg.Wait()

return globalErr
}

func (c *Collection) addRow(ctx context.Context, id string, document string, embedding []float32, metadata map[string]string) error {
doc, err := newDocument(ctx, id, embedding, metadata, document, c.embed)
if err != nil {
return fmt.Errorf("couldn't create document '%s': %w", id, err)
}

c.documentsLock.Lock()
// We don't defer the unlock because we want to do it earlier.
c.documents[id] = doc
c.documentsLock.Unlock()

// Persist the document
if c.persistDirectory != "" {
safeID := hash2hex(id)
filePath := path.Join(c.persistDirectory, safeID)
err := persist(filePath, doc)
if err != nil {
return fmt.Errorf("couldn't persist document: %w", err)
}
}

return nil
}
8 changes: 4 additions & 4 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ func TestCollection_Add(t *testing.T) {
// Add document
ids := []string{"1", "2"}
metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
documents := []string{"hello world", "hallo welt"}
err = c.Add(context.Background(), ids, nil, metadatas, documents)
contents := []string{"hello world", "hallo welt"}
err = c.Add(context.Background(), ids, nil, metadatas, contents)
if err != nil {
t.Error("expected nil, got", err)
}
Expand All @@ -54,8 +54,8 @@ func TestCollection_Count(t *testing.T) {
// Add documents
ids := []string{"1", "2"}
metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
documents := []string{"hello world", "hallo welt"}
err = c.Add(context.Background(), ids, nil, metadatas, documents)
contents := []string{"hello world", "hallo welt"}
err = c.Add(context.Background(), ids, nil, metadatas, contents)
if err != nil {
t.Error("expected nil, got", err)
}
Expand Down
8 changes: 4 additions & 4 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import (
"sync"
)

// EmbeddingFunc is a function that creates embeddings for a given document.
// EmbeddingFunc is a function that creates embeddings for a given text.
// chromem-go will use OpenAI`s "text-embedding-3-small" model by default,
// but you can provide your own function, using any model you like.
type EmbeddingFunc func(ctx context.Context, document string) ([]float32, error)
type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error)

// DB is the chromem-go database. It holds collections, which hold documents.
//
Expand Down Expand Up @@ -91,7 +91,7 @@ func NewPersistentDB(path string) (*DB, error) {
c := &Collection{
// We can fill Name, persistDirectory and metadata only after reading
// the metadata.
documents: make(map[string]*document),
documents: make(map[string]*Document),
// We can fill embed only when the user calls DB.GetCollection() or
// DB.GetOrCreateCollection().
}
Expand Down Expand Up @@ -119,7 +119,7 @@ func NewPersistentDB(path string) (*DB, error) {
c.metadata = pc.Metadata
} else if filepath.Ext(collectionDirEntry.Name()) == ".gob" {
// Read document
d := &document{}
d := &Document{}
err := read(fPath, d)
if err != nil {
return nil, fmt.Errorf("couldn't read document: %w", err)
Expand Down
Loading

0 comments on commit ed5dca6

Please sign in to comment.