Skip to content

Commit

Permalink
feat: sdkserver: add credential routes (#846)
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Linville <[email protected]>
  • Loading branch information
g-linville authored Sep 13, 2024
1 parent c0f116c commit 90e7868
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pkg/cli/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (c *Credential) Run(cmd *cobra.Command, _ []string) error {

ctx := c.root.CredentialContext
if c.AllContexts {
ctx = "*"
ctx = credentials.AllCredentialContexts
}

opts, err := c.root.NewGPTScriptOpts()
Expand Down
9 changes: 7 additions & 2 deletions pkg/credentials/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ import (
"github.com/gptscript-ai/gptscript/pkg/config"
)

const (
DefaultCredentialContext = "default"
AllCredentialContexts = "*"
)

type CredentialBuilder interface {
EnsureCredentialHelpers(ctx context.Context) error
}
Expand Down Expand Up @@ -105,7 +110,7 @@ func (s Store) List(ctx context.Context) ([]Credential, error) {
if err != nil {
return nil, err
}
if s.credCtx == "*" || c.Context == s.credCtx {
if s.credCtx == AllCredentialContexts || c.Context == s.credCtx {
creds = append(creds, c)
}
}
Expand Down Expand Up @@ -139,7 +144,7 @@ func validateCredentialCtx(ctx string) error {
return fmt.Errorf("credential context cannot be empty")
}

if ctx == "*" { // this represents "all contexts" and is allowed
if ctx == AllCredentialContexts {
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/gptscript/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func Complete(opts ...Options) Options {
result.Env = os.Environ()
}
if result.CredentialContext == "" {
result.CredentialContext = "default"
result.CredentialContext = credentials.DefaultCredentialContext
}

return result
Expand Down
176 changes: 176 additions & 0 deletions pkg/sdkserver/credentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package sdkserver

import (
"encoding/json"
"fmt"
"net/http"

"github.com/gptscript-ai/gptscript/pkg/config"
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/repos/runtimes"
)

func (s *server) initializeCredentialStore(ctx string) (credentials.CredentialStore, error) {
cfg, err := config.ReadCLIConfig(s.gptscriptOpts.OpenAI.ConfigFile)
if err != nil {
return nil, fmt.Errorf("failed to read CLI config: %w", err)
}

// TODO - are we sure we want to always use runtimes.Default here?
store, err := credentials.NewStore(cfg, runtimes.Default(s.gptscriptOpts.Cache.CacheDir), ctx, s.gptscriptOpts.Cache.CacheDir)
if err != nil {
return nil, fmt.Errorf("failed to initialize credential store: %w", err)
}

return store, nil
}

func (s *server) listCredentials(w http.ResponseWriter, r *http.Request) {
logger := gcontext.GetLogger(r.Context())
req := new(credentialsRequest)
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err))
return
}

if req.AllContexts {
req.Context = credentials.AllCredentialContexts
} else if req.Context == "" {
req.Context = credentials.DefaultCredentialContext
}

store, err := s.initializeCredentialStore(req.Context)
if err != nil {
writeError(logger, w, http.StatusInternalServerError, err)
return
}

creds, err := store.List(r.Context())
if err != nil {
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to list credentials: %w", err))
return
}

// Remove the environment variable values (which are secrets) and refresh tokens from the response.
for i := range creds {
for k := range creds[i].Env {
creds[i].Env[k] = ""
}
creds[i].RefreshToken = ""
}

writeResponse(logger, w, map[string]any{"stdout": creds})
}

func (s *server) createCredential(w http.ResponseWriter, r *http.Request) {
logger := gcontext.GetLogger(r.Context())
req := new(credentialsRequest)
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err))
return
}

cred := new(credentials.Credential)
if err := json.Unmarshal([]byte(req.Content), cred); err != nil {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid credential: %w", err))
return
}

if cred.Context == "" {
cred.Context = credentials.DefaultCredentialContext
}

store, err := s.initializeCredentialStore(cred.Context)
if err != nil {
writeError(logger, w, http.StatusInternalServerError, err)
return
}

if err := store.Add(r.Context(), *cred); err != nil {
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to create credential: %w", err))
return
}

writeResponse(logger, w, map[string]any{"stdout": "Credential created successfully"})
}

func (s *server) revealCredential(w http.ResponseWriter, r *http.Request) {
logger := gcontext.GetLogger(r.Context())
req := new(credentialsRequest)
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err))
return
}

if req.Name == "" {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("missing credential name"))
return
}

if req.AllContexts || req.Context == credentials.AllCredentialContexts {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("allContexts is not supported for credential retrieval; please specify the specific context that the credential is in"))
return
} else if req.Context == "" {
req.Context = credentials.DefaultCredentialContext
}

store, err := s.initializeCredentialStore(req.Context)
if err != nil {
writeError(logger, w, http.StatusInternalServerError, err)
return
}

cred, ok, err := store.Get(r.Context(), req.Name)
if err != nil {
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to get credential: %w", err))
return
} else if !ok {
writeError(logger, w, http.StatusNotFound, fmt.Errorf("credential not found"))
return
}

writeResponse(logger, w, map[string]any{"stdout": cred})
}

func (s *server) deleteCredential(w http.ResponseWriter, r *http.Request) {
logger := gcontext.GetLogger(r.Context())
req := new(credentialsRequest)
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err))
}

if req.Name == "" {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("missing credential name"))
return
}

if req.AllContexts || req.Context == credentials.AllCredentialContexts {
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("allContexts is not supported for credential deletion; please specify the specific context that the credential is in"))
return
} else if req.Context == "" {
req.Context = credentials.DefaultCredentialContext
}

store, err := s.initializeCredentialStore(req.Context)
if err != nil {
writeError(logger, w, http.StatusInternalServerError, err)
return
}

// Check to see if a cred exists so we can return a 404 if it doesn't.
if _, ok, err := store.Get(r.Context(), req.Name); err != nil {
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to get credential: %w", err))
return
} else if !ok {
writeError(logger, w, http.StatusNotFound, fmt.Errorf("credential not found"))
return
}

if err := store.Remove(r.Context(), req.Name); err != nil {
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to delete credential: %w", err))
return
}

writeResponse(logger, w, map[string]any{"stdout": "Credential deleted successfully"})
}
5 changes: 5 additions & 0 deletions pkg/sdkserver/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ func (s *server) addRoutes(mux *http.ServeMux) {
mux.HandleFunc("POST /confirm/{id}", s.confirm)
mux.HandleFunc("POST /prompt/{id}", s.prompt)
mux.HandleFunc("POST /prompt-response/{id}", s.promptResponse)

mux.HandleFunc("POST /credentials", s.listCredentials)
mux.HandleFunc("POST /credentials/create", s.createCredential)
mux.HandleFunc("POST /credentials/reveal", s.revealCredential)
mux.HandleFunc("POST /credentials/delete", s.deleteCredential)
}

// health just provides an endpoint for checking whether the server is running and accessible.
Expand Down
7 changes: 7 additions & 0 deletions pkg/sdkserver/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,10 @@ type prompt struct {
Type runner.EventType `json:"type,omitempty"`
Time time.Time `json:"time,omitempty"`
}

type credentialsRequest struct {
content `json:",inline"`
AllContexts bool `json:"allContexts"`
Context string `json:"context"`
Name string `json:"name"`
}

0 comments on commit 90e7868

Please sign in to comment.