From 69585ad086faaaef649300ccdaa59898a2229e2c Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Mon, 25 Mar 2024 16:15:14 -0700 Subject: [PATCH 1/8] Refactor OAuth2Listener to better support many ports --- cli/get.go | 20 +++++--- cli/login.go | 82 ++++++++++++++++----------------- cli/oauth2.go | 112 ++++++++++++++++++++++++++++++--------------- cli/oauth2_test.go | 59 ++++++++++++++++++++++++ 4 files changed, 188 insertions(+), 85 deletions(-) create mode 100644 cli/oauth2_test.go diff --git a/cli/get.go b/cli/get.go index f5052d8a..65973ebd 100644 --- a/cli/get.go +++ b/cli/get.go @@ -39,11 +39,13 @@ func init() { getCmd.Flags().String(FlagRoleSessionName, "KeyConjurer-AssumeRole", "the name of the role session name that will show up in CloudTrail logs") getCmd.Flags().StringP(FlagOutputType, "o", outputTypeEnvironmentVariable, "Format to save new credentials in. Supported outputs: env, awscli,tencentcli") getCmd.Flags().String(FlagShellType, shellTypeInfer, "If output type is env, determines which format to output credentials in - by default, the format is inferred based on the execution environment. WSL users may wish to overwrite this to `bash`") - getCmd.Flags().String(FlagAWSCLIPath, "~/.aws/", "Path for directory used by the aws-cli tool. Default is \"~/.aws\".") getCmd.Flags().String(FlagTencentCLIPath, "~/.tencent/", "Path for directory used by the tencent-cli tool. Default is \"~/.tencent\".") getCmd.Flags().String(FlagCloudType, "aws", "Choose a cloud vendor. Default is aws. Can choose aws or tencent") getCmd.Flags().Bool(FlagBypassCache, false, "Do not check the cache for accounts and send the application ID as-is to Okta. This is useful if you have an ID you know is an Okta application ID and it is not stored in your local account cache.") getCmd.Flags().Bool(FlagLogin, false, "Login to Okta before running the command") + getCmd.Flags().String(FlagAWSCLIPath, "~/.aws/", "Path for directory used by the aws CLI") + getCmd.Flags().BoolP(FlagURLOnly, "u", false, "Print only the URL to visit rather than a user-friendly message") + getCmd.Flags().BoolP(FlagNoBrowser, "b", false, "Do not open a browser window, printing the URL instead") } func isMemberOfSlice(slice []string, val string) bool { @@ -74,16 +76,22 @@ A role must be specified when using this command through the --role flag. You ma ctx := cmd.Context() oidcDomain, _ := cmd.Flags().GetString(FlagOIDCDomain) clientID, _ := cmd.Flags().GetString(FlagClientID) + if HasTokenExpired(config.Tokens) { if ok, _ := cmd.Flags().GetBool(FlagLogin); ok { - token, err := Login(ctx, oidcDomain, clientID, LoginOutputModeBrowser{}) - if err != nil { - return err + urlOnly, _ := cmd.Flags().GetBool(FlagURLOnly) + noBrowser, _ := cmd.Flags().GetBool(FlagNoBrowser) + login := LoginCommand{ + Config: config, + OIDCDomain: oidcDomain, + ClientID: clientID, + MachineOutput: ShouldUseMachineOutput(cmd.Flags()) || urlOnly, + NoBrowser: noBrowser, } - if err := config.SaveOAuthToken(token); err != nil { + + if err := login.Execute(cmd.Context()); err != nil { return err } - } else { return ErrTokensExpiredOrAbsent } diff --git a/cli/login.go b/cli/login.go index 92c9683a..b5363ca7 100644 --- a/cli/login.go +++ b/cli/login.go @@ -9,7 +9,6 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" "golang.org/x/exp/slog" - "golang.org/x/oauth2" ) var ( @@ -45,65 +44,62 @@ var loginCmd = &cobra.Command{ oidcDomain, _ := cmd.Flags().GetString(FlagOIDCDomain) clientID, _ := cmd.Flags().GetString(FlagClientID) urlOnly, _ := cmd.Flags().GetBool(FlagURLOnly) - - var outputMode LoginOutputMode = LoginOutputModeBrowser{} - if noBrowser, _ := cmd.Flags().GetBool(FlagNoBrowser); noBrowser { - if ShouldUseMachineOutput(cmd.Flags()) || urlOnly { - outputMode = LoginOutputModeURLOnly{} - } else { - outputMode = LoginOutputModeHumanFriendlyMessage{} - } - } - - token, err := Login(cmd.Context(), oidcDomain, clientID, outputMode) - if err != nil { - return err + noBrowser, _ := cmd.Flags().GetBool(FlagNoBrowser) + command := LoginCommand{ + Config: config, + OIDCDomain: oidcDomain, + ClientID: clientID, + MachineOutput: ShouldUseMachineOutput(cmd.Flags()) || urlOnly, + NoBrowser: noBrowser, } - return config.SaveOAuthToken(token) + return command.Execute(cmd.Context()) }, } -func Login(ctx context.Context, domain, clientID string, outputMode LoginOutputMode) (*oauth2.Token, error) { - oauthCfg, err := DiscoverOAuth2Config(ctx, domain, clientID) +type LoginCommand struct { + Config *Config + OIDCDomain string + ClientID string + MachineOutput bool + NoBrowser bool +} + +func (c LoginCommand) Execute(ctx context.Context) error { + oauthCfg, err := DiscoverOAuth2Config(ctx, c.OIDCDomain, c.ClientID) if err != nil { - return nil, err + return err } - state, err := GenerateState() - if err != nil { - return nil, err + handler := RedirectionFlowHandler{ + Config: oauthCfg, + OnDisplayURL: openBrowserToURL, } - codeVerifier, codeChallenge, err := GenerateCodeVerifierAndChallenge() + if c.NoBrowser { + if c.MachineOutput { + handler.OnDisplayURL = printURLToConsole + } else { + handler.OnDisplayURL = friendlyPrintURLToConsole + } + } + + state := GenerateState() + challenge := GeneratePkceChallenge() + token, err := handler.HandlePendingSession(ctx, challenge, state) if err != nil { - return nil, err + return err } - return RedirectionFlow(ctx, oauthCfg, state, codeChallenge, codeVerifier, outputMode) + return c.Config.SaveOAuthToken(token) } -type LoginOutputMode interface { - PrintURL(url string) error +func friendlyPrintURLToConsole(url string) error { + fmt.Printf("Visit the following link in your terminal: %s\n", url) + return nil } -type LoginOutputModeBrowser struct{} - -func (LoginOutputModeBrowser) PrintURL(url string) error { +func openBrowserToURL(url string) error { slog.Debug("trying to open browser window", slog.String("url", url)) return browser.OpenURL(url) } - -type LoginOutputModeURLOnly struct{} - -func (LoginOutputModeURLOnly) PrintURL(url string) error { - fmt.Fprintln(os.Stdout, url) - return nil -} - -type LoginOutputModeHumanFriendlyMessage struct{} - -func (LoginOutputModeHumanFriendlyMessage) PrintURL(url string) error { - fmt.Printf("Visit the following link in your terminal: %s\n", url) - return nil -} diff --git a/cli/oauth2.go b/cli/oauth2.go index 701e9ef4..357714db 100644 --- a/cli/oauth2.go +++ b/cli/oauth2.go @@ -9,8 +9,10 @@ import ( "encoding/json" "errors" "fmt" + "net" "net/http" "net/url" + "os" "strings" "github.com/RobotsAndPencils/go-saml" @@ -63,20 +65,10 @@ type OAuth2CallbackInfo struct { } type OAuth2Listener struct { - Addr string - errCh chan error + Socket net.Listener callbackCh chan OAuth2CallbackInfo } -func NewOAuth2Listener() OAuth2Listener { - return OAuth2Listener{ - // 5RIOT on a phone pad - Addr: ":57468", - errCh: make(chan error), - callbackCh: make(chan OAuth2CallbackInfo), - } -} - func ParseCallbackRequest(r *http.Request) (OAuth2CallbackInfo, error) { info := OAuth2CallbackInfo{ Error: r.FormValue("error"), @@ -88,6 +80,23 @@ func ParseCallbackRequest(r *http.Request) (OAuth2CallbackInfo, error) { return info, nil } +func NewOAuth2Listener(socket net.Listener) OAuth2Listener { + return OAuth2Listener{ + Socket: socket, + callbackCh: make(chan OAuth2CallbackInfo), + } +} + +func (o OAuth2Listener) Close() error { + if o.callbackCh != nil { + close(o.callbackCh) + } + if o.Socket != nil { + return o.Socket.Close() + } + return nil +} + func (o OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { info, err := ParseCallbackRequest(r) if err == nil { @@ -99,19 +108,13 @@ func (o OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "You may close this window now.") } -func (o OAuth2Listener) Listen(ctx context.Context) { - server := http.Server{Addr: o.Addr, Handler: o} - go func() { - <-ctx.Done() - server.Close() - }() - - if err := server.ListenAndServe(); err != http.ErrServerClosed { - o.errCh <- err +func (o OAuth2Listener) Listen() error { + err := http.Serve(o.Socket, o) + if errors.Is(err, http.ErrServerClosed) { + return nil } - close(o.callbackCh) - close(o.errCh) + return err } func (o OAuth2Listener) WaitForAuthorizationCode(ctx context.Context, state string) (string, error) { @@ -126,8 +129,6 @@ func (o OAuth2Listener) WaitForAuthorizationCode(ctx context.Context, state stri } return info.Code, nil - case err := <-o.errCh: - return "", err case <-ctx.Done(): return "", ctx.Err() } @@ -142,31 +143,70 @@ func (e OAuth2Error) Error() string { return fmt.Sprintf("oauth2 error: %s (%s)", e.Description, e.Reason) } -func GenerateCodeVerifierAndChallenge() (string, string, error) { +func GeneratePkceChallenge() PkceChallenge { codeVerifierBuf := make([]byte, stateBufSize) rand.Read(codeVerifierBuf) codeVerifier := base64.RawURLEncoding.EncodeToString(codeVerifierBuf) codeChallengeHash := sha256.Sum256([]byte(codeVerifier)) codeChallenge := base64.RawURLEncoding.EncodeToString(codeChallengeHash[:]) - return codeVerifier, codeChallenge, nil + return PkceChallenge{Verifier: codeVerifier, Challenge: codeChallenge} } -func GenerateState() (string, error) { +func GenerateState() string { stateBuf := make([]byte, stateBufSize) rand.Read(stateBuf) - return base64.URLEncoding.EncodeToString([]byte(stateBuf)), nil + return base64.URLEncoding.EncodeToString(stateBuf) +} + +type PkceChallenge struct { + Challenge string + Verifier string +} + +func printURLToConsole(url string) error { + fmt.Fprintln(os.Stdout, url) + return nil +} + +type RedirectionFlowHandler struct { + Config *oauth2.Config + OnDisplayURL func(url string) error + + // Listen is a function that can be provided to override how the redirection flow handler opens a network socket. + // If this is not specified, the handler will attempt to create a connection that listens to 0.0.0.0:57468 on IPv4. + Listen func() (net.Listener, error) } -func RedirectionFlow(ctx context.Context, oauthCfg *oauth2.Config, state, codeChallenge, codeVerifier string, outputMode LoginOutputMode) (*oauth2.Token, error) { - listener := NewOAuth2Listener() - go listener.Listen(ctx) - oauthCfg.RedirectURL = "http://localhost:57468" - url := oauthCfg.AuthCodeURL(state, +func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challenge PkceChallenge, state string) (*oauth2.Token, error) { + if r.OnDisplayURL == nil { + r.OnDisplayURL = printURLToConsole + } + + if r.Listen == nil { + r.Listen = func() (net.Listener, error) { + var lc net.ListenConfig + sock, err := lc.Listen(ctx, "tcp4", net.JoinHostPort("0.0.0.0", "57468")) + return sock, err + } + } + + sock, err := r.Listen() + if err != nil { + return nil, err + } + + r.Config.RedirectURL = "http://localhost:57468" + url := r.Config.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge_method", "S256"), - oauth2.SetAuthURLParam("code_challenge", codeChallenge), + oauth2.SetAuthURLParam("code_challenge", challenge.Challenge), ) - if err := outputMode.PrintURL(url); err != nil { + listener := NewOAuth2Listener(sock) + defer listener.Close() + // This error can be ignored. + go listener.Listen() + + if err := r.OnDisplayURL(url); err != nil { // This is unlikely to ever happen return nil, fmt.Errorf("failed to display link: %w", err) } @@ -176,7 +216,7 @@ func RedirectionFlow(ctx context.Context, oauthCfg *oauth2.Config, state, codeCh return nil, fmt.Errorf("failed to get authorization code: %w", err) } - return oauthCfg.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) + return r.Config.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", challenge.Verifier)) } func ExchangeAccessTokenForWebSSOToken(ctx context.Context, client *http.Client, oauthCfg *oauth2.Config, token *TokenSet, applicationID string) (*oauth2.Token, error) { diff --git a/cli/oauth2_test.go b/cli/oauth2_test.go new file mode 100644 index 00000000..b72067bb --- /dev/null +++ b/cli/oauth2_test.go @@ -0,0 +1,59 @@ +package main + +import ( + "context" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// This test ensures that the success flow of receiving an authorization code from a HTTP handler works correctly. +func Test_OAuth2Listener_WaitForAuthorizationCodeWorksCorrectly(t *testing.T) { + // TODO: Initializing private fields means we should probably refactor OAuth2Listener's constructor to not require net.Listener, + // but instead to have net.Listener be a thing in Listen(); + // + // Users should be able to use ServeHTTP. + listener := OAuth2Listener{ + callbackCh: make(chan OAuth2CallbackInfo), + } + + expectedState := "state goes here" + expectedCode := "code goes here" + + go func() { + w := httptest.NewRecorder() + values := url.Values{ + "code": []string{expectedCode}, + "state": []string{expectedState}, + } + + uri := url.URL{ + Scheme: "http", + Host: "localhost", + Path: "/oauth2/callback", + RawQuery: values.Encode(), + } + + req := httptest.NewRequest("GET", uri.String(), nil) + listener.ServeHTTP(w, req) + }() + + deadline, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + code, err := listener.WaitForAuthorizationCode(deadline, expectedState) + assert.NoError(t, err) + cancel() + + assert.Equal(t, expectedCode, code) + assert.NoError(t, listener.Close()) +} + +func Test_OAuth2Listener_ZeroValueNeverPanics(t *testing.T) { + var listener OAuth2Listener + deadline, _ := context.WithTimeout(context.Background(), 500*time.Millisecond) + _, err := listener.WaitForAuthorizationCode(deadline, "") + assert.ErrorIs(t, context.DeadlineExceeded, err) + assert.NoError(t, listener.Close()) +} From fe7dfcff2faad99e8b0d427580f96797d7c8434a Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Mon, 25 Mar 2024 17:46:54 -0700 Subject: [PATCH 2/8] Add the ability to attempt many ports if one fails Okta does not implement the BCP212 correctly, and does not allow wildcard ports in loopback addresses. This is our way of circumventing that in the event that one of the ports is taken. --- cli/consts.go | 1 + cli/login.go | 6 ++++++ cli/oauth2.go | 44 ++++++++++++++++++++++++++++++------- cli/oauth2_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 97 insertions(+), 8 deletions(-) diff --git a/cli/consts.go b/cli/consts.go index 7dfd8e32..678c0a37 100644 --- a/cli/consts.go +++ b/cli/consts.go @@ -8,6 +8,7 @@ var ( Version = "TBD" BuildTimestamp = "BuildTimestamp is not set" DownloadURL = "URL not set yet" + CallbackPorts = []string{"57468", "58888", "59999", "60000"} ) const ( diff --git a/cli/login.go b/cli/login.go index b5363ca7..6982ff86 100644 --- a/cli/login.go +++ b/cli/login.go @@ -73,6 +73,7 @@ func (c LoginCommand) Execute(ctx context.Context) error { handler := RedirectionFlowHandler{ Config: oauthCfg, + Listen: ListenAnyPort("127.0.0.1", CallbackPorts), OnDisplayURL: openBrowserToURL, } @@ -94,6 +95,11 @@ func (c LoginCommand) Execute(ctx context.Context) error { return c.Config.SaveOAuthToken(token) } +func printURLToConsole(url string) error { + fmt.Fprintln(os.Stdout, url) + return nil +} + func friendlyPrintURLToConsole(url string) error { fmt.Printf("Visit the following link in your terminal: %s\n", url) return nil diff --git a/cli/oauth2.go b/cli/oauth2.go index 357714db..eb57070a 100644 --- a/cli/oauth2.go +++ b/cli/oauth2.go @@ -12,7 +12,6 @@ import ( "net" "net/http" "net/url" - "os" "strings" "github.com/RobotsAndPencils/go-saml" @@ -163,9 +162,32 @@ type PkceChallenge struct { Verifier string } -func printURLToConsole(url string) error { - fmt.Fprintln(os.Stdout, url) - return nil +var ErrNoPortsAvailable = errors.New("no ports available") + +// findFirstFreePort will attempt to open a network listener for each port in turn, and return the first one that succeeded. +// +// If none succeed, ErrNoPortsAvailable is returned. +// +// This is useful for supporting OIDC servers that do not allow for ephemeral ports to be used in the loopback address, like Okta. +func findFirstFreePort(ctx context.Context, broadcastAddr string, ports []string) (net.Listener, error) { + var lc net.ListenConfig + for _, port := range ports { + sock, err := lc.Listen(ctx, "tcp4", net.JoinHostPort(broadcastAddr, port)) + if err == nil { + return sock, nil + } + } + + return nil, ErrNoPortsAvailable +} + +// ListenAnyPort is a function that can be passed to RedirectionFlowHandler that will attempt to listen to exactly one of the ports in the supplied array. +// +// This function does not guarantee it will try ports in the order they are supplied, but it will return either a listener bound to exactly one of the ports, or the error ErrNoPortsAvailable. +func ListenAnyPort(broadcastAddr string, ports []string) func(ctx context.Context) (net.Listener, error) { + return func(ctx context.Context) (net.Listener, error) { + return findFirstFreePort(ctx, broadcastAddr, ports) + } } type RedirectionFlowHandler struct { @@ -174,7 +196,7 @@ type RedirectionFlowHandler struct { // Listen is a function that can be provided to override how the redirection flow handler opens a network socket. // If this is not specified, the handler will attempt to create a connection that listens to 0.0.0.0:57468 on IPv4. - Listen func() (net.Listener, error) + Listen func(ctx context.Context) (net.Listener, error) } func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challenge PkceChallenge, state string) (*oauth2.Token, error) { @@ -183,19 +205,25 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe } if r.Listen == nil { - r.Listen = func() (net.Listener, error) { + r.Listen = func(ctx context.Context) (net.Listener, error) { var lc net.ListenConfig sock, err := lc.Listen(ctx, "tcp4", net.JoinHostPort("0.0.0.0", "57468")) return sock, err } } - sock, err := r.Listen() + sock, err := r.Listen(ctx) + if err != nil { + return nil, err + } + + _, port, err := net.SplitHostPort(sock.Addr().String()) if err != nil { + // Failed to split the host and port. We need the port to continue, so bail return nil, err } - r.Config.RedirectURL = "http://localhost:57468" + r.Config.RedirectURL = fmt.Sprintf("http://%s", net.JoinHostPort("localhost", port)) url := r.Config.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge", challenge.Challenge), diff --git a/cli/oauth2_test.go b/cli/oauth2_test.go index b72067bb..c2e97af5 100644 --- a/cli/oauth2_test.go +++ b/cli/oauth2_test.go @@ -2,12 +2,14 @@ package main import ( "context" + "net" "net/http/httptest" "net/url" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // This test ensures that the success flow of receiving an authorization code from a HTTP handler works correctly. @@ -57,3 +59,55 @@ func Test_OAuth2Listener_ZeroValueNeverPanics(t *testing.T) { assert.ErrorIs(t, context.DeadlineExceeded, err) assert.NoError(t, listener.Close()) } + +// This test is going to be flaky because processes may open ports outside of our control. +func Test_ListenAnyPort_WorksCorrectly(t *testing.T) { + ports := []string{"58080", "58081", "58082", "58083"} + socket, err := net.Listen("tcp4", net.JoinHostPort("127.0.0.1", ports[0])) + t.Cleanup(func() { + socket.Close() + }) + require.NoError(t, err, "Could not open socket on port: %s", ports[0]) + + listenFunc := ListenAnyPort("127.0.0.1", ports) + openedSocket, err := listenFunc(context.Background()) + + assert.NoError(t, err) + _, port, err := net.SplitHostPort(openedSocket.Addr().String()) + assert.NoError(t, err) + // There is no guarantee on which port FindFirstFreePort will choose, but it must pick one from the given list. + assert.Contains(t, ports, port) + openedSocket.Close() +} + +func Test_ListenAnyPort_RejectsIfNoPortsAvailable(t *testing.T) { + var ports []string + listenFunc := ListenAnyPort("127.0.0.1", ports) + _, err := listenFunc(context.Background()) + assert.ErrorIs(t, ErrNoPortsAvailable, err) +} + +func Test_ListenAnyPort_RejectsIfAllProvidedPortsExhausted(t *testing.T) { + ports := []string{"58080", "58081", "58082", "58083"} + var sockets []net.Listener + var activePorts []string + // This exhausts all sockets in 'ports' and dumps them into 'activePorts'. + for _, port := range ports { + socket, err := net.Listen("tcp4", net.JoinHostPort("127.0.0.1", port)) + if err == nil { + sockets = append(sockets, socket) + activePorts = append(activePorts, port) + } + } + + require.NotEmpty(t, activePorts, "could not open any sockets") + + t.Cleanup(func() { + for _, socket := range sockets { + socket.Close() + } + }) + + _, err := ListenAnyPort("127.0.0.1", activePorts)(context.Background()) + assert.ErrorIs(t, err, ErrNoPortsAvailable) +} From 7bf7c076030a690c67e7a5506b803d40e5a3c387 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Mon, 25 Mar 2024 18:09:05 -0700 Subject: [PATCH 3/8] Fix typo --- cli/oauth2.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/oauth2.go b/cli/oauth2.go index eb57070a..d6719555 100644 --- a/cli/oauth2.go +++ b/cli/oauth2.go @@ -99,7 +99,7 @@ func (o OAuth2Listener) Close() error { func (o OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { info, err := ParseCallbackRequest(r) if err == nil { - // The only errors that might occur would be incorreclty formatted requests, which we will silently drop. + // The only errors that might occur would be incorrectly formatted requests, which we will silently drop. o.callbackCh <- info } From 7555b7951968304f2dce0bd04a4fce3ce8951e3c Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Mon, 25 Mar 2024 18:24:03 -0700 Subject: [PATCH 4/8] Extract listenFixedPort function --- cli/oauth2.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/cli/oauth2.go b/cli/oauth2.go index d6719555..657fffa4 100644 --- a/cli/oauth2.go +++ b/cli/oauth2.go @@ -190,6 +190,12 @@ func ListenAnyPort(broadcastAddr string, ports []string) func(ctx context.Contex } } +func listenFixedPort(ctx context.Context) (net.Listener, error) { + var lc net.ListenConfig + sock, err := lc.Listen(ctx, "tcp4", net.JoinHostPort("0.0.0.0", "57468")) + return sock, err +} + type RedirectionFlowHandler struct { Config *oauth2.Config OnDisplayURL func(url string) error @@ -205,11 +211,7 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe } if r.Listen == nil { - r.Listen = func(ctx context.Context) (net.Listener, error) { - var lc net.ListenConfig - sock, err := lc.Listen(ctx, "tcp4", net.JoinHostPort("0.0.0.0", "57468")) - return sock, err - } + r.Listen = listenFixedPort } sock, err := r.Listen(ctx) From 6000b9d50ce0bb2ece0cd0504c884e58c881326a Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Mon, 25 Mar 2024 18:43:56 -0700 Subject: [PATCH 5/8] Fix a race condition when browsers would send two requests Some browsers would send two requests, with one landing reliably between when the OAuth2Listener was closed (and thus, its channel was closed) and when the http server would be closed. This change solves this problem by closing the channel after the request is received and only ever processing a single request. Other requests will receive responses, but will be silently ignored --- cli/oauth2.go | 63 +++++++++++++++++------------------------ cli/oauth2_test.go | 70 +++++++++++++++++++++++++++------------------- 2 files changed, 67 insertions(+), 66 deletions(-) diff --git a/cli/oauth2.go b/cli/oauth2.go index 657fffa4..12a5cc30 100644 --- a/cli/oauth2.go +++ b/cli/oauth2.go @@ -13,6 +13,7 @@ import ( "net/http" "net/url" "strings" + "sync" "github.com/RobotsAndPencils/go-saml" "github.com/coreos/go-oidc" @@ -63,11 +64,6 @@ type OAuth2CallbackInfo struct { ErrorDescription string } -type OAuth2Listener struct { - Socket net.Listener - callbackCh chan OAuth2CallbackInfo -} - func ParseCallbackRequest(r *http.Request) (OAuth2CallbackInfo, error) { info := OAuth2CallbackInfo{ Error: r.FormValue("error"), @@ -79,44 +75,35 @@ func ParseCallbackRequest(r *http.Request) (OAuth2CallbackInfo, error) { return info, nil } -func NewOAuth2Listener(socket net.Listener) OAuth2Listener { - return OAuth2Listener{ - Socket: socket, - callbackCh: make(chan OAuth2CallbackInfo), - } +// OAuth2Listener will listen for a single callback request from a web server and return the code if it matched, or an error otherwise. +type OAuth2Listener struct { + once sync.Once + callbackCh chan OAuth2CallbackInfo } -func (o OAuth2Listener) Close() error { - if o.callbackCh != nil { - close(o.callbackCh) - } - if o.Socket != nil { - return o.Socket.Close() +func NewOAuth2Listener() OAuth2Listener { + return OAuth2Listener{ + callbackCh: make(chan OAuth2CallbackInfo, 1), } - return nil } -func (o OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { - info, err := ParseCallbackRequest(r) - if err == nil { - // The only errors that might occur would be incorrectly formatted requests, which we will silently drop. - o.callbackCh <- info - } +func (o *OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // This can sometimes be called multiple times, depending on the browser. + // We will simply ignore any other requests and only serve the first. + o.once.Do(func() { + info, err := ParseCallbackRequest(r) + if err == nil { + // The only errors that might occur would be incorrectly formatted requests, which we will silently drop. + o.callbackCh <- info + } + close(o.callbackCh) + }) - // This is displayed to the end user in their browser. + // We still want to provide feedback to the end-user. fmt.Fprintln(w, "You may close this window now.") } -func (o OAuth2Listener) Listen() error { - err := http.Serve(o.Socket, o) - if errors.Is(err, http.ErrServerClosed) { - return nil - } - - return err -} - -func (o OAuth2Listener) WaitForAuthorizationCode(ctx context.Context, state string) (string, error) { +func (o *OAuth2Listener) WaitForAuthorizationCode(ctx context.Context, state string) (string, error) { select { case info := <-o.callbackCh: if info.Error != "" { @@ -218,6 +205,7 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe if err != nil { return nil, err } + defer sock.Close() _, port, err := net.SplitHostPort(sock.Addr().String()) if err != nil { @@ -231,10 +219,9 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe oauth2.SetAuthURLParam("code_challenge", challenge.Challenge), ) - listener := NewOAuth2Listener(sock) - defer listener.Close() - // This error can be ignored. - go listener.Listen() + listener := NewOAuth2Listener() + // TODO: This error probably should not be ignored if it is not http.ErrServerClosed + go http.Serve(sock, &listener) if err := r.OnDisplayURL(url); err != nil { // This is unlikely to ever happen diff --git a/cli/oauth2_test.go b/cli/oauth2_test.go index c2e97af5..39e90300 100644 --- a/cli/oauth2_test.go +++ b/cli/oauth2_test.go @@ -12,36 +12,29 @@ import ( "github.com/stretchr/testify/require" ) -// This test ensures that the success flow of receiving an authorization code from a HTTP handler works correctly. -func Test_OAuth2Listener_WaitForAuthorizationCodeWorksCorrectly(t *testing.T) { - // TODO: Initializing private fields means we should probably refactor OAuth2Listener's constructor to not require net.Listener, - // but instead to have net.Listener be a thing in Listen(); - // - // Users should be able to use ServeHTTP. - listener := OAuth2Listener{ - callbackCh: make(chan OAuth2CallbackInfo), +func sendRequestToListener(listener *OAuth2Listener, values url.Values) { + uri := url.URL{ + Scheme: "http", + Host: "localhost", + Path: "/oauth2/callback", + RawQuery: values.Encode(), } + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", uri.String(), nil) + listener.ServeHTTP(w, req) +} + +// This test ensures that the success flow of receiving an authorization code from a HTTP handler works correctly. +func Test_OAuth2Listener_WaitForAuthorizationCodeWorksCorrectly(t *testing.T) { + listener := NewOAuth2Listener() expectedState := "state goes here" expectedCode := "code goes here" - go func() { - w := httptest.NewRecorder() - values := url.Values{ - "code": []string{expectedCode}, - "state": []string{expectedState}, - } - - uri := url.URL{ - Scheme: "http", - Host: "localhost", - Path: "/oauth2/callback", - RawQuery: values.Encode(), - } - - req := httptest.NewRequest("GET", uri.String(), nil) - listener.ServeHTTP(w, req) - }() + go sendRequestToListener(&listener, url.Values{ + "code": []string{expectedCode}, + "state": []string{expectedState}, + }) deadline, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) code, err := listener.WaitForAuthorizationCode(deadline, expectedState) @@ -49,18 +42,39 @@ func Test_OAuth2Listener_WaitForAuthorizationCodeWorksCorrectly(t *testing.T) { cancel() assert.Equal(t, expectedCode, code) - assert.NoError(t, listener.Close()) } func Test_OAuth2Listener_ZeroValueNeverPanics(t *testing.T) { var listener OAuth2Listener deadline, _ := context.WithTimeout(context.Background(), 500*time.Millisecond) _, err := listener.WaitForAuthorizationCode(deadline, "") + // This will timeout because the OAuth2Listener will forever listen on a nil channel assert.ErrorIs(t, context.DeadlineExceeded, err) - assert.NoError(t, listener.Close()) } -// This test is going to be flaky because processes may open ports outside of our control. +// Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic prevents an issue where OAuth2Listener would send a request to a closed channel +func Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic(t *testing.T) { + listener := NewOAuth2Listener() + expectedState := "state goes here" + expectedCode := "code goes here" + + go sendRequestToListener(&listener, url.Values{ + "code": []string{expectedCode}, + "state": []string{expectedState}, + }) + + go sendRequestToListener(&listener, url.Values{ + "code": []string{"not the expected code and should be discarded"}, + "state": []string{"not the expected state and should be discarded"}, + }) + + deadlineCtx, _ := context.WithTimeout(context.Background(), 500*time.Millisecond) + code, err := listener.WaitForAuthorizationCode(deadlineCtx, expectedState) + assert.NoError(t, err) + assert.Equal(t, expectedCode, code) +} + +// Test_ListenAnyPort_WorksCorrectly is going to be flaky because processes may open ports outside of our control. func Test_ListenAnyPort_WorksCorrectly(t *testing.T) { ports := []string{"58080", "58081", "58082", "58083"} socket, err := net.Listen("tcp4", net.JoinHostPort("127.0.0.1", ports[0])) From 23d2af47e818940aeae5517df376e0dca405c57c Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Mon, 25 Mar 2024 21:26:32 -0700 Subject: [PATCH 6/8] Make it harder to cause a race condition to begin with Instead of returning a struct with mutable struct, we can make it harder to cause a race condition by wrapping all of that state instead of a function closure which cannot be modified from the outside, and by using sync.Once() to ensure that only one request is ever handled. We only need to have the sync.Once() for close()ing the channel to permit a user to permit a user to free the resources manually. It is still possible to cause a panic by closing this channel before receiving a request. --- cli/oauth2.go | 125 ++++++++++++++++++++++++--------------------- cli/oauth2_test.go | 80 ++++++++++++++++++----------- 2 files changed, 119 insertions(+), 86 deletions(-) diff --git a/cli/oauth2.go b/cli/oauth2.go index 12a5cc30..6aa1b442 100644 --- a/cli/oauth2.go +++ b/cli/oauth2.go @@ -57,67 +57,73 @@ func DiscoverOAuth2Config(ctx context.Context, domain, clientID string) (*oauth2 return &cfg, nil } -type OAuth2CallbackInfo struct { - Code string - State string - Error string - ErrorDescription string -} - -func ParseCallbackRequest(r *http.Request) (OAuth2CallbackInfo, error) { - info := OAuth2CallbackInfo{ - Error: r.FormValue("error"), - ErrorDescription: r.FormValue("error_description"), - State: r.FormValue("state"), - Code: r.FormValue("code"), - } - - return info, nil +// OAuth2CallbackState encapsulates all of the information from an oauth2 callback. +// +// To retrieve the Code from the struct, you must use the Verify(string) function. +type OAuth2CallbackState struct { + code string + state string + errorMessage string + errorDescription string } -// OAuth2Listener will listen for a single callback request from a web server and return the code if it matched, or an error otherwise. -type OAuth2Listener struct { - once sync.Once - callbackCh chan OAuth2CallbackInfo +// FromRequest parses the given http.Request and populates the OAuth2CallbackState with those values. +func (o *OAuth2CallbackState) FromRequest(r *http.Request) { + o.errorMessage = r.FormValue("error") + o.errorDescription = r.FormValue("error_description") + o.state = r.FormValue("state") + o.code = r.FormValue("code") } -func NewOAuth2Listener() OAuth2Listener { - return OAuth2Listener{ - callbackCh: make(chan OAuth2CallbackInfo, 1), +// Verify safely compares the given state with the state from the OAuth2 callback. +// +// If they match, the code is returned, with a nil value. Otherwise, an empty string and an error is returned. +func (o OAuth2CallbackState) Verify(expectedState string) (string, error) { + if o.errorMessage != "" { + return "", OAuth2Error{Reason: o.errorMessage, Description: o.errorDescription} } -} -func (o *OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // This can sometimes be called multiple times, depending on the browser. - // We will simply ignore any other requests and only serve the first. - o.once.Do(func() { - info, err := ParseCallbackRequest(r) - if err == nil { - // The only errors that might occur would be incorrectly formatted requests, which we will silently drop. - o.callbackCh <- info - } - close(o.callbackCh) - }) + if strings.Compare(o.state, expectedState) != 0 { + return "", OAuth2Error{Reason: "invalid_state", Description: "state mismatch"} + } - // We still want to provide feedback to the end-user. - fmt.Fprintln(w, "You may close this window now.") + return o.code, nil } -func (o *OAuth2Listener) WaitForAuthorizationCode(ctx context.Context, state string) (string, error) { - select { - case info := <-o.callbackCh: - if info.Error != "" { - return "", OAuth2Error{Reason: info.Error, Description: info.ErrorDescription} - } - - if strings.Compare(info.State, state) != 0 { - return "", OAuth2Error{Reason: "invalid_state", Description: "state mismatch"} - } +// OAuth2CallbackHandler returns a http.Handler, channel and function triple. +// +// The http handler will accept exactly one request, which it will assume is an OAuth2 callback, parse it into an OAuth2CallbackState and then provide it to the given channel. Subsequent requests will be silently ignored. +// +// The function may be called to ensure that the channel is closed. The channel is closed when a request is received. In general, it is a good idea to ensure this function is called in a defer() block. +func OAuth2CallbackHandler() (http.Handler, <-chan OAuth2CallbackState, func()) { + // TODO: It is possible for the caller to close a panic() if they execute the function in the triplet while the handler has not yet received a request. + // That caller is us, so I don't care that much, but that probably indicates that this design is smelly. + // + // We should look at the Go SDK to see how they handle similar cases - channels that are not bound by a timer, or similar. + + ch := make(chan OAuth2CallbackState, 1) + var reqHandle, closeHandle sync.Once + closeFn := func() { + closeHandle.Do(func() { + close(ch) + }) + } - return info.Code, nil - case <-ctx.Done(): - return "", ctx.Err() + fn := func(w http.ResponseWriter, r *http.Request) { + // This can sometimes be called multiple times, depending on the browser. + // We will simply ignore any other requests and only serve the first. + reqHandle.Do(func() { + var state OAuth2CallbackState + state.FromRequest(r) + ch <- state + closeFn() + }) + + // We still want to provide feedback to the end-user. + fmt.Fprintln(w, "You may close this window now.") } + + return http.HandlerFunc(fn), ch, closeFn } type OAuth2Error struct { @@ -219,21 +225,26 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe oauth2.SetAuthURLParam("code_challenge", challenge.Challenge), ) - listener := NewOAuth2Listener() + callbackHandler, ch, cancel := OAuth2CallbackHandler() // TODO: This error probably should not be ignored if it is not http.ErrServerClosed - go http.Serve(sock, &listener) + go http.Serve(sock, callbackHandler) + defer cancel() if err := r.OnDisplayURL(url); err != nil { // This is unlikely to ever happen return nil, fmt.Errorf("failed to display link: %w", err) } - code, err := listener.WaitForAuthorizationCode(ctx, state) - if err != nil { - return nil, fmt.Errorf("failed to get authorization code: %w", err) + select { + case info := <-ch: + code, err := info.Verify(state) + if err != nil { + return nil, fmt.Errorf("failed to get authorization code: %w", err) + } + return r.Config.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", challenge.Verifier)) + case <-ctx.Done(): + return nil, ctx.Err() } - - return r.Config.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", challenge.Verifier)) } func ExchangeAccessTokenForWebSSOToken(ctx context.Context, client *http.Client, oauthCfg *oauth2.Config, token *TokenSet, applicationID string) (*oauth2.Token, error) { diff --git a/cli/oauth2_test.go b/cli/oauth2_test.go index 39e90300..260f310d 100644 --- a/cli/oauth2_test.go +++ b/cli/oauth2_test.go @@ -3,16 +3,16 @@ package main import ( "context" "net" + "net/http" "net/http/httptest" "net/url" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func sendRequestToListener(listener *OAuth2Listener, values url.Values) { +func sendOAuth2CallbackRequest(handler http.Handler, values url.Values) { uri := url.URL{ Scheme: "http", Host: "localhost", @@ -22,56 +22,78 @@ func sendRequestToListener(listener *OAuth2Listener, values url.Values) { w := httptest.NewRecorder() req := httptest.NewRequest("GET", uri.String(), nil) - listener.ServeHTTP(w, req) + handler.ServeHTTP(w, req) } -// This test ensures that the success flow of receiving an authorization code from a HTTP handler works correctly. -func Test_OAuth2Listener_WaitForAuthorizationCodeWorksCorrectly(t *testing.T) { - listener := NewOAuth2Listener() +func Test_OAuth2CallbackHandler_YieldsCorrectlyFormattedState(t *testing.T) { + handler, ch, cancel := OAuth2CallbackHandler() + t.Cleanup(func() { + cancel() + }) + expectedState := "state goes here" expectedCode := "code goes here" - go sendRequestToListener(&listener, url.Values{ + go sendOAuth2CallbackRequest(handler, url.Values{ "code": []string{expectedCode}, "state": []string{expectedState}, }) - deadline, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) - code, err := listener.WaitForAuthorizationCode(deadline, expectedState) + callbackState := <-ch + code, err := callbackState.Verify(expectedState) assert.NoError(t, err) - cancel() - assert.Equal(t, expectedCode, code) } -func Test_OAuth2Listener_ZeroValueNeverPanics(t *testing.T) { - var listener OAuth2Listener - deadline, _ := context.WithTimeout(context.Background(), 500*time.Millisecond) - _, err := listener.WaitForAuthorizationCode(deadline, "") - // This will timeout because the OAuth2Listener will forever listen on a nil channel - assert.ErrorIs(t, context.DeadlineExceeded, err) +func Test_OAuth2CallbackState_VerifyWorksCorrectly(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + expectedState := "state goes here" + expectedCode := "code goes here" + callbackState := OAuth2CallbackState{ + code: expectedCode, + state: expectedState, + } + code, err := callbackState.Verify(expectedState) + assert.NoError(t, err) + assert.Equal(t, expectedCode, code) + }) + + t.Run("unhappy path", func(t *testing.T) { + expectedState := "state goes here" + expectedCode := "code goes here" + callbackState := OAuth2CallbackState{ + code: expectedCode, + state: expectedState, + } + _, err := callbackState.Verify("mismatching state") + var oauthErr OAuth2Error + assert.ErrorAs(t, err, &oauthErr) + assert.Equal(t, "invalid_state", oauthErr.Reason) + }) } // Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic prevents an issue where OAuth2Listener would send a request to a closed channel func Test_OAuth2Listener_MultipleRequestsDoesNotCausePanic(t *testing.T) { - listener := NewOAuth2Listener() - expectedState := "state goes here" - expectedCode := "code goes here" + handler, ch, cancel := OAuth2CallbackHandler() + t.Cleanup(func() { + cancel() + }) - go sendRequestToListener(&listener, url.Values{ - "code": []string{expectedCode}, - "state": []string{expectedState}, + go sendOAuth2CallbackRequest(handler, url.Values{ + // We send empty values because we don't care about processing in this test + "code": []string{""}, + "state": []string{""}, }) - go sendRequestToListener(&listener, url.Values{ + // We drain the channel of the first request so the handler completes. + // Without this step, we would get 'stuck' in the sync.Once(). + <-ch + + // We send this request synchronously to ensure that any panics are caught during the test. + sendOAuth2CallbackRequest(handler, url.Values{ "code": []string{"not the expected code and should be discarded"}, "state": []string{"not the expected state and should be discarded"}, }) - - deadlineCtx, _ := context.WithTimeout(context.Background(), 500*time.Millisecond) - code, err := listener.WaitForAuthorizationCode(deadlineCtx, expectedState) - assert.NoError(t, err) - assert.Equal(t, expectedCode, code) } // Test_ListenAnyPort_WorksCorrectly is going to be flaky because processes may open ports outside of our control. From c2c1e90bbc99dc9bd1b13e9a2b485f6492bfe4e9 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 26 Mar 2024 13:32:16 -0700 Subject: [PATCH 7/8] Add comments on CallbackPorts --- cli/consts.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cli/consts.go b/cli/consts.go index 678c0a37..8ba04543 100644 --- a/cli/consts.go +++ b/cli/consts.go @@ -8,7 +8,11 @@ var ( Version = "TBD" BuildTimestamp = "BuildTimestamp is not set" DownloadURL = "URL not set yet" - CallbackPorts = []string{"57468", "58888", "59999", "60000"} + // CallbackPorts is a list of ports that will be attempted in no particular order for hosting an Oauth2 callback web server. + // This cannot be set using -ldflags='-X ..' because -X requires that this be a string literal or uninitialized. + // + // These ports are chosen somewhat arbitrarily + CallbackPorts = []string{"57468", "47512", "57123", "61232", "48231", "49757", "59834", "54293"} ) const ( From 0a2e8e61510d942f7e7a7b0b3d4d775f6ed49179 Mon Sep 17 00:00:00 2001 From: Dan Pantry Date: Tue, 26 Mar 2024 17:34:28 -0700 Subject: [PATCH 8/8] Remove Keyconjurer Upgrade This feature has caused us a lot of problems over the years: * It would brick itself on Windows if you failed to download the file for any reason * Sometimes the DownloadURL would not be set * Relying on the releases always being at a specific URL meant that it was difficult to change it for any reason (For example, to support a stable and development release channel) We are removing this command and will instead move to supporting package managers like Homebrew for OSX and Chocolatey for Windows. It is recommended that folks who depend on this project do the same. --- cli/consts.go | 8 +-- cli/root.go | 1 - cli/upgrade.go | 139 ------------------------------------------------- 3 files changed, 1 insertion(+), 147 deletions(-) delete mode 100644 cli/upgrade.go diff --git a/cli/consts.go b/cli/consts.go index 8ba04543..7c151184 100644 --- a/cli/consts.go +++ b/cli/consts.go @@ -7,7 +7,6 @@ var ( ServerAddress string Version = "TBD" BuildTimestamp = "BuildTimestamp is not set" - DownloadURL = "URL not set yet" // CallbackPorts is a list of ports that will be attempted in no particular order for hosting an Oauth2 callback web server. // This cannot be set using -ldflags='-X ..' because -X requires that this be a string literal or uninitialized. // @@ -19,10 +18,5 @@ const ( // DefaultTTL for requested credentials in hours DefaultTTL uint = 1 // DefaultTimeRemaining for new key requests in minutes - DefaultTimeRemaining uint = 5 - LinuxAmd64BinaryName string = "keyconjurer-linux-amd64" - LinuxArm64BinaryName string = "keyconjurer-linux-arm64" - WindowsBinaryName string = "keyconjurer-windows.exe" - DarwinArm64BinaryName string = "keyconjurer-darwin-arm64" - DarwinAmd64BinaryName string = "keyconjurer-darwin-amd64" + DefaultTimeRemaining uint = 5 ) diff --git a/cli/root.go b/cli/root.go index 07838e85..1ddb6d2f 100644 --- a/cli/root.go +++ b/cli/root.go @@ -32,7 +32,6 @@ func init() { rootCmd.AddCommand(accountsCmd) rootCmd.AddCommand(getCmd) rootCmd.AddCommand(setCmd) - rootCmd.AddCommand(upgradeCmd) rootCmd.AddCommand(&switchCmd) rootCmd.AddCommand(&aliasCmd) rootCmd.AddCommand(&unaliasCmd) diff --git a/cli/upgrade.go b/cli/upgrade.go deleted file mode 100644 index 2f0a4215..00000000 --- a/cli/upgrade.go +++ /dev/null @@ -1,139 +0,0 @@ -package main - -import ( - "context" - "errors" - "fmt" - "io" - "net/http" - "os" - "os/exec" - "runtime" - - "github.com/spf13/cobra" -) - -var upgradeCmd = &cobra.Command{ - Use: "upgrade", - Short: "Downloads the latest version of KeyConjurer", - Args: cobra.ExactArgs(0), - RunE: func(cmd *cobra.Command, args []string) error { - keyConjurerRcPath, err := os.Executable() - if err != nil { - return err - } - - switch runtime.GOOS { - case "windows": - return windowsDownload(keyConjurerRcPath) - default: - return defaultDownload(context.Background(), NewHTTPClient(), keyConjurerRcPath) - } - }} - -// windowsDownload uses a special way to replace the binary due to restrictions in Windows. Because -// -// you cannot replace the currently executing binary, a temporary script is created. This script -// waits 3 seconds for the current process to exit, then downloads the latest Windows binary and -// replaces the old one, finally it removes itself from the filesystem. The cmd prompt should -// appear on the users screen to give them feedback that the download process began an ended. -func windowsDownload(keyConjurerRcPath string) error { - f, err := os.CreateTemp(os.TempDir(), "keyconjurer-downloader-*.cmd") - if err != nil { - return fmt.Errorf("unable to create download script: %w", err) - } - - command := fmt.Sprintf("timeout 3 && bitsadmin /transfer keyconjurerdownload /priority foreground /download %s/%s %s && del %s && exit", DownloadURL, WindowsBinaryName, keyConjurerRcPath, f.Name()) - fileData := []byte(command) - - if _, err = f.Write(fileData); err != nil { - return fmt.Errorf("failed to write to temporary file: %w", err) - } - - if err := f.Close(); err != nil { - return err - } - - cmd := exec.Command("cmd", "/C", "start", f.Name()) - return cmd.Start() -} - -// defaultDownload replaces the currently executing binary by writing over it directly. -func defaultDownload(ctx context.Context, client *http.Client, keyConjurerRcPath string) error { - tmp, err := os.CreateTemp(os.TempDir(), "keyconjurer") - if err != nil { - return fmt.Errorf("failed to create temporary file for upgrade: %w", err) - } - - defer tmp.Close() - src, err := DownloadLatestBinary(ctx, client, tmp) - if err != nil { - return fmt.Errorf("unable to download the latest binary: %w", err) - } - - if err := tmp.Close(); err != nil { - return fmt.Errorf("could not close tmp file: %w", err) - } - - bytesCopied, err := io.Copy(tmp, src) - if err != nil { - return fmt.Errorf("failed to copy new keyconjurer: %s", err) - } - - // Re-open the temporary file for reading and copy: - r, err := os.Open(tmp.Name()) - if err != nil { - return fmt.Errorf("could not open temporary file %s: %w", tmp.Name(), err) - } - - kc, _ := os.OpenFile(keyConjurerRcPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0744) - if err != nil { - return fmt.Errorf("unable to open %q: %w", keyConjurerRcPath, err) - } - - bytesCopied2, err := io.Copy(kc, r) - if err != nil || bytesCopied != bytesCopied2 { - // If an error occurs here, KeyConjurer has been overwritten and is potentially corrrupted - return fmt.Errorf("failed to copy new keyconjurer contents - keyconjurer is potentially corrupted and may need to be downloaded again: %w", err) - } - - return nil -} - -func getBinaryName() string { - switch runtime.GOOS { - case "linux": - if runtime.GOARCH == "arm64" { - return LinuxArm64BinaryName - } - - return LinuxAmd64BinaryName - case "windows": - return WindowsBinaryName - default: - if runtime.GOARCH == "arm64" { - return DarwinArm64BinaryName - } - return DarwinAmd64BinaryName - } -} - -// DownloadLatestBinary downloads the latest keyconjurer binary from the web. -func DownloadLatestBinary(ctx context.Context, client *http.Client, w io.Writer) (io.ReadCloser, error) { - binaryURL := fmt.Sprintf("%s/%s", DownloadURL, getBinaryName()) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, binaryURL, nil) - if err != nil { - return nil, fmt.Errorf("could not upgrade: %w", err) - } - - res, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("could not upgrade: %w", err) - } - - if res.StatusCode != 200 { - return nil, errors.New("could not upgrade: response did not indicate success - are you being blocked by the server?") - } - - return req.Body, nil -}