diff --git a/openai/assistant.go b/openai/assistant.go index 6f1c2d2..47e158c 100644 --- a/openai/assistant.go +++ b/openai/assistant.go @@ -116,8 +116,8 @@ type AssistantToolResources struct { type AssistantResponseFormat struct { StringValue string `json:"-"` - Type string `json:"type"` - JsonSchema *struct { + Type string `json:"type"` + JsonSchema *struct { Description *string `json:"description,omitempty"` Name string `json:"name"` Schema map[string]interface{} `json:"schema"` diff --git a/openai/client.go b/openai/client.go index c82a07d..4ffd1c1 100644 --- a/openai/client.go +++ b/openai/client.go @@ -16,7 +16,8 @@ const ( // Client - OpenAI client. type Client struct { - authToken string + authToken string + adminToken string BaseURL *url.URL OrganizationID string @@ -25,10 +26,11 @@ type Client struct { } // NewClient creates new OpenAI client. -func NewClient(authToken string) *Client { +func NewClient(authToken string, adminToken string) *Client { c := &Client{ HTTPClient: &http.Client{Timeout: 30 * time.Second}, authToken: authToken, + adminToken: adminToken, UserAgent: "skyscrapr/openai-sdk-go", } c.BaseURL, _ = url.Parse(apiURL) diff --git a/openai/endpoint.go b/openai/endpoint.go index e5cf5f3..12884d0 100644 --- a/openai/endpoint.go +++ b/openai/endpoint.go @@ -1,6 +1,7 @@ package openai import ( + "fmt" "net/http" "net/url" "path" @@ -25,6 +26,10 @@ type betaEndpoint struct { endpoint } +type organizationEndpoint struct { + endpoint +} + func newEndpoint(c *Client, endpointPath string) *endpoint { e := &endpoint{ Client: c, @@ -40,6 +45,13 @@ func newBetaEndpoint(c *Client, endpointPath string) *betaEndpoint { return e } +func newOrganizationEndpoint(c *Client, endpointPath string) *organizationEndpoint { + e := &organizationEndpoint{ + endpoint: *newEndpoint(c, endpointPath), + } + return e +} + func (e *endpoint) buildURL(endpointPath string) (*url.URL, error) { u, err := url.Parse(endpointPath) if err != nil { @@ -64,3 +76,9 @@ func (e *betaEndpoint) newRequest(method string, u *url.URL, body interface{}) ( req.Header.Set("OpenAI-Beta", "assistants=v2") return req, err } + +func (e *organizationEndpoint) newRequest(method string, u *url.URL, body interface{}) (*http.Request, error) { + req, err := e.Client.newRequest(method, u, body) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", e.adminToken)) + return req, err +} diff --git a/openai/endpoint_test.go b/openai/endpoint_test.go index d6828e2..abced05 100644 --- a/openai/endpoint_test.go +++ b/openai/endpoint_test.go @@ -6,7 +6,7 @@ import ( func TestNewEndpoint(t *testing.T) { testEndpointPath := "testEndpointPath" - testClient := NewClient("testapikey") + testClient := NewClient("testapikey", "testadminkey") e := newEndpoint(testClient, testEndpointPath) if e.BaseURL.String() != testClient.BaseURL.String() { t.Errorf("VendorsEndpoint BaseURL mismatch. Got %s. Want %s", e.BaseURL.String(), testClient.BaseURL.String()) diff --git a/openai/projects.go b/openai/projects.go new file mode 100644 index 0000000..675f352 --- /dev/null +++ b/openai/projects.go @@ -0,0 +1,90 @@ +package openai + +import ( + "fmt" + "net/url" +) + +const ProjectsEndpointPath = "/organization/projects" + +// ProjectsEndpoint - OpenAI Projects Endpoint +// +// List and describe the projects available. +// You can refer to the [Projects]: https://platform.openai.com/docs/api-reference/projects documentation. +type ProjectsEndpoint struct { + *organizationEndpoint +} + +// Projects - Projects Endpoint +func (c *Client) Projects() *ProjectsEndpoint { + return &ProjectsEndpoint{newOrganizationEndpoint(c, ProjectsEndpointPath)} +} + +// Project - OpenAPI Project. +type Project struct { + ID string `json:"id"` + Object string `json:"object"` + Name string `json:"name"` + CreatedAt int64 `json:"created_at"` + ArchivedAt int64 `json:"archived_at"` + Status string `json:"status"` +} + +type Projects struct { + Object string `json:"object"` + Data []Project `json:"data"` +} + +type ProjectRequest struct { + // The name of the assistant. The maximum length is 256 characters. + Name *string `json:"name"` +} + +// Lists the currently available projects, +// and provides basic information about each one. +// +// [OpenAI Documentation]: https://platform.openai.com/docs/api-reference/projects/list +func (e *ProjectsEndpoint) ListProjects() ([]Project, error) { + var projects Projects + err := e.do(e, "GET", "", nil, nil, &projects) + // TODO: This needs to move somewhere central + if err == nil && projects.Object != "list" { + err = fmt.Errorf("expected 'list' object type, got %s", projects.Object) + } + return projects.Data, err +} + +// Create a project. +// [OpenAI Documentation]: https://platform.openai.com/docs/api-reference/projects/create +func (e *ProjectsEndpoint) CreateProject(req *ProjectRequest) (*Project, error) { + var project Project + err := e.do(e, "POST", "", req, nil, &project) + return &project, err +} + +// Retrieves a project instance, +// providing basic information about the project. +// +// [OpenAI Documentation]: https://platform.openai.com/docs/api-reference/projects/retrieve +func (e *ProjectsEndpoint) RetrieveProject(id string) (*Project, error) { + var project Project + err := e.do(e, "GET", id, nil, nil, &project) + return &project, err +} + +// Modifies a project instance, +// +// [OpenAI Documentation]: https://platform.openai.com/docs/api-reference/projects/modify +func (e *ProjectsEndpoint) ModifyProject(id string, req ProjectRequest) (*Project, error) { + var project Project + err := e.do(e, "POST", id, req, nil, &project) + return &project, err +} + +// Archive a project. +// [OpenAI Documentation]: https://platform.openai.com/docs/api-reference/projects/archive +func (e *ProjectsEndpoint) ArchiveProject(id string) (*Project, error) { + var project Project + err := e.do(e, "POST", url.QueryEscape(id)+"/archive", nil, nil, &project) + return &project, err +} diff --git a/openai/projects_test.go b/openai/projects_test.go new file mode 100644 index 0000000..d7d3782 --- /dev/null +++ b/openai/projects_test.go @@ -0,0 +1,70 @@ +package openai_test + +import ( + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/skyscrapr/openai-sdk-go/openai" + "github.com/skyscrapr/openai-sdk-go/openai/test" +) + +// TestListProjects Tests the Projects endpoint of the API using the mocked server. +func TestListProjects(t *testing.T) { + ts := openai_test.NewTestServer() + ts.RegisterHandler("/v1/organizations/projects", func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.Projects{Object: "list", Data: nil}) + fmt.Fprintln(w, string(resBytes)) + }) + ts.HTTPServer.Start() + defer ts.HTTPServer.Close() + + client := openai_test.NewTestClient(ts) + _, err := client.Projects().ListProjects() + t.Helper() + if err != nil { + t.Error(err, "TestListProjects error") + } +} + +func TestListProjectsInvalidObject(t *testing.T) { + expectedError := "expected 'list' object type, got project" + + ts := openai_test.NewTestServer() + ts.RegisterHandler("/v1/organizations/projects", func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.Projects{Object: "project", Data: nil}) + fmt.Fprintln(w, string(resBytes)) + }) + ts.HTTPServer.Start() + defer ts.HTTPServer.Close() + + client := openai_test.NewTestClient(ts) + _, err := client.Projects().ListProjects() + t.Helper() + if err != nil && err.Error() != expectedError { + t.Errorf("Unexpected error: %v , expected: %s", err, expectedError) + t.Fail() + } +} + +func TestRetrieveProject(t *testing.T) { + testProjectID := "testProjectID" + ts := openai_test.NewTestServer() + ts.RegisterHandler("/v1/organizations/projects/testProjectID", func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.Project{Object: "project", ID: testProjectID}) + fmt.Fprintln(w, string(resBytes)) + }) + ts.HTTPServer.Start() + defer ts.HTTPServer.Close() + + client := openai_test.NewTestClient(ts) + project, err := client.Projects().RetrieveProject(testProjectID) + t.Helper() + if err != nil { + t.Error(err, "GetProject error") + } + if project.ID != testProjectID { + t.Errorf("ProjectsEndpoint GetProject Project ID mismatch. Got %s. Expected %s", testProjectID, project.ID) + } +} diff --git a/openai/test/client.go b/openai/test/client.go index c5d53a6..b04ba63 100644 --- a/openai/test/client.go +++ b/openai/test/client.go @@ -6,13 +6,14 @@ import ( ) const test_api_key = "this-is-my-secure-apikey-do-not-steal!!" +const test_admin_key = "this-is-my-secure-adminkey-do-not-steal!!" func GetTestAuthToken() string { return test_api_key } func NewTestClient(ts *TestServer) *openai.Client { - client := openai.NewClient(test_api_key) + client := openai.NewClient(test_api_key, test_admin_key) if ts != nil { client.BaseURL, _ = url.Parse(ts.HTTPServer.URL) }