From b7147075b883339283b5cc76d08e116f5bd1c919 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Mon, 14 Oct 2024 08:28:32 -0400 Subject: [PATCH] feat: add workspace API to SDK This change also changes the error behavior when running tools that are simply wrapped commands. Previously, all such tools would not return an error, rather an error message in hopes that the LLM would retry. However, if the tool is just a command (i.e. has no parent), then it should return an error so that the caller doesn't have to guess whether an error occurred. Signed-off-by: Donnie Adams --- pkg/engine/cmd.go | 5 +- pkg/sdkserver/datasets.go | 10 +- pkg/sdkserver/routes.go | 9 + pkg/sdkserver/workspaces.go | 328 ++++++++++++++++++++++++++++++++++++ 4 files changed, 345 insertions(+), 7 deletions(-) create mode 100644 pkg/sdkserver/workspaces.go diff --git a/pkg/engine/cmd.go b/pkg/engine/cmd.go index 5b27a579..1dcdaff0 100644 --- a/pkg/engine/cmd.go +++ b/pkg/engine/cmd.go @@ -154,12 +154,13 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate result = stdout if err := cmd.Run(); err != nil { - if toolCategory == NoCategory { + if toolCategory == NoCategory && ctx.Parent != nil { + // If this is a sub-call, then don't return the error; return the error as a message so that the LLM can retry. return fmt.Sprintf("ERROR: got (%v) while running tool, OUTPUT: %s", err, stdoutAndErr), nil } log.Errorf("failed to run tool [%s] cmd %v: %v", tool.Parameters.Name, cmd.Args, err) combinedOutput = stdoutAndErr.String() - return "", fmt.Errorf("ERROR: %s: %w", result, err) + return "", fmt.Errorf("ERROR: %s: %w", stdoutAndErr, err) } combinedOutput = stdoutAndErr.String() diff --git a/pkg/sdkserver/datasets.go b/pkg/sdkserver/datasets.go index 0085132c..a65566a4 100644 --- a/pkg/sdkserver/datasets.go +++ b/pkg/sdkserver/datasets.go @@ -62,7 +62,7 @@ func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) { return } - prg, err := loader.Program(r.Context(), "List Datasets from "+req.getToolRepo(), "", loader.Options{ + prg, err := loader.Program(r.Context(), req.getToolRepo(), "List Datasets", loader.Options{ Cache: g.Cache, }) @@ -123,7 +123,7 @@ func (s *server) createDataset(w http.ResponseWriter, r *http.Request) { return } - prg, err := loader.Program(r.Context(), "Create Dataset from "+req.getToolRepo(), "", loader.Options{ + prg, err := loader.Program(r.Context(), req.getToolRepo(), "Create Dataset", loader.Options{ Cache: g.Cache, }) @@ -192,7 +192,7 @@ func (s *server) addDatasetElement(w http.ResponseWriter, r *http.Request) { return } - prg, err := loader.Program(r.Context(), "Add Element from "+req.getToolRepo(), "", loader.Options{ + prg, err := loader.Program(r.Context(), req.getToolRepo(), "Add Element", loader.Options{ Cache: g.Cache, }) if err != nil { @@ -251,7 +251,7 @@ func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) { return } - prg, err := loader.Program(r.Context(), "List Elements from "+req.getToolRepo(), "", loader.Options{ + prg, err := loader.Program(r.Context(), req.getToolRepo(), "List Elements", loader.Options{ Cache: g.Cache, }) if err != nil { @@ -314,7 +314,7 @@ func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) { return } - prg, err := loader.Program(r.Context(), "Get Element from "+req.getToolRepo(), "", loader.Options{ + prg, err := loader.Program(r.Context(), req.getToolRepo(), "Get Element", loader.Options{ Cache: g.Cache, }) if err != nil { diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index 8427a6a5..894823b3 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -72,6 +72,15 @@ func (s *server) addRoutes(mux *http.ServeMux) { mux.HandleFunc("POST /datasets/list-elements", s.listDatasetElements) mux.HandleFunc("POST /datasets/get-element", s.getDatasetElement) mux.HandleFunc("POST /datasets/add-element", s.addDatasetElement) + + mux.HandleFunc("POST /workspaces/create", s.createWorkspace) + mux.HandleFunc("POST /workspaces/delete", s.deleteWorkspace) + mux.HandleFunc("POST /workspaces/list", s.listWorkspaceContents) + mux.HandleFunc("POST /workspaces/mkdir", s.mkDirInWorkspace) + mux.HandleFunc("POST /workspaces/rmdir", s.rmDirInWorkspace) + mux.HandleFunc("POST /workspaces/write-file", s.writeFileInWorkspace) + mux.HandleFunc("POST /workspaces/delete-file", s.removeFileInWorkspace) + mux.HandleFunc("POST /workspaces/read-file", s.readFileInWorkspace) } // health just provides an endpoint for checking whether the server is running and accessible. diff --git a/pkg/sdkserver/workspaces.go b/pkg/sdkserver/workspaces.go new file mode 100644 index 00000000..2541a190 --- /dev/null +++ b/pkg/sdkserver/workspaces.go @@ -0,0 +1,328 @@ +package sdkserver + +import ( + "encoding/json" + "fmt" + "net/http" + + gcontext "github.com/gptscript-ai/gptscript/pkg/context" + "github.com/gptscript-ai/gptscript/pkg/loader" +) + +type workspaceCommonRequest struct { + ID string `json:"id"` + WorkspaceToolRepo string `json:"workspaceToolRepo"` +} + +func (w workspaceCommonRequest) getToolRepo() string { + if w.WorkspaceToolRepo != "" { + return w.WorkspaceToolRepo + } + return "/Users/thedadams/code/workspace-provider" +} + +type createWorkspaceRequest struct { + workspaceCommonRequest `json:",inline"` + ProviderType string `json:"providerType"` +} + +func (s *server) createWorkspace(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + var reqObject createWorkspaceRequest + if err := json.NewDecoder(r.Body).Decode(&reqObject); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err)) + return + } + + prg, err := loader.Program(r.Context(), reqObject.getToolRepo(), "Create Workspace", loader.Options{Cache: s.client.Cache}) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) + return + } + + if reqObject.ProviderType == "" { + reqObject.ProviderType = "directory" + } + + out, err := s.client.Run( + r.Context(), + prg, + s.gptscriptOpts.Env, + fmt.Sprintf( + `{"provider": "%s"}`, + reqObject.ProviderType, + ), + ) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) + return + } + + writeResponse(logger, w, map[string]any{"stdout": out}) +} + +type deleteWorkspaceRequest struct { + workspaceCommonRequest `json:",inline"` +} + +func (s *server) deleteWorkspace(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + var reqObject deleteWorkspaceRequest + if err := json.NewDecoder(r.Body).Decode(&reqObject); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err)) + return + } + + prg, err := loader.Program(r.Context(), reqObject.getToolRepo(), "Delete Workspace", loader.Options{Cache: s.client.Cache}) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) + return + } + + out, err := s.client.Run( + r.Context(), + prg, + s.gptscriptOpts.Env, + fmt.Sprintf( + `{"workspace_id": "%s"}`, + reqObject.ID, + ), + ) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) + return + } + + writeResponse(logger, w, map[string]any{"stdout": out}) +} + +type listWorkspaceContentsRequest struct { + workspaceCommonRequest `json:",inline"` + ID string `json:"id"` + SubDir string `json:"subDir"` + NonRecursive bool `json:"nonRecursive"` + ExcludeHidden bool `json:"excludeHidden"` + JSON bool `json:"json"` +} + +func (s *server) listWorkspaceContents(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + var reqObject listWorkspaceContentsRequest + if err := json.NewDecoder(r.Body).Decode(&reqObject); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err)) + return + } + + prg, err := loader.Program(r.Context(), reqObject.getToolRepo(), "List Workspace Contents", loader.Options{Cache: s.client.Cache}) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) + return + } + + out, err := s.client.Run( + r.Context(), + prg, + s.gptscriptOpts.Env, + fmt.Sprintf( + `{"workspace_id": "%s", "ls_sub_dir": "%s", "ls_non_recursive": %t, "ls_exclude_hidden": %t, "ls_json": %t}`, + reqObject.ID, reqObject.SubDir, reqObject.NonRecursive, reqObject.ExcludeHidden, reqObject.JSON, + ), + ) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) + return + } + + writeResponse(logger, w, map[string]any{"stdout": out}) +} + +type mkDirRequest struct { + workspaceCommonRequest `json:",inline"` + DirectoryName string `json:"directoryName"` + IgnoreExists bool `json:"ignoreExists"` + CreateDirs bool `json:"createDirs"` +} + +func (s *server) mkDirInWorkspace(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + var reqObject mkDirRequest + if err := json.NewDecoder(r.Body).Decode(&reqObject); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err)) + return + } + + prg, err := loader.Program(r.Context(), reqObject.getToolRepo(), "Create Directory In Workspace", loader.Options{Cache: s.client.Cache}) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) + return + } + + out, err := s.client.Run( + r.Context(), + prg, + s.gptscriptOpts.Env, + fmt.Sprintf( + `{"workspace_id": "%s", "directory_name": "%s", "mk_dir_ignore_exists": %t, "mk_dir_create_dirs": %t}`, + reqObject.ID, reqObject.DirectoryName, reqObject.IgnoreExists, reqObject.CreateDirs, + ), + ) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) + return + } + + writeResponse(logger, w, map[string]any{"stdout": out}) +} + +type rmDirRequest struct { + workspaceCommonRequest `json:",inline"` + DirectoryName string `json:"directoryName"` + IgnoreNotFound bool `json:"ignoreNotFound"` + MustBeEmpty bool `json:"mustBeEmpty"` +} + +func (s *server) rmDirInWorkspace(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + var reqObject rmDirRequest + if err := json.NewDecoder(r.Body).Decode(&reqObject); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err)) + return + } + + prg, err := loader.Program(r.Context(), reqObject.getToolRepo(), "Remove Directory In Workspace", loader.Options{Cache: s.client.Cache}) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) + return + } + + out, err := s.client.Run( + r.Context(), + prg, + s.gptscriptOpts.Env, + fmt.Sprintf( + `{"workspace_id": "%s", "directory_name": "%s", "ignore_not_found": %t, "rm_dir_must_be_empty": %t}`, + reqObject.ID, reqObject.DirectoryName, reqObject.IgnoreNotFound, reqObject.MustBeEmpty, + ), + ) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) + return + } + + writeResponse(logger, w, map[string]any{"stdout": out}) +} + +type writeFileInWorkspaceRequest struct { + workspaceCommonRequest `json:",inline"` + FilePath string `json:"filePath"` + Contents string `json:"contents"` + Base64EncodedInput bool `json:"base64EncodedInput"` + MustNotExist bool `json:"mustNotExist"` + CreateDirs bool `json:"createDirs"` + WithoutCreate bool `json:"withoutCreate"` +} + +func (s *server) writeFileInWorkspace(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + var reqObject writeFileInWorkspaceRequest + if err := json.NewDecoder(r.Body).Decode(&reqObject); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err)) + return + } + + prg, err := loader.Program(r.Context(), reqObject.getToolRepo(), "Write File In Workspace", loader.Options{Cache: s.client.Cache}) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) + return + } + + out, err := s.client.Run( + r.Context(), + prg, + s.gptscriptOpts.Env, + fmt.Sprintf( + `{"workspace_id": "%s", "file_path": "%s", "file_contents": "%s", "write_file_must_not_exist": %t, "write_file_create_dirs": %t, "write_file_without_create": %t, "write_file_base64_encoded_input": %t}`, + reqObject.ID, reqObject.FilePath, reqObject.Contents, reqObject.MustNotExist, reqObject.CreateDirs, reqObject.WithoutCreate, reqObject.Base64EncodedInput, + ), + ) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) + return + } + + writeResponse(logger, w, map[string]any{"stdout": out}) +} + +type rmFileInWorkspaceRequest struct { + workspaceCommonRequest `json:",inline"` + FilePath string `json:"filePath"` + IgnoreNotFound bool `json:"ignoreNotFound"` +} + +func (s *server) removeFileInWorkspace(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + var reqObject rmFileInWorkspaceRequest + if err := json.NewDecoder(r.Body).Decode(&reqObject); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err)) + return + } + + prg, err := loader.Program(r.Context(), reqObject.getToolRepo(), "Remove File In Workspace", loader.Options{Cache: s.client.Cache}) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) + return + } + + out, err := s.client.Run( + r.Context(), + prg, + s.gptscriptOpts.Env, + fmt.Sprintf( + `{"workspace_id": "%s", "file_path": "%s", "ignore_not_found": %t}`, + reqObject.ID, reqObject.FilePath, reqObject.IgnoreNotFound, + ), + ) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) + return + } + + writeResponse(logger, w, map[string]any{"stdout": out}) +} + +type readFileInWorkspaceRequest struct { + workspaceCommonRequest `json:",inline"` + FilePath string `json:"filePath"` + Base64EncodeOutput bool `json:"base64EncodeOutput"` +} + +func (s *server) readFileInWorkspace(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + var reqObject readFileInWorkspaceRequest + if err := json.NewDecoder(r.Body).Decode(&reqObject); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err)) + return + } + + prg, err := loader.Program(r.Context(), reqObject.getToolRepo(), "Read File In Workspace", loader.Options{Cache: s.client.Cache}) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) + return + } + + out, err := s.client.Run( + r.Context(), + prg, + s.gptscriptOpts.Env, + fmt.Sprintf( + `{"workspace_id": "%s", "file_path": "%s", "read_file_base64_encode_output": %t}`, + reqObject.ID, reqObject.FilePath, reqObject.Base64EncodeOutput, + ), + ) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err)) + return + } + + writeResponse(logger, w, map[string]any{"stdout": out}) +}