Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor authentication process and update tests accordingly #693

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 2 additions & 28 deletions api_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package openai_test
import (
"context"
"errors"
"io"
"os"
"testing"

Expand All @@ -26,7 +25,7 @@ func TestAPI(t *testing.T) {
_, err = c.ListEngines(ctx)
checks.NoError(t, err, "ListEngines error")

_, err = c.GetEngine(ctx, "davinci")
_, err = c.GetEngine(ctx, "text-embedding-3-small")
checks.NoError(t, err, "GetEngine error")

fileRes, err := c.ListFiles(ctx)
Expand All @@ -42,7 +41,7 @@ func TestAPI(t *testing.T) {
"The food was delicious and the waiter",
"Other examples of embedding request",
},
Model: openai.AdaSearchQuery,
Model: openai.SmallEmbedding3,
}
_, err = c.CreateEmbeddings(ctx, embeddingReq)
checks.NoError(t, err, "Embedding error")
Expand Down Expand Up @@ -77,31 +76,6 @@ func TestAPI(t *testing.T) {
)
checks.NoError(t, err, "CreateChatCompletion (with name) returned error")

stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{
Prompt: "Ex falso quodlibet",
Model: openai.GPT3Ada,
MaxTokens: 5,
Stream: true,
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()

counter := 0
for {
_, err = stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
t.Errorf("Stream error: %v", err)
} else {
counter++
}
}
if counter == 0 {
t.Error("Stream did not return any responses")
}

_, err = c.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Expand Down
16 changes: 9 additions & 7 deletions api_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,55 +41,57 @@ func TestRequestAuthHeader(t *testing.T) {
Name string
APIType APIType
HeaderKey string
Token string
Token AuthBuilder
OrgID string
Expect string
}{
{
"OpenAIDefault",
"",
"Authorization",
"dummy-token-openai",
APIKey("dummy-token-openai"),
"",
"Bearer dummy-token-openai",
},
{
"OpenAIOrg",
APITypeOpenAI,
"Authorization",
"dummy-token-openai",
APIKey("dummy-token-openai"),
"dummy-org-openai",
"Bearer dummy-token-openai",
},
{
"OpenAI",
APITypeOpenAI,
"Authorization",
"dummy-token-openai",
APIKey("dummy-token-openai"),
"",
"Bearer dummy-token-openai",
},
{
"AzureAD",
APITypeAzureAD,
"Authorization",
"dummy-token-azure",
APIKey("dummy-token-azure"),
"",
"Bearer dummy-token-azure",
},
{
"Azure",
APITypeAzure,
AzureAPIKeyHeader,
"dummy-api-key-here",
AzureAPIKey("dummy-api-key-here"),
"",
"dummy-api-key-here",
},
}

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
az := DefaultConfig(c.Token)
az := ClientConfig{
AuthToken: c.Token,
}
az.APIType = c.APIType
az.OrgID = c.OrgID

Expand Down
12 changes: 4 additions & 8 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

// Client is OpenAI GPT-3 API client.
type Client struct {
config ClientConfig
config *ClientConfig

requestBuilder utils.RequestBuilder
createFormBuilder func(io.Writer) utils.FormBuilder
Expand Down Expand Up @@ -47,7 +47,7 @@ func NewClient(authToken string) *Client {
// NewClientWithConfig creates new OpenAI API client for specified config.
func NewClientWithConfig(config ClientConfig) *Client {
return &Client{
config: config,
config: &config,
requestBuilder: utils.NewRequestBuilder(),
createFormBuilder: func(body io.Writer) utils.FormBuilder {
return utils.NewFormBuilder(body)
Expand Down Expand Up @@ -173,12 +173,8 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
func (c *Client) setCommonHeaders(req *http.Request) {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
// Azure API Key authentication
if c.config.APIType == APITypeAzure {
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
} else if c.config.authToken != "" {
// OpenAI or Azure AD authentication
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
}
c.config.AuthToken(req)

if c.config.OrgID != "" {
req.Header.Set("OpenAI-Organization", c.config.OrgID)
}
Expand Down
15 changes: 13 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"net/http"
"reflect"
"strings"
"testing"

"github.com/sashabaranov/go-openai/internal/test"
Expand All @@ -22,16 +23,26 @@ func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any, _ htt
return nil, errTestRequestBuilderFailed
}

func getToken(authToken AuthBuilder, transformer func(h http.Header) string) string {
req, _ := http.NewRequest("GET", "http://example.com", nil)
authToken(req)
return transformer(req.Header)
}

func getBearerToken(h http.Header) string {
return strings.ReplaceAll(h.Get("Authorization"), "Bearer ", "")
}

func TestClient(t *testing.T) {
const mockToken = "mock token"
client := NewClient(mockToken)
if client.config.authToken != mockToken {
if getToken(client.config.AuthToken, getBearerToken) != mockToken {
t.Errorf("Client does not contain proper token")
}

const mockOrg = "mock org"
client = NewOrgClient(mockToken, mockOrg)
if client.config.authToken != mockToken {
if getToken(client.config.AuthToken, getBearerToken) != mockToken {
t.Errorf("Client does not contain proper token")
}
if client.config.OrgID != mockOrg {
Expand Down
22 changes: 18 additions & 4 deletions config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package openai

import (
"fmt"
"net/http"
"regexp"
)
Expand All @@ -23,10 +24,23 @@ const (

const AzureAPIKeyHeader = "api-key"

type AuthBuilder func(req *http.Request)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest we make it an interface with setAuthTokenForRequest func


func APIKey(authToken string) AuthBuilder {
return func(req *http.Request) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken))
}
}

func AzureAPIKey(apiKey string) AuthBuilder {
return func(req *http.Request) {
req.Header.Set(AzureAPIKeyHeader, apiKey)
}
}

// ClientConfig is a configuration of a client.
type ClientConfig struct {
authToken string

AuthToken AuthBuilder
BaseURL string
OrgID string
APIType APIType
Expand All @@ -39,7 +53,7 @@ type ClientConfig struct {

func DefaultConfig(authToken string) ClientConfig {
return ClientConfig{
authToken: authToken,
AuthToken: APIKey(authToken),
BaseURL: openaiAPIURLv1,
APIType: APITypeOpenAI,
OrgID: "",
Expand All @@ -52,7 +66,7 @@ func DefaultConfig(authToken string) ClientConfig {

func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
return ClientConfig{
authToken: apiKey,
AuthToken: AzureAPIKey(apiKey),
BaseURL: baseURL,
OrgID: "",
APIType: APITypeAzure,
Expand Down
Loading