diff --git a/cmd/port-forward.go b/cmd/port-forward.go new file mode 100644 index 00000000..1a18518d --- /dev/null +++ b/cmd/port-forward.go @@ -0,0 +1,307 @@ +package cmd + +import ( + "errors" + "fmt" + log "github.com/sirupsen/logrus" + "os" + "os/signal" + "strconv" + "strings" + "syscall" + + "github.com/pterm/pterm" + "github.com/spf13/cobra" + "golang.org/x/net/context" + + "github.com/qovery/qovery-cli/pkg" + "github.com/qovery/qovery-cli/utils" +) + +var portForwardCmd = &cobra.Command{ + Use: "port-forward", + Short: "Port forward a port to an application container", + Run: func(cmd *cobra.Command, args []string) { + utils.Capture(cmd) + + if len(ports) == 0 { + log.Fatal("port flag must be specified at least once") + return + } + + var portForwardRequest *pkg.PortForwardRequest + var err error + if len(args) > 0 { + portForwardRequest, err = portForwardRequestWithApplicationUrl(args) + } else { + portForwardRequest, err = portForwardRequestWithoutArg() + } + if err != nil { + utils.PrintlnError(err) + return + } + + for _, port := range ports { + ps := strings.Split(port, ":") + var localPortStr, remotePortStr string + if len(ps) > 1 { + localPortStr = ps[0] + remotePortStr = ps[1] + } else { + localPortStr = ps[0] + remotePortStr = ps[0] + } + + localPort, err := strconv.ParseUint(localPortStr, 10, 16) + if err != nil { + log.Fatal("Invalid local port {} {}", port, err) + } + + remotePort, err := strconv.ParseUint(remotePortStr, 10, 16) + if err != nil { + log.Fatal("Invalid remote port {} {}", port, err) + } + + req := *portForwardRequest + req.LocalPort = uint16(localPort) + req.Port = uint16(remotePort) + go pkg.ExecPortForward(&req) + } + + done := make(chan os.Signal, 1) + signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) + <-done + }, +} +var ( + ports []string +) + +func portForwardRequestWithoutArg() (*pkg.PortForwardRequest, error) { + useContext := false + currentContext, err := utils.CurrentContext() + if err != nil { + return nil, err + } + + utils.PrintlnInfo("Current context:") + if currentContext.ServiceId != "" && currentContext.ServiceName != "" && + currentContext.EnvironmentId != "" && currentContext.EnvironmentName != "" && + currentContext.ProjectId != "" && currentContext.ProjectName != "" && + currentContext.OrganizationId != "" && currentContext.OrganizationName != "" { + if err := utils.PrintlnContext(); err != nil { + fmt.Println("Context not yet configured.") + } + fmt.Println() + + utils.PrintlnInfo("Continue with port-forward command using this context ?") + useContext = utils.Validate("context") + fmt.Println() + } else { + if err := utils.PrintlnContext(); err != nil { + fmt.Println("Context not yet configured.") + fmt.Println("Unable to use current context for `port-forward` command.") + fmt.Println() + } + } + + var req *pkg.PortForwardRequest + if useContext { + req, err = portForwardRequestFromContext(currentContext) + } else { + req, err = portForwardRequestFromSelect() + } + if err != nil { + return nil, err + } + + return req, nil +} + +func portForwardRequestFromSelect() (*pkg.PortForwardRequest, error) { + utils.PrintlnInfo("Select organization") + orga, err := utils.SelectOrganization() + if err != nil { + return nil, err + } + + utils.PrintlnInfo("Select project") + project, err := utils.SelectProject(orga.ID) + if err != nil { + return nil, err + } + + utils.PrintlnInfo("Select environment") + env, err := utils.SelectEnvironment(project.ID) + if err != nil { + return nil, err + } + + utils.PrintlnInfo("Select service") + service, err := utils.SelectService(env.ID) + if err != nil { + return nil, err + } + + return &pkg.PortForwardRequest{ + ServiceID: service.ID, + ProjectID: project.ID, + OrganizationID: orga.ID, + EnvironmentID: env.ID, + ClusterID: env.ClusterID, + PodName: podName, + Port: 0, + LocalPort: 0, + }, nil +} + +func portForwardRequestFromContext(currentContext utils.QoveryContext) (*pkg.PortForwardRequest, error) { + tokenType, token, err := utils.GetAccessToken() + if err != nil { + utils.PrintlnError(err) + os.Exit(1) + panic("unreachable") // staticcheck false positive: https://staticcheck.io/docs/checks#SA5011 + } + + client := utils.GetQoveryClient(tokenType, token) + + e, res, err := client.EnvironmentMainCallsAPI.GetEnvironment(context.Background(), string(currentContext.EnvironmentId)).Execute() + if err != nil { + return nil, err + } + if res.StatusCode >= 400 { + return nil, errors.New("Received " + res.Status + " response while fetching environment. ") + } + + return &pkg.PortForwardRequest{ + ServiceID: currentContext.ServiceId, + ProjectID: currentContext.ProjectId, + OrganizationID: currentContext.OrganizationId, + EnvironmentID: currentContext.EnvironmentId, + ClusterID: utils.Id(e.ClusterId), + PodName: podName, + Port: 0, + LocalPort: 0, + }, nil +} + +func portForwardRequestWithApplicationUrl(args []string) (*pkg.PortForwardRequest, error) { + var url = args[0] + url = strings.Replace(url, "https://console.qovery.com/", "", 1) + url = strings.Replace(url, "https://new.console.qovery.com/", "", 1) + urlSplit := strings.Split(url, "/") + + if len(urlSplit) < 8 { + return nil, errors.New("Wrong URL format: " + url) + } + + var organizationId = urlSplit[1] + organization, err := utils.GetOrganizationById(organizationId) + if err != nil { + return nil, err + } + + var projectId = urlSplit[3] + project, err := utils.GetProjectById(projectId) + if err != nil { + return nil, err + } + + var environmentId = urlSplit[5] + environment, err := utils.GetEnvironmentById(environmentId) + if err != nil { + return nil, err + } + + environmentServices, err := utils.GetEnvironmentServicesById(environmentId) + if err != nil { + return nil, err + } + + var service utils.Service + var serviceId = urlSplit[7] + for _, envService := range environmentServices { + if envService.ID == serviceId { + switch envService.Type { + + case utils.ApplicationType: + applicationAPI, err := utils.GetApplicationById(serviceId) + if err != nil { + return nil, err + } + service = utils.Service{ + ID: applicationAPI.ID, + Name: applicationAPI.Name, + Type: utils.ApplicationType, + } + + case utils.ContainerType: + containerAPI, err := utils.GetContainerById(serviceId) + if err != nil { + return nil, err + } + service = utils.Service{ + ID: containerAPI.ID, + Name: containerAPI.Name, + Type: utils.ContainerType, + } + + case utils.JobType: + jobAPI, err := utils.GetJobById(serviceId) + if err != nil { + return nil, err + } + service = utils.Service{ + ID: jobAPI.ID, + Name: jobAPI.Name, + Type: utils.JobType, + } + + case utils.DatabaseType: + db, err := utils.GetDatabaseById(serviceId) + if err != nil { + return nil, err + } + service = *db + + case utils.HelmType: + helm, err := utils.GetHelmById(serviceId) + if err != nil { + return nil, err + } + service = *helm + + default: + return nil, errors.New("ServiceLevel type `" + string(envService.Type) + "` is not supported for port-forward") + } + } + } + + _ = pterm.DefaultTable.WithData(pterm.TableData{ + {"Organization", string(organization.Name)}, + {"Project", string(project.Name)}, + {"Environment", string(environment.Name)}, + {"ServiceLevel", string(service.Name)}, + {"ServiceType", string(service.Type)}, + }).Render() + + return &pkg.PortForwardRequest{ + OrganizationID: organization.ID, + ProjectID: project.ID, + EnvironmentID: environment.ID, + ServiceID: service.ID, + ClusterID: environment.ClusterID, + PodName: podName, + Port: 8000, + LocalPort: 8000, + }, nil +} + +func init() { + var portForwardCmd = portForwardCmd + portForwardCmd.Flags().StringVarP(&podName, "pod", "", "", "pod name where to forward traffic") + portForwardCmd.Flags().StringSliceVarP(&ports, "port", "p", nil, "port that will be forwarded. Format \"local_port:remote_port\" i.e: 8080:80") + _ = portForwardCmd.MarkFlagRequired("port") + + rootCmd.AddCommand(portForwardCmd) +} diff --git a/cmd/shell.go b/cmd/shell.go index 3c41f0ee..b875b122 100644 --- a/cmd/shell.go +++ b/cmd/shell.go @@ -221,6 +221,20 @@ func shellRequestWithApplicationUrl(args []string) (*pkg.ShellRequest, error) { Type: utils.JobType, } + case utils.DatabaseType: + db, err := utils.GetDatabaseById(serviceId) + if err != nil { + return nil, err + } + service = *db + + case utils.HelmType: + helm, err := utils.GetHelmById(serviceId) + if err != nil { + return nil, err + } + service = *helm + default: return nil, errors.New("ServiceLevel type `" + string(envService.Type) + "` is not supported for shell") } diff --git a/pkg/port-forward.go b/pkg/port-forward.go new file mode 100644 index 00000000..0309bfad --- /dev/null +++ b/pkg/port-forward.go @@ -0,0 +1,125 @@ +package pkg + +import ( + "errors" + "fmt" + "github.com/appscode/go-querystring/query" + "io" + "net" + "net/http" + "net/url" + "regexp" + + "github.com/gorilla/websocket" + "github.com/qovery/qovery-cli/utils" + log "github.com/sirupsen/logrus" +) + +type PortForwardRequest struct { + ServiceID utils.Id `url:"service"` + EnvironmentID utils.Id `url:"environment"` + ProjectID utils.Id `url:"project"` + OrganizationID utils.Id `url:"organization"` + ClusterID utils.Id `url:"cluster"` + PodName string `url:"pod_name,omitempty"` + Port uint16 `url:"port"` + LocalPort uint16 +} + +type WebsocketPortForward struct { + ws *websocket.Conn +} + +func (w WebsocketPortForward) Write(p []byte) (n int, err error) { + err = w.ws.WriteMessage(websocket.BinaryMessage, p) + + return len(p), err +} +func (w WebsocketPortForward) Read(p []byte) (n int, err error) { + _, msg, err := w.ws.ReadMessage() + if err != nil { + return 0, err + } + + return copy(p, msg), err +} + +func mkWebsocketConn(req *PortForwardRequest) (*WebsocketPortForward, error) { + command, err := query.Values(req) + if err != nil { + return nil, err + } + + wsURL, err := url.Parse("wss://ws.qovery.com/shell/portforward") + if err != nil { + return nil, err + } + pattern := regexp.MustCompile("%5B([0-9]+)%5D=") + wsURL.RawQuery = pattern.ReplaceAllString(command.Encode(), "[${1}]=") + + tokenType, token, err := utils.GetAccessToken() + if err != nil { + return nil, err + } + + headers := http.Header{"Authorization": {utils.GetAuthorizationHeaderValue(tokenType, token)}} + wsConn, _, err := websocket.DefaultDialer.Dial(wsURL.String(), headers) + if err != nil { + return nil, err + } + + ws := WebsocketPortForward{ws: wsConn} + return &ws, nil +} + +func ExecPortForward(req *PortForwardRequest) { + listen, error := net.Listen("tcp", fmt.Sprintf("localhost:%d", req.LocalPort)) + + // Handles eventual errors + if error != nil { + fmt.Println(error) + return + } + + fmt.Printf("Listening on %s => %d\n", listen.Addr().String(), req.Port) + + for { + // Accepts connections + con, error := listen.Accept() + + // Handles eventual errors + if error != nil { + fmt.Println(error) + continue + } + + go handleConnection(con, req) + } +} + +func handleConnection(con net.Conn, req *PortForwardRequest) { + var errRet error + fmt.Printf("Connection accepted from %s => %d\n", con.RemoteAddr().String(), req.Port) + defer func() { + con.Close() + fmt.Printf("Connection closed from %s => %d\n", con.RemoteAddr().String(), req.Port) + var e *websocket.CloseError + if errors.As(errRet, &e) && e.Code != websocket.CloseNormalClosure { + log.Error("connection terminated badly with ", e) + } + }() + + wsConn, err := mkWebsocketConn(req) + if err != nil { + log.Fatal("error while creating websocket connection", err) + } + defer func() { + wsConn.ws.Close() + }() + + go func() { + _, _ = io.Copy(wsConn, con) + }() + _, err = io.Copy(con, wsConn) + errRet = err +} diff --git a/utils/qovery.go b/utils/qovery.go index aa18ef34..2be35b0e 100644 --- a/utils/qovery.go +++ b/utils/qovery.go @@ -644,6 +644,52 @@ func GetContainerById(id string) (*Container, error) { }, nil } +func GetDatabaseById(id string) (*Service, error) { + tokenType, token, err := GetAccessToken() + if err != nil { + return nil, err + } + + client := GetQoveryClient(tokenType, token) + + database, res, err := client.DatabaseMainCallsAPI.GetDatabase(context.Background(), id).Execute() + if res.StatusCode >= 400 { + return nil, errors.New("Received " + res.Status + " response while getting database " + id) + } + if err != nil { + return nil, err + } + + return &Service{ + ID: Id(database.Id), + Name: Name(database.GetName()), + Type: DatabaseType, + }, nil +} + +func GetHelmById(id string) (*Service, error) { + tokenType, token, err := GetAccessToken() + if err != nil { + return nil, err + } + + client := GetQoveryClient(tokenType, token) + + helm, res, err := client.HelmMainCallsAPI.GetHelm(context.Background(), id).Execute() + if res.StatusCode >= 400 { + return nil, errors.New("Received " + res.Status + " response while getting helm " + id) + } + if err != nil { + return nil, err + } + + return &Service{ + ID: Id(helm.Id), + Name: Name(helm.GetName()), + Type: HelmType, + }, nil +} + type Job struct { ID Id Name Name