Skip to content

Commit

Permalink
Merge pull request #74 from iwilltry42/feat/azure
Browse files Browse the repository at this point in the history
add: AzureOpenAI compatibility
  • Loading branch information
philippgille authored May 13, 2024
2 parents 4c3f500 + 88f1efa commit aa151c6
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
15 changes: 15 additions & 0 deletions embed_compat.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,18 @@ const baseURLLocalAI = "http://localhost:8080/v1"
func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model, nil)
}

const (
azureDefaultAPIVersion = "2024-02-01"
)

// NewEmbeddingFuncAzureOpenAI returns a function that creates embeddings for a text
// using the Azure OpenAI API.
// The `deploymentURL` is the URL of the deployed model, e.g. "https://YOUR_RESOURCE_NAME.openai.azure.com/openai/deployments/YOUR_DEPLOYMENT_NAME"
// See https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/embeddings?tabs=console#how-to-get-embeddings
func NewEmbeddingFuncAzureOpenAI(apiKey string, deploymentURL string, apiVersion string, model string) EmbeddingFunc {
if apiVersion == "" {
apiVersion = azureDefaultAPIVersion
}
return newEmbeddingFuncOpenAICompat(deploymentURL, apiKey, model, nil, map[string]string{"api-key": apiKey}, map[string]string{"api-version": apiVersion})
}
25 changes: 25 additions & 0 deletions embed_openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,19 @@ func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) Embedding
// The flag is optional. If it's nil, it will be autodetected on the first request
// (which bears a small risk that the vector just happens to have a length of 1).
func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool) EmbeddingFunc {
return newEmbeddingFuncOpenAICompat(baseURL, apiKey, model, normalized, nil, nil)
}

// newEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text
// using an OpenAI compatible API.
// It offers options to set request headers and query parameters
// e.g. to pass the `api-key` header and the `api-version` query parameter for Azure OpenAI.
//
// The `normalized` parameter indicates whether the vectors returned by the embedding
// model are already normalized, as is the case for OpenAI's and Mistral's models.
// The flag is optional. If it's nil, it will be autodetected on the first request
// (which bears a small risk that the vector just happens to have a length of 1).
func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool, headers map[string]string, queryParams map[string]string) EmbeddingFunc {
// We don't set a default timeout here, although it's usually a good idea.
// In our case though, the library user can set the timeout on the context,
// and it might have to be a long timeout, depending on the text length.
Expand Down Expand Up @@ -84,6 +97,18 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)

// Add headers
for k, v := range headers {
req.Header.Add(k, v)
}

// Add query parameters
q := req.URL.Query()
for k, v := range queryParams {
q.Add(k, v)
}
req.URL.RawQuery = q.Encode()

// Send the request.
resp, err := client.Do(req)
if err != nil {
Expand Down

0 comments on commit aa151c6

Please sign in to comment.