Skip to content

Commit

Permalink
Merge pull request #61 from philippgille/add-cohere-embedding-provider
Browse files Browse the repository at this point in the history
Add Cohere embedding provider
  • Loading branch information
philippgille authored Mar 23, 2024
2 parents bb2271e + 8854814 commit f971ad9
Showing 1 changed file with 167 additions and 0 deletions.
167 changes: 167 additions & 0 deletions embed_cohere.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package chromem

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync"
)

type EmbeddingModelCohere string

const (
EmbeddingModelCohereMultilingualV2 EmbeddingModelCohere = "embed-multilingual-v2.0"
EmbeddingModelCohereEnglishLightV2 EmbeddingModelCohere = "embed-english-light-v2.0"
EmbeddingModelCohereEnglishV2 EmbeddingModelCohere = "embed-english-v2.0"
EmbeddingModelCohereMultilingualLightV3 EmbeddingModelCohere = "embed-multilingual-light-v3.0"
EmbeddingModelCohereEnglishLightV3 EmbeddingModelCohere = "embed-english-light-v3.0"
EmbeddingModelCohereMultilingualV3 EmbeddingModelCohere = "embed-multilingual-v3.0"
EmbeddingModelCohereEnglishV3 EmbeddingModelCohere = "embed-english-v3.0"
)

// Prefixes for external use.
const (
InputTypeCohereSearchDocumentPrefix string = "search_document: "
InputTypeCohereSearchQueryPrefix string = "search_query: "
InputTypeCohereClassificationPrefix string = "classification: "
InputTypeCohereClusteringPrefix string = "clustering: "
)

// Input types for internal use.
const (
inputTypeCohereSearchDocument string = "search_document"
inputTypeCohereSearchQuery string = "search_query"
inputTypeCohereClassification string = "classification"
inputTypeCohereClustering string = "clustering"
)

const baseURLCohere = "https://api.cohere.ai/v1"

var validInputTypesCohere = map[string]string{
inputTypeCohereSearchDocument: InputTypeCohereSearchDocumentPrefix,
inputTypeCohereSearchQuery: InputTypeCohereSearchQueryPrefix,
inputTypeCohereClassification: InputTypeCohereClassificationPrefix,
inputTypeCohereClustering: InputTypeCohereClusteringPrefix,
}

type cohereResponse struct {
Embeddings [][]float32 `json:"embeddings"`
}

// NewEmbeddingFuncCohere returns a function that creates embeddings for a text
// using Cohere's API. One important difference to OpenAI's and other's APIs is
// that Cohere differentiates between document embeddings and search/query embeddings.
// In order for this embedding func to do the differentiation, you have to prepend
// the text with either "search_document" or "search_query". We'll cut off that
// prefix before sending the document/query body to the API, we'll just use the
// prefix to choose the right "input type" as they call it.
//
// When you set up a chromem-go collection with this embedding function, you might
// want to create the document separately with [NewDocument] and then cut off the
// prefix before adding the document to the collection. Otherwise when you query
// the collection, the returned documents will still have the prefix in their content.
//
// cohereFunc := chromem.NewEmbeddingFuncCohere(cohereApiKey, chromem.EmbeddingModelCohereEnglishV3)
// content := "The sky is blue because of Rayleigh scattering."
// // Create the document with the prefix.
// contentWithPrefix := chromem.InputTypeCohereSearchDocumentPrefix + content
// doc, _ := NewDocument(ctx, id, metadata, nil, contentWithPrefix, cohereFunc)
// // Remove the prefix so that later query results don't have it.
// doc.Content = content
// _ = collection.AddDocument(ctx, doc)
//
// This is not necessary if you don't keep the content in the documents, as chromem-go
// also works when documents only have embeddings.
// You can also keep the prefix in the document, and only remove it after querying.
//
// We plan to improve this in the future.
func NewEmbeddingFuncCohere(apiKey string, model EmbeddingModelCohere) EmbeddingFunc {
// We don't set a default timeout here, although it's usually a good idea.
// In our case though, the library user can set the timeout on the context,
// and it might have to be a long timeout, depending on the text length.
client := &http.Client{}

var checkedNormalized bool
checkNormalized := sync.Once{}

return func(ctx context.Context, text string) ([]float32, error) {
var inputType string
for validInputType, validInputTypePrefix := range validInputTypesCohere {
if strings.HasPrefix(text, validInputTypePrefix) {
inputType = validInputType
text = strings.TrimPrefix(text, validInputTypePrefix)
break
}
}
if inputType == "" {
return nil, errors.New("text must start with a valid input type plus colon and space")
}

// Prepare the request body.
reqBody, err := json.Marshal(map[string]any{
"model": model,
"texts": []string{text},
"input_type": inputType,
})
if err != nil {
return nil, fmt.Errorf("couldn't marshal request body: %w", err)
}

// Create the request. Creating it with context is important for a timeout
// to be possible, because the client is configured without a timeout.
req, err := http.NewRequestWithContext(ctx, "POST", baseURLCohere+"/embed", bytes.NewBuffer(reqBody))
if err != nil {
return nil, fmt.Errorf("couldn't create request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)

// Send the request.
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("couldn't send request: %w", err)
}
defer resp.Body.Close()

// Check the response status.
if resp.StatusCode != http.StatusOK {
return nil, errors.New("error response from the embedding API: " + resp.Status)
}

// Read and decode the response body.
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("couldn't read response body: %w", err)
}
var embeddingResponse cohereResponse
err = json.Unmarshal(body, &embeddingResponse)
if err != nil {
return nil, fmt.Errorf("couldn't unmarshal response body: %w", err)
}

// Check if the response contains embeddings.
if len(embeddingResponse.Embeddings) == 0 || len(embeddingResponse.Embeddings[0]) == 0 {
return nil, errors.New("no embeddings found in the response")
}

v := embeddingResponse.Embeddings[0]
checkNormalized.Do(func() {
if isNormalized(v) {
checkedNormalized = true
} else {
checkedNormalized = false
}
})
if !checkedNormalized {
v = normalizeVector(v)
}

return v, nil
}
}

0 comments on commit f971ad9

Please sign in to comment.