From 51e0db56931cd127619b353ad7e2cc470273dfd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sun, 28 Jul 2024 16:47:22 +0200 Subject: [PATCH 1/2] Make Vertex options hybrid This way we can still have the mandatory options directly in the embedding func constructor, making it more similar to the existing ones. --- embed_vertex.go | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/embed_vertex.go b/embed_vertex.go index 6efc667..38937da 100644 --- a/embed_vertex.go +++ b/embed_vertex.go @@ -25,32 +25,26 @@ const ( const baseURLVertex = "https://us-central1-aiplatform.googleapis.com/v1" -type vertexConfig struct { - apiKey string - project string - model EmbeddingModelVertex - - // Optional +type vertexOptions struct { apiEndpoint string autoTruncate bool } -func NewVertexConfig(apiKey, project string, model EmbeddingModelVertex) *vertexConfig { - return &vertexConfig{ - apiKey: apiKey, - project: project, - model: model, +// NewVertexOptions creates a new vertexOptions struct with default values. +// Use the `With...()` methods to change them. +func NewVertexOptions() *vertexOptions { + return &vertexOptions{ apiEndpoint: baseURLVertex, autoTruncate: false, } } -func (c *vertexConfig) WithAPIEndpoint(apiEndpoint string) *vertexConfig { +func (c *vertexOptions) WithAPIEndpoint(apiEndpoint string) *vertexOptions { c.apiEndpoint = apiEndpoint return c } -func (c *vertexConfig) WithAutoTruncate(autoTruncate bool) *vertexConfig { +func (c *vertexOptions) WithAutoTruncate(autoTruncate bool) *vertexOptions { c.autoTruncate = autoTruncate return c } @@ -68,7 +62,12 @@ type vertexEmbeddings struct { // there's more here, but we only care about the embeddings } -func NewEmbeddingFuncVertex(config *vertexConfig) EmbeddingFunc { +// NewEmbeddingFuncVertex creates an EmbeddingFunc that uses the GCP Vertex API. +// For the opts you can pass nil to use the default options. +func NewEmbeddingFuncVertex(apiKey, project string, model EmbeddingModelVertex, opts *vertexOptions) EmbeddingFunc { + if opts == nil { + opts = NewVertexOptions() + } // We don't set a default timeout here, although it's usually a good idea. // In our case though, the library user can set the timeout on the context, @@ -79,7 +78,6 @@ func NewEmbeddingFuncVertex(config *vertexConfig) EmbeddingFunc { checkNormalized := sync.Once{} return func(ctx context.Context, text string) ([]float32, error) { - b := map[string]any{ "instances": []map[string]any{ { @@ -87,7 +85,7 @@ func NewEmbeddingFuncVertex(config *vertexConfig) EmbeddingFunc { }, }, "parameters": map[string]any{ - "autoTruncate": config.autoTruncate, + "autoTruncate": opts.autoTruncate, }, } @@ -97,7 +95,7 @@ func NewEmbeddingFuncVertex(config *vertexConfig) EmbeddingFunc { return nil, fmt.Errorf("couldn't marshal request body: %w", err) } - fullURL := fmt.Sprintf("%s/projects/%s/locations/us-central1/publishers/google/models/%s:predict", config.apiEndpoint, config.project, config.model) + fullURL := fmt.Sprintf("%s/projects/%s/locations/us-central1/publishers/google/models/%s:predict", opts.apiEndpoint, project, model) // Create the request. Creating it with context is important for a timeout // to be possible, because the client is configured without a timeout. @@ -107,7 +105,7 @@ func NewEmbeddingFuncVertex(config *vertexConfig) EmbeddingFunc { } req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+config.apiKey) + req.Header.Set("Authorization", "Bearer "+apiKey) // Send the request. resp, err := client.Do(req) From a8ff96967be6ce4acfcf01c1b1a9fb85bf9d1a88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sun, 28 Jul 2024 17:17:24 +0200 Subject: [PATCH 2/2] Rename NewVertexOptions to DefaultVertexOptions To make clear that the created struct is prefilled with default values --- embed_vertex.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/embed_vertex.go b/embed_vertex.go index 38937da..384b5b7 100644 --- a/embed_vertex.go +++ b/embed_vertex.go @@ -30,9 +30,9 @@ type vertexOptions struct { autoTruncate bool } -// NewVertexOptions creates a new vertexOptions struct with default values. +// DefaultVertexOptions creates a new vertexOptions struct with default values. // Use the `With...()` methods to change them. -func NewVertexOptions() *vertexOptions { +func DefaultVertexOptions() *vertexOptions { return &vertexOptions{ apiEndpoint: baseURLVertex, autoTruncate: false, @@ -66,7 +66,7 @@ type vertexEmbeddings struct { // For the opts you can pass nil to use the default options. func NewEmbeddingFuncVertex(apiKey, project string, model EmbeddingModelVertex, opts *vertexOptions) EmbeddingFunc { if opts == nil { - opts = NewVertexOptions() + opts = DefaultVertexOptions() } // We don't set a default timeout here, although it's usually a good idea.