Skip to content

Commit

Permalink
Add the ability to attempt many ports if one fails
Browse files Browse the repository at this point in the history
Okta does not implement the BCP212 correctly, and does not allow
wildcard ports in loopback addresses. This is our way of circumventing
that in the event that one of the ports is taken.
  • Loading branch information
punmechanic committed Mar 26, 2024
1 parent 69585ad commit fe7dfcf
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 8 deletions.
1 change: 1 addition & 0 deletions cli/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ var (
Version = "TBD"
BuildTimestamp = "BuildTimestamp is not set"
DownloadURL = "URL not set yet"
CallbackPorts = []string{"57468", "58888", "59999", "60000"}
)

const (
Expand Down
6 changes: 6 additions & 0 deletions cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func (c LoginCommand) Execute(ctx context.Context) error {

handler := RedirectionFlowHandler{
Config: oauthCfg,
Listen: ListenAnyPort("127.0.0.1", CallbackPorts),
OnDisplayURL: openBrowserToURL,
}

Expand All @@ -94,6 +95,11 @@ func (c LoginCommand) Execute(ctx context.Context) error {
return c.Config.SaveOAuthToken(token)
}

func printURLToConsole(url string) error {
fmt.Fprintln(os.Stdout, url)
return nil
}

func friendlyPrintURLToConsole(url string) error {
fmt.Printf("Visit the following link in your terminal: %s\n", url)
return nil
Expand Down
44 changes: 36 additions & 8 deletions cli/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"net"
"net/http"
"net/url"
"os"
"strings"

"github.com/RobotsAndPencils/go-saml"
Expand Down Expand Up @@ -163,9 +162,32 @@ type PkceChallenge struct {
Verifier string
}

func printURLToConsole(url string) error {
fmt.Fprintln(os.Stdout, url)
return nil
var ErrNoPortsAvailable = errors.New("no ports available")

// findFirstFreePort will attempt to open a network listener for each port in turn, and return the first one that succeeded.
//
// If none succeed, ErrNoPortsAvailable is returned.
//
// This is useful for supporting OIDC servers that do not allow for ephemeral ports to be used in the loopback address, like Okta.
func findFirstFreePort(ctx context.Context, broadcastAddr string, ports []string) (net.Listener, error) {
var lc net.ListenConfig
for _, port := range ports {
sock, err := lc.Listen(ctx, "tcp4", net.JoinHostPort(broadcastAddr, port))
if err == nil {
return sock, nil
}
}

return nil, ErrNoPortsAvailable
}

// ListenAnyPort is a function that can be passed to RedirectionFlowHandler that will attempt to listen to exactly one of the ports in the supplied array.
//
// This function does not guarantee it will try ports in the order they are supplied, but it will return either a listener bound to exactly one of the ports, or the error ErrNoPortsAvailable.
func ListenAnyPort(broadcastAddr string, ports []string) func(ctx context.Context) (net.Listener, error) {
return func(ctx context.Context) (net.Listener, error) {
return findFirstFreePort(ctx, broadcastAddr, ports)
}
}

type RedirectionFlowHandler struct {
Expand All @@ -174,7 +196,7 @@ type RedirectionFlowHandler struct {

// Listen is a function that can be provided to override how the redirection flow handler opens a network socket.
// If this is not specified, the handler will attempt to create a connection that listens to 0.0.0.0:57468 on IPv4.
Listen func() (net.Listener, error)
Listen func(ctx context.Context) (net.Listener, error)
}

func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challenge PkceChallenge, state string) (*oauth2.Token, error) {
Expand All @@ -183,19 +205,25 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe
}

if r.Listen == nil {
r.Listen = func() (net.Listener, error) {
r.Listen = func(ctx context.Context) (net.Listener, error) {
var lc net.ListenConfig
sock, err := lc.Listen(ctx, "tcp4", net.JoinHostPort("0.0.0.0", "57468"))
return sock, err
}
}

sock, err := r.Listen()
sock, err := r.Listen(ctx)
if err != nil {
return nil, err
}

_, port, err := net.SplitHostPort(sock.Addr().String())
if err != nil {
// Failed to split the host and port. We need the port to continue, so bail
return nil, err
}

r.Config.RedirectURL = "http://localhost:57468"
r.Config.RedirectURL = fmt.Sprintf("http://%s", net.JoinHostPort("localhost", port))
url := r.Config.AuthCodeURL(state,
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code_challenge", challenge.Challenge),
Expand Down
54 changes: 54 additions & 0 deletions cli/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package main

import (
"context"
"net"
"net/http/httptest"
"net/url"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// This test ensures that the success flow of receiving an authorization code from a HTTP handler works correctly.
Expand Down Expand Up @@ -57,3 +59,55 @@ func Test_OAuth2Listener_ZeroValueNeverPanics(t *testing.T) {
assert.ErrorIs(t, context.DeadlineExceeded, err)
assert.NoError(t, listener.Close())
}

// This test is going to be flaky because processes may open ports outside of our control.
func Test_ListenAnyPort_WorksCorrectly(t *testing.T) {
ports := []string{"58080", "58081", "58082", "58083"}
socket, err := net.Listen("tcp4", net.JoinHostPort("127.0.0.1", ports[0]))
t.Cleanup(func() {
socket.Close()
})
require.NoError(t, err, "Could not open socket on port: %s", ports[0])

listenFunc := ListenAnyPort("127.0.0.1", ports)
openedSocket, err := listenFunc(context.Background())

assert.NoError(t, err)
_, port, err := net.SplitHostPort(openedSocket.Addr().String())
assert.NoError(t, err)
// There is no guarantee on which port FindFirstFreePort will choose, but it must pick one from the given list.
assert.Contains(t, ports, port)
openedSocket.Close()
}

func Test_ListenAnyPort_RejectsIfNoPortsAvailable(t *testing.T) {
var ports []string
listenFunc := ListenAnyPort("127.0.0.1", ports)
_, err := listenFunc(context.Background())
assert.ErrorIs(t, ErrNoPortsAvailable, err)
}

func Test_ListenAnyPort_RejectsIfAllProvidedPortsExhausted(t *testing.T) {
ports := []string{"58080", "58081", "58082", "58083"}
var sockets []net.Listener
var activePorts []string
// This exhausts all sockets in 'ports' and dumps them into 'activePorts'.
for _, port := range ports {
socket, err := net.Listen("tcp4", net.JoinHostPort("127.0.0.1", port))
if err == nil {
sockets = append(sockets, socket)
activePorts = append(activePorts, port)
}
}

require.NotEmpty(t, activePorts, "could not open any sockets")

t.Cleanup(func() {
for _, socket := range sockets {
socket.Close()
}
})

_, err := ListenAnyPort("127.0.0.1", activePorts)(context.Background())
assert.ErrorIs(t, err, ErrNoPortsAvailable)
}

0 comments on commit fe7dfcf

Please sign in to comment.