diff --git a/collection.go b/collection.go index 4e45f2a..d3395cd 100644 --- a/collection.go +++ b/collection.go @@ -19,7 +19,7 @@ type Collection struct { persistDirectory string metadata map[string]string - documents map[string]*document + documents map[string]*Document documentsLock sync.RWMutex embed EmbeddingFunc } @@ -38,7 +38,7 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, Name: name, metadata: m, - documents: make(map[string]*document), + documents: make(map[string]*Document), embed: embed, } @@ -73,24 +73,166 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, // // - ids: The ids of the embeddings you wish to add // - embeddings: The embeddings to add. If nil, embeddings will be computed based -// on the documents using the embeddingFunc set for the Collection. Optional. +// on the contents using the embeddingFunc set for the Collection. Optional. // - metadatas: The metadata to associate with the embeddings. When querying, // you can filter on this metadata. Optional. -// - documents: The documents to associate with the embeddings. +// - contents: The contents to associate with the embeddings. // -// A row-based API will be added when Chroma adds it (they already plan to). -func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, documents []string) error { - return c.add(ctx, ids, documents, embeddings, metadatas, 1) +// This is a Chroma-like method. For a more Go-idiomatic one, see [AddDocuments]. +func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string) error { + return c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 1) } // AddConcurrently is like Add, but adds embeddings concurrently. // This is mostly useful when you don't pass any embeddings so they have to be created. // Upon error, concurrently running operations are canceled and the error is returned. -func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, documents []string, concurrency int) error { +// +// This is a Chroma-like method. For a more Go-idiomatic one, see [AddDocuments]. +func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string, concurrency int) error { + if len(ids) == 0 { + return errors.New("ids are empty") + } + if len(embeddings) == 0 && len(contents) == 0 { + return errors.New("either embeddings or contents must be filled") + } + if len(embeddings) != 0 { + if len(embeddings) != len(ids) { + return errors.New("ids and embeddings must have the same length") + } + } else { + // Assign empty slice so we can simply access via index later + embeddings = make([][]float32, len(ids)) + } + if len(metadatas) != 0 && len(ids) != len(metadatas) { + return errors.New("ids, metadatas and contents must have the same length") + } + if len(contents) != 0 { + if len(contents) != len(ids) { + return errors.New("ids and contents must have the same length") + } + } else { + // Assign empty slice so we can simply access via index later + contents = make([]string, len(ids)) + } if concurrency < 1 { return errors.New("concurrency must be at least 1") } - return c.add(ctx, ids, documents, embeddings, metadatas, concurrency) + + // Convert Chroma-style parameters into a slice of documents. + docs := make([]Document, 0, len(ids)) + for i, id := range ids { + docs = append(docs, Document{ + ID: id, + Metadata: metadatas[i], + Embedding: embeddings[i], + Content: contents[i], + }) + } + + return c.AddDocuments(ctx, docs, concurrency) +} + +// AddDocuments adds documents to the collection with the specified concurrency. +// If the documents don't have embeddings, they will be created using the collection's +// embedding function. +// Upon error, concurrently running operations are canceled and the error is returned. +func (c *Collection) AddDocuments(ctx context.Context, documents []Document, concurrency int) error { + if len(documents) == 0 { + // TODO: Should this be a no-op instead? + return errors.New("documents slice is nil or empty") + } + if concurrency < 1 { + return errors.New("concurrency must be at least 1") + } + // For other validations we rely on AddDocument. + + var globalErr error + globalErrLock := sync.Mutex{} + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + setGlobalErr := func(err error) { + globalErrLock.Lock() + defer globalErrLock.Unlock() + // Another goroutine might have already set the error. + if globalErr == nil { + globalErr = err + // Cancel the operation for all other goroutines. + cancel(globalErr) + } + } + + var wg sync.WaitGroup + semaphore := make(chan struct{}, concurrency) + for _, doc := range documents { + wg.Add(1) + go func(doc Document) { + defer wg.Done() + + // Don't even start if another goroutine already failed. + if ctx.Err() != nil { + return + } + + // Wait here while $concurrency other goroutines are creating documents. + semaphore <- struct{}{} + defer func() { <-semaphore }() + + err := c.AddDocument(ctx, doc) + if err != nil { + setGlobalErr(fmt.Errorf("couldn't add document '%s': %w", doc.ID, err)) + return + } + }(doc) + } + + wg.Wait() + + return globalErr +} + +// AddDocument adds a document to the collection. +// If the document doesn't have an embedding, it will be created using the collection's +// embedding function. +func (c *Collection) AddDocument(ctx context.Context, doc Document) error { + if doc.ID == "" { + return errors.New("document ID is empty") + } + if len(doc.Embedding) == 0 && doc.Content == "" { + return errors.New("either document embedding or content must be filled") + } + + // We copy the metadata to avoid data races in case the caller modifies the + // map after creating the document while we range over it. + m := make(map[string]string, len(doc.Metadata)) + for k, v := range doc.Metadata { + m[k] = v + } + + // Create embedding if they don't exist + if len(doc.Embedding) == 0 { + embedding, err := c.embed(ctx, doc.Content) + if err != nil { + return fmt.Errorf("couldn't create embedding of document: %w", err) + } + doc.Embedding = embedding + } + + c.documentsLock.Lock() + // We don't defer the unlock because we want to do it earlier. + c.documents[doc.ID] = &doc + c.documentsLock.Unlock() + + // Persist the document + if c.persistDirectory != "" { + safeID := hash2hex(doc.ID) + filePath := path.Join(c.persistDirectory, safeID) + err := persist(filePath, doc) + if err != nil { + return fmt.Errorf("couldn't persist document: %w", err) + } + } + + return nil } // Count returns the number of documents in the collection. @@ -155,91 +297,3 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, // Return the top nResults return res[:nResults], nil } - -func (c *Collection) add(ctx context.Context, ids []string, documents []string, embeddings [][]float32, metadatas []map[string]string, concurrency int) error { - if len(ids) == 0 || len(documents) == 0 { - return errors.New("ids and documents must not be empty") - } - if len(ids) != len(documents) { - return errors.New("ids and documents must have the same length") - } - if len(embeddings) != 0 && len(ids) != len(embeddings) { - return errors.New("ids, embeddings and documents must have the same length") - } - if len(metadatas) != 0 && len(ids) != len(metadatas) { - return errors.New("ids, metadatas and documents must have the same length") - } - - ctx, cancel := context.WithCancelCause(ctx) - defer cancel(nil) - - var wg sync.WaitGroup - var globalErr error - var globalErrLock sync.Mutex - semaphore := make(chan struct{}, concurrency) - for i, document := range documents { - var embedding []float32 - var metadata map[string]string - if len(embeddings) != 0 { - embedding = embeddings[i] - } - if len(metadatas) != 0 { - metadata = metadatas[i] - } - - wg.Add(1) - go func(id string, embedding []float32, metadata map[string]string, document string) { - defer wg.Done() - - // Don't even start if we already have an error - if ctx.Err() != nil { - return - } - - // Wait here while $concurrency other goroutines are creating documents. - semaphore <- struct{}{} - defer func() { <-semaphore }() - - err := c.addRow(ctx, id, document, embedding, metadata) - if err != nil { - globalErrLock.Lock() - defer globalErrLock.Unlock() - // Another goroutine might have already set the error. - if globalErr == nil { - globalErr = err - // Cancel the operation for all other goroutines. - cancel(globalErr) - } - return - } - }(ids[i], embedding, metadata, document) - } - - wg.Wait() - - return globalErr -} - -func (c *Collection) addRow(ctx context.Context, id string, document string, embedding []float32, metadata map[string]string) error { - doc, err := newDocument(ctx, id, embedding, metadata, document, c.embed) - if err != nil { - return fmt.Errorf("couldn't create document '%s': %w", id, err) - } - - c.documentsLock.Lock() - // We don't defer the unlock because we want to do it earlier. - c.documents[id] = doc - c.documentsLock.Unlock() - - // Persist the document - if c.persistDirectory != "" { - safeID := hash2hex(id) - filePath := path.Join(c.persistDirectory, safeID) - err := persist(filePath, doc) - if err != nil { - return fmt.Errorf("couldn't persist document: %w", err) - } - } - - return nil -} diff --git a/collection_test.go b/collection_test.go index 472e119..8b1059a 100644 --- a/collection_test.go +++ b/collection_test.go @@ -26,8 +26,8 @@ func TestCollection_Add(t *testing.T) { // Add document ids := []string{"1", "2"} metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} - documents := []string{"hello world", "hallo welt"} - err = c.Add(context.Background(), ids, nil, metadatas, documents) + contents := []string{"hello world", "hallo welt"} + err = c.Add(context.Background(), ids, nil, metadatas, contents) if err != nil { t.Error("expected nil, got", err) } @@ -54,8 +54,8 @@ func TestCollection_Count(t *testing.T) { // Add documents ids := []string{"1", "2"} metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} - documents := []string{"hello world", "hallo welt"} - err = c.Add(context.Background(), ids, nil, metadatas, documents) + contents := []string{"hello world", "hallo welt"} + err = c.Add(context.Background(), ids, nil, metadatas, contents) if err != nil { t.Error("expected nil, got", err) } diff --git a/db.go b/db.go index 0f9ac92..42ab698 100644 --- a/db.go +++ b/db.go @@ -10,10 +10,10 @@ import ( "sync" ) -// EmbeddingFunc is a function that creates embeddings for a given document. +// EmbeddingFunc is a function that creates embeddings for a given text. // chromem-go will use OpenAI`s "text-embedding-3-small" model by default, // but you can provide your own function, using any model you like. -type EmbeddingFunc func(ctx context.Context, document string) ([]float32, error) +type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error) // DB is the chromem-go database. It holds collections, which hold documents. // @@ -91,7 +91,7 @@ func NewPersistentDB(path string) (*DB, error) { c := &Collection{ // We can fill Name, persistDirectory and metadata only after reading // the metadata. - documents: make(map[string]*document), + documents: make(map[string]*Document), // We can fill embed only when the user calls DB.GetCollection() or // DB.GetOrCreateCollection(). } @@ -119,7 +119,7 @@ func NewPersistentDB(path string) (*DB, error) { c.metadata = pc.Metadata } else if filepath.Ext(collectionDirEntry.Name()) == ".gob" { // Read document - d := &document{} + d := &Document{} err := read(fPath, d) if err != nil { return nil, fmt.Errorf("couldn't read document: %w", err) diff --git a/document.go b/document.go index 49b8a38..4fe1f54 100644 --- a/document.go +++ b/document.go @@ -2,38 +2,49 @@ package chromem import ( "context" + "errors" ) -type document struct { - ID string - Metadata map[string]string - Document string - - Vectors []float32 +// Document represents a single document. +type Document struct { + ID string + Metadata map[string]string + Embedding []float32 + Content string } -// newDocument creates a new document, including its embeddings. +// NewDocument creates a new document, including its embeddings. +// Metadata is optional. // If the embeddings are not provided, they are created using the embedding function. -func newDocument(ctx context.Context, id string, embeddings []float32, metadata map[string]string, doc string, embed EmbeddingFunc) (*document, error) { - if len(embeddings) == 0 { - vectors, err := embed(ctx, doc) - if err != nil { - return nil, err - } - embeddings = vectors +// You can leave the content empty if you only want to store embeddings. +// If embeddingFunc is nil, the default embedding function is used. +// +// If you want to create a document without embeddings, for example to let [Collection.AddDocuments] +// create them concurrently, you can create a document with `chromem.Document{...}` +// instead of using this constructor. +func NewDocument(ctx context.Context, id string, metadata map[string]string, embedding []float32, content string, embeddingFunc EmbeddingFunc) (Document, error) { + if id == "" { + return Document{}, errors.New("id is empty") + } + if len(embedding) == 0 && content == "" { + return Document{}, errors.New("either embedding or content must be filled") } - // We copy the metadata to avoid data races in case the caller modifies the - // map after creating the document while we range over it. - m := make(map[string]string, len(metadata)) - for k, v := range metadata { - m[k] = v + if embeddingFunc == nil { + embeddingFunc = NewEmbeddingFuncDefault() } - return &document{ - ID: id, - Metadata: metadata, - Document: doc, + if len(embedding) == 0 { + var err error + embedding, err = embeddingFunc(ctx, content) + if err != nil { + return Document{}, err + } + } - Vectors: embeddings, + return Document{ + ID: id, + Metadata: metadata, + Embedding: embedding, + Content: content, }, nil } diff --git a/embed_compat.go b/embed_compat.go index f82e091..a3d18c9 100644 --- a/embed_compat.go +++ b/embed_compat.go @@ -6,7 +6,7 @@ const ( embeddingModelMistral = "mistral-embed" ) -// NewEmbeddingFuncMistral returns a function that creates embeddings for a document +// NewEmbeddingFuncMistral returns a function that creates embeddings for a text // using the Mistral API. func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc { // The Mistral API docs don't mention the `encoding_format` as optional, @@ -25,7 +25,7 @@ const ( EmbeddingModelJina2BaseZH EmbeddingModelJina = "jina-embeddings-v2-base-zh" ) -// NewEmbeddingFuncJina returns a function that creates embeddings for a document +// NewEmbeddingFuncJina returns a function that creates embeddings for a text // using the Jina API. func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc { return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model)) @@ -46,7 +46,7 @@ const ( EmbeddingModelMixedbreadGTELargeZh EmbeddingModelMixedbread = "gte-large-zh" ) -// NewEmbeddingFuncMixedbread returns a function that creates embeddings for a document +// NewEmbeddingFuncMixedbread returns a function that creates embeddings for a text // using the mixedbread.ai API. func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc { return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model)) @@ -54,7 +54,7 @@ func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) E const baseURLLocalAI = "http://localhost:8080/v1" -// NewEmbeddingFuncLocalAI returns a function that creates embeddings for a document +// NewEmbeddingFuncLocalAI returns a function that creates embeddings for a text // using the LocalAI API. // You can start a LocalAI instance like this: // diff --git a/embed_ollama.go b/embed_ollama.go index aa14905..2e5f718 100644 --- a/embed_ollama.go +++ b/embed_ollama.go @@ -16,21 +16,21 @@ type ollamaResponse struct { Embedding []float32 `json:"embedding"` } -// NewEmbeddingFuncOllama returns a function that creates embeddings for a document +// NewEmbeddingFuncOllama returns a function that creates embeddings for a text // using Ollama's embedding API. You can pass any model that Ollama supports and // that supports embeddings. A good one as of 2024-03-02 is "nomic-embed-text". // See https://ollama.com/library/nomic-embed-text func NewEmbeddingFuncOllama(model string) EmbeddingFunc { // We don't set a default timeout here, although it's usually a good idea. // In our case though, the library user can set the timeout on the context, - // and it might have to be a long timeout, depending on the document size. + // and it might have to be a long timeout, depending on the text length. client := &http.Client{} - return func(ctx context.Context, document string) ([]float32, error) { + return func(ctx context.Context, text string) ([]float32, error) { // Prepare the request body. reqBody, err := json.Marshal(map[string]string{ "model": model, - "prompt": document, + "prompt": text, }) if err != nil { return nil, fmt.Errorf("couldn't marshal request body: %w", err) diff --git a/embed_openai.go b/embed_openai.go index 681db86..2d301c7 100644 --- a/embed_openai.go +++ b/embed_openai.go @@ -27,22 +27,22 @@ type openAIResponse struct { } `json:"data"` } -// NewEmbeddingFuncDefault returns a function that creates embeddings for a document +// NewEmbeddingFuncDefault returns a function that creates embeddings for a text // using OpenAI`s "text-embedding-3-small" model via their API. -// The model supports a maximum document length of 8191 tokens. +// The model supports a maximum text length of 8191 tokens. // The API key is read from the environment variable "OPENAI_API_KEY". func NewEmbeddingFuncDefault() EmbeddingFunc { apiKey := os.Getenv("OPENAI_API_KEY") return NewEmbeddingFuncOpenAI(apiKey, EmbeddingModelOpenAI3Small) } -// NewEmbeddingFuncOpenAI returns a function that creates embeddings for a document +// NewEmbeddingFuncOpenAI returns a function that creates embeddings for a text // using the OpenAI API. func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc { return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model)) } -// NewEmbeddingFuncOpenAICompat returns a function that creates embeddings for a document +// NewEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text // using an OpenAI compatible API. For example: // - Azure OpenAI: https://azure.microsoft.com/en-us/products/ai-services/openai-service // - LitLLM: https://github.com/BerriAI/litellm @@ -51,13 +51,13 @@ func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) Embedding func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string) EmbeddingFunc { // We don't set a default timeout here, although it's usually a good idea. // In our case though, the library user can set the timeout on the context, - // and it might have to be a long timeout, depending on the document size. + // and it might have to be a long timeout, depending on the text length. client := &http.Client{} - return func(ctx context.Context, document string) ([]float32, error) { + return func(ctx context.Context, text string) ([]float32, error) { // Prepare the request body. reqBody, err := json.Marshal(map[string]string{ - "input": document, + "input": text, "model": model, }) if err != nil { diff --git a/embed_openai_test.go b/embed_openai_test.go index af729b5..70c62f9 100644 --- a/embed_openai_test.go +++ b/embed_openai_test.go @@ -24,10 +24,10 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) { apiKey := "secret" model := "model-small" baseURLSuffix := "/v1" - document := "hello world" + input := "hello world" wantBody, err := json.Marshal(map[string]string{ - "input": document, + "input": input, "model": model, }) if err != nil { @@ -76,7 +76,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) { baseURL := ts.URL + baseURLSuffix f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model) - res, err := f(context.Background(), document) + res, err := f(context.Background(), input) if err != nil { t.Error("expected nil, got", err) } diff --git a/example/main.go b/example/main.go index 0e5e632..35e710d 100644 --- a/example/main.go +++ b/example/main.go @@ -54,6 +54,7 @@ func main() { } // Add docs to the collection, if the collection was just created (and not // loaded from persistent storage). + docs := []chromem.Document{} if collection.Count() == 0 { // Here we use a DBpedia sample, where each line contains the lead section/introduction // to some Wikipedia article and its category. @@ -62,9 +63,6 @@ func main() { panic(err) } d := json.NewDecoder(f) - var ids []string - var metadatas []map[string]string - var texts []string log.Println("Reading JSON lines...") for i := 1; ; i++ { var article struct { @@ -78,12 +76,14 @@ func main() { panic(err) } - ids = append(ids, strconv.Itoa(i)) - metadatas = append(metadatas, map[string]string{"category": article.Category}) - texts = append(texts, article.Text) + docs = append(docs, chromem.Document{ + ID: strconv.Itoa(i), + Metadata: map[string]string{"category": article.Category}, + Content: article.Text, + }) } log.Println("Adding documents to chromem-go, including creating their embeddings via Ollama API...") - err = collection.AddConcurrently(ctx, ids, nil, metadatas, texts, runtime.NumCPU()) + err = collection.AddDocuments(ctx, docs, runtime.NumCPU()) if err != nil { panic(err) } @@ -105,12 +105,12 @@ func main() { // Print the retrieved documents and their similarity to the question. for i, res := range docRes { - log.Printf("Document %d (similarity: %f): \"%s\"\n", i+1, res.Similarity, res.Document) + log.Printf("Document %d (similarity: %f): \"%s\"\n", i+1, res.Similarity, res.Content) } // Now we can ask the LLM again, augmenting the question with the knowledge we retrieved. // In this example we just use both retrieved documents as context. - contexts := []string{docRes[0].Document, docRes[1].Document} + contexts := []string{docRes[0].Content, docRes[1].Content} log.Println("Asking LLM with augmented question...") reply = askLLM(ctx, contexts, question) log.Printf("Reply after augmenting the question with knowledge: \"" + reply + "\"\n") diff --git a/query.go b/query.go index 317a729..4d15f60 100644 --- a/query.go +++ b/query.go @@ -2,6 +2,7 @@ package chromem import ( "context" + "fmt" "runtime" "strings" "sync" @@ -12,9 +13,9 @@ var supportedFilters = []string{"$contains", "$not_contains"} // Result represents a single result from a query. type Result struct { ID string - Embedding []float32 Metadata map[string]string - Document string + Embedding []float32 + Content string // The cosine similarity between the query and the document. // The higher the value, the more similar the document is to the query. @@ -24,8 +25,8 @@ type Result struct { // filterDocs filters a map of documents by metadata and content. // It does this concurrently. -func filterDocs(docs map[string]*document, where, whereDocument map[string]string) []*document { - filteredDocs := make([]*document, 0, len(docs)) +func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document { + filteredDocs := make([]*Document, 0, len(docs)) filteredDocsLock := sync.Mutex{} // Determine concurrency. Use number of docs or CPUs, whichever is smaller. @@ -36,7 +37,7 @@ func filterDocs(docs map[string]*document, where, whereDocument map[string]strin concurrency = numDocs } - docChan := make(chan *document, concurrency*2) + docChan := make(chan *Document, concurrency*2) wg := sync.WaitGroup{} for i := 0; i < concurrency; i++ { @@ -65,7 +66,7 @@ func filterDocs(docs map[string]*document, where, whereDocument map[string]strin // documentMatchesFilters checks if a document matches the given filters. // When calling this function, the whereDocument keys must already be validated! -func documentMatchesFilters(document *document, where, whereDocument map[string]string) bool { +func documentMatchesFilters(document *Document, where, whereDocument map[string]string) bool { // A document's metadata must have *all* the fields in the where clause. for k, v := range where { // TODO: Do we want to check for existence of the key? I.e. should @@ -80,11 +81,11 @@ func documentMatchesFilters(document *document, where, whereDocument map[string] for k, v := range whereDocument { switch k { case "$contains": - if !strings.Contains(document.Document, v) { + if !strings.Contains(document.Content, v) { return false } case "$not_contains": - if strings.Contains(document.Document, v) { + if strings.Contains(document.Content, v) { return false } default: @@ -97,7 +98,7 @@ func documentMatchesFilters(document *document, where, whereDocument map[string] return true } -func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*document) ([]Result, error) { +func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]Result, error) { res := make([]Result, len(docs)) resLock := sync.Mutex{} @@ -109,14 +110,23 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*docu concurrency = numDocs } - ctx, cancel := context.WithCancelCause(ctx) - defer cancel(nil) - - docChan := make(chan *document, concurrency*2) var globalErr error globalErrLock := sync.Mutex{} + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + setGlobalErr := func(err error) { + globalErrLock.Lock() + defer globalErrLock.Unlock() + // Another goroutine might have already set the error. + if globalErr == nil { + globalErr = err + // Cancel the operation for all other goroutines. + cancel(globalErr) + } + } wg := sync.WaitGroup{} + docChan := make(chan *Document, concurrency*2) for i := 0; i < concurrency; i++ { wg.Add(1) go func() { @@ -127,16 +137,9 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*docu return } - sim, err := cosineSimilarity(queryVectors, doc.Vectors) + sim, err := cosineSimilarity(queryVectors, doc.Embedding) if err != nil { - globalErrLock.Lock() - defer globalErrLock.Unlock() - // Another goroutine might have already set the error. - if globalErr == nil { - globalErr = err - // Cancel the operation for all other goroutines. - cancel(globalErr) - } + setGlobalErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err)) return } @@ -144,9 +147,9 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*docu // We don't defer the unlock because we want to unlock much earlier. res = append(res, Result{ ID: doc.ID, - Embedding: doc.Vectors, Metadata: doc.Metadata, - Document: doc.Document, + Embedding: doc.Embedding, + Content: doc.Content, Similarity: sim, })