diff --git a/embed_google.go b/embed_google.go new file mode 100644 index 0000000..e042d95 --- /dev/null +++ b/embed_google.go @@ -0,0 +1,159 @@ +package chromem + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "sync" +) + +type EmbeddingModelGoogle string + +const ( + EmbeddingModelGoogleEnglishV1 EmbeddingModelGoogle = "textembedding-gecko@001" + EmbeddingModelGoogleEnglishV2 EmbeddingModelGoogle = "textembedding-gecko@002" + EmbeddingModelGoogleEnglishV3 EmbeddingModelGoogle = "textembedding-gecko@003" + EmbeddingModelGoogleEnglishV4 EmbeddingModelGoogle = "text-embedding-004" + + EmbeddingModelGoogleMultilingualV1 EmbeddingModelGoogle = "textembedding-gecko-multilingual@001" + EmbeddingModelGoogleMultilingualV2 EmbeddingModelGoogle = "text-multilingual-embedding-002" +) + +const baseURLGoogle = "https://us-central1-aiplatform.googleapis.com/v1" + +type GoogleOptions struct { + APIEndpoint string + AutoTruncate bool +} + +func DefaultGoogleOptions() *GoogleOptions { + return &GoogleOptions{ + APIEndpoint: baseURLGoogle, + AutoTruncate: false, + } +} + +type GoogleOption func(*GoogleOptions) + +func WithGoogleAPIEndpoint(apiEndpoint string) GoogleOption { + return func(o *GoogleOptions) { + o.APIEndpoint = apiEndpoint + } +} + +func WithGoogleAutoTruncate(autoTruncate bool) GoogleOption { + return func(o *GoogleOptions) { + o.AutoTruncate = autoTruncate + } +} + +type googleResponse struct { + Predictions []googlePrediction `json:"predictions"` +} + +type googlePrediction struct { + Embeddings googleEmbeddings `json:"embeddings"` +} + +type googleEmbeddings struct { + Values []float32 `json:"values"` + // there's more here, but we only care about the embeddings +} + +func NewEmbeddingFuncGoogle(apiKey, project string, model EmbeddingModelGoogle, opts ...GoogleOption) EmbeddingFunc { + + cfg := DefaultGoogleOptions() + for _, opt := range opts { + opt(cfg) + } + + if cfg.APIEndpoint == "" { + cfg.APIEndpoint = baseURLGoogle + } + + // We don't set a default timeout here, although it's usually a good idea. + // In our case though, the library user can set the timeout on the context, + // and it might have to be a long timeout, depending on the text length. + client := &http.Client{} + + var checkedNormalized bool + checkNormalized := sync.Once{} + + return func(ctx context.Context, text string) ([]float32, error) { + + b := map[string]any{ + "instances": []map[string]any{ + { + "content": text, + }, + }, + "parameters": map[string]any{ + "autoTruncate": cfg.AutoTruncate, + }, + } + + // Prepare the request body. + reqBody, err := json.Marshal(b) + if err != nil { + return nil, fmt.Errorf("couldn't marshal request body: %w", err) + } + + fullURL := fmt.Sprintf("%s/projects/%s/locations/us-central1/publishers/google/models/%s:predict", cfg.APIEndpoint, project, model) + + // Create the request. Creating it with context is important for a timeout + // to be possible, because the client is configured without a timeout. + req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewBuffer(reqBody)) + if err != nil { + return nil, fmt.Errorf("couldn't create request: %w", err) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + // Send the request. + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("couldn't send request: %w", err) + } + defer resp.Body.Close() + + // Check the response status. + if resp.StatusCode != http.StatusOK { + return nil, errors.New("error response from the embedding API: " + resp.Status) + } + + // Read and decode the response body. + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("couldn't read response body: %w", err) + } + var embeddingResponse googleResponse + err = json.Unmarshal(body, &embeddingResponse) + if err != nil { + return nil, fmt.Errorf("couldn't unmarshal response body: %w", err) + } + + // Check if the response contains embeddings. + if len(embeddingResponse.Predictions) == 0 || len(embeddingResponse.Predictions[0].Embeddings.Values) == 0 { + return nil, errors.New("no embeddings found in the response") + } + + v := embeddingResponse.Predictions[0].Embeddings.Values + checkNormalized.Do(func() { + if isNormalized(v) { + checkedNormalized = true + } else { + checkedNormalized = false + } + }) + if !checkedNormalized { + v = normalizeVector(v) + } + + return v, nil + } +}