Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Vertex options hybrid #92

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions embed_vertex.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,26 @@ const (

const baseURLVertex = "https://us-central1-aiplatform.googleapis.com/v1"

type vertexConfig struct {
apiKey string
project string
model EmbeddingModelVertex

// Optional
type vertexOptions struct {
apiEndpoint string
autoTruncate bool
}

func NewVertexConfig(apiKey, project string, model EmbeddingModelVertex) *vertexConfig {
return &vertexConfig{
apiKey: apiKey,
project: project,
model: model,
// DefaultVertexOptions creates a new vertexOptions struct with default values.
// Use the `With...()` methods to change them.
func DefaultVertexOptions() *vertexOptions {
return &vertexOptions{
apiEndpoint: baseURLVertex,
autoTruncate: false,
}
}

func (c *vertexConfig) WithAPIEndpoint(apiEndpoint string) *vertexConfig {
func (c *vertexOptions) WithAPIEndpoint(apiEndpoint string) *vertexOptions {
c.apiEndpoint = apiEndpoint
return c
}

func (c *vertexConfig) WithAutoTruncate(autoTruncate bool) *vertexConfig {
func (c *vertexOptions) WithAutoTruncate(autoTruncate bool) *vertexOptions {
c.autoTruncate = autoTruncate
return c
}
Expand All @@ -68,7 +62,12 @@ type vertexEmbeddings struct {
// there's more here, but we only care about the embeddings
}

func NewEmbeddingFuncVertex(config *vertexConfig) EmbeddingFunc {
// NewEmbeddingFuncVertex creates an EmbeddingFunc that uses the GCP Vertex API.
// For the opts you can pass nil to use the default options.
func NewEmbeddingFuncVertex(apiKey, project string, model EmbeddingModelVertex, opts *vertexOptions) EmbeddingFunc {
if opts == nil {
opts = DefaultVertexOptions()
}

// We don't set a default timeout here, although it's usually a good idea.
// In our case though, the library user can set the timeout on the context,
Expand All @@ -79,15 +78,14 @@ func NewEmbeddingFuncVertex(config *vertexConfig) EmbeddingFunc {
checkNormalized := sync.Once{}

return func(ctx context.Context, text string) ([]float32, error) {

b := map[string]any{
"instances": []map[string]any{
{
"content": text,
},
},
"parameters": map[string]any{
"autoTruncate": config.autoTruncate,
"autoTruncate": opts.autoTruncate,
},
}

Expand All @@ -97,7 +95,7 @@ func NewEmbeddingFuncVertex(config *vertexConfig) EmbeddingFunc {
return nil, fmt.Errorf("couldn't marshal request body: %w", err)
}

fullURL := fmt.Sprintf("%s/projects/%s/locations/us-central1/publishers/google/models/%s:predict", config.apiEndpoint, config.project, config.model)
fullURL := fmt.Sprintf("%s/projects/%s/locations/us-central1/publishers/google/models/%s:predict", opts.apiEndpoint, project, model)

// Create the request. Creating it with context is important for a timeout
// to be possible, because the client is configured without a timeout.
Expand All @@ -107,7 +105,7 @@ func NewEmbeddingFuncVertex(config *vertexConfig) EmbeddingFunc {
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+config.apiKey)
req.Header.Set("Authorization", "Bearer "+apiKey)

// Send the request.
resp, err := client.Do(req)
Expand Down