diff --git a/cli/consts.go b/cli/consts.go index 7dfd8e32..678c0a37 100644 --- a/cli/consts.go +++ b/cli/consts.go @@ -8,6 +8,7 @@ var ( Version = "TBD" BuildTimestamp = "BuildTimestamp is not set" DownloadURL = "URL not set yet" + CallbackPorts = []string{"57468", "58888", "59999", "60000"} ) const ( diff --git a/cli/login.go b/cli/login.go index b5363ca7..6982ff86 100644 --- a/cli/login.go +++ b/cli/login.go @@ -73,6 +73,7 @@ func (c LoginCommand) Execute(ctx context.Context) error { handler := RedirectionFlowHandler{ Config: oauthCfg, + Listen: ListenAnyPort("127.0.0.1", CallbackPorts), OnDisplayURL: openBrowserToURL, } @@ -94,6 +95,11 @@ func (c LoginCommand) Execute(ctx context.Context) error { return c.Config.SaveOAuthToken(token) } +func printURLToConsole(url string) error { + fmt.Fprintln(os.Stdout, url) + return nil +} + func friendlyPrintURLToConsole(url string) error { fmt.Printf("Visit the following link in your terminal: %s\n", url) return nil diff --git a/cli/oauth2.go b/cli/oauth2.go index 357714db..eb57070a 100644 --- a/cli/oauth2.go +++ b/cli/oauth2.go @@ -12,7 +12,6 @@ import ( "net" "net/http" "net/url" - "os" "strings" "github.com/RobotsAndPencils/go-saml" @@ -163,9 +162,32 @@ type PkceChallenge struct { Verifier string } -func printURLToConsole(url string) error { - fmt.Fprintln(os.Stdout, url) - return nil +var ErrNoPortsAvailable = errors.New("no ports available") + +// findFirstFreePort will attempt to open a network listener for each port in turn, and return the first one that succeeded. +// +// If none succeed, ErrNoPortsAvailable is returned. +// +// This is useful for supporting OIDC servers that do not allow for ephemeral ports to be used in the loopback address, like Okta. +func findFirstFreePort(ctx context.Context, broadcastAddr string, ports []string) (net.Listener, error) { + var lc net.ListenConfig + for _, port := range ports { + sock, err := lc.Listen(ctx, "tcp4", net.JoinHostPort(broadcastAddr, port)) + if err == nil { + return sock, nil + } + } + + return nil, ErrNoPortsAvailable +} + +// ListenAnyPort is a function that can be passed to RedirectionFlowHandler that will attempt to listen to exactly one of the ports in the supplied array. +// +// This function does not guarantee it will try ports in the order they are supplied, but it will return either a listener bound to exactly one of the ports, or the error ErrNoPortsAvailable. +func ListenAnyPort(broadcastAddr string, ports []string) func(ctx context.Context) (net.Listener, error) { + return func(ctx context.Context) (net.Listener, error) { + return findFirstFreePort(ctx, broadcastAddr, ports) + } } type RedirectionFlowHandler struct { @@ -174,7 +196,7 @@ type RedirectionFlowHandler struct { // Listen is a function that can be provided to override how the redirection flow handler opens a network socket. // If this is not specified, the handler will attempt to create a connection that listens to 0.0.0.0:57468 on IPv4. - Listen func() (net.Listener, error) + Listen func(ctx context.Context) (net.Listener, error) } func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challenge PkceChallenge, state string) (*oauth2.Token, error) { @@ -183,19 +205,25 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe } if r.Listen == nil { - r.Listen = func() (net.Listener, error) { + r.Listen = func(ctx context.Context) (net.Listener, error) { var lc net.ListenConfig sock, err := lc.Listen(ctx, "tcp4", net.JoinHostPort("0.0.0.0", "57468")) return sock, err } } - sock, err := r.Listen() + sock, err := r.Listen(ctx) + if err != nil { + return nil, err + } + + _, port, err := net.SplitHostPort(sock.Addr().String()) if err != nil { + // Failed to split the host and port. We need the port to continue, so bail return nil, err } - r.Config.RedirectURL = "http://localhost:57468" + r.Config.RedirectURL = fmt.Sprintf("http://%s", net.JoinHostPort("localhost", port)) url := r.Config.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge", challenge.Challenge), diff --git a/cli/oauth2_test.go b/cli/oauth2_test.go index b72067bb..c2e97af5 100644 --- a/cli/oauth2_test.go +++ b/cli/oauth2_test.go @@ -2,12 +2,14 @@ package main import ( "context" + "net" "net/http/httptest" "net/url" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // This test ensures that the success flow of receiving an authorization code from a HTTP handler works correctly. @@ -57,3 +59,55 @@ func Test_OAuth2Listener_ZeroValueNeverPanics(t *testing.T) { assert.ErrorIs(t, context.DeadlineExceeded, err) assert.NoError(t, listener.Close()) } + +// This test is going to be flaky because processes may open ports outside of our control. +func Test_ListenAnyPort_WorksCorrectly(t *testing.T) { + ports := []string{"58080", "58081", "58082", "58083"} + socket, err := net.Listen("tcp4", net.JoinHostPort("127.0.0.1", ports[0])) + t.Cleanup(func() { + socket.Close() + }) + require.NoError(t, err, "Could not open socket on port: %s", ports[0]) + + listenFunc := ListenAnyPort("127.0.0.1", ports) + openedSocket, err := listenFunc(context.Background()) + + assert.NoError(t, err) + _, port, err := net.SplitHostPort(openedSocket.Addr().String()) + assert.NoError(t, err) + // There is no guarantee on which port FindFirstFreePort will choose, but it must pick one from the given list. + assert.Contains(t, ports, port) + openedSocket.Close() +} + +func Test_ListenAnyPort_RejectsIfNoPortsAvailable(t *testing.T) { + var ports []string + listenFunc := ListenAnyPort("127.0.0.1", ports) + _, err := listenFunc(context.Background()) + assert.ErrorIs(t, ErrNoPortsAvailable, err) +} + +func Test_ListenAnyPort_RejectsIfAllProvidedPortsExhausted(t *testing.T) { + ports := []string{"58080", "58081", "58082", "58083"} + var sockets []net.Listener + var activePorts []string + // This exhausts all sockets in 'ports' and dumps them into 'activePorts'. + for _, port := range ports { + socket, err := net.Listen("tcp4", net.JoinHostPort("127.0.0.1", port)) + if err == nil { + sockets = append(sockets, socket) + activePorts = append(activePorts, port) + } + } + + require.NotEmpty(t, activePorts, "could not open any sockets") + + t.Cleanup(func() { + for _, socket := range sockets { + socket.Close() + } + }) + + _, err := ListenAnyPort("127.0.0.1", activePorts)(context.Background()) + assert.ErrorIs(t, err, ErrNoPortsAvailable) +}