Skip to content

Commit

Permalink
add: google vertex ai embedding function
Browse files Browse the repository at this point in the history
  • Loading branch information
iwilltry42 committed Jul 23, 2024
1 parent 7606de4 commit ed2161a
Showing 1 changed file with 159 additions and 0 deletions.
159 changes: 159 additions & 0 deletions embed_google.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package chromem

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"sync"
)

// EmbeddingModelGoogle is the identifier of a Google embedding model. Its
// string value is inserted verbatim as the model segment of the request path
// built by NewEmbeddingFuncGoogle.
type EmbeddingModelGoogle string

// Predefined model identifiers. The constant names group them into English
// and multilingual variants; the string values are what is sent to the API.
const (
EmbeddingModelGoogleEnglishV1 EmbeddingModelGoogle = "textembedding-gecko@001"
EmbeddingModelGoogleEnglishV2 EmbeddingModelGoogle = "textembedding-gecko@002"
EmbeddingModelGoogleEnglishV3 EmbeddingModelGoogle = "textembedding-gecko@003"
EmbeddingModelGoogleEnglishV4 EmbeddingModelGoogle = "text-embedding-004"

EmbeddingModelGoogleMultilingualV1 EmbeddingModelGoogle = "textembedding-gecko-multilingual@001"
EmbeddingModelGoogleMultilingualV2 EmbeddingModelGoogle = "text-multilingual-embedding-002"
)

// baseURLGoogle is the API endpoint used when no other endpoint is
// configured. Its host names the us-central1 region, which is also the
// location hard-coded into the request path by NewEmbeddingFuncGoogle.
const baseURLGoogle = "https://us-central1-aiplatform.googleapis.com/v1"

// GoogleOptions holds the optional settings of the Google embedding function.
type GoogleOptions struct {
// APIEndpoint is the base URL that the request path
// ("/projects/<project>/locations/...") is appended to.
APIEndpoint string
// AutoTruncate is sent to the API as the "autoTruncate" request parameter.
// NOTE(review): presumably this lets the service truncate over-long input
// instead of rejecting it; confirm against the API documentation.
AutoTruncate bool
}

// DefaultGoogleOptions returns a freshly allocated GoogleOptions holding the
// defaults: the built-in API endpoint and auto-truncation switched off.
func DefaultGoogleOptions() *GoogleOptions {
	defaults := new(GoogleOptions)
	defaults.APIEndpoint = baseURLGoogle
	defaults.AutoTruncate = false
	return defaults
}

// GoogleOption is a functional option that modifies a GoogleOptions value in
// place. Options are applied in order by NewEmbeddingFuncGoogle on top of the
// defaults.
type GoogleOption func(*GoogleOptions)

// WithGoogleAPIEndpoint returns an option that replaces the API endpoint
// (the base URL the request path is appended to) with apiEndpoint.
func WithGoogleAPIEndpoint(apiEndpoint string) GoogleOption {
	var setEndpoint GoogleOption = func(target *GoogleOptions) {
		target.APIEndpoint = apiEndpoint
	}
	return setEndpoint
}

// WithGoogleAutoTruncate returns an option that sets the AutoTruncate flag,
// which is forwarded to the API as the "autoTruncate" request parameter.
func WithGoogleAutoTruncate(autoTruncate bool) GoogleOption {
	var setTruncate GoogleOption = func(target *GoogleOptions) {
		target.AutoTruncate = autoTruncate
	}
	return setTruncate
}

// googleResponse is the subset of the API's JSON response that is decoded.
// Only the first entry of Predictions is used by the embedding function.
type googleResponse struct {
Predictions []googlePrediction `json:"predictions"`
}

// googlePrediction is a single element of the response's "predictions" array.
type googlePrediction struct {
Embeddings googleEmbeddings `json:"embeddings"`
}

// googleEmbeddings carries the embedding vector of one prediction.
type googleEmbeddings struct {
// Values is the embedding vector, decoded from the "values" JSON array.
Values []float32 `json:"values"`
// there's more here, but we only care about the embeddings
}

// NewEmbeddingFuncGoogle returns an EmbeddingFunc that embeds a text by
// POSTing it to the ":predict" endpoint of the given Google model.
//
// Parameters:
//   - apiKey is sent as a bearer token in the Authorization header.
//   - project is used as the project segment of the request path.
//   - model is used as the model segment of the request path.
//   - opts are applied in order on top of DefaultGoogleOptions. An empty
//     APIEndpoint after applying them falls back to the built-in default.
//
// The returned function sends one request per call and returns the embedding
// of the first prediction in the response. It returns an error if the request
// cannot be built or sent, if the API answers with a status other than 200 OK
// (the error then includes the status and, if present, the beginning of the
// response body), if the response cannot be decoded, or if it contains no
// embedding values.
//
// The first successful response is checked once with isNormalized. If that
// check fails, the vector of that call and of every later call is passed
// through normalizeVector before being returned.
func NewEmbeddingFuncGoogle(apiKey, project string, model EmbeddingModelGoogle, opts ...GoogleOption) EmbeddingFunc {
	// Upper bound on how much of a non-200 response body is copied into the
	// returned error, so a huge error page cannot bloat the message.
	const maxErrorBodyBytes = 4 << 10

	cfg := DefaultGoogleOptions()
	for _, opt := range opts {
		opt(cfg)
	}

	if cfg.APIEndpoint == "" {
		cfg.APIEndpoint = baseURLGoogle
	}

	// Endpoint, project and model are fixed for the lifetime of the returned
	// function, so the URL is built once instead of on every call.
	fullURL := fmt.Sprintf("%s/projects/%s/locations/us-central1/publishers/google/models/%s:predict", cfg.APIEndpoint, project, model)

	// We don't set a default timeout here, although it's usually a good idea.
	// In our case though, the library user can set the timeout on the context,
	// and it might have to be a long timeout, depending on the text length.
	client := &http.Client{}

	// embeddingsNormalized is written exactly once inside checkNormalized.Do
	// and only read after Do has returned, so sync.Once orders the accesses.
	var embeddingsNormalized bool
	var checkNormalized sync.Once

	return func(ctx context.Context, text string) ([]float32, error) {
		b := map[string]any{
			"instances": []map[string]any{
				{
					"content": text,
				},
			},
			"parameters": map[string]any{
				"autoTruncate": cfg.AutoTruncate,
			},
		}

		// Prepare the request body.
		reqBody, err := json.Marshal(b)
		if err != nil {
			return nil, fmt.Errorf("couldn't marshal request body: %w", err)
		}

		// Create the request. Creating it with context is important for a timeout
		// to be possible, because the client is configured without a timeout.
		req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(reqBody))
		if err != nil {
			return nil, fmt.Errorf("couldn't create request: %w", err)
		}
		req.Header.Set("Accept", "application/json")
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Authorization", "Bearer "+apiKey)

		// Send the request.
		resp, err := client.Do(req)
		if err != nil {
			return nil, fmt.Errorf("couldn't send request: %w", err)
		}
		defer resp.Body.Close()

		// Check the response status.
		if resp.StatusCode != http.StatusOK {
			msg := "error response from the embedding API: " + resp.Status
			// Best effort: the body usually explains the failure. A read error
			// is deliberately ignored because the status alone is still a
			// valid error to report.
			if detail, readErr := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes)); readErr == nil {
				if detail = bytes.TrimSpace(detail); len(detail) > 0 {
					msg += ": " + string(detail)
				}
			}
			return nil, errors.New(msg)
		}

		// Read and decode the response body.
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("couldn't read response body: %w", err)
		}
		var embeddingResponse googleResponse
		if err := json.Unmarshal(body, &embeddingResponse); err != nil {
			return nil, fmt.Errorf("couldn't unmarshal response body: %w", err)
		}

		// Check if the response contains embeddings.
		if len(embeddingResponse.Predictions) == 0 || len(embeddingResponse.Predictions[0].Embeddings.Values) == 0 {
			return nil, errors.New("no embeddings found in the response")
		}

		v := embeddingResponse.Predictions[0].Embeddings.Values
		// Only the first successful response is inspected; the outcome is
		// reused for all later calls.
		checkNormalized.Do(func() {
			embeddingsNormalized = isNormalized(v)
		})
		if !embeddingsNormalized {
			v = normalizeVector(v)
		}

		return v, nil
	}
}

0 comments on commit ed2161a

Please sign in to comment.