Skip to content

Commit

Permalink
better interface for instantiating tp.Cient
Browse files Browse the repository at this point in the history
  • Loading branch information
btoews committed Nov 16, 2023
1 parent a793d0f commit 746fdb8
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 50 deletions.
11 changes: 10 additions & 1 deletion flyio/flyio.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package flyio

import "github.com/superfly/macaroon"
import (
"github.com/superfly/macaroon"
"github.com/superfly/macaroon/tp"
)

const (
// well-known locations
Expand All @@ -14,3 +17,9 @@ const (
func ParsePermissionAndDischargeTokens(header string) ([]byte, [][]byte, error) {
return macaroon.ParsePermissionAndDischargeTokens(header, LocationPermission)
}

// DischargeClient returns a *tp.Client suitable for discharging third party
// caveats in fly.io permission tokens.
func DischargeClient(opts ...tp.ClientOption) *tp.Client {
return tp.NewClient(LocationPermission, opts...)
}
154 changes: 123 additions & 31 deletions tp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,100 @@ import (

type ClientOption func(*Client)

// WithHTTP specifies the HTTP client to use for requests to third parties.
// Third parties may try to set cookies to expedite future discharge flows. This
// may be facilitated by setting the http.Client's Jar field. With cookies
// enabled it's important to use a different cookie jar and hence client when
// fetching discharge tokens for multiple users.
func WithHTTP(h *http.Client) ClientOption {
return func(c *Client) {
if c.http == nil {
c.http = h
return
}

if authed, isAuthed := c.http.Transport.(*authenticatedHTTP); isAuthed {
authed.t = h.Transport
cpy := *h
cpy.Transport = authed
c.http = &cpy
return
}

c.http = h
}
}

// WithBearerAuthentication specifies a token to be sent in requests to the
// specified host in the `Authorization: Bearer` header.
func WithBearerAuthentication(hostname, token string) ClientOption {
token = "Bearer " + token

return func(c *Client) {
if c.http == nil {
cpy := *http.DefaultClient
c.http = &cpy
}

switch t := c.http.Transport.(type) {
case *authenticatedHTTP:
t.auth[hostname] = token
default:
c.http.Transport = &authenticatedHTTP{
t: t,
auth: map[string]string{hostname: token},
}
}
}
}

// WithUserURLCallback specifies a function to call when when the third party
// needs to interact with the end-user directly. The provided URL should be
// opened in the user's browser if possible. Otherwise it should be displayed to
// the user and they should be instructed to open it themselves. (Optional, but
// attempts at user-interactive discharge flow will fail)
func WithUserURLCallback(cb func(ctx context.Context, url string) error) ClientOption {
return func(c *Client) {
c.UserURLCallback = cb
}
}

// WithPollingBackoff specifies a function determining how long to wait before
// making the next request when polling the third party to see if a discharge is
// ready. This is called the first time with a zero duration. (Optional)
func WithPollingBackoff(nextBackoff func(lastBO time.Duration) (nextBO time.Duration)) ClientOption {
return func(c *Client) {
c.PollBackoffNext = nextBackoff
}
}

type Client struct {
// Location identifier for the party that issued the first party macaroon.
FirstPartyLocation string

// HTTP client to use for requests to third parties. Third parties may try
// to set cookies to expedite future discharge flows. This may be
// facilitated by setting the http.Client's Jar field. With cookies enabled
// it's important to use a different cookie jar and hence client when
// fetching discharge tokens for multiple users.
HTTP *http.Client

// Function to call when when the third party needs to interact with the
// end-user directly. The provided URL should be opened in the user's
// browser if possible. Otherwise it should be displayed to the user and
// they should be instructed to open it themselves. (Optional, but attempts
// at user-interactive discharge flow will fail)
UserURLCallback func(ctx context.Context, url string) error

// A function determining how long to wait before making the next request
// when polling the third party to see if a discharge is ready. This is
// called the first time with a zero duration. (Optional)
PollBackoffNext func(lastBO time.Duration) (nextBO time.Duration)
firstPartyLocation string
http *http.Client
UserURLCallback func(ctx context.Context, url string) error
PollBackoffNext func(lastBO time.Duration) (nextBO time.Duration)
}

// NewClient returns a Client for discharging third party caveats in macaroons
// issued by the specified first party.
func NewClient(firstPartyLocation string, opts ...ClientOption) *Client {
client := &Client{
firstPartyLocation: firstPartyLocation,
}

for _, opt := range opts {
opt(client)
}

if client.http == nil {
client.http = http.DefaultClient
}

if client.PollBackoffNext == nil {
client.PollBackoffNext = defaultBackoff
}

return client
}

func (c *Client) NeedsDischarge(tokenHeader string) (bool, error) {
Expand Down Expand Up @@ -93,7 +165,7 @@ func (c *Client) undischargedTickets(tokenHeader string) (map[string][][]byte, e
return nil, err
}

perms, _, _, disToks, err := macaroon.FindPermissionAndDischargeTokens(toks, c.FirstPartyLocation)
perms, _, _, disToks, err := macaroon.FindPermissionAndDischargeTokens(toks, c.firstPartyLocation)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -146,7 +218,7 @@ func (c *Client) doInitRequest(ctx context.Context, thirdPartyLocation string, t
}
hreq.Header.Set("Content-Type", "application/json")

hresp, err := c.http().Do(hreq)
hresp, err := c.http.Do(hreq)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -180,7 +252,7 @@ func (c *Client) doPoll(ctx context.Context, pollURL string) (string, error) {

pollLoop:
for {
hresp, err := c.http().Do(req)
hresp, err := c.http.Do(req)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -243,13 +315,6 @@ func (c *Client) openUserInteractiveURL(ctx context.Context, url string) error {
return errors.New("client not configured for opening URLs")
}

func (c *Client) http() *http.Client {
if c.HTTP != nil {
return c.HTTP
}
return http.DefaultClient
}

func initURL(location string) string {
if strings.HasSuffix(location, "/") {
return location + InitPath[1:]
Expand All @@ -265,3 +330,30 @@ type Error struct {
func (e Error) Error() string {
return fmt.Sprintf("tp error (%d): %s", e.StatusCode, e.Msg)
}

type authenticatedHTTP struct {
t http.RoundTripper
auth map[string]string
}

func (a *authenticatedHTTP) RoundTrip(r *http.Request) (*http.Response, error) {
if cred := a.auth[r.URL.Hostname()]; cred != "" {
r.Header.Set("Authorization", cred)
}

return a.transport().RoundTrip(r)
}

func (a *authenticatedHTTP) transport() http.RoundTripper {
if a.t == nil {
return http.DefaultTransport
}
return a.t
}

func defaultBackoff(lastBO time.Duration) (nextBO time.Duration) {
if lastBO == 0 {
return time.Second
}
return 2 * lastBO
}
9 changes: 4 additions & 5 deletions tp/immediate_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func newImmediateServer(tp *TP) *immediateSever {
}

func (is *immediateSever) handleInitRequest(w http.ResponseWriter, r *http.Request) {
if username, password, _ := r.BasicAuth(); username != "mulder" || password != "trustno1" {
if r.Header.Get("Authorization") != "Bearer trustno1" {
is.tp.RespondError(w, r, http.StatusUnauthorized, "bad client authentication")
return
}
Expand Down Expand Up @@ -69,10 +69,9 @@ func ExampleTP_RespondDischarge() {
_, err = validateFirstPartyMacaroon(firstPartyMacaroon)
fmt.Printf("validation error without 3p discharge token: %v\n", err)

client := &Client{
FirstPartyLocation: firstPartyLocation,
HTTP: basicAuthClient("mulder", "trustno1"),
}
client := NewClient(firstPartyLocation,
WithBearerAuthentication("127.0.0.1", "trustno1"),
)

firstPartyMacaroon, err = client.FetchDischargeTokens(context.Background(), firstPartyMacaroon)
if err != nil {
Expand Down
72 changes: 59 additions & 13 deletions tp/tp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -54,13 +55,60 @@ func TestTP(t *testing.T) {
})

hdr := genFP(t, tp, myCaveat("fp-cav"))
c := &Client{FirstPartyLocation: firstPartyLocation}
c := NewClient(firstPartyLocation)
hdr, err = c.FetchDischargeTokens(context.Background(), hdr)
assert.NoError(t, err)
cavs := checkFP(t, hdr)
assert.Equal(t, []string{"fp-cav", "dis-cav"}, cavs)
})

t.Run("WithBearerAuthentication", func(t *testing.T) {
t.Run("sends token to correct host", func(t *testing.T) {
handleInit = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "Bearer my-token" {
tp.RespondError(w, r, http.StatusUnauthorized, "bad client authentication")
return
}
_, err := CaveatsFromRequest(r)
assert.NoError(t, err)

tp.RespondDischarge(w, r)
})

u, err := url.Parse(tp.Location)
assert.NoError(t, err)

hdr := genFP(t, tp)
c := NewClient(firstPartyLocation,
WithBearerAuthentication(u.Hostname(), "my-token"),
)
hdr, err = c.FetchDischargeTokens(context.Background(), hdr)
assert.NoError(t, err)
checkFP(t, hdr)
})

t.Run("doesn't send token to wrong host", func(t *testing.T) {
handleInit = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "" {
tp.RespondError(w, r, http.StatusUnauthorized, "bad client authentication")
return
}
_, err := CaveatsFromRequest(r)
assert.NoError(t, err)

tp.RespondDischarge(w, r)
})

hdr := genFP(t, tp)
c := NewClient(firstPartyLocation,
WithBearerAuthentication("wrong.com", "my-token"),
)
hdr, err = c.FetchDischargeTokens(context.Background(), hdr)
assert.NoError(t, err)
checkFP(t, hdr)
})
})

t.Run("poll response", func(t *testing.T) {
pollSecret := ""
pollSecretSet := make(chan struct{})
Expand All @@ -75,15 +123,14 @@ func TestTP(t *testing.T) {

hdr := genFP(t, tp, myCaveat("fp-cav"))

c := &Client{
FirstPartyLocation: firstPartyLocation,
PollBackoffNext: func(last time.Duration) time.Duration {
c := NewClient(firstPartyLocation,
WithPollingBackoff(func(last time.Duration) time.Duration {
if last == 0 {
return 10 * time.Millisecond
}
return 10 * time.Second
},
}
}),
)

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
Expand Down Expand Up @@ -120,20 +167,19 @@ func TestTP(t *testing.T) {

hdr := genFP(t, tp, myCaveat("fp-cav"))

c := &Client{
FirstPartyLocation: firstPartyLocation,
PollBackoffNext: func(last time.Duration) time.Duration {
c := NewClient(firstPartyLocation,
WithPollingBackoff(func(last time.Duration) time.Duration {
if last == 0 {
return 10 * time.Millisecond
}
return 10 * time.Second
},
UserURLCallback: func(_ context.Context, url string) error {
}),
WithUserURLCallback(func(_ context.Context, url string) error {
time.Sleep(10 * time.Millisecond)
assert.NoError(t, tp.DischargeUserInteractive(context.Background(), userSecret, myCaveat("dis-cav")))
return nil
},
}
}),
)

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
Expand Down

0 comments on commit 746fdb8

Please sign in to comment.