-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #61 from philippgille/add-cohere-embedding-provider
Add Cohere embedding provider
- Loading branch information
Showing
1 changed file
with
167 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
package chromem | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"strings" | ||
"sync" | ||
) | ||
|
||
type EmbeddingModelCohere string | ||
|
||
const ( | ||
EmbeddingModelCohereMultilingualV2 EmbeddingModelCohere = "embed-multilingual-v2.0" | ||
EmbeddingModelCohereEnglishLightV2 EmbeddingModelCohere = "embed-english-light-v2.0" | ||
EmbeddingModelCohereEnglishV2 EmbeddingModelCohere = "embed-english-v2.0" | ||
EmbeddingModelCohereMultilingualLightV3 EmbeddingModelCohere = "embed-multilingual-light-v3.0" | ||
EmbeddingModelCohereEnglishLightV3 EmbeddingModelCohere = "embed-english-light-v3.0" | ||
EmbeddingModelCohereMultilingualV3 EmbeddingModelCohere = "embed-multilingual-v3.0" | ||
EmbeddingModelCohereEnglishV3 EmbeddingModelCohere = "embed-english-v3.0" | ||
) | ||
|
||
// Prefixes for external use. | ||
const ( | ||
InputTypeCohereSearchDocumentPrefix string = "search_document: " | ||
InputTypeCohereSearchQueryPrefix string = "search_query: " | ||
InputTypeCohereClassificationPrefix string = "classification: " | ||
InputTypeCohereClusteringPrefix string = "clustering: " | ||
) | ||
|
||
// Input types for internal use. | ||
const ( | ||
inputTypeCohereSearchDocument string = "search_document" | ||
inputTypeCohereSearchQuery string = "search_query" | ||
inputTypeCohereClassification string = "classification" | ||
inputTypeCohereClustering string = "clustering" | ||
) | ||
|
||
const baseURLCohere = "https://api.cohere.ai/v1" | ||
|
||
var validInputTypesCohere = map[string]string{ | ||
inputTypeCohereSearchDocument: InputTypeCohereSearchDocumentPrefix, | ||
inputTypeCohereSearchQuery: InputTypeCohereSearchQueryPrefix, | ||
inputTypeCohereClassification: InputTypeCohereClassificationPrefix, | ||
inputTypeCohereClustering: InputTypeCohereClusteringPrefix, | ||
} | ||
|
||
type cohereResponse struct { | ||
Embeddings [][]float32 `json:"embeddings"` | ||
} | ||
|
||
// NewEmbeddingFuncCohere returns a function that creates embeddings for a text | ||
// using Cohere's API. One important difference to OpenAI's and other's APIs is | ||
// that Cohere differentiates between document embeddings and search/query embeddings. | ||
// In order for this embedding func to do the differentiation, you have to prepend | ||
// the text with either "search_document" or "search_query". We'll cut off that | ||
// prefix before sending the document/query body to the API, we'll just use the | ||
// prefix to choose the right "input type" as they call it. | ||
// | ||
// When you set up a chromem-go collection with this embedding function, you might | ||
// want to create the document separately with [NewDocument] and then cut off the | ||
// prefix before adding the document to the collection. Otherwise when you query | ||
// the collection, the returned documents will still have the prefix in their content. | ||
// | ||
// cohereFunc := chromem.NewEmbeddingFuncCohere(cohereApiKey, chromem.EmbeddingModelCohereEnglishV3) | ||
// content := "The sky is blue because of Rayleigh scattering." | ||
// // Create the document with the prefix. | ||
// contentWithPrefix := chromem.InputTypeCohereSearchDocumentPrefix + content | ||
// doc, _ := NewDocument(ctx, id, metadata, nil, contentWithPrefix, cohereFunc) | ||
// // Remove the prefix so that later query results don't have it. | ||
// doc.Content = content | ||
// _ = collection.AddDocument(ctx, doc) | ||
// | ||
// This is not necessary if you don't keep the content in the documents, as chromem-go | ||
// also works when documents only have embeddings. | ||
// You can also keep the prefix in the document, and only remove it after querying. | ||
// | ||
// We plan to improve this in the future. | ||
func NewEmbeddingFuncCohere(apiKey string, model EmbeddingModelCohere) EmbeddingFunc { | ||
// We don't set a default timeout here, although it's usually a good idea. | ||
// In our case though, the library user can set the timeout on the context, | ||
// and it might have to be a long timeout, depending on the text length. | ||
client := &http.Client{} | ||
|
||
var checkedNormalized bool | ||
checkNormalized := sync.Once{} | ||
|
||
return func(ctx context.Context, text string) ([]float32, error) { | ||
var inputType string | ||
for validInputType, validInputTypePrefix := range validInputTypesCohere { | ||
if strings.HasPrefix(text, validInputTypePrefix) { | ||
inputType = validInputType | ||
text = strings.TrimPrefix(text, validInputTypePrefix) | ||
break | ||
} | ||
} | ||
if inputType == "" { | ||
return nil, errors.New("text must start with a valid input type plus colon and space") | ||
} | ||
|
||
// Prepare the request body. | ||
reqBody, err := json.Marshal(map[string]any{ | ||
"model": model, | ||
"texts": []string{text}, | ||
"input_type": inputType, | ||
}) | ||
if err != nil { | ||
return nil, fmt.Errorf("couldn't marshal request body: %w", err) | ||
} | ||
|
||
// Create the request. Creating it with context is important for a timeout | ||
// to be possible, because the client is configured without a timeout. | ||
req, err := http.NewRequestWithContext(ctx, "POST", baseURLCohere+"/embed", bytes.NewBuffer(reqBody)) | ||
if err != nil { | ||
return nil, fmt.Errorf("couldn't create request: %w", err) | ||
} | ||
req.Header.Set("Accept", "application/json") | ||
req.Header.Set("Content-Type", "application/json") | ||
req.Header.Set("Authorization", "Bearer "+apiKey) | ||
|
||
// Send the request. | ||
resp, err := client.Do(req) | ||
if err != nil { | ||
return nil, fmt.Errorf("couldn't send request: %w", err) | ||
} | ||
defer resp.Body.Close() | ||
|
||
// Check the response status. | ||
if resp.StatusCode != http.StatusOK { | ||
return nil, errors.New("error response from the embedding API: " + resp.Status) | ||
} | ||
|
||
// Read and decode the response body. | ||
body, err := io.ReadAll(resp.Body) | ||
if err != nil { | ||
return nil, fmt.Errorf("couldn't read response body: %w", err) | ||
} | ||
var embeddingResponse cohereResponse | ||
err = json.Unmarshal(body, &embeddingResponse) | ||
if err != nil { | ||
return nil, fmt.Errorf("couldn't unmarshal response body: %w", err) | ||
} | ||
|
||
// Check if the response contains embeddings. | ||
if len(embeddingResponse.Embeddings) == 0 || len(embeddingResponse.Embeddings[0]) == 0 { | ||
return nil, errors.New("no embeddings found in the response") | ||
} | ||
|
||
v := embeddingResponse.Embeddings[0] | ||
checkNormalized.Do(func() { | ||
if isNormalized(v) { | ||
checkedNormalized = true | ||
} else { | ||
checkedNormalized = false | ||
} | ||
}) | ||
if !checkedNormalized { | ||
v = normalizeVector(v) | ||
} | ||
|
||
return v, nil | ||
} | ||
} |