diff --git a/agent/agent.go b/agent/agent.go index d701e40b589..a9eeb9fe7b2 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -42,22 +42,16 @@ import ( "context" "crypto/rsa" "fmt" - "io" "math/rand" "net" - "net/http" - "net/netip" "net/url" "os" "runtime" "strings" - "sync" "sync/atomic" "time" "github.com/Masterminds/semver" - dockerclient "github.com/docker/docker/client" - "github.com/labstack/echo/v4" "github.com/pkg/errors" "github.com/shellhub-io/shellhub/agent/pkg/keygen" "github.com/shellhub-io/shellhub/agent/pkg/sysinfo" @@ -112,6 +106,10 @@ type Config struct { // MaxRetryConnectionTimeout specifies the maximum time, in seconds, that an agent will wait // before attempting to reconnect to the ShellHub server. Default is 60 seconds. MaxRetryConnectionTimeout int `env:"MAX_RETRY_CONNECTION_TIMEOUT,default=60" validate:"min=10,max=120"` + + // ConnectionVersion specifies the version of the connection protocol to use. + // Supported values are 1 and 2. Default is 1. + ConnectionVersion int `env:"CONNECTION_VERSION,default=1"` } func LoadConfigFromEnv() (*Config, map[string]interface{}, error) { @@ -161,10 +159,13 @@ type Agent struct { cli client.Client serverInfo *models.Info server *server.Server - tunnel *tunnel.Tunnel listening chan bool closed atomic.Bool mode Mode + // listener is the current connection to the server. + listener atomic.Pointer[net.Listener] + // logger is the agent's logger instance. + logger *log.Entry } // NewAgent creates a new agent instance, requiring the ShellHub server's address to connect to, the namespace's tenant @@ -255,6 +256,16 @@ func (a *Agent) Initialize() error { a.closed.Store(false) + a.logger = log.WithFields(log.Fields{ + "version": AgentVersion, + "tenant_id": a.authData.Namespace, + "server_address": a.config.ServerAddress, + "ssh_endpoint": a.serverInfo.Endpoints.SSH, + "api_endpoint": a.serverInfo.Endpoints.API, + "connection_version": a.config.ConnectionVersion, + "sshid": fmt.Sprintf("%s.%s@%s", a.authData.Namespace, a.authData.Name, strings.Split(a.serverInfo.Endpoints.SSH, ":")[0]), + }) + return nil } @@ -370,195 +381,93 @@ func (a *Agent) isClosed() bool { func (a *Agent) Close() error { a.closed.Store(true) - return a.tunnel.Close() -} - -func sshHandler(serv *server.Server) func(c echo.Context) error { - return func(c echo.Context) error { - hj, ok := c.Response().Writer.(http.Hijacker) - if !ok { - return c.String(http.StatusInternalServerError, "webserver doesn't support hijacking") - } + l := a.listener.Load() + if l == nil { + return nil + } - conn, _, err := hj.Hijack() - if err != nil { - return c.String(http.StatusInternalServerError, "failed to hijack connection") - } + return (*l).Close() +} - id := c.Param("id") - httpConn := c.Request().Context().Value("http-conn").(net.Conn) - serv.Sessions.Store(id, httpConn) - serv.HandleConn(httpConn) +const ( + ConnectionV1 = 1 + ConnectionV2 = 2 +) - conn.Close() +func (a *Agent) Listen(ctx context.Context) error { + a.mode.Serve(a) - return nil + switch a.config.ConnectionVersion { + case ConnectionV1: + return a.listenV1(ctx) + case ConnectionV2: + return a.listenV2(ctx) + default: + return fmt.Errorf("unsupported connection version: %d", a.config.ConnectionVersion) } } -// httpProxyHandler handlers proxy connections to the required address. -func httpProxyHandler(agent *Agent) func(c echo.Context) error { - const ProxyHandlerNetwork = "tcp" +func (a *Agent) listenV1(ctx context.Context) error { + tun := tunnel.NewTunnelV1() - return func(c echo.Context) error { - logger := log.WithFields(log.Fields{ - "remote": c.Request().RemoteAddr, - "namespace": c.Request().Header.Get("X-Namespace"), - "path": c.Request().Header.Get("X-Path"), - "version": AgentVersion, - }) + tun.Handle(HandleSSHOpenV1, sshHandlerV1(a)) + tun.Handle(HandleSSHCloseV1, sshCloseHandlerV1(a)) + tun.Handle(HandleHTTPProxyV1, httpProxyHandlerV1(a)) - errorResponse := func(err error, msg string, code int) error { - logger.WithError(err).Debug(msg) + go a.ping(ctx, AgentPingDefaultInterval) //nolint:errcheck - return c.String(code, msg) - } + ctx, cancel := context.WithCancel(ctx) + go func() { + for { + if a.isClosed() { + a.logger.Info("Stopped listening for connections") - host, port, err := net.SplitHostPort(c.Param("addr")) - if err != nil { - return errorResponse(err, "failed because address is invalid", http.StatusInternalServerError) - } + cancel() - if _, ok := agent.mode.(*ConnectorMode); ok { - cli, err := dockerclient.NewClientWithOpts(dockerclient.FromEnv, dockerclient.WithAPIVersionNegotiation()) - if err != nil { - return errorResponse(err, "failed to connect to the Docker Engine", http.StatusInternalServerError) + return } - container, err := cli.ContainerInspect(context.Background(), agent.server.ContainerID) - if err != nil { - return errorResponse(err, "failed to inspect the container", http.StatusInternalServerError) - } + ShellHubConnectV1Path := "/ssh/connection" - var target string + a.logger.Debug("Using tunnel version 1") - addr, err := netip.ParseAddr(host) + listener, err := a.cli.NewReverseListenerV1( + ctx, + a.authData.Token, + ShellHubConnectV1Path, + ) if err != nil { - return errorResponse(err, "failed to parse the for lookback checkage", http.StatusInternalServerError) - } - - if addr.IsLoopback() { - for _, network := range container.NetworkSettings.Networks { - target = network.IPAddress - - break - } - } else { - for _, network := range container.NetworkSettings.Networks { - subnet, err := netip.ParsePrefix(fmt.Sprintf("%s/%d", network.Gateway, network.IPPrefixLen)) - if err != nil { - logger.WithError(err).Trace("Failed to parse the gateway on proxy") - - continue - } - - ip, err := netip.ParseAddr(host) - if err != nil { - logger.WithError(err).Trace("Failed to parse the address on proxy") + a.logger.Error("Failed to connect to server through reverse tunnel. Retry in 10 seconds") - continue - } - - if subnet.Contains(ip) { - target = ip.String() - - break - } - } - } + time.Sleep(time.Second * 10) - if target == "" { - return errorResponse(nil, "address not found on the device", http.StatusInternalServerError) + continue } + a.listener.Store(&listener) - host = target - } - - // NOTE: Gets the to address to connect to. This address can be just a port, :8080, or the host and port, - // localhost:8080. - addr := fmt.Sprintf("%s:%s", host, port) + a.logger.Info("Server connection established") - in, err := net.Dial(ProxyHandlerNetwork, addr) - if err != nil { - return errorResponse(err, "failed to connect to the server on device", http.StatusInternalServerError) - } - - defer in.Close() + a.listening <- true - // NOTE: Inform to the connection that the dial was successfully. - if err := c.NoContent(http.StatusOK); err != nil { - return errorResponse(err, "failed to send the ok status code back to server", http.StatusInternalServerError) - } + if err := tun.Listen(ctx, listener); err != nil { + a.logger.WithError(err).Error("Tunnel listener exited with error") + } - // NOTE: Hijacks the connection to control the data transferred to the client connected. This way, we don't - // depend upon anything externally, only the data. - out, _, err := c.Response().Hijack() - if err != nil { - return errorResponse(err, "failed to hijack connection", http.StatusInternalServerError) + a.listening <- false } + }() - defer out.Close() // nolint:errcheck - - wg := new(sync.WaitGroup) - done := sync.OnceFunc(func() { - defer in.Close() - defer out.Close() - - logger.Trace("close called on in and out connections") - }) - - wg.Add(1) - go func() { - defer done() - defer wg.Done() - - io.Copy(in, out) //nolint:errcheck - }() - - wg.Add(1) - go func() { - defer done() - defer wg.Done() - - io.Copy(out, in) //nolint:errcheck - }() - - logger.WithError(err).Trace("proxy handler waiting for data pipe") - wg.Wait() - - logger.WithError(err).Trace("proxy handler done") - - return nil - } -} - -func sshCloseHandler(a *Agent, serv *server.Server) func(c echo.Context) error { - return func(c echo.Context) error { - id := c.Param("id") - serv.CloseSession(id) - - log.WithFields( - log.Fields{ - "id": id, - "version": AgentVersion, - "tenant_id": a.authData.Namespace, - "server_address": a.config.ServerAddress, - }, - ).Info("A tunnel connection was closed") + <-ctx.Done() - return nil - } + return a.Close() } -// Listen creates the SSH server and listening for connections. -func (a *Agent) Listen(ctx context.Context) error { - a.mode.Serve(a) +func (a *Agent) listenV2(ctx context.Context) error { + tun := tunnel.NewTunnelV2(a.cli) - a.tunnel = tunnel.NewBuilder(). - WithSSHHandler(sshHandler(a.server)). - WithSSHCloseHandler(sshCloseHandler(a, a.server)). - WithHTTPProxyHandler(httpProxyHandler(a)). - Build() + tun.Handle(HandleSSHOpenV2, sshHandlerV2(a)) + tun.Handle(HandleSSHCloseV2, sshCloseHandlerV2(a)) + tun.Handle(HandleHTTPProxyV2, httpProxyHandlerV2(a)) go a.ping(ctx, AgentPingDefaultInterval) //nolint:errcheck @@ -566,65 +475,38 @@ func (a *Agent) Listen(ctx context.Context) error { go func() { for { if a.isClosed() { - log.WithFields(log.Fields{ - "version": AgentVersion, - "tenant_id": a.authData.Namespace, - "server_address": a.config.ServerAddress, - }).Info("Stopped listening for connections") + a.logger.Info("Stopped listening for connections") cancel() return } - namespace := a.authData.Namespace - tenantName := a.authData.Name - sshEndpoint := a.serverInfo.Endpoints.SSH + ShellHubConnectV2Path := "/agent/connection" - sshid := strings.NewReplacer( - "{namespace}", namespace, - "{tenantName}", tenantName, - "{sshEndpoint}", strings.Split(sshEndpoint, ":")[0], - ).Replace("{namespace}.{tenantName}@{sshEndpoint}") + a.logger.Debug("Using tunnel version 2") - listener, err := a.cli.NewReverseListener(ctx, a.authData.Token, "/ssh/connection") + listener, err := a.cli.NewReverseListenerV2( + ctx, + a.authData.Token, + ShellHubConnectV2Path, + client.NewReverseV2ConfigFromMap(a.authData.Config), + ) if err != nil { - log.WithError(err).WithFields(log.Fields{ - "version": AgentVersion, - "tenant_id": a.authData.Namespace, - "server_address": a.config.ServerAddress, - "ssh_server": sshEndpoint, - "sshid": sshid, - }).Error("Failed to connect to server through reverse tunnel. Retry in 10 seconds") + a.logger.Error("Failed to connect to server through reverse tunnel. Retry in 10 seconds") + time.Sleep(time.Second * 10) continue } + a.listener.Store(&listener) - log.WithFields(log.Fields{ - "namespace": namespace, - "hostname": tenantName, - "server_address": a.config.ServerAddress, - "ssh_server": sshEndpoint, - "sshid": sshid, - }).Info("Server connection established") + a.logger.Info("Server connection established") a.listening <- true - { - // NOTE: Tunnel'll only realize that it lost its connection to the ShellHub SSH when the next - // "keep-alive" connection fails. As a result, it will take this interval to reconnect to its server. - err := a.tunnel.Listen(listener) - - log.WithError(err).WithFields(log.Fields{ - "namespace": namespace, - "hostname": tenantName, - "server_address": a.config.ServerAddress, - "ssh_server": sshEndpoint, - "sshid": sshid, - }).Info("Tunnel listener closed") - - listener.Close() // nolint:errcheck + if err := tun.Listen(ctx, listener); err != nil { + a.logger.WithError(err).Error("Tunnel listener exited with error") } a.listening <- false diff --git a/agent/go.mod b/agent/go.mod index 8a6d18b2789..63ab735f138 100644 --- a/agent/go.mod +++ b/agent/go.mod @@ -9,8 +9,10 @@ require ( github.com/docker/docker v28.5.1+incompatible github.com/gliderlabs/ssh v0.3.5 github.com/go-playground/assert/v2 v2.2.0 - github.com/labstack/echo/v4 v4.13.4 + github.com/gorilla/websocket v1.5.0 + github.com/labstack/echo/v4 v4.10.2 github.com/mattn/go-shellwords v1.0.12 + github.com/multiformats/go-multistream v0.6.1 github.com/openwall/yescrypt-go v1.0.0 github.com/pkg/errors v0.9.1 github.com/pkg/sftp v1.13.9 @@ -22,14 +24,6 @@ require ( golang.org/x/sys v0.37.0 ) -require ( - github.com/labstack/gommon v0.4.2 // indirect - github.com/mattn/go-colorable v0.1.14 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect - github.com/valyala/bytebufferpool v1.0.0 // indirect - github.com/valyala/fasttemplate v1.2.2 // indirect -) - require ( github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect @@ -48,18 +42,24 @@ require ( github.com/go-playground/validator/v10 v10.11.2 // indirect github.com/go-resty/resty/v2 v2.7.0 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/gorilla/websocket v1.5.0 // indirect + github.com/hashicorp/yamux v0.1.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/kr/fs v0.1.0 // indirect + github.com/labstack/gommon v0.4.0 // indirect github.com/leodido/go-urn v1.2.2 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.17 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/sys/atomicwriter v0.1.0 // indirect + github.com/multiformats/go-varint v0.0.6 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sethvargo/go-envconfig v0.9.0 // indirect github.com/spf13/pflag v1.0.9 // indirect github.com/stretchr/objx v0.5.2 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasttemplate v1.2.2 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/otel v1.37.0 // indirect @@ -69,6 +69,7 @@ require ( go.opentelemetry.io/proto/otlp v1.2.0 // indirect golang.org/x/net v0.45.0 // indirect golang.org/x/text v0.30.0 // indirect + golang.org/x/time v0.11.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect gotest.tools/v3 v3.5.1 // indirect ) diff --git a/agent/go.sum b/agent/go.sum index 112cc29b2a4..1c54a4bfec7 100644 --- a/agent/go.sum +++ b/agent/go.sum @@ -56,6 +56,8 @@ github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWm github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 h1:/c3QmbOGMGTOumP2iT/rCwB7b0QDGLKzqOmktBjT+Is= github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1/go.mod h1:5SN9VR2LTsRFsrEC6FHgRbTWrTHu6tqPeKxEQv15giM= +github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= +github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww= @@ -66,16 +68,19 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/labstack/echo/v4 v4.13.4 h1:oTZZW+T3s9gAu5L8vmzihV7/lkXGZuITzTQkTEhcXEA= -github.com/labstack/echo/v4 v4.13.4/go.mod h1:g63b33BZ5vZzcIUF8AtRH40DrTlXnx4UMC8rBdndmjQ= -github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= -github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= +github.com/labstack/echo/v4 v4.10.2 h1:n1jAhnq/elIFTHr1EYpiYtyKgx4RW9ccVgkqByZaN2M= +github.com/labstack/echo/v4 v4.10.2/go.mod h1:OEyqf2//K1DFdE57vw2DRgWY0M7s65IVQO2FzvI4J5k= +github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8= +github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= github.com/leodido/go-urn v1.2.2 h1:7z68G0FCGvDk646jz1AelTYNYWrTNm0bEcFAo147wt4= github.com/leodido/go-urn v1.2.2/go.mod h1:kUaIbLZWttglzwNuG0pgsh5vuV6u2YcGBYz1hIPjtOQ= -github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= -github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= +github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-shellwords v1.0.12 h1:M2zGm7EW6UQJvDeQxo4T51eKPurbeFbe8WtebGE2xrk= github.com/mattn/go-shellwords v1.0.12/go.mod h1:EZzvwXDESEeg03EKmM+RmDnNOPKG4lLtQsUlTZDWQ8Y= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= @@ -88,6 +93,10 @@ github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= +github.com/multiformats/go-multistream v0.6.1 h1:4aoX5v6T+yWmc2raBHsTvzmFhOI8WVOer28DeBBEYdQ= +github.com/multiformats/go-multistream v0.6.1/go.mod h1:ksQf6kqHAb6zIsyw7Zm+gAuVo57Qbq84E27YlYqavqw= +github.com/multiformats/go-varint v0.0.6 h1:gk85QWKxh3TazbLxED/NlDVv8+q+ReFJk7Y2W/KhfNY= +github.com/multiformats/go-varint v0.0.6/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXSrVKRY101jdMZYE= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= @@ -127,6 +136,7 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= @@ -187,13 +197,16 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= @@ -246,6 +259,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= diff --git a/agent/handlers.go b/agent/handlers.go new file mode 100644 index 00000000000..1f54cff350b --- /dev/null +++ b/agent/handlers.go @@ -0,0 +1,408 @@ +package main + +import ( + "context" + "fmt" + "io" + "net" + "net/http" + "net/netip" + "sync" + + dockerclient "github.com/docker/docker/client" + "github.com/labstack/echo/v4" + "github.com/pkg/errors" + "github.com/shellhub-io/shellhub/agent/pkg/tunnel" + log "github.com/sirupsen/logrus" +) + +const ( + // HandleSSHOpenV2 is the protocol used to open a new SSH connection. + HandleSSHOpenV2 = "/ssh/open/1.0.0" + // HandleSSHCloseV2 is the protocol used to close an existing SSH connection. + HandleSSHCloseV2 = "/ssh/close/1.0.0" + // HandleHTTPProxyV2 is the protocol used to open a new HTTP proxy connection. + HandleHTTPProxyV2 = "/http/proxy/1.0.0" +) + +// httpProxyHandlerV2 handlers proxy connections to the required address. +func httpProxyHandlerV2(agent *Agent) tunnel.HandlerFunc { + const ProxyHandlerNetwork = "tcp" + + return func(ctx tunnel.Context, rwc io.ReadWriteCloser) error { + headers, err := ctx.Headers() + if err != nil { + log.WithError(err).Error("failed to get the headers from the connection") + + return err + } + + id := headers["id"] + host := headers["host"] + port := headers["port"] + + logger := log.WithFields(log.Fields{ + "id": id, + "host": host, + "port": port, + }) + + if _, ok := agent.mode.(*ConnectorMode); ok { + cli, err := dockerclient.NewClientWithOpts(dockerclient.FromEnv, dockerclient.WithAPIVersionNegotiation()) + if err != nil { + log.WithError(err).Error("failed to create the Docker client") + + return ctx.Error(errors.New("failed to connect to the Docker Engine")) + } + + container, err := cli.ContainerInspect(context.Background(), agent.server.ContainerID) + if err != nil { + log.WithError(err).Error("failed to inspect the container") + + return ctx.Error(errors.New("failed to inspect the container")) + } + + var target string + + addr, err := netip.ParseAddr(host) + if err != nil { + log.WithError(err).Error("failed to parse the address on proxy") + + return ctx.Error(errors.New("failed to parse the address on proxy")) + } + + if addr.IsLoopback() { + log.Trace("host is a loopback address, using the container IP address") + + for _, network := range container.NetworkSettings.Networks { + target = network.IPAddress + + break + } + } else { + for _, network := range container.NetworkSettings.Networks { + subnet, err := netip.ParsePrefix(fmt.Sprintf("%s/%d", network.Gateway, network.IPPrefixLen)) + if err != nil { + logger.WithError(err).Error("failed to parse the gateway on proxy") + + continue + } + + ip, err := netip.ParseAddr(host) + if err != nil { + logger.WithError(err).Error("failed to parse the address on proxy") + + continue + } + + if subnet.Contains(ip) { + target = ip.String() + + break + } + } + } + + if target == "" { + return ctx.Error(errors.New("address not found on the device")) + } + + host = target + } + + ErrFailedDialToAddressAndPort := errors.New("failed to dial to the address and port") + + logger.Trace("proxy handler connecting to the address") + + in, err := net.Dial(ProxyHandlerNetwork, net.JoinHostPort(host, port)) + if err != nil { + logger.WithError(err).Error("proxy handler failed to dial to the address") + + return ctx.Error(ErrFailedDialToAddressAndPort) + } + + defer in.Close() + + logger.Trace("proxy handler dialed to the address") + + // TODO: Add consts for status values. + if err := ctx.Status("ok"); err != nil { + logger.WithError(err).Error("proxy handler failed to send status response") + + return err + } + + wg := new(sync.WaitGroup) + done := sync.OnceFunc(func() { + defer in.Close() + defer rwc.Close() + + logger.Trace("close called on in and out connections") + }) + + wg.Add(1) + go func() { + defer done() + defer wg.Done() + + if _, err := io.Copy(in, rwc); err != nil && err != io.EOF { + logger.WithError(err).Error("proxy handler copy from rwc to in failed") + } + }() + + wg.Add(1) + go func() { + defer done() + defer wg.Done() + + if _, err := io.Copy(rwc, in); err != nil && err != io.EOF { + logger.WithError(err).Error("proxy handler copy from in to rwc failed") + } + }() + + logger.WithError(err).Info("proxy handler waiting for data pipe") + + wg.Wait() + + logger.WithError(err).Info("proxy handler done") + + return nil + } +} + +func sshHandlerV2(agent *Agent) tunnel.HandlerFunc { + return func(ctx tunnel.Context, rwc io.ReadWriteCloser) error { + defer rwc.Close() + + headers, err := ctx.Headers() + if err != nil { + log.WithError(err).Error("failed to get the headers from the connection") + + return err + } + + id := headers["id"] + + conn, ok := rwc.(net.Conn) + if !ok { + log.Error("failed to cast the ReadWriteCloser to net.Conn") + + return errors.New("failed to cast the ReadWriteCloser to net.Conn") + } + + agent.server.Sessions.Store(id, conn) + agent.server.HandleConn(conn) + + return nil + } +} + +func sshCloseHandlerV2(agent *Agent) tunnel.HandlerFunc { + return func(ctx tunnel.Context, rwc io.ReadWriteCloser) error { + defer rwc.Close() + + headers, err := ctx.Headers() + if err != nil { + log.WithError(err).Error("failed to get the headers from the connection") + + return err + } + + id := headers["id"] + + agent.server.CloseSession(id) + + log.WithFields( + log.Fields{ + "id": id, + "version": AgentVersion, + "tenant_id": agent.authData.Namespace, + "server_address": agent.config.ServerAddress, + }, + ).Info("A tunnel connection was closed") + + return nil + } +} + +const ( + HandleSSHOpenV1 = "GET:///ssh/:id" + HandleSSHCloseV1 = "GET:///ssh/close/:id" + HandleHTTPProxyV1 = "CONNECT:///http/proxy/:addr" +) + +func httpProxyHandlerV1(agent *Agent) func(c echo.Context) error { + const ProxyHandlerNetwork = "tcp" + + return func(c echo.Context) error { + logger := log.WithFields(log.Fields{ + "remote": c.Request().RemoteAddr, + "namespace": c.Request().Header.Get("X-Namespace"), + "path": c.Request().Header.Get("X-Path"), + "version": AgentVersion, + }) + + errorResponse := func(err error, msg string, code int) error { + logger.WithError(err).Debug(msg) + + return c.String(code, msg) + } + + host, port, err := net.SplitHostPort(c.Param("addr")) + if err != nil { + return errorResponse(err, "failed because address is invalid", http.StatusInternalServerError) + } + + if _, ok := agent.mode.(*ConnectorMode); ok { + cli, err := dockerclient.NewClientWithOpts(dockerclient.FromEnv, dockerclient.WithAPIVersionNegotiation()) + if err != nil { + return errorResponse(err, "failed to connect to the Docker Engine", http.StatusInternalServerError) + } + + container, err := cli.ContainerInspect(context.Background(), agent.server.ContainerID) + if err != nil { + return errorResponse(err, "failed to inspect the container", http.StatusInternalServerError) + } + + var target string + + addr, err := netip.ParseAddr(host) + if err != nil { + return errorResponse(err, "failed to parse the for lookback checkage", http.StatusInternalServerError) + } + + if addr.IsLoopback() { + for _, network := range container.NetworkSettings.Networks { + target = network.IPAddress + + break + } + } else { + for _, network := range container.NetworkSettings.Networks { + subnet, err := netip.ParsePrefix(fmt.Sprintf("%s/%d", network.Gateway, network.IPPrefixLen)) + if err != nil { + logger.WithError(err).Trace("Failed to parse the gateway on proxy") + + continue + } + + ip, err := netip.ParseAddr(host) + if err != nil { + logger.WithError(err).Trace("Failed to parse the address on proxy") + + continue + } + + if subnet.Contains(ip) { + target = ip.String() + + break + } + } + } + + if target == "" { + return errorResponse(nil, "address not found on the device", http.StatusInternalServerError) + } + + host = target + } + + // NOTE: Gets the to address to connect to. This address can be just a port, :8080, or the host and port, + // localhost:8080. + addr := fmt.Sprintf("%s:%s", host, port) + + in, err := net.Dial(ProxyHandlerNetwork, addr) + if err != nil { + return errorResponse(err, "failed to connect to the server on device", http.StatusInternalServerError) + } + + defer in.Close() + + // NOTE: Inform to the connection that the dial was successfully. + if err := c.NoContent(http.StatusOK); err != nil { + return errorResponse(err, "failed to send the ok status code back to server", http.StatusInternalServerError) + } + + // NOTE: Hijacks the connection to control the data transferred to the client connected. This way, we don't + // depend upon anything externally, only the data. + out, _, err := c.Response().Hijack() + if err != nil { + return errorResponse(err, "failed to hijack connection", http.StatusInternalServerError) + } + + defer out.Close() // nolint:errcheck + + wg := new(sync.WaitGroup) + done := sync.OnceFunc(func() { + defer in.Close() + defer out.Close() + + logger.Trace("close called on in and out connections") + }) + + wg.Add(1) + go func() { + defer done() + defer wg.Done() + + io.Copy(in, out) //nolint:errcheck + }() + + wg.Add(1) + go func() { + defer done() + defer wg.Done() + + io.Copy(out, in) //nolint:errcheck + }() + + logger.WithError(err).Trace("proxy handler waiting for data pipe") + wg.Wait() + + logger.WithError(err).Trace("proxy handler done") + + return nil + } +} + +func sshHandlerV1(ag *Agent) func(c echo.Context) error { + return func(c echo.Context) error { + hj, ok := c.Response().Writer.(http.Hijacker) + if !ok { + return c.String(http.StatusInternalServerError, "webserver doesn't support hijacking") + } + + conn, _, err := hj.Hijack() + if err != nil { + return c.String(http.StatusInternalServerError, "failed to hijack connection") + } + + id := c.Param("id") + httpConn := c.Request().Context().Value("http-conn").(net.Conn) + ag.server.Sessions.Store(id, httpConn) + ag.server.HandleConn(httpConn) + + conn.Close() + + return nil + } +} + +func sshCloseHandlerV1(a *Agent) func(c echo.Context) error { + return func(c echo.Context) error { + id := c.Param("id") + a.server.CloseSession(id) + + log.WithFields( + log.Fields{ + "id": id, + "version": AgentVersion, + "tenant_id": a.authData.Namespace, + "server_address": a.config.ServerAddress, + }, + ).Info("A tunnel connection was closed") + + return nil + } +} diff --git a/agent/pkg/tunnel/context.go b/agent/pkg/tunnel/context.go new file mode 100644 index 00000000000..820730ee6a0 --- /dev/null +++ b/agent/pkg/tunnel/context.go @@ -0,0 +1,79 @@ +package tunnel + +import ( + "context" + "encoding/json" + "errors" + "io" + "time" + + log "github.com/sirupsen/logrus" +) + +type Context struct { + ctx context.Context + + encoder *json.Encoder + decoder *json.Decoder +} + +func (c Context) Deadline() (deadline time.Time, ok bool) { + return c.ctx.Deadline() +} + +func (c Context) Done() <-chan struct{} { + return c.ctx.Done() +} + +func (c Context) Err() error { + return c.ctx.Err() +} + +func (c Context) Value(key any) any { + return c.ctx.Value(key) +} + +func (c Context) Status(status string) error { + if err := c.encoder.Encode(map[string]string{"status": status}); err != nil { + log.WithError(err).Error("failed to send status response") + + return errors.Join(errors.New("failed to send status response"), err) + } + + return nil +} + +func (c Context) Error(err error) error { + if err := c.encoder.Encode(map[string]string{"error": err.Error()}); err != nil { + log.WithError(err).Error("failed to send error response") + + return errors.Join(errors.New("failed to send error response"), err) + } + + return nil +} + +type Headers map[string]string + +func (c Context) Headers() (Headers, error) { + // TODO: cache the headers after the first call. + var header Headers + + if err := c.decoder.Decode(&header); err != nil { + log.WithError(err).Error("failed to decode the header") + + return nil, err + } + + return header, nil +} + +func NewContext(ctx context.Context, rwc io.ReadWriteCloser) Context { + return Context{ + ctx: ctx, + encoder: json.NewEncoder(rwc), + decoder: json.NewDecoder(rwc), + } +} + +type HandlerFunc func(ctx Context, rwc io.ReadWriteCloser) error diff --git a/agent/pkg/tunnel/tunnel.go b/agent/pkg/tunnel/tunnel.go index cb070ec3a48..d6f47bcd4a5 100644 --- a/agent/pkg/tunnel/tunnel.go +++ b/agent/pkg/tunnel/tunnel.go @@ -2,57 +2,101 @@ package tunnel import ( "context" + "errors" + "io" "net" "net/http" + "strings" + "github.com/gorilla/websocket" "github.com/labstack/echo/v4" - "github.com/shellhub-io/shellhub/pkg/revdial" + "github.com/multiformats/go-multistream" + "github.com/shellhub-io/shellhub/pkg/api/client" + log "github.com/sirupsen/logrus" ) -type Tunnel struct { - router *echo.Echo - srv *http.Server - HTTPProxyHandler func(e echo.Context) error - SSHHandler func(e echo.Context) error - SSHCloseHandler func(e echo.Context) error +type HandlerConstraint interface { + echo.HandlerFunc | HandlerFunc } -type Builder struct { - tunnel *Tunnel +type Tunnel[H HandlerConstraint] interface { + Handle(protocol string, handler H) + Listen(ctx context.Context, listener net.Listener) error + Close() error } -func NewBuilder() *Builder { - return &Builder{ - tunnel: NewTunnel(), +type TunnelV2 struct { + mux *multistream.MultistreamMuxer[string] + cli client.Client + listener net.Listener +} + +func NewTunnelV2(cli client.Client) Tunnel[HandlerFunc] { + return &TunnelV2{ + mux: multistream.NewMultistreamMuxer[string](), + cli: cli, } } -func (t *Builder) WithHTTPProxyHandler(handler func(e echo.Context) error) *Builder { - t.tunnel.HTTPProxyHandler = handler +func (t *TunnelV2) Handle(protocol string, handler HandlerFunc) { + t.mux.AddHandler(protocol, func(protocol string, rwc io.ReadWriteCloser) error { + log.WithField("protocol", protocol).Debug("handling connection") + defer log.WithField("protocol", protocol).Debug("handling connection closed") - return t + // TODO: Should we receive a context from outside? + return handler(NewContext(context.TODO(), rwc), rwc) + }) } -func (t *Builder) WithSSHHandler(handler func(e echo.Context) error) *Builder { - t.tunnel.SSHHandler = handler +func (t *TunnelV2) Listen(ctx context.Context, listener net.Listener) error { + t.listener = listener - return t -} + for { + stream, err := listener.Accept() + if err != nil { + log.WithError(err).Trace("failed to accept stream") -func (t *Builder) WithSSHCloseHandler(handler func(e echo.Context) error) *Builder { - t.tunnel.SSHCloseHandler = handler + switch { + case websocket.IsCloseError(err, websocket.CloseAbnormalClosure): + return errors.Join(ErrTunnelDisconnect, err) + } - return t + return err + } + + log.Trace("new stream accepted") + + go func() { + log.Trace("handling stream") + + if err := t.mux.Handle(stream); err != nil { + log.WithError(err).Trace("failed to handle stream") + + _ = stream.Close() + } + + log.Trace("stream handled") + }() + } +} + +// Close implements Tunnel. +func (t *TunnelV2) Close() error { + return t.listener.Close() } -func (t *Builder) Build() *Tunnel { - return t.tunnel +// ErrTunnelDisconnect is returned when the tunnel connection is closed. +var ErrTunnelDisconnect = errors.New("tunnel disconnected") + +type TunnelV1 struct { + router *echo.Echo + srv *http.Server } -func NewTunnel() *Tunnel { +func NewTunnelV1() *TunnelV1 { e := echo.New() - t := &Tunnel{ + t := &TunnelV1{ router: e, srv: &http.Server{ //nolint:gosec Handler: e, @@ -60,41 +104,30 @@ func NewTunnel() *Tunnel { return context.WithValue(ctx, "http-conn", c) //nolint:revive }, }, - SSHHandler: func(_ echo.Context) error { - panic("ConnHandler can not be nil") - }, - SSHCloseHandler: func(_ echo.Context) error { - panic("CloseHandler can not be nil") - }, - HTTPProxyHandler: func(_ echo.Context) error { - panic("ProxyHandler can not be nil") - }, } - e.GET("/ssh/:id", func(e echo.Context) error { - return t.SSHHandler(e) - }) - e.GET("/ssh/close/:id", func(e echo.Context) error { - return t.SSHCloseHandler(e) - }) - e.CONNECT("/http/proxy/:addr", func(e echo.Context) error { - // NOTE: The CONNECT HTTP method requests that a proxy establish a HTTP tunnel to this server, and if - // successful, blindly forward data in both directions until the tunnel is closed. - // - // https://en.wikipedia.org/wiki/HTTP_tunnel - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/CONNECT - return t.HTTPProxyHandler(e) - }) return t } -// Listen to reverse listener. -func (t *Tunnel) Listen(l *revdial.Listener) error { - return t.srv.Serve(l) +func (t *TunnelV1) Handle(protocol string, handler echo.HandlerFunc) { + parts := strings.SplitN(protocol, "://", 2) + + method := parts[0] + path := parts[1] + + t.router.Add(method, path, func(c echo.Context) error { + log.WithField("protocol", protocol).Debug("handling connection") + defer log.WithField("protocol", protocol).Debug("handling connection closed") + + return handler(c) + }) +} + +func (t *TunnelV1) Listen(ctx context.Context, listener net.Listener) error { + return t.srv.Serve(listener) } -// Close closes the tunnel. -func (t *Tunnel) Close() error { +func (t *TunnelV1) Close() error { if err := t.router.Close(); err != nil { return err } diff --git a/docker-compose.agent.yml b/docker-compose.agent.yml index f63953f8e50..8222671eabf 100644 --- a/docker-compose.agent.yml +++ b/docker-compose.agent.yml @@ -22,6 +22,7 @@ services: - SHELLHUB_LOG_LEVEL=${SHELLHUB_LOG_LEVEL} - SHELLHUB_LOG_FORMAT=${SHELLHUB_LOG_FORMAT} - SHELLHUB_PERMIT_EMPTY_PASSWORDS=${SHELLHUB_PERMIT_EMPTY_PASSWORDS} + - SHELLHUB_CONNECTION_VERSION=2 volumes: - ./agent:/go/src/github.com/shellhub-io/shellhub/agent - ./pkg:/go/src/github.com/shellhub-io/shellhub/pkg diff --git a/gateway/nginx/conf.d/shellhub.conf b/gateway/nginx/conf.d/shellhub.conf index 3bb8781982d..57192b993b4 100644 --- a/gateway/nginx/conf.d/shellhub.conf +++ b/gateway/nginx/conf.d/shellhub.conf @@ -368,6 +368,29 @@ server { proxy_redirect off; } + location /agent/connection { + set $upstream ssh:8080; + + auth_request /auth; + auth_request_set $tenant_id $upstream_http_x_tenant_id; + auth_request_set $device_uid $upstream_http_x_device_uid; + proxy_pass http://$upstream; + proxy_set_header Connection $connection_upgrade; + proxy_set_header Host $host; + proxy_set_header Upgrade $http_upgrade; + {{ if $cfg.EnableProxyProtocol -}} + proxy_set_header X-Real-IP $proxy_protocol_addr; + {{ else -}} + proxy_set_header X-Real-IP $x_real_ip; + {{ end -}} + proxy_set_header X-Device-UID $device_uid; + proxy_set_header X-Tenant-ID $tenant_id; + proxy_set_header X-Request-ID $request_id; + proxy_http_version 1.1; + proxy_cache_bypass $http_upgrade; + proxy_redirect off; + } + location /ssh/auth { set $upstream api:8080; diff --git a/go.mod b/go.mod index a22f0d38c11..15a258db73a 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/golang-jwt/jwt/v4 v4.5.2 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.0 + github.com/hashicorp/yamux v0.1.2 github.com/hibiken/asynq v0.24.1 github.com/jarcoal/httpmock v1.3.1 github.com/labstack/echo/v4 v4.10.2 diff --git a/go.sum b/go.sum index 7034a5d08e0..f6a31a6df86 100644 --- a/go.sum +++ b/go.sum @@ -178,6 +178,8 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= +github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= github.com/hibiken/asynq v0.24.1 h1:+5iIEAyA9K/lcSPvx3qoPtsKJeKI5u9aOIvUmSsazEw= github.com/hibiken/asynq v0.24.1/go.mod h1:u5qVeSbrnfT+vtG5Mq8ZPzQu/BmCKMHvTGb91uy9Tts= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= diff --git a/pkg/api/client/client.go b/pkg/api/client/client.go index 411b7d36e3e..b7c41098ce8 100644 --- a/pkg/api/client/client.go +++ b/pkg/api/client/client.go @@ -13,8 +13,8 @@ import ( "time" resty "github.com/go-resty/resty/v2" + "github.com/shellhub-io/shellhub/pkg/api/client/reverser" "github.com/shellhub-io/shellhub/pkg/models" - "github.com/shellhub-io/shellhub/pkg/revdial" log "github.com/sirupsen/logrus" ) @@ -28,7 +28,12 @@ type publicAPI interface { Endpoints() (*models.Endpoints, error) AuthDevice(req *models.DeviceAuthRequest) (*models.DeviceAuthResponse, error) AuthPublicKey(req *models.PublicKeyAuthRequest, token string) (*models.PublicKeyAuthResponse, error) - NewReverseListener(ctx context.Context, token string, connPath string) (*revdial.Listener, error) + // NewReverseListener creates a new reverse listener to be used by the Agent to connect to ShellHub's SSH server + // using RevDial protocol. + NewReverseListenerV1(ctx context.Context, token string, path string) (net.Listener, error) + // NewReverseListenerV2 creates a new reverse listener to be used by the Agent to connect to ShellHub's SSH server + // using Yamux protocol. + NewReverseListenerV2(ctx context.Context, token string, path string, cfg *ReverseListenerV2Config) (net.Listener, error) } //go:generate mockery --name=Client --filename=client.go @@ -44,7 +49,7 @@ type client struct { http *resty.Client logger *log.Logger // reverser is used to create a reverse listener to Agent from ShellHub's SSH server. - reverser IReverser + reverser reverser.Reverser } var ErrParseAddress = fmt.Errorf("could not parse the address to the required format") diff --git a/pkg/api/client/client_public.go b/pkg/api/client/client_public.go index 9b8a8bd6a11..b194667e317 100644 --- a/pkg/api/client/client_public.go +++ b/pkg/api/client/client_public.go @@ -3,10 +3,16 @@ package client import ( "context" "errors" + "net" + "net/http" + "net/url" + "os" + "time" resty "github.com/go-resty/resty/v2" + "github.com/hashicorp/yamux" "github.com/shellhub-io/shellhub/pkg/models" - "github.com/shellhub-io/shellhub/pkg/revdial" + "github.com/shellhub-io/shellhub/pkg/wsconnadapter" log "github.com/sirupsen/logrus" ) @@ -105,14 +111,164 @@ func (c *client) AuthPublicKey(req *models.PublicKeyAuthRequest, token string) ( // NewReverseListener creates a new reverse listener connection to ShellHub's server. This listener receives the SSH // requests coming from the ShellHub server. Only authenticated devices can obtain a listener connection. -func (c *client) NewReverseListener(ctx context.Context, token string, connPath string) (*revdial.Listener, error) { +func (c *client) NewReverseListenerV1(ctx context.Context, token string, path string) (net.Listener, error) { if token == "" { return nil, errors.New("token is empty") } - if err := c.reverser.Auth(ctx, token, connPath); err != nil { + if err := c.reverser.Auth(ctx, token, path); err != nil { return nil, err } return c.reverser.NewListener() } + +type ReverseListenerV2Config struct { + // AcceptBacklog is used to limit how many streams may be + // waiting an accept. + AcceptBacklog int `json:"yamux_accept_backlog"` + + // EnableKeepalive is used to do a period keep alive + // messages using a ping. + EnableKeepAlive bool `json:"yamux_enable_keep_alive"` + + // KeepAliveInterval is how often to perform the keep alive + KeepAliveInterval time.Duration `json:"yamux_keep_alive_interval"` + + // ConnectionWriteTimeout is meant to be a "safety valve" timeout after + // we which will suspect a problem with the underlying connection and + // close it. This is only applied to writes, where's there's generally + // an expectation that things will move along quickly. + ConnectionWriteTimeout time.Duration `json:"yamux_connection_write_timeout"` + + // MaxStreamWindowSize is used to control the maximum + // window size that we allow for a stream. + MaxStreamWindowSize uint32 `json:"yamux_max_stream_window_size"` + + // StreamOpenTimeout is the maximum amount of time that a stream will + // be allowed to remain in pending state while waiting for an ack from the peer. + // Once the timeout is reached the session will be gracefully closed. + // A zero value disables the StreamOpenTimeout allowing unbounded + // blocking on OpenStream calls. + StreamOpenTimeout time.Duration `json:"yamux_stream_open_timeout"` + + // StreamCloseTimeout is the maximum time that a stream will allowed to + // be in a half-closed state when `Close` is called before forcibly + // closing the connection. Forcibly closed connections will empty the + // receive buffer, drop any future packets received for that stream, + // and send a RST to the remote side. + StreamCloseTimeout time.Duration `json:"yamux_stream_close_timeout"` +} + +var DefaultReverseListenerV2Config = ReverseListenerV2Config{ + AcceptBacklog: 256, + EnableKeepAlive: true, + KeepAliveInterval: 35 * time.Second, + ConnectionWriteTimeout: 15 * time.Second, + MaxStreamWindowSize: 256 * 1024, + StreamCloseTimeout: 5 * time.Minute, + StreamOpenTimeout: 75 * time.Second, +} + +// NewReverseV2ConfigFromMap creates a new Config from a map[string]any received from auth data from the server +// or returns the default config if the map is nil. If a key is missing, the default value is used. +func NewReverseV2ConfigFromMap(m map[string]any) *ReverseListenerV2Config { + cfg := DefaultReverseListenerV2Config + + if v, ok := m["yamux_accept_backlog"].(int); ok { + cfg.AcceptBacklog = v + } + + if v, ok := m["yamux_enable_keep_alive"].(bool); ok { + cfg.EnableKeepAlive = v + } + + if v, ok := m["yamux_keep_alive_interval"].(time.Duration); ok { + cfg.KeepAliveInterval = v + } + + if v, ok := m["yamux_connection_write_timeout"].(time.Duration); ok { + cfg.ConnectionWriteTimeout = v + } + + if v, ok := m["yamux_max_stream_window_size"].(uint32); ok { + cfg.MaxStreamWindowSize = v + } + + if v, ok := m["yamux_stream_open_timeout"].(time.Duration); ok { + cfg.StreamOpenTimeout = v + } + + if v, ok := m["yamux_stream_close_timeout"].(time.Duration); ok { + cfg.StreamCloseTimeout = v + } + + return &cfg +} + +func YamuxConfigFromReverseListenerV2(cfg *ReverseListenerV2Config) *yamux.Config { + if cfg == nil { + cfg = &DefaultReverseListenerV2Config + } + + return &yamux.Config{ + AcceptBacklog: cfg.AcceptBacklog, + EnableKeepAlive: cfg.EnableKeepAlive, + KeepAliveInterval: cfg.KeepAliveInterval, + ConnectionWriteTimeout: cfg.ConnectionWriteTimeout, + MaxStreamWindowSize: cfg.MaxStreamWindowSize, + StreamCloseTimeout: cfg.StreamCloseTimeout, + StreamOpenTimeout: cfg.StreamOpenTimeout, + // NOTE: LogOutput is required, and without it yamux will failed to create the session. + LogOutput: os.Stderr, + } +} + +func (c *client) NewReverseListenerV2(ctx context.Context, token string, path string, cfg *ReverseListenerV2Config) (net.Listener, error) { + if token == "" { + return nil, errors.New("token is empty") + } + + u, err := url.JoinPath(c.http.BaseURL, path) + if err != nil { + return nil, err + } + + wsconn, _, err := DialContext(ctx, u, http.Header{ + "Authorization": []string{"Bearer " + token}, + }) + if err != nil { + return nil, err + } + + var listener *yamux.Session + + conn := wsconnadapter.New(wsconn) + + listener, err = yamux.Server(conn, YamuxConfigFromReverseListenerV2(cfg)) + if err != nil { + log.WithError(err).WithFields(log.Fields{ + "accept_backlog": cfg.AcceptBacklog, + "enable_keep_alive": cfg.EnableKeepAlive, + "keep_alive_interval": cfg.KeepAliveInterval, + "connection_write_timeout": cfg.ConnectionWriteTimeout, + "max_stream_window_size": cfg.MaxStreamWindowSize, + "stream_close_timeout": cfg.StreamCloseTimeout, + "stream_open_timeout": cfg.StreamOpenTimeout, + }).Error("failed to create muxed session") + + // NOTE: If we fail to create the session, we should try again with the [DefaultConfig] as the client + // could be using different settings. + log.WithError(err).Warning("trying to create muxed session with default config") + listener, err = yamux.Server(conn, YamuxConfigFromReverseListenerV2(&DefaultReverseListenerV2Config)) + if err != nil { + log.WithError(err).Error("failed to create muxed session with default config") + + return nil, err + } + + log.WithError(err).Warning("muxed session created with default config due to error with custom config") + } + + return listener, err +} diff --git a/pkg/api/client/client_public_test.go b/pkg/api/client/client_public_test.go index d9ca125cd32..bf991554048 100644 --- a/pkg/api/client/client_public_test.go +++ b/pkg/api/client/client_public_test.go @@ -9,7 +9,7 @@ import ( "testing" mock "github.com/jarcoal/httpmock" - reversermock "github.com/shellhub-io/shellhub/pkg/api/client/mocks" + reversermock "github.com/shellhub-io/shellhub/pkg/api/client/reverser/mocks" "github.com/shellhub-io/shellhub/pkg/models" "github.com/shellhub-io/shellhub/pkg/revdial" "github.com/stretchr/testify/assert" @@ -415,7 +415,7 @@ func TestAuthPublicKey(t *testing.T) { } func TestReverseListener(t *testing.T) { - mock := new(reversermock.IReverser) + mock := new(reversermock.Reverser) tests := []struct { description string @@ -468,7 +468,7 @@ func TestReverseListener(t *testing.T) { test.requiredMocks() - _, err = cli.NewReverseListener(ctx, test.token, "") + _, err = cli.NewReverseListenerV1(ctx, test.token, "") assert.Equal(t, err, test.expected) }) } diff --git a/pkg/api/client/mocks/client.go b/pkg/api/client/mocks/client.go index 4730a4a493d..eb6959bb9e6 100644 --- a/pkg/api/client/mocks/client.go +++ b/pkg/api/client/mocks/client.go @@ -1,14 +1,17 @@ -// Code generated by mockery v2.20.0. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package mocks import ( context "context" - models "github.com/shellhub-io/shellhub/pkg/models" + client "github.com/shellhub-io/shellhub/pkg/api/client" + mock "github.com/stretchr/testify/mock" - revdial "github.com/shellhub-io/shellhub/pkg/revdial" + models "github.com/shellhub-io/shellhub/pkg/models" + + net "net" ) // Client is an autogenerated mock type for the Client type @@ -20,6 +23,10 @@ type Client struct { func (_m *Client) AuthDevice(req *models.DeviceAuthRequest) (*models.DeviceAuthResponse, error) { ret := _m.Called(req) + if len(ret) == 0 { + panic("no return value specified for AuthDevice") + } + var r0 *models.DeviceAuthResponse var r1 error if rf, ok := ret.Get(0).(func(*models.DeviceAuthRequest) (*models.DeviceAuthResponse, error)); ok { @@ -46,6 +53,10 @@ func (_m *Client) AuthDevice(req *models.DeviceAuthRequest) (*models.DeviceAuthR func (_m *Client) AuthPublicKey(req *models.PublicKeyAuthRequest, token string) (*models.PublicKeyAuthResponse, error) { ret := _m.Called(req, token) + if len(ret) == 0 { + panic("no return value specified for AuthPublicKey") + } + var r0 *models.PublicKeyAuthResponse var r1 error if rf, ok := ret.Get(0).(func(*models.PublicKeyAuthRequest, string) (*models.PublicKeyAuthResponse, error)); ok { @@ -68,10 +79,14 @@ func (_m *Client) AuthPublicKey(req *models.PublicKeyAuthRequest, token string) return r0, r1 } -// Endpoints provides a mock function with given fields: +// Endpoints provides a mock function with no fields func (_m *Client) Endpoints() (*models.Endpoints, error) { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for Endpoints") + } + var r0 *models.Endpoints var r1 error if rf, ok := ret.Get(0).(func() (*models.Endpoints, error)); ok { @@ -98,6 +113,10 @@ func (_m *Client) Endpoints() (*models.Endpoints, error) { func (_m *Client) GetDevice(uid string) (*models.Device, error) { ret := _m.Called(uid) + if len(ret) == 0 { + panic("no return value specified for GetDevice") + } + var r0 *models.Device var r1 error if rf, ok := ret.Get(0).(func(string) (*models.Device, error)); ok { @@ -124,6 +143,10 @@ func (_m *Client) GetDevice(uid string) (*models.Device, error) { func (_m *Client) GetInfo(agentVersion string) (*models.Info, error) { ret := _m.Called(agentVersion) + if len(ret) == 0 { + panic("no return value specified for GetInfo") + } + var r0 *models.Info var r1 error if rf, ok := ret.Get(0).(func(string) (*models.Info, error)); ok { @@ -146,10 +169,14 @@ func (_m *Client) GetInfo(agentVersion string) (*models.Info, error) { return r0, r1 } -// ListDevices provides a mock function with given fields: +// ListDevices provides a mock function with no fields func (_m *Client) ListDevices() ([]models.Device, error) { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for ListDevices") + } + var r0 []models.Device var r1 error if rf, ok := ret.Get(0).(func() ([]models.Device, error)); ok { @@ -172,25 +199,29 @@ func (_m *Client) ListDevices() ([]models.Device, error) { return r0, r1 } -// NewReverseListener provides a mock function with given fields: ctx, token, connPath -func (_m *Client) NewReverseListener(ctx context.Context, token string, connPath string) (*revdial.Listener, error) { - ret := _m.Called(ctx, token, connPath) +// NewReverseListenerV1 provides a mock function with given fields: ctx, token, path +func (_m *Client) NewReverseListenerV1(ctx context.Context, token string, path string) (net.Listener, error) { + ret := _m.Called(ctx, token, path) + + if len(ret) == 0 { + panic("no return value specified for NewReverseListenerV1") + } - var r0 *revdial.Listener + var r0 net.Listener var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (*revdial.Listener, error)); ok { - return rf(ctx, token, connPath) + if rf, ok := ret.Get(0).(func(context.Context, string, string) (net.Listener, error)); ok { + return rf(ctx, token, path) } - if rf, ok := ret.Get(0).(func(context.Context, string, string) *revdial.Listener); ok { - r0 = rf(ctx, token, connPath) + if rf, ok := ret.Get(0).(func(context.Context, string, string) net.Listener); ok { + r0 = rf(ctx, token, path) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*revdial.Listener) + r0 = ret.Get(0).(net.Listener) } } if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, token, connPath) + r1 = rf(ctx, token, path) } else { r1 = ret.Error(1) } @@ -198,13 +229,42 @@ func (_m *Client) NewReverseListener(ctx context.Context, token string, connPath return r0, r1 } -type mockConstructorTestingTNewClient interface { - mock.TestingT - Cleanup(func()) +// NewReverseListenerV2 provides a mock function with given fields: ctx, token, path, cfg +func (_m *Client) NewReverseListenerV2(ctx context.Context, token string, path string, cfg *client.ReverseListenerV2Config) (net.Listener, error) { + ret := _m.Called(ctx, token, path, cfg) + + if len(ret) == 0 { + panic("no return value specified for NewReverseListenerV2") + } + + var r0 net.Listener + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, *client.ReverseListenerV2Config) (net.Listener, error)); ok { + return rf(ctx, token, path, cfg) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, *client.ReverseListenerV2Config) net.Listener); ok { + r0 = rf(ctx, token, path, cfg) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Listener) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, *client.ReverseListenerV2Config) error); ok { + r1 = rf(ctx, token, path, cfg) + } else { + r1 = ret.Error(1) + } + + return r0, r1 } // NewClient creates a new instance of Client. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewClient(t mockConstructorTestingTNewClient) *Client { +// The first argument is typically a *testing.T value. +func NewClient(t interface { + mock.TestingT + Cleanup(func()) +}) *Client { mock := &Client{} mock.Mock.Test(t) diff --git a/pkg/api/client/options.go b/pkg/api/client/options.go index ac25421311b..cb693d8094b 100644 --- a/pkg/api/client/options.go +++ b/pkg/api/client/options.go @@ -4,6 +4,7 @@ import ( "net/url" "strconv" + "github.com/shellhub-io/shellhub/pkg/api/client/reverser" "github.com/sirupsen/logrus" ) @@ -66,7 +67,7 @@ func WithLogger(logger *logrus.Logger) Opt { } } -func WithReverser(reverser IReverser) Opt { +func WithReverser(reverser reverser.Reverser) Opt { return func(c *client) error { c.reverser = reverser diff --git a/pkg/api/client/reverser.go b/pkg/api/client/reverser.go index 992396f01e3..9c2cb674ca5 100644 --- a/pkg/api/client/reverser.go +++ b/pkg/api/client/reverser.go @@ -8,16 +8,11 @@ import ( "net/url" "github.com/gorilla/websocket" + "github.com/shellhub-io/shellhub/pkg/api/client/reverser" "github.com/shellhub-io/shellhub/pkg/revdial" "github.com/shellhub-io/shellhub/pkg/wsconnadapter" ) -//go:generate mockery --name=IReverser --filename=reverser.go -type IReverser interface { - Auth(ctx context.Context, token string, connPath string) error - NewListener() (*revdial.Listener, error) -} - type Reverser struct { conn *websocket.Conn // host is the ShellHub's server address. @@ -26,7 +21,7 @@ type Reverser struct { host string } -var _ IReverser = new(Reverser) +var _ reverser.Reverser = new(Reverser) func NewReverser(host string) *Reverser { return &Reverser{ diff --git a/pkg/api/client/mocks/reverser.go b/pkg/api/client/reverser/mocks/reverser.go similarity index 55% rename from pkg/api/client/mocks/reverser.go rename to pkg/api/client/reverser/mocks/reverser.go index 60a1289f126..a0b8d3aba44 100644 --- a/pkg/api/client/mocks/reverser.go +++ b/pkg/api/client/reverser/mocks/reverser.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.20.0. DO NOT EDIT. +// Code generated by mockery v2.53.3. DO NOT EDIT. package mocks @@ -9,15 +9,19 @@ import ( mock "github.com/stretchr/testify/mock" ) -// IReverser is an autogenerated mock type for the IReverser type -type IReverser struct { +// Reverser is an autogenerated mock type for the Reverser type +type Reverser struct { mock.Mock } // Auth provides a mock function with given fields: ctx, token, connPath -func (_m *IReverser) Auth(ctx context.Context, token string, connPath string) error { +func (_m *Reverser) Auth(ctx context.Context, token string, connPath string) error { ret := _m.Called(ctx, token, connPath) + if len(ret) == 0 { + panic("no return value specified for Auth") + } + var r0 error if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { r0 = rf(ctx, token, connPath) @@ -28,10 +32,14 @@ func (_m *IReverser) Auth(ctx context.Context, token string, connPath string) er return r0 } -// NewListener provides a mock function with given fields: -func (_m *IReverser) NewListener() (*revdial.Listener, error) { +// NewListener provides a mock function with no fields +func (_m *Reverser) NewListener() (*revdial.Listener, error) { ret := _m.Called() + if len(ret) == 0 { + panic("no return value specified for NewListener") + } + var r0 *revdial.Listener var r1 error if rf, ok := ret.Get(0).(func() (*revdial.Listener, error)); ok { @@ -54,14 +62,13 @@ func (_m *IReverser) NewListener() (*revdial.Listener, error) { return r0, r1 } -type mockConstructorTestingTNewIReverser interface { +// NewReverser creates a new instance of Reverser. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewReverser(t interface { mock.TestingT Cleanup(func()) -} - -// NewIReverser creates a new instance of IReverser. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewIReverser(t mockConstructorTestingTNewIReverser) *IReverser { - mock := &IReverser{} +}) *Reverser { + mock := &Reverser{} mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) }) diff --git a/pkg/api/client/reverser/reverser.go b/pkg/api/client/reverser/reverser.go new file mode 100644 index 00000000000..13724de8abb --- /dev/null +++ b/pkg/api/client/reverser/reverser.go @@ -0,0 +1,13 @@ +package reverser + +import ( + "context" + + "github.com/shellhub-io/shellhub/pkg/revdial" +) + +//go:generate mockery --name=Reverser --filename=reverser.go +type Reverser interface { + Auth(ctx context.Context, token string, connPath string) error + NewListener() (*revdial.Listener, error) +} diff --git a/pkg/connman/connman.go b/pkg/connman/connman.go deleted file mode 100644 index 22c17ef3b8d..00000000000 --- a/pkg/connman/connman.go +++ /dev/null @@ -1,77 +0,0 @@ -package connman - -import ( - "context" - "errors" - "net" - - "github.com/shellhub-io/shellhub/pkg/revdial" - "github.com/shellhub-io/shellhub/pkg/wsconnadapter" - log "github.com/sirupsen/logrus" -) - -var ErrNoConnection = errors.New("no connection") - -type ConnectionManager struct { - dialers *SyncSliceMap - DialerDoneCallback func(string, *revdial.Dialer) - DialerKeepAliveCallback func(string, *revdial.Dialer) -} - -func New() *ConnectionManager { - return &ConnectionManager{ - dialers: &SyncSliceMap{}, - DialerDoneCallback: func(string, *revdial.Dialer) { - }, - } -} - -func (m *ConnectionManager) Set(key string, conn *wsconnadapter.Adapter, connPath string) { - dialer := revdial.NewDialer(conn.Logger, conn, connPath) - - m.dialers.Store(key, dialer) - - if size := m.dialers.Size(key); size > 1 { - log.WithFields(log.Fields{ - "key": key, - "size": size, - }).Warning("Multiple connections stored for the same identifier.") - } - - m.DialerKeepAliveCallback(key, dialer) - - // Start the ping loop and get the channel for pong responses - pong := conn.Ping() - - go func() { - for { - select { - case <-pong: - m.DialerKeepAliveCallback(key, dialer) - - continue - case <-dialer.Done(): - m.dialers.Delete(key, dialer) - m.DialerDoneCallback(key, dialer) - - return - } - } - }() -} - -func (m *ConnectionManager) Dial(ctx context.Context, key string) (net.Conn, error) { - dialer, ok := m.dialers.Load(key) - if !ok { - return nil, ErrNoConnection - } - - if size := m.dialers.Size(key); size > 1 { - log.WithFields(log.Fields{ - "key": key, - "size": size, - }).Warning("Multiple connections found for the same identifier during reverse tunnel dialing.") - } - - return dialer.(*revdial.Dialer).Dial(ctx) -} diff --git a/pkg/httptunnel/httptunnel.go b/pkg/httptunnel/httptunnel.go deleted file mode 100644 index 6c1146f32e1..00000000000 --- a/pkg/httptunnel/httptunnel.go +++ /dev/null @@ -1,140 +0,0 @@ -package httptunnel - -import ( - "bufio" - "context" - "io" - "net" - "net/http" - "strings" - - "github.com/gorilla/websocket" - "github.com/labstack/echo/v4" - "github.com/shellhub-io/shellhub/pkg/connman" - "github.com/shellhub-io/shellhub/pkg/revdial" - "github.com/shellhub-io/shellhub/pkg/wsconnadapter" -) - -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - Subprotocols: []string{"binary"}, - CheckOrigin: func(_ *http.Request) bool { - return true - }, -} - -const ( - DefaultConnectionURL = "/connection" - DefaultRevdialURL = "/revdial" -) - -type Tunnel struct { - ConnectionPath string - DialerPath string - ConnectionHandler func(*http.Request) (string, error) - CloseHandler func(string) - KeepAliveHandler func(string) - connman *connman.ConnectionManager - id chan string - online chan bool -} - -func NewTunnel(connectionPath, dialerPath string) *Tunnel { - tunnel := &Tunnel{ - ConnectionPath: connectionPath, - DialerPath: dialerPath, - ConnectionHandler: func(_ *http.Request) (string, error) { - panic("ConnectionHandler not implemented") - }, - CloseHandler: func(string) { - }, - KeepAliveHandler: func(string) { - }, - connman: connman.New(), - id: make(chan string), - online: make(chan bool), - } - - tunnel.connman.DialerDoneCallback = func(id string, _ *revdial.Dialer) { - tunnel.CloseHandler(id) - } - - tunnel.connman.DialerKeepAliveCallback = func(id string, _ *revdial.Dialer) { - tunnel.KeepAliveHandler(id) - } - - return tunnel -} - -func (t *Tunnel) Router() http.Handler { - e := echo.New() - - e.GET(t.ConnectionPath, func(c echo.Context) error { - conn, err := upgrader.Upgrade(c.Response(), c.Request(), nil) - if err != nil { - return c.String(http.StatusInternalServerError, err.Error()) - } - - key, err := t.ConnectionHandler(c.Request()) - if err != nil { - conn.Close() - - return c.String(http.StatusBadRequest, err.Error()) - } - - requestID := c.Request().Header.Get("X-Request-ID") - parts := strings.Split(key, ":") - tenant := parts[0] - device := parts[1] - - t.connman.Set( - key, - wsconnadapter. - New(conn). - WithID(requestID). - WithDevice(tenant, device), - t.DialerPath, - ) - - return nil - }) - - e.GET(t.DialerPath, echo.WrapHandler(revdial.ConnHandler(upgrader))) - - return e -} - -func (t *Tunnel) Dial(ctx context.Context, id string) (net.Conn, error) { - return t.connman.Dial(ctx, id) -} - -func (t *Tunnel) SendRequest(ctx context.Context, id string, req *http.Request) (*http.Response, error) { - conn, err := t.connman.Dial(ctx, id) - if err != nil { - return nil, err - } - - if err := req.Write(conn); err != nil { - return nil, err - } - - resp, err := http.ReadResponse(bufio.NewReader(conn), req) - if err != nil { - return nil, err - } - - return resp, nil -} - -func (t *Tunnel) ForwardResponse(resp *http.Response, w http.ResponseWriter) { - for key, values := range resp.Header { - for _, value := range values { - w.Header().Add(key, value) - } - } - - w.WriteHeader(resp.StatusCode) - io.Copy(w, resp.Body) // nolint:errcheck - resp.Body.Close() -} diff --git a/pkg/models/device.go b/pkg/models/device.go index 6af17a0dd1e..063fcb42a5d 100644 --- a/pkg/models/device.go +++ b/pkg/models/device.go @@ -70,6 +70,12 @@ type DeviceAuthResponse struct { Token string `json:"token"` Name string `json:"name"` Namespace string `json:"namespace"` + // Config holds device-specific configuration settings. + // This can include various parameters that the device needs to operate correctly. + // The structure of this map can vary depending on the device type and its requirements. + // Example configurations might include network settings, operational modes, or feature toggles. + // It's designed to be flexible to accommodate different device needs. + Config map[string]any `json:"config,omitempty"` } type DeviceIdentity struct { diff --git a/pkg/wsconnadapter/wsconnadapter.go b/pkg/wsconnadapter/wsconnadapter.go index b9366122ef5..e43c481ea97 100644 --- a/pkg/wsconnadapter/wsconnadapter.go +++ b/pkg/wsconnadapter/wsconnadapter.go @@ -35,24 +35,24 @@ type Adapter struct { CreatedAt time.Time } -func (a *Adapter) WithID(requestID string) *Adapter { - a.Logger = a.Logger.WithFields(log.Fields{ - "request-id": requestID, - }) +type Option func(*Adapter) - return a +func WithID(id string) Option { + return func(a *Adapter) { + a.UUID = id + } } -func (a *Adapter) WithDevice(tenant string, device string) *Adapter { - a.Logger = a.Logger.WithFields(log.Fields{ - "tenant": tenant, - "device": device, - }) - - return a +func WithDevice(tenant string, device string) Option { + return func(a *Adapter) { + a.Logger = a.Logger.WithFields(log.Fields{ + "tenant": tenant, + "device": device, + }) + } } -func New(conn *websocket.Conn) *Adapter { +func New(conn *websocket.Conn, options ...Option) *Adapter { adapter := &Adapter{ conn: conn, Logger: log.NewEntry(&log.Logger{ @@ -64,6 +64,10 @@ func New(conn *websocket.Conn) *Adapter { CreatedAt: clock.Now(), } + for _, option := range options { + option(adapter) + } + return adapter } diff --git a/ssh/go.mod b/ssh/go.mod index 53d6233fcee..d5a4ef7ff7b 100644 --- a/ssh/go.mod +++ b/ssh/go.mod @@ -7,22 +7,28 @@ require ( github.com/gliderlabs/ssh v0.3.8 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/gorilla/websocket v1.5.3 + github.com/hashicorp/yamux v0.1.2 github.com/labstack/echo-contrib v0.17.4 github.com/labstack/echo/v4 v4.13.4 + github.com/multiformats/go-multistream v0.6.1 github.com/pires/go-proxyproto v0.8.0 github.com/shellhub-io/shellhub v0.13.4 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.11.1 golang.org/x/crypto v0.43.0 golang.org/x/net v0.46.0 + golang.org/x/time v0.11.0 ) require ( + github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect + github.com/Microsoft/hcsshim v0.12.2 // indirect github.com/adhocore/gronx v1.8.1 // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.11.2 // indirect @@ -30,28 +36,34 @@ require ( github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-resty/resty/v2 v2.11.0 // indirect github.com/golang/protobuf v1.5.4 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/hibiken/asynq v0.24.1 // indirect - github.com/klauspost/compress v1.17.4 // indirect + github.com/klauspost/compress v1.18.0 // indirect github.com/labstack/gommon v0.4.2 // indirect github.com/leodido/go-urn v1.2.2 // indirect + github.com/lufia/plan9stats v0.0.0-20240408141607-282e7b5d6b74 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/multiformats/go-varint v0.0.6 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/redis/go-redis/v9 v9.0.3 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect github.com/sethvargo/go-envconfig v0.9.0 // indirect + github.com/shirou/gopsutil/v3 v3.24.3 // indirect github.com/spf13/cast v1.3.1 // indirect github.com/stretchr/objx v0.5.2 // indirect + github.com/tklauser/go-sysconf v0.3.13 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect github.com/vmihailenco/go-tinylfu v0.2.2 // indirect github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + go.uber.org/goleak v1.3.0 // indirect golang.org/x/sync v0.17.0 // indirect golang.org/x/sys v0.37.0 // indirect golang.org/x/text v0.30.0 // indirect - golang.org/x/time v0.11.0 // indirect google.golang.org/protobuf v1.36.6 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/ssh/go.sum b/ssh/go.sum index b51042a3568..9f49980efb7 100644 --- a/ssh/go.sum +++ b/ssh/go.sum @@ -1,13 +1,13 @@ dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= -github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= -github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= -github.com/Microsoft/hcsshim v0.11.7 h1:vl/nj3Bar/CvJSYo7gIQPyRWc9f3c6IeSNavBTSZNZQ= -github.com/Microsoft/hcsshim v0.11.7/go.mod h1:MV8xMfmECjl5HdO7U/3/hFVnkmSBjAjmA09d4bExKcU= +github.com/Microsoft/hcsshim v0.12.2 h1:AcXy+yfRvrx20g9v7qYaJv5Rh+8GaHOS6b8G6Wx/nKs= +github.com/Microsoft/hcsshim v0.12.2/go.mod h1:RZV12pcHCXQ42XnlQ3pz6FZfmrC1C+R4gaOHhRNML1g= github.com/adhocore/gronx v1.8.1 h1:F2mLTG5sB11z7vplwD4iydz3YCEjstSfYmCrdSm3t6A= github.com/adhocore/gronx v1.8.1/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= @@ -56,8 +56,9 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= +github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -94,19 +95,22 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= -github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= +github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= github.com/hibiken/asynq v0.24.1 h1:+5iIEAyA9K/lcSPvx3qoPtsKJeKI5u9aOIvUmSsazEw= github.com/hibiken/asynq v0.24.1/go.mod h1:u5qVeSbrnfT+vtG5Mq8ZPzQu/BmCKMHvTGb91uy9Tts= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= -github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= @@ -122,8 +126,9 @@ github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0 github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= github.com/leodido/go-urn v1.2.2 h1:7z68G0FCGvDk646jz1AelTYNYWrTNm0bEcFAo147wt4= github.com/leodido/go-urn v1.2.2/go.mod h1:kUaIbLZWttglzwNuG0pgsh5vuV6u2YcGBYz1hIPjtOQ= -github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= +github.com/lufia/plan9stats v0.0.0-20240408141607-282e7b5d6b74 h1:1KuuSOy4ZNgW0KA2oYIngXVFhQcXxhLqCVK7cBcldkk= +github.com/lufia/plan9stats v0.0.0-20240408141607-282e7b5d6b74/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= @@ -144,6 +149,10 @@ github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= +github.com/multiformats/go-multistream v0.6.1 h1:4aoX5v6T+yWmc2raBHsTvzmFhOI8WVOer28DeBBEYdQ= +github.com/multiformats/go-multistream v0.6.1/go.mod h1:ksQf6kqHAb6zIsyw7Zm+gAuVo57Qbq84E27YlYqavqw= +github.com/multiformats/go-varint v0.0.6 h1:gk85QWKxh3TazbLxED/NlDVv8+q+ReFJk7Y2W/KhfNY= +github.com/multiformats/go-varint v0.0.6/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXSrVKRY101jdMZYE= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= @@ -167,8 +176,9 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= +github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/redis/go-redis/v9 v9.0.3 h1:+7mmR26M0IvyLxGZUHxu4GiBkJkVDid0Un+j4ScYu4k= github.com/redis/go-redis/v9 v9.0.3/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= @@ -179,10 +189,11 @@ github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6po github.com/rwtodd/Go.Sed v0.0.0-20210816025313-55464686f9ef/go.mod h1:8AEUvGVi2uQ5b24BIhcr0GCcpd/RNAFWaN2CJFrWIIQ= github.com/sethvargo/go-envconfig v0.9.0 h1:Q6FQ6hVEeTECULvkJZakq3dZMeBQ3JUpcKMfPQbKMDE= github.com/sethvargo/go-envconfig v0.9.0/go.mod h1:Iz1Gy1Sf3T64TQlJSvee81qDhf7YIlt8GMUX6yyNFs0= -github.com/shirou/gopsutil/v3 v3.23.12 h1:z90NtUkp3bMtmICZKpC4+WaknU1eXtp5vtbQ11DgpE4= -github.com/shirou/gopsutil/v3 v3.23.12/go.mod h1:1FrWgea594Jp7qmjHUUPlJDTPgcsb9mGnXDxavtikzM= +github.com/shirou/gopsutil/v3 v3.24.3 h1:eoUGJSmdfLzJ3mxIhmOAhgKEKgQkeOwKpz1NbhVnuPE= +github.com/shirou/gopsutil/v3 v3.24.3/go.mod h1:JpND7O217xa72ewWz9zN2eIIkPWsDN/3pl0H8Qt0uwg= github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= +github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= @@ -199,16 +210,20 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/testcontainers/testcontainers-go v0.32.0 h1:ug1aK08L3gCHdhknlTTwWjPHPS+/alvLJU/DRxTD/ME= github.com/testcontainers/testcontainers-go v0.32.0/go.mod h1:CRHrzHLQhlXUsa5gXjTOfqIEJcrK5+xMDmBr/WMI88E= github.com/testcontainers/testcontainers-go/modules/redis v0.32.0 h1:HW5Qo9qfLi5iwfS7cbXwG6qe8ybXGePcgGPEmVlVDlo= github.com/testcontainers/testcontainers-go/modules/redis v0.32.0/go.mod h1:5kltdxVKZG0aP1iegeqKz4K8HHyP0wbkW5o84qLyMjY= -github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= -github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= +github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4= +github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0= github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= +github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4= +github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= @@ -223,8 +238,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw= -github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= +github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI= @@ -235,8 +250,9 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= -go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -276,24 +292,30 @@ golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/ssh/http/handlers.go b/ssh/http/handlers.go new file mode 100644 index 00000000000..bc5615c01a6 --- /dev/null +++ b/ssh/http/handlers.go @@ -0,0 +1,302 @@ +package http + +import ( + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/labstack/echo/v4" + "github.com/shellhub-io/shellhub/pkg/api/internalclient" + "github.com/shellhub-io/shellhub/pkg/wsconnadapter" + "github.com/shellhub-io/shellhub/ssh/pkg/dialer" + log "github.com/sirupsen/logrus" +) + +type Handlers struct { + Config *Config + Dialer *dialer.Dialer + Client internalclient.Client +} + +const ( + // HandleSSHClosePath receives a request to close an existing SSH session. + HandleSSHClosePath = "/api/sessions/:uid/close" + // HandleHTTPProxyPath proxies an inbound HTTP request to a device's HTTP server. + HandleHTTPProxyPath = "/http/proxy" + // HandleHealthcheckPath is used for readiness/liveness checks. + HandleHealthcheckPath = "/healthcheck" +) + +const ( + // HandleConnectionV1Path is the connection endpoint where agents using revdial connects to establish + // a WebSocket connection. Each new logical session requires an extra reverse dial handshake. + HandleConnectionV1Path = "/ssh/connection" + // HandleConnectionV2Path is the connection endpoint where agents using yamux/multistream connects to + // establish a WebSocket connection. Subsequent logical streams are opened without additional HTTP + // handshakes and are protocol-negotiated via multistream-select. + HandleConnectionV2Path = "/agent/connection" +) + +const ( + // HandleRevdialPath is the reverse dial endpoint where agents using revdial requests a new logical + // session. + HandleRevdialPath = "/ssh/revdial" +) + +// HandleSSHClose receives a notification from the agent that an SSH +// session should be closed. It dials the device (choosing the correct +// transport version) and then performs the version-specific close +// sequence: HTTP GET for V1 or multistream + JSON payload for V2. +func (h *Handlers) HandleSSHClose(c echo.Context) error { + var data struct { + UID string `param:"uid"` + Device string `json:"device"` + } + + if err := c.Bind(&data); err != nil { + return err + } + + ctx := c.Request().Context() + + tenant := c.Request().Header.Get("X-Tenant-ID") + + if _, err := h.Dialer.DialTo(ctx, tenant, data.Device, dialer.SSHCloseTarget{SessionID: data.UID}); err != nil { + log.WithError(err).Error("failed to send ssh close message") + + return ErrDeviceTunnelDial + } + + return c.NoContent(http.StatusOK) +} + +// HandleHTTPProxy proxies an inbound HTTP request to a device's HTTP +// service exposed through the reverse tunnel/web endpoint feature. It +// supports both transport versions: +// - V1: issues a CONNECT prelude then performs a standard HTTP request over the established raw tunnel. +// - V2: negotiates the /http/proxy multistream protocol and exchanges a JSON envelope to set up the target host/port. +// +// The handler then hijacks the Echo response writer to stream data +// bidirectionally between client and device. +func (h *Handlers) HandleHTTPProxy(c echo.Context) error { + requestID := c.Request().Header.Get("X-Request-ID") + + address := c.Request().Header.Get("X-Address") + log.WithFields(log.Fields{ + "request-id": requestID, + "address": address, + }).Debug("address value") + + path := c.Request().Header.Get("X-Path") + log.WithFields(log.Fields{ + "request-id": requestID, + "address": address, + }).Debug("path") + + endpoint, err := h.Client.LookupWebEndpoints(c.Request().Context(), address) + if err != nil { + log.WithError(err).Error("failed to get the web endpoint") + + return c.JSON(http.StatusForbidden, NewMessageFromError(ErrWebEndpointForbidden)) + } + + logger := log.WithFields(log.Fields{ + "request-id": requestID, + "namespace": endpoint.Namespace, + "device": endpoint.DeviceUID, + }) + + conn, err := h.Dialer.DialTo(c.Request().Context(), endpoint.Namespace, endpoint.DeviceUID, dialer.HTTPProxyTarget{ + RequestID: requestID, + Host: endpoint.Host, + Port: endpoint.Port, + }) + if err != nil { + logger.WithError(err).Error("failed to dial to device") + + return c.JSON(http.StatusForbidden, NewMessageFromError(ErrDeviceTunnelDial)) + } + defer conn.Close() + + logger.Trace("new web endpoint connection initialized") + defer logger.Trace("web endpoint connection doned") + + req := c.Request() + req.Host = strings.Join([]string{address, h.Config.WebEndpointsDomain}, ".") + req.URL, err = url.Parse(path) + if err != nil { + logger.WithError(err).Error("failed to parse the path") + + return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelReadResponse)) + } + + if err := req.Write(conn); err != nil { + logger.WithError(err).Error("failed to write the request to the agent") + + return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelWriteRequest)) + } + + log.WithFields(log.Fields{ + "request-id": requestID, + "method": req.Method, + "url": req.URL.String(), + "host": req.Host, + "headers": req.Header, + }).Debug("request to device") + + ctr := http.NewResponseController(c.Response()) + out, _, err := ctr.Hijack() + if err != nil { + logger.WithError(err).Error("failed to hijack the http request") + + return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelHijackRequest)) + } + + defer out.Close() + + // Bidirectional copy between the client and the device. + var wg sync.WaitGroup + wg.Add(2) + + starTime := time.Now() + + go func() { + defer wg.Done() + + if _, err := io.Copy(conn, out); err != nil { + logger.WithError(err).Debug("in and out done returned a error") + } + + logger.Trace("in and out done") + }() + + go func() { + defer wg.Done() + + if _, err := io.Copy(out, conn); err != nil { + logger.WithError(err).Debug("out and in done returned a error") + } + + logger.Trace("out and in done") + }() + + wg.Wait() + + logger.WithFields(log.Fields{ + "duration": time.Since(starTime).String(), + }).Info("web endpoint request completed") + + return nil +} + +// HandleHealthcheck returns a simple 200 OK used for readiness/liveness +// checks. +func (h *Handlers) HandleHealthcheck(c echo.Context) error { + return c.String(http.StatusOK, "OK") +} + +// HandleConnectionV1 upgrades the HTTP connection to WebSocket and +// registers a legacy (V1) reverse dialer for the agent. Each new logical +// session requires an extra reverse dial handshake. +func (h *Handlers) HandleConnectionV1(c echo.Context) error { + conn, err := upgrader.Upgrade(c.Response(), c.Request(), nil) + if err != nil { + return c.String(http.StatusInternalServerError, err.Error()) + } + + requestID := c.Request().Header.Get("X-Request-ID") + + tenant := c.Request().Header.Get("X-Tenant-ID") + uid := c.Request().Header.Get("X-Device-UID") + + // WARN: In versions before 0.15, the agent's authentication may not provide the "X-Tenant-ID" header. + // This can cause issues with establishing sessions and tracking online devices. To solve this, + // we retrieve the tenant ID by querying the API. Maybe this can be removed in a future release. + if tenant == "" { + device, err := h.Client.GetDevice(c.Request().Context(), uid) + if err != nil { + log.WithError(err). + WithField("uid", uid). + Error("unable to retrieve device's tenant id") + + return err + } + + tenant = device.TenantID + } + + h.Dialer.Manager.Set( + dialer.NewKey(tenant, uid), + wsconnadapter.New( + conn, + wsconnadapter.WithID(requestID), + wsconnadapter.WithDevice(tenant, uid), + ), + HandleRevdialPath, + ) + + return nil +} + +type HandleConnectionV2Data struct { + RequestID string `header:"x-request-id" validate:"required"` + UID string `header:"x-device-uid" validate:"required,len=64"` + Tenant string `header:"x-tenant-id" validate:"required,uuid"` +} + +// HandleConnectionV2 upgrades the HTTP connection to WebSocket and +// binds it to a yamux session (V2). Subsequent logical streams are +// opened without additional HTTP handshakes and are protocol-negotiated +// via multistream-select. +func (h *Handlers) HandleConnectionV2(c echo.Context) error { + log.Trace("handling v2 connection") + defer log.Trace("v2 connection handle closed") + + conn, err := upgrader.Upgrade(c.Response(), c.Request(), nil) + if err != nil { + return c.String(http.StatusInternalServerError, err.Error()) + } + + var data HandleConnectionV2Data + + if err := c.Bind(&data); err != nil { + log.WithError(err).Error("failed to bind the request") + + return err + } + + if err := c.Validate(&data); err != nil { + log.WithError(err).Error("failed to validate the request") + + return err + } + + logger := log.WithFields(log.Fields{ + "request-id": data.RequestID, + "tenant": data.Tenant, + "uid": data.UID, + }) + + logger.Info("v2 connection established") + + if err := h.Dialer.Manager.Bind( + data.Tenant, + data.UID, + wsconnadapter.New( + conn, + wsconnadapter.WithID(data.RequestID), + wsconnadapter.WithDevice(data.Tenant, data.UID), + ), + ); err != nil { + logger.WithError(err).Error("failed to bind the connection") + + return err + } + + logger.Info("v2 connection bound") + + return nil +} diff --git a/ssh/http/server.go b/ssh/http/server.go new file mode 100644 index 00000000000..07870ad40db --- /dev/null +++ b/ssh/http/server.go @@ -0,0 +1,150 @@ +package http + +import ( + "errors" + "net/http" + + "github.com/gorilla/websocket" + "github.com/labstack/echo/v4" + "github.com/shellhub-io/shellhub/pkg/api/internalclient" + "github.com/shellhub-io/shellhub/pkg/revdial" + "github.com/shellhub-io/shellhub/pkg/validator" + "github.com/shellhub-io/shellhub/ssh/pkg/dialer" +) + +type Message struct { + Message string `json:"message"` +} + +func NewMessageFromError(err error) Message { + return Message{ + Message: err.Error(), + } +} + +// Config controls optional features for the SSH HTTP sidecar server. +// +// When WebEndpoints is enabled the server exposes an HTTP proxy entry +// point (/http/proxy) that allows externally accessible per-device +// subdomains to be resolved and forwarded through the reverse tunnel +// transport (supporting both legacy V1 and yamux/multistream V2). +type Config struct { + // WebEndpoints enables the web endpoints (HTTP proxy) feature. + WebEndpoints bool + // WebEndpointsDomain is the base domain used when constructing the + // host header for tunneled HTTP requests (e.g.
.). + WebEndpointsDomain string +} + +// Server wires HTTP routes (connection upgrade, reverse dialing, +// web endpoint proxy, healthcheck) to the underlying dialer and +// handlers. It exposes both V1 (/ssh/connection + /ssh/revdial) and V2 +// (/connection) endpoints during the transition period while agents +// upgrade. +type Server struct { + Config *Config + Router *echo.Echo + Handlers *Handlers +} + +var ( + ErrWebEndpointForbidden = errors.New("web endpoint not found") + ErrDeviceTunnelDial = errors.New("failed to connect to device") + ErrDeviceTunnelWriteRequest = errors.New("failed to send data to the device") + ErrDeviceTunnelReadResponse = errors.New("failed to write the response back to the client") + ErrDeviceTunnelHijackRequest = errors.New("failed to capture the request") + ErrDeviceTunnelParsePath = errors.New("failed to parse the path") + ErrDeviceTunnelConnect = errors.New("failed to connect to the port on device") +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + Subprotocols: []string{"binary"}, + CheckOrigin: func(_ *http.Request) bool { + return true + }, +} + +// ListenAndServe starts the Echo HTTP server on the provided address. +func (s *Server) ListenAndServe(address string) error { + return s.Router.Start(address) +} + +type Binder struct{} + +func NewBinder() *Binder { + return &Binder{} +} + +func (b *Binder) Bind(s any, c echo.Context) error { + binder := new(echo.DefaultBinder) + if err := binder.Bind(s, c); err != nil { + err := err.(*echo.HTTPError) //nolint:forcetypeassert + + return err + } + + if err := binder.BindHeaders(c, s); err != nil { + err := err.(*echo.HTTPError) //nolint:forcetypeassert + + return err + } + + return nil +} + +type Validator struct { + validator *validator.Validator +} + +// NewValidator creates a new validator for the echo framework from the ShellHub validator. +func NewValidator() *Validator { + return &Validator{validator: validator.New()} +} + +// Validate is called by the echo framework to validate the request body. +// If the request body is invalid, it returns an error with the invalid fields. +func (v *Validator) Validate(structure any) error { + if ok, err := v.validator.Struct(structure); !ok || err != nil { + return err + } + + return nil +} + +func NewServer(d *dialer.Dialer, cli internalclient.Client, cfg *Config) *Server { + r := echo.New() + + r.Binder = NewBinder() + r.Validator = NewValidator() + r.HideBanner = true + r.HidePort = true + + handlers := &Handlers{ + Dialer: d, + Client: cli, + Config: cfg, + } + + r.GET(HandleConnectionV1Path, handlers.HandleConnectionV1) + r.GET(HandleConnectionV2Path, handlers.HandleConnectionV2) + + r.GET(HandleRevdialPath, echo.WrapHandler(revdial.ConnHandler(upgrader))) + + r.POST(HandleSSHClosePath, handlers.HandleSSHClose) + r.GET(HandleHealthcheckPath, handlers.HandleHealthcheck) + + if cfg.WebEndpoints { + // NOTE: The `/http/proxy` endpoint is invoked by the NGINX gateway when a tunnel URL is accessed. It processes + // the `X-Address` and `X-Path` headers, which specify the tunnel's address and the target path on the server, + // returning an error related to the connection to device or what was returned from the server inside the tunnel. + r.Any(HandleHTTPProxyPath, handlers.HandleHTTPProxy) + } + + return &Server{ + Config: cfg, + Router: r, + Handlers: handlers, + } +} diff --git a/ssh/main.go b/ssh/main.go index baa657cafac..553e5495179 100644 --- a/ssh/main.go +++ b/ssh/main.go @@ -2,16 +2,17 @@ package main import ( "fmt" - "net/http" "runtime" "time" "github.com/labstack/echo-contrib/pprof" + "github.com/shellhub-io/shellhub/pkg/api/internalclient" "github.com/shellhub-io/shellhub/pkg/cache" "github.com/shellhub-io/shellhub/pkg/envs" "github.com/shellhub-io/shellhub/pkg/loglevel" - "github.com/shellhub-io/shellhub/ssh/pkg/tunnel" - "github.com/shellhub-io/shellhub/ssh/server" + "github.com/shellhub-io/shellhub/ssh/http" + "github.com/shellhub-io/shellhub/ssh/pkg/dialer" + ssh "github.com/shellhub-io/shellhub/ssh/server" "github.com/shellhub-io/shellhub/ssh/web" log "github.com/sirupsen/logrus" ) @@ -47,17 +48,20 @@ func main() { Fatal("failed to connect to redis cache") } - tun, err := tunnel.NewTunnel("/ssh/connection", "/ssh/revdial", tunnel.Config{ - Tunnels: env.WebEndpoints, - TunnelsDomain: env.WebEndpointsDomain, - RedisURI: env.RedisURI, - }) + cli, err := internalclient.NewClient(nil, internalclient.WithAsynqWorker(env.RedisURI)) if err != nil { log.WithError(err). Fatal("failed to create the tunnel") } - router := tun.GetRouter() + d := dialer.NewDialer(cli) + + h := http.NewServer(d, cli, &http.Config{ + WebEndpoints: env.WebEndpoints, + WebEndpointsDomain: env.WebEndpointsDomain, + }) + + router := h.Router web.NewSSHServerBridge(router, cache) @@ -68,6 +72,11 @@ func main() { log.Info("Profiling enabled at http://0.0.0.0:8080/debug/pprof/") } + s := ssh.NewServer(d, cache, &ssh.Options{ + ConnectTimeout: env.ConnectTimeout, + AllowPublickeyAccessBelow060: env.AllowPublickeyAccessBelow060, + }) + errs := make(chan error) go func() { @@ -79,7 +88,7 @@ func main() { } }() - errs <- http.ListenAndServe(ListenAddress, router) //nolint:gosec + errs <- h.ListenAndServe(ListenAddress) }() go func() { @@ -91,10 +100,7 @@ func main() { } }() - errs <- server.NewServer(&server.Options{ - ConnectTimeout: env.ConnectTimeout, - AllowPublickeyAccessBelow060: env.AllowPublickeyAccessBelow060, - }, tun.Tunnel, cache).ListenAndServe() + errs <- s.ListenAndServe() }() if err := <-errs; err != nil { diff --git a/ssh/pkg/dialer/dialer.go b/ssh/pkg/dialer/dialer.go new file mode 100644 index 00000000000..104f7338312 --- /dev/null +++ b/ssh/pkg/dialer/dialer.go @@ -0,0 +1,97 @@ +package dialer + +import ( + "context" + "errors" + "net" + "strings" + + "github.com/shellhub-io/shellhub/pkg/api/internalclient" + log "github.com/sirupsen/logrus" +) + +// NewKey joins tenant and device UID in the canonical form used as the +// identifier inside the connection manager maps. +func NewKey(tenant, uid string) string { + return strings.Join([]string{tenant, uid}, ":") +} + +type Dialer struct { + Manager *Manager + client internalclient.Client +} + +func NewDialer(client internalclient.Client) *Dialer { + m := NewManager() + + m.DialerDoneCallback = func(key string) { + // TODO: Use `Key` struct when available to avoid string parsing on every call. + parts := strings.Split(key, ":") + if len(parts) != 2 { + log.Error("failed to parse key at close handler") + + return + } + + tenant := parts[0] + uid := parts[1] + + if err := client.DevicesOffline(context.TODO(), uid); err != nil { + log.WithError(err). + WithFields(log.Fields{ + "uid": uid, + "tenant_id": tenant, + }). + Error("failed to set device offline") + } + } + + m.DialerKeepAliveCallback = func(key string) { + // TODO: Use `Key` struct when available to avoid string parsing on every call. + parts := strings.Split(key, ":") + if len(parts) != 2 { + log.Error("failed to parse key at keep alive handler") + + return + } + + tenant := parts[0] + uid := parts[1] + + if err := client.DevicesHeartbeat(context.TODO(), uid); err != nil { + log.WithError(err). + WithFields(log.Fields{ + "uid": uid, + "tenant_id": tenant, + }). + Error("failed to send heartbeat signal") + } + } + + return &Dialer{ + Manager: m, + client: client, + } +} + +var ErrInvalidArgument = errors.New("invalid argument") + +// DialTo establishes a raw reverse connection to the device and performs +// the version-specific bootstrap for the provided target. It returns a +// connection ready for application protocol usage. +func (t *Dialer) DialTo(ctx context.Context, tenant string, uid string, target Target) (net.Conn, error) { + if tenant == "" || uid == "" { + return nil, ErrInvalidArgument + } + + conn, version, err := t.Manager.Dial(ctx, NewKey(tenant, uid)) + if err != nil { + return nil, err + } + + if target == nil { + return conn, nil + } + + return target.prepare(conn, version) +} diff --git a/ssh/pkg/dialer/docs.go b/ssh/pkg/dialer/docs.go new file mode 100644 index 00000000000..0ffe163a60a --- /dev/null +++ b/ssh/pkg/dialer/docs.go @@ -0,0 +1,50 @@ +// Package dialer provides utilities to manage and use reverse connections +// opened by agents so the server (or other services) can dial back into a +// device. The package supports two transport modes (protocol versions): the +// legacy revdial-based HTTP transport (v1) and a yamux-multiplexed transport +// (v2). When using v2, per-stream application protocols are negotiated using +// multistream identifiers defined in this package. +// +// # High level concepts +// +// - Manager: a connection manager that stores active reverse transports and +// exposes methods to bind new agent connections and to dial a device by +// its key. It also exposes callbacks for tracking when connections are +// closed or when keep-alive events occur. +// +// - Dialer: a thin wrapper around a Manager which also holds an +// internalclient to perform device lifecycle operations (heartbeat / +// offline notifications) and provides DialTo which returns a ready-to-use +// net.Conn for a requested Target. +// +// - Target: an interface implemented by small helpers that prepare a raw +// connection for a particular application-level purpose (for example, +// opening or closing an SSH session, or establishing an HTTP proxy). The +// prepare method will perform any necessary handshake depending on the +// negotiated connection version. +// +// # Versioning +// +// ConnectionVersion1 (v1) uses the older revdial/http handshake where the +// client expects HTTP-style GET/CONNECT requests. ConnectionVersion2 (v2) +// uses a yamux session and performs per-stream negotiation with the +// multistream protocol strings (see ProtoSSHOpen, ProtoSSHClose, +// ProtoHTTPProxy). Callers should prepare the appropriate Target and the +// dialer will perform the correct handshake based on the returned +// ConnectionVersion. +// +// Usage (server-side) +// +// Typical server usage is: +// - When an agent connects, call Manager.Bind(tenant, uid, conn) to +// register the reverse transport. The manager will keep the session alive +// and call configured callbacks on events. +// - To connect to a device, create a Dialer (NewDialer) and call +// Dialer.DialTo(ctx, tenant, uid, target). DialTo returns a net.Conn +// already prepared for the requested target (or a raw connection if the +// target is nil). +// +// The package intentionally keeps the wire-level protocol identifiers and +// version handling colocated with the dial logic so the agent and server +// implementations can remain compatible and easy to reason about. +package dialer diff --git a/ssh/pkg/dialer/manager.go b/ssh/pkg/dialer/manager.go new file mode 100644 index 00000000000..f20548930c1 --- /dev/null +++ b/ssh/pkg/dialer/manager.go @@ -0,0 +1,204 @@ +package dialer + +import ( + "context" + "errors" + "net" + "os" + "time" + + "github.com/hashicorp/yamux" + "github.com/shellhub-io/shellhub/pkg/revdial" + "github.com/shellhub-io/shellhub/pkg/wsconnadapter" + log "github.com/sirupsen/logrus" +) + +var ErrNoConnection = errors.New("no connection") + +type Manager struct { + Connections *SyncSliceMap + DialerDoneCallback func(string) + DialerKeepAliveCallback func(string) +} + +func NewManager() *Manager { + return &Manager{ + Connections: &SyncSliceMap{}, + DialerDoneCallback: func(string) {}, + DialerKeepAliveCallback: func(string) {}, + } +} + +func (m *Manager) Set(key string, conn *wsconnadapter.Adapter, connPath string) { + dialer := revdial.NewDialer(conn.Logger, conn, connPath) + + m.Connections.Store(key, dialer) + + if size := m.Connections.Size(key); size > 1 { + log.WithFields(log.Fields{ + "key": key, + "size": size, + }).Warning("Multiple connections stored for the same identifier.") + } + + m.DialerKeepAliveCallback(key) + + // Start the ping loop and get the channel for pong responses + pong := conn.Ping() + + go func() { + for { + select { + case <-pong: + m.DialerKeepAliveCallback(key) + + continue + case <-dialer.Done(): + m.Connections.Delete(key, dialer) + m.DialerDoneCallback(key) + + return + } + } + }() +} + +// BindPingInterval is the interval between pings sent to the yamux session +// to keep it alive. It should be less than the NAT timeout to avoid +// disconnections. +// It should be the same value as used by the revdial.Dialer ping interval. +const BindPingInterval = 35 * time.Second + +// Bind binds a WebSocket connection to a yamux session and stores it in the connection manager. +// All new agents should use this handler to register their reverse connection. +func (m *Manager) Bind(tenant string, uid string, conn *wsconnadapter.Adapter) error { + key := NewKey(tenant, uid) + + session, err := yamux.Client(conn, &yamux.Config{ + AcceptBacklog: 256, + // NOTE: As we need to keep the registered connection alive, we use our own ping/pong mechanism. + EnableKeepAlive: false, + // NOTE: Although we disable the built-in keepalive, we still need to set the interval to a non-zero value to + // avoid yamux error when verifying the configuration. We've created a Pull Request to improve this behavior. + // TODO: Remove this workaround when yamux supports disabling keepalive completely. + KeepAliveInterval: BindPingInterval, + ConnectionWriteTimeout: 15 * time.Second, + MaxStreamWindowSize: 256 * 1024, + StreamCloseTimeout: 5 * time.Minute, + StreamOpenTimeout: 75 * time.Second, + LogOutput: os.Stderr, + }) + if err != nil { + log.WithError(err).Error("failed to create yamux client session") + + return err + } + + m.Connections.Store(key, session) + + if size := m.Connections.Size(key); size > 1 { + log.WithFields(log.Fields{ + "key": key, + "size": size, + }).Warning("Multiple connections stored for the same identifier.") + } + + m.DialerKeepAliveCallback(key) + + go func() { + for { + select { + // NOTE: Ping is also important to keep the underlying WebSocket connection alive and avoid NAT timeouts. + case <-time.After(BindPingInterval): + if _, err := session.Ping(); err != nil { + log.WithFields(log.Fields{ + "key": key, + }).WithError(err).Error("failed to ping yamux session") + + m.Connections.Delete(key, session) + m.DialerDoneCallback(key) + + return + } + + m.DialerKeepAliveCallback(key) + + continue + case <-session.CloseChan(): + m.Connections.Delete(key, session) + m.DialerDoneCallback(key) + + return + } + } + }() + + return nil +} + +// ConnectionVersion protocol version identifiers used when dialing a device. +type ConnectionVersion byte + +const ( + // ConnectionVersionUnknown is used when the transport version could not be determined. + ConnectionVersionUnknown ConnectionVersion = 0 + // ConnectionVersion1 is the legacy transport using revdial over HTTP. + ConnectionVersion1 ConnectionVersion = 1 + // ConnectionVersion2 is the current transport using yamux multiplexing. + ConnectionVersion2 ConnectionVersion = 2 +) + +// Dial tries to find a connection by its key and dials it. +// +// It returns the connection, its version ([ConnectionVersion1] or [ConnectionVersion2]) and an error, +func (m *Manager) Dial(ctx context.Context, key string) (net.Conn, ConnectionVersion, error) { + loaded, ok := m.Connections.Load(key) + if !ok { + return nil, ConnectionVersionUnknown, ErrNoConnection + } + + if size := m.Connections.Size(key); size > 1 { + log.WithFields(log.Fields{ + "key": key, + "size": size, + }).Warning("Multiple connections found for the same identifier during reverse tunnel dialing.") + } + + if dialer, ok := loaded.(*revdial.Dialer); ok { + log.WithFields(log.Fields{ + "key": key, + "version": "v1", + }).Debug("using v1 dialer for reverse tunnel dialing") + + conn, err := dialer.Dial(ctx) + if err != nil { + log.WithFields(log.Fields{ + "key": key, + "version": "v1", + }).WithError(err).Error("failed to dial reverse connection") + + return nil, ConnectionVersionUnknown, err + } + + return conn, ConnectionVersion1, nil + } + + if session, ok := loaded.(*yamux.Session); ok { + log.WithFields(log.Fields{ + "key": key, + "version": "v2", + }).Debug("using v2 connection for reverse tunnel dialing") + + conn, err := session.Open() + if err != nil { + log.WithFields(log.Fields{ + "key": key, + "version": "v2", + }).WithError(err).Error("failed to open yamux stream for reverse connection") + } + + return conn, ConnectionVersion2, nil + } + + return nil, ConnectionVersionUnknown, ErrNoConnection +} diff --git a/ssh/pkg/dialer/protocols.go b/ssh/pkg/dialer/protocols.go new file mode 100644 index 00000000000..cba9269708f --- /dev/null +++ b/ssh/pkg/dialer/protocols.go @@ -0,0 +1,12 @@ +package dialer + +// Multistream protocol identifiers used when negotiating per-stream +// application protocols over a V2 yamux connection. +// +// The agent and server must keep these values in sync. Changing a value +// is a wire incompatible change. +const ( + ProtoSSHOpen = "/ssh/open/1.0.0" + ProtoSSHClose = "/ssh/close/1.0.0" + ProtoHTTPProxy = "/http/proxy/1.0.0" +) diff --git a/pkg/connman/syncslicemap.go b/ssh/pkg/dialer/syncslicemap.go similarity index 99% rename from pkg/connman/syncslicemap.go rename to ssh/pkg/dialer/syncslicemap.go index f37e1383aac..07aa92a5504 100644 --- a/pkg/connman/syncslicemap.go +++ b/ssh/pkg/dialer/syncslicemap.go @@ -1,4 +1,4 @@ -package connman +package dialer import "sync" diff --git a/pkg/connman/syncslicemap_test.go b/ssh/pkg/dialer/syncslicemap_test.go similarity index 99% rename from pkg/connman/syncslicemap_test.go rename to ssh/pkg/dialer/syncslicemap_test.go index 7414f9f9e8c..11344aa9223 100644 --- a/pkg/connman/syncslicemap_test.go +++ b/ssh/pkg/dialer/syncslicemap_test.go @@ -1,4 +1,4 @@ -package connman +package dialer import ( "testing" diff --git a/ssh/pkg/dialer/target.go b/ssh/pkg/dialer/target.go new file mode 100644 index 00000000000..44d63bddc5e --- /dev/null +++ b/ssh/pkg/dialer/target.go @@ -0,0 +1,126 @@ +package dialer + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "strconv" + + "github.com/multiformats/go-multistream" + log "github.com/sirupsen/logrus" +) + +type Target interface { + prepare(conn net.Conn, version ConnectionVersion) (net.Conn, error) +} + +// SSHOpenTarget prepares a connection for initiating a new SSH session +// with the agent. +type SSHOpenTarget struct{ SessionID string } + +func (t SSHOpenTarget) prepare(conn net.Conn, version ConnectionVersion) (net.Conn, error) { // nolint:ireturn + switch version { + case ConnectionVersion1: + log.Debug("preparing SSH open target for connection version 1") + + req, _ := http.NewRequest(http.MethodGet, fmt.Sprintf("/ssh/%s", t.SessionID), nil) + if err := req.Write(conn); err != nil { + log.Errorf("failed to write HTTP request: %v", err) + + return nil, err + } + case ConnectionVersion2: + log.Debug("preparing SSH open target for connection version 2") + + if err := multistream.SelectProtoOrFail(ProtoSSHOpen, conn); err != nil { + return nil, err + } + if err := json.NewEncoder(conn).Encode(map[string]string{"id": t.SessionID}); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unsupported connection version: %d", version) + } + + return conn, nil +} + +// SSHCloseTarget prepares a connection to request closing an existing SSH session. +type SSHCloseTarget struct{ SessionID string } + +func (t SSHCloseTarget) prepare(conn net.Conn, version ConnectionVersion) (net.Conn, error) { // nolint:ireturn + switch version { + case ConnectionVersion1: + req, _ := http.NewRequest(http.MethodGet, fmt.Sprintf("/ssh/close/%s", t.SessionID), nil) + if err := req.Write(conn); err != nil { + return nil, err + } + case ConnectionVersion2: + if err := multistream.SelectProtoOrFail(ProtoSSHClose, conn); err != nil { + return nil, err + } + if err := json.NewEncoder(conn).Encode(map[string]string{"id": t.SessionID}); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unsupported connection version: %d", version) + } + + return conn, nil +} + +// HTTPProxyTarget prepares a connection for proxying HTTP traffic to a +// device web endpoint. After preparation the caller should write the +// final HTTP request (with rewritten Host + URL) directly to the +// returned connection. +type HTTPProxyTarget struct { + RequestID string + Host string + Port int +} + +func (t HTTPProxyTarget) prepare(conn net.Conn, version ConnectionVersion) (net.Conn, error) { // nolint:ireturn + switch version { + case ConnectionVersion1: + // Write initial handshake request and expect 200 OK. + handshakeReq, _ := http.NewRequest(http.MethodConnect, fmt.Sprintf("/http/proxy/%s:%d", t.Host, t.Port), nil) + if err := handshakeReq.Write(conn); err != nil { + return nil, err + } + resp, err := http.ReadResponse(bufio.NewReader(conn), handshakeReq) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("http proxy handshake failed: %s", resp.Status) + } + case ConnectionVersion2: + if err := multistream.SelectProtoOrFail(ProtoHTTPProxy, conn); err != nil { + return nil, err + } + if err := json.NewEncoder(conn).Encode(map[string]string{ + "id": t.RequestID, + "host": t.Host, + "port": strconv.Itoa(t.Port), + }); err != nil { + return nil, err + } + result := map[string]string{} + + // NOTE: limit the size of the response to avoid DoS via large payloads. + const Limit = 512 + if err := json.NewDecoder(io.LimitReader(conn, Limit)).Decode(&result); err != nil { + return nil, err + } + if result["status"] != "ok" { + return nil, fmt.Errorf("http proxy negotiation failed: %s", result["message"]) + } + default: + return nil, fmt.Errorf("unsupported connection version: %d", version) + } + + return conn, nil +} diff --git a/ssh/pkg/dialer/throttle.go b/ssh/pkg/dialer/throttle.go new file mode 100644 index 00000000000..327caf6a548 --- /dev/null +++ b/ssh/pkg/dialer/throttle.go @@ -0,0 +1,247 @@ +package dialer + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + "golang.org/x/time/rate" +) + +// ErrNegativeLimit is returned when attempting to set a negative limit. +var ErrNegativeLimit = errors.New("negative throttle limit") + +// Option configures a Throttler. +type Option func(*Throttler) + +// WithReadLimit sets the read bytes-per-second limit and burst. +// If bps <= 0 => unlimited. If burst <=0 it defaults to bps. +func WithReadLimit(bps int, burst int) Option { + return func(t *Throttler) { + t.setLimiter(&t.readMu, &t.readLimiter, bps, burst) + } +} + +// WithWriteLimit sets the write bytes-per-second limit and burst. +// If bps <= 0 => unlimited. If burst <=0 it defaults to bps. +func WithWriteLimit(bps int, burst int) Option { + return func(t *Throttler) { + t.setLimiter(&t.writeMu, &t.writeLimiter, bps, burst) + } +} + +// Throttler wraps an underlying io.Reader / io.Writer (optionally both) and +// enforces directional byte-per-second limits using token buckets. +// It is safe for concurrent use of Read and Write. +type Throttler struct { + // Underlying read side (may be nil if only writing). + R io.Reader + // Underlying write side (may be nil if only reading). + W io.Writer + + readMu sync.RWMutex + readLimiter *rate.Limiter + + writeMu sync.RWMutex + writeLimiter *rate.Limiter +} + +func NewThrottler(r io.Reader, w io.Writer, opts ...Option) *Throttler { + t := &Throttler{R: r, W: w} + + for _, o := range opts { + o(t) + } + + return t +} + +// setLimiter (internal) creates or clears a limiter based on bps. +func (t *Throttler) setLimiter(mu *sync.RWMutex, lim **rate.Limiter, bps int, burst int) { + mu.Lock() + defer mu.Unlock() + + if bps <= 0 { + *lim = nil + + return + } + + if burst <= 0 { + burst = bps + } + + *lim = rate.NewLimiter(rate.Limit(bps), burst) +} + +// UpdateReadLimit dynamically changes the read limit. +func (t *Throttler) UpdateReadLimit(bps int, burst int) error { + if bps < 0 || burst < 0 { + return ErrNegativeLimit + } + + t.setLimiter(&t.readMu, &t.readLimiter, bps, burst) + + return nil +} + +// UpdateWriteLimit dynamically changes the write limit. +func (t *Throttler) UpdateWriteLimit(bps int, burst int) error { + if bps < 0 || burst < 0 { + return ErrNegativeLimit + } + + t.setLimiter(&t.writeMu, &t.writeLimiter, bps, burst) + + return nil +} + +// Read implements io.Reader with throttling. +func (t *Throttler) Read(p []byte) (int, error) { + if t.R == nil { + return 0, errors.New("read not supported (nil underlying Reader)") + } + + lim := t.getReadLimiter() + + if lim == nil { + return t.R.Read(p) + } + + maxChunk := lim.Burst() + if maxChunk <= 0 { + maxChunk = 32 * 1024 + } + + total := 0 + for total < len(p) { + remaining := len(p) - total + chunk := min(remaining, maxChunk) + + if err := lim.WaitN(context.Background(), chunk); err != nil { + if total > 0 { + return total, err + } + + return 0, err + } + + n, err := t.R.Read(p[total : total+chunk]) + total += n + if err != nil || n == 0 { + return total, err + } + + if n < chunk { + break + } + } + + return total, nil +} + +// Write implements io.Writer with throttling. +func (t *Throttler) Write(p []byte) (int, error) { + if t.W == nil { + return 0, errors.New("write not supported (nil underlying Writer)") + } + + lim := t.getWriteLimiter() + + if lim == nil { + return t.W.Write(p) + } + + maxChunk := lim.Burst() + if maxChunk <= 0 { + maxChunk = 32 * 1024 + } + + total := 0 + for total < len(p) { + remaining := len(p) - total + chunk := min(remaining, maxChunk) + + if err := lim.WaitN(context.Background(), chunk); err != nil { + if total > 0 { + return total, err + } + + return 0, err + } + + n, err := t.W.Write(p[total : total+chunk]) + total += n + if err != nil || n == 0 { + return total, err + } + + if n < chunk { + break + } + } + + return total, nil +} + +// Helper getters with read locks for concurrency. +func (t *Throttler) getReadLimiter() *rate.Limiter { + t.readMu.RLock() + defer t.readMu.RUnlock() + + return t.readLimiter +} + +func (t *Throttler) getWriteLimiter() *rate.Limiter { + t.writeMu.RLock() + defer t.writeMu.RUnlock() + + return t.writeLimiter +} + +type ConnThrottler struct { + Conn net.Conn + Throttler *Throttler +} + +func (c *ConnThrottler) Close() error { + return c.Conn.Close() +} + +func (c *ConnThrottler) LocalAddr() net.Addr { + return c.Conn.LocalAddr() +} + +func (c *ConnThrottler) Read(b []byte) (n int, err error) { + return c.Throttler.Read(b) +} + +func (c *ConnThrottler) RemoteAddr() net.Addr { + return c.Conn.RemoteAddr() +} + +func (c *ConnThrottler) SetDeadline(t time.Time) error { + return c.Conn.SetDeadline(t) +} + +func (c *ConnThrottler) SetReadDeadline(t time.Time) error { + return c.Conn.SetReadDeadline(t) +} + +func (c *ConnThrottler) SetWriteDeadline(t time.Time) error { + return c.Conn.SetWriteDeadline(t) +} + +func (c *ConnThrottler) Write(b []byte) (n int, err error) { + return c.Throttler.Write(b) +} + +func NewConnThrottler(conn net.Conn, readBps, readBurst, writeBps, writeBurst int) net.Conn { + return &ConnThrottler{ + Conn: conn, + Throttler: NewThrottler(conn, conn, WithReadLimit(readBps, readBurst), WithWriteLimit(writeBps, writeBurst)), + } +} diff --git a/ssh/pkg/dialer/throttle_test.go b/ssh/pkg/dialer/throttle_test.go new file mode 100644 index 00000000000..b337879a4d3 --- /dev/null +++ b/ssh/pkg/dialer/throttle_test.go @@ -0,0 +1,147 @@ +package dialer + +import ( + "bytes" + "io" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func expectedMinDuration(total, bps, burst int) time.Duration { + if bps <= 0 { + return 0 + } + + remaining := total - burst + if remaining <= 0 { + return 0 + } + + secs := float64(remaining) / float64(bps) + + return time.Duration(secs * float64(time.Second)) +} + +func TestThrottler_TableDriven(t *testing.T) { + cases := []struct { + name string + run func(t *testing.T) + }{ + { + name: "UnlimitedReadFast", + run: func(t *testing.T) { + data := bytes.Repeat([]byte("x"), 1024) + r := bytes.NewReader(data) + + th := NewThrottler(r, nil) // no limits + + buf := make([]byte, len(data)) + start := time.Now() + n, err := th.Read(buf) + dur := time.Since(start) + + assert.Truef(t, err == nil || err == io.EOF, "unexpected read error: %v", err) + assert.Equal(t, len(data), n, "read bytes mismatch") + assert.LessOrEqual(t, dur, 100*time.Millisecond, "unlimited read took too long") + }, + }, + { + name: "NegativeLimitValidation", + run: func(t *testing.T) { + th := NewThrottler(nil, nil) + err := th.UpdateReadLimit(-1, 1) + assert.Equal(t, ErrNegativeLimit, err) + err = th.UpdateWriteLimit(-1, 1) + assert.Equal(t, ErrNegativeLimit, err) + }, + }, + { + name: "ReadRateEnforced", + run: func(t *testing.T) { + total := 200 + bps := 50 + burst := 10 + + data := bytes.Repeat([]byte("r"), total) + r := bytes.NewReader(data) + th := NewThrottler(r, nil, WithReadLimit(bps, burst)) + + buf := make([]byte, total) + start := time.Now() + n, err := th.Read(buf) + dur := time.Since(start) + + assert.Truef(t, err == nil || err == io.EOF, "unexpected read error: %v", err) + assert.Equal(t, total, n, "read bytes mismatch") + + expect := expectedMinDuration(total, bps, burst) + // allow 20% timing slack for scheduler and test flakiness + slack := expect / 5 + assert.Truef(t, dur+slack >= expect, "read duration = %v; want at least ~%v (with slack %v)", dur, expect, slack) + }, + }, + { + name: "WriteRateEnforced", + run: func(t *testing.T) { + total := 200 + bps := 50 + burst := 10 + + var bufOut bytes.Buffer + th := NewThrottler(nil, &bufOut, WithWriteLimit(bps, burst)) + + data := bytes.Repeat([]byte("w"), total) + start := time.Now() + n, err := th.Write(data) + dur := time.Since(start) + + assert.NoError(t, err, "unexpected write error") + assert.Equal(t, total, n, "written bytes mismatch") + + expect := expectedMinDuration(total, bps, burst) + slack := expect / 5 + assert.Truef(t, dur+slack >= expect, "write duration = %v; want at least ~%v (with slack %v)", dur, expect, slack) + }, + }, + { + name: "ConnThrottlerPassthrough", + run: func(t *testing.T) { + c1, c2 := net.Pipe() + t.Cleanup(func() { c1.Close(); c2.Close() }) + + // Wrap c2 with unlimited throttler + thrConn := NewConnThrottler(c2, 0, 0, 0, 0) + + // write from c1, read from thrConn + msg := []byte("hello-throttle") + + done := make(chan error, 1) + go func() { + defer c1.Close() + _, err := c1.Write(msg) + done <- err + }() + + // read on wrapped conn + got := make([]byte, len(msg)) + n, err := thrConn.Read(got) + assert.Truef(t, err == nil || err == io.EOF, "conn read error: %v", err) + assert.Equal(t, len(msg), n, "conn read bytes mismatch") + assert.Equal(t, msg, got, "conn read data mismatch") + + // ensure writer had no error + err = <-done + assert.NoError(t, err, "writer error") + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tc.run(t) + }) + } +} diff --git a/ssh/pkg/tunnel/tunnel.go b/ssh/pkg/tunnel/tunnel.go deleted file mode 100644 index dc9c2f1c853..00000000000 --- a/ssh/pkg/tunnel/tunnel.go +++ /dev/null @@ -1,329 +0,0 @@ -package tunnel - -import ( - "bufio" - "context" - "errors" - "fmt" - "io" - "net" - "net/http" - "net/url" - "strings" - "sync" - - "github.com/labstack/echo/v4" - "github.com/shellhub-io/shellhub/pkg/api/internalclient" - "github.com/shellhub-io/shellhub/pkg/httptunnel" - log "github.com/sirupsen/logrus" -) - -var ( - ErrWebEndpointForbidden = errors.New("web endpoint not found") - ErrDeviceTunnelDial = errors.New("failed to connect to device") - ErrDeviceTunnelWriteRequest = errors.New("failed to send data to the device") - ErrDeviceTunnelReadResponse = errors.New("failed to write the response back to the client") - ErrDeviceTunnelHijackRequest = errors.New("failed to capture the request") - ErrDeviceTunnelParsePath = errors.New("failed to parse the path") - ErrDeviceTunnelConnect = errors.New("failed to connect to the port on device") -) - -type Message struct { - Message string `json:"message"` -} - -func NewMessageFromError(err error) Message { - return Message{ - Message: err.Error(), - } -} - -type Config struct { - // Tunnels defines if tunnel's feature is enabled. - Tunnels bool - // TunnelsDomain define the domain of tunnels feature when it's enabled. - TunnelsDomain string - // RedisURI is the redis URI connection. - RedisURI string -} - -func (c Config) Validate() error { - if c.Tunnels && c.TunnelsDomain == "" { - return errors.New("tunnels feature is enabled, but tunnel's domain is empty") - } - - if c.RedisURI == "" { - return errors.New("redis uri is empty") - } - - return nil -} - -type Tunnel struct { - Tunnel *httptunnel.Tunnel - API internalclient.Client - router *echo.Echo -} - -func NewTunnel(connection string, dial string, config Config) (*Tunnel, error) { - if err := config.Validate(); err != nil { - return nil, err - } - - api, err := internalclient.NewClient(nil, internalclient.WithAsynqWorker(config.RedisURI)) - if err != nil { - return nil, err - } - - tunnel := &Tunnel{ - Tunnel: httptunnel.NewTunnel(connection, dial), - API: api, - } - - tunnel.Tunnel.ConnectionHandler = func(request *http.Request) (string, error) { - tenant := request.Header.Get("X-Tenant-ID") - uid := request.Header.Get("X-Device-UID") - - // WARN: - // In versions before 0.15, the agent's authentication may not provide the "X-Tenant-ID" header. - // This can cause issues with establishing sessions and tracking online devices. To solve this, - // we retrieve the tenant ID by querying the API. Maybe this can be removed in a future release. - if tenant == "" { - device, err := tunnel.API.GetDevice(context.TODO(), uid) - if err != nil { - log.WithError(err). - WithField("uid", uid). - Error("unable to retrieve device's tenant id") - - return "", err - } - - tenant = device.TenantID - } - - return tenant + ":" + uid, nil - } - tunnel.Tunnel.CloseHandler = func(key string) { - parts := strings.Split(key, ":") - if len(parts) != 2 { - log.Error("failed to parse key at close handler") - - return - } - - tenant := parts[0] - uid := parts[1] - - if err := tunnel.API.DevicesOffline(context.TODO(), uid); err != nil { - log.WithError(err). - WithFields(log.Fields{ - "uid": uid, - "tenant_id": tenant, - }). - Error("failed to set device offline") - } - } - tunnel.Tunnel.KeepAliveHandler = func(key string) { - parts := strings.Split(key, ":") - if len(parts) != 2 { - log.Error("failed to parse key at keep alive handler") - - return - } - - tenant := parts[0] - uid := parts[1] - - if err := tunnel.API.DevicesHeartbeat(context.TODO(), uid); err != nil { - log.WithError(err). - WithFields(log.Fields{ - "uid": uid, - "tenant_id": tenant, - }). - Error("failed to send heartbeat signal") - } - } - - tunnel.router = tunnel.Tunnel.Router().(*echo.Echo) - - // `/sessions/:uid/close` is the endpoint that is called by the agent to inform the SSH's server that the session is - // closed. - tunnel.router.POST("/api/sessions/:uid/close", func(c echo.Context) error { - var data struct { - UID string `param:"uid"` - Device string `json:"device"` - } - - if err := c.Bind(&data); err != nil { - return err - } - - ctx := c.Request().Context() - - tenant := c.Request().Header.Get("X-Tenant-ID") - - conn, err := tunnel.Dial(ctx, fmt.Sprintf("%s:%s", tenant, data.Device)) - if err != nil { - log.WithError(err).Error("could not found the connection to this device") - - return err - } - - req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/ssh/close/%s", data.UID), nil) - if err != nil { - log.WithError(err).Error("failed to create a the request for the device") - - return err - } - - if err := req.Write(conn); err != nil { - log.WithError(err).Error("failed to perform the HTTP request to the device to close the session") - - return err - } - - return c.NoContent(http.StatusOK) - }) - - if config.Tunnels { - // The `/http/proxy` endpoint is invoked by the NGINX gateway when a tunnel URL is accessed. It processes the - // `X-Address` and `X-Path` headers, which specify the tunnel's address and the target path on the server, returning - // an error related to the connection to device or what was returned from the server inside the tunnel. - tunnel.router.Any("/http/proxy", func(c echo.Context) error { - requestID := c.Request().Header.Get("X-Request-ID") - - address := c.Request().Header.Get("X-Address") - log.WithFields(log.Fields{ - "request-id": requestID, - "address": address, - }).Debug("address value") - - path := c.Request().Header.Get("X-Path") - log.WithFields(log.Fields{ - "request-id": requestID, - "address": address, - }).Debug("path") - - endpoint, err := tunnel.API.LookupWebEndpoints(c.Request().Context(), address) - if err != nil { - log.WithError(err).Error("failed to get the web endpoint") - - return c.JSON(http.StatusForbidden, NewMessageFromError(ErrWebEndpointForbidden)) - } - - logger := log.WithFields(log.Fields{ - "request-id": requestID, - "namespace": endpoint.Namespace, - "device": endpoint.Device, - }) - - in, err := tunnel.Dial(c.Request().Context(), fmt.Sprintf("%s:%s", endpoint.Namespace, endpoint.DeviceUID)) - if err != nil { - logger.WithError(err).Error("failed to dial to device") - - return c.JSON(http.StatusForbidden, NewMessageFromError(ErrDeviceTunnelDial)) - } - - defer in.Close() - - logger.Trace("new web endpoint connection initialized") - defer logger.Trace("web endpoint connection doned") - - // NOTE: Connects to the HTTP proxy before doing the actual request. In this case, we are connecting to all - // hosts on the agent because we aren't specifying any host, on the port specified. The proxy route accepts - // connections for any port, but this route should only connect to the HTTP server. - req, _ := http.NewRequest(http.MethodConnect, fmt.Sprintf("/http/proxy/%s:%d", endpoint.Host, endpoint.Port), nil) - - if err := req.Write(in); err != nil { - logger.WithError(err).Error("failed to write the request to the agent") - - return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelWriteRequest)) - } - - if resp, err := http.ReadResponse(bufio.NewReader(in), req); err != nil || resp.StatusCode != http.StatusOK { - logger.WithError(err).Error("failed to connect to HTTP port on device") - - return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelConnect)) - } - - req = c.Request() - req.Host = strings.Join([]string{address, config.TunnelsDomain}, ".") - req.URL, err = url.Parse(path) - if err != nil { - logger.WithError(err).Error("failed to parse the path") - - return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelReadResponse)) - } - - if err := req.Write(in); err != nil { - logger.WithError(err).Error("failed to write the request to the agent") - - return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelWriteRequest)) - } - - ctr := http.NewResponseController(c.Response()) - out, _, err := ctr.Hijack() - if err != nil { - logger.WithError(err).Error("failed to hijact the http request") - - return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelHijackRequest)) - } - - defer out.Close() - - // Bidirectional copy between the client and the device. - var wg sync.WaitGroup - wg.Add(2) - - done := sync.OnceFunc(func() { - defer in.Close() - defer out.Close() - - logger.Trace("close called on in and out connections") - }) - - go func() { - defer done() - defer wg.Done() - - if _, err := io.Copy(in, out); err != nil { - logger.WithError(err).Debug("in and out done returned a error") - } - - logger.Trace("in and out done") - }() - - go func() { - defer done() - defer wg.Done() - - if _, err := io.Copy(out, in); err != nil { - logger.WithError(err).Debug("out and in done returned a error") - } - - logger.Trace("out and in done") - }() - - wg.Wait() - - logger.Debug("http proxy is done") - - return nil - }) - } - - tunnel.router.GET("/healthcheck", func(c echo.Context) error { - return c.String(http.StatusOK, "OK") - }) - - return tunnel, nil -} - -func (t *Tunnel) GetRouter() *echo.Echo { - return t.router -} - -// Dial trys to get a connetion to a device specifying a key, what is a combination of tenant and device's UID. -func (t *Tunnel) Dial(ctx context.Context, key string) (net.Conn, error) { - return t.Tunnel.Dial(ctx, key) -} diff --git a/ssh/server/server.go b/ssh/server/server.go index f28901617c8..535daef9b3d 100644 --- a/ssh/server/server.go +++ b/ssh/server/server.go @@ -10,7 +10,7 @@ import ( gliderssh "github.com/gliderlabs/ssh" "github.com/pires/go-proxyproto" "github.com/shellhub-io/shellhub/pkg/cache" - "github.com/shellhub-io/shellhub/pkg/httptunnel" + "github.com/shellhub-io/shellhub/ssh/pkg/dialer" "github.com/shellhub-io/shellhub/ssh/pkg/target" "github.com/shellhub-io/shellhub/ssh/server/auth" "github.com/shellhub-io/shellhub/ssh/server/channels" @@ -29,7 +29,7 @@ type Options struct { type Server struct { sshd *gliderssh.Server opts *Options - tunnel *httptunnel.Tunnel + dialer *dialer.Dialer } var ( @@ -43,10 +43,10 @@ var ( AccessDeniedMessage string ) -func NewServer(opts *Options, tunnel *httptunnel.Tunnel, cache cache.Cache) *Server { +func NewServer(dialer *dialer.Dialer, cache cache.Cache, opts *Options) *Server { server := &Server{ // nolint: exhaustruct opts: opts, - tunnel: tunnel, + dialer: dialer, } server.sshd = &gliderssh.Server{ // nolint: exhaustruct @@ -79,7 +79,7 @@ func NewServer(opts *Options, tunnel *httptunnel.Tunnel, cache cache.Cache) *Ser return message(InvalidSSHIDMessage) } - sess, err := session.NewSession(ctx, tunnel, cache) + sess, err := session.NewSession(ctx, dialer, cache) if err != nil { logger.WithError(err).Error("failed to create the session") diff --git a/ssh/session/session.go b/ssh/session/session.go index d9f4984caad..44bf65147a7 100644 --- a/ssh/session/session.go +++ b/ssh/session/session.go @@ -18,8 +18,8 @@ import ( "github.com/shellhub-io/shellhub/pkg/cache" "github.com/shellhub-io/shellhub/pkg/clock" "github.com/shellhub-io/shellhub/pkg/envs" - "github.com/shellhub-io/shellhub/pkg/httptunnel" "github.com/shellhub-io/shellhub/pkg/models" + "github.com/shellhub-io/shellhub/ssh/pkg/dialer" "github.com/shellhub-io/shellhub/ssh/pkg/host" "github.com/shellhub-io/shellhub/ssh/pkg/target" log "github.com/sirupsen/logrus" @@ -147,7 +147,7 @@ type Session struct { Client *Client api internalclient.Client - tunnel *httptunnel.Tunnel + dialer *dialer.Dialer // Events is a connection to the endpoint to save session's events. Events *Events @@ -236,7 +236,7 @@ func (s *Seats) SetPty(seat int, status bool) { // the session without registering, connecting to the agent, etc. // // It's designed to be used within New. -func NewSession(ctx gliderssh.Context, tunnel *httptunnel.Tunnel, cache cache.Cache) (*Session, error) { +func NewSession(ctx gliderssh.Context, dialer *dialer.Dialer, cache cache.Cache) (*Session, error) { snap := getSnapshot(ctx) api, err := internalclient.NewClient(nil) @@ -317,7 +317,7 @@ func NewSession(ctx gliderssh.Context, tunnel *httptunnel.Tunnel, cache cache.Ca session := &Session{ UID: ctx.SessionID(), api: api, - tunnel: tunnel, + dialer: dialer, Events: &Events{ mu: sync.Mutex{}, conn: events, @@ -555,24 +555,29 @@ func (s *Session) connect(ctx gliderssh.Context, authOpt authFunc) error { return nil } +var ErrDialUnknown = errors.New("unknown protocol version") + +// Dial establishes the underlying transport to the target device. For V1 +// transports an HTTP GET request is issued (legacy reverse tunnel). For +// V2 transports a multistream protocol selection is performed using the +// ProtoSSHOpen identifier followed by a JSON envelope with the session +// id. After this method returns s.Agent.Conn is a raw channel ready for +// SSH key exchange and channel opens. func (s *Session) Dial(ctx gliderssh.Context) error { var err error ctx.Lock() - conn, err := s.tunnel.Dial(ctx, s.Device.TenantID+":"+s.Device.UID) + defer ctx.Unlock() + + conn, err := s.dialer.DialTo(ctx, s.Device.TenantID, s.Device.UID, dialer.SSHOpenTarget{SessionID: s.UID}) if err != nil { - return errors.Join(ErrDial, err) - } + log.WithFields(log.Fields{"session": s.UID, "sshid": s.SSHID}).WithError(err).Error("failed to open ssh session") - req, _ := http.NewRequest(http.MethodGet, fmt.Sprintf("/ssh/%s", s.UID), nil) - if err = req.Write(conn); err != nil { - return err + return errors.Join(ErrDial, err) } s.Agent.Conn = conn - ctx.Unlock() - return nil }