Skip to content

Commit

Permalink
Merge pull request PelicanPlatform#1591 from bbockelm/generate_token
Browse files Browse the repository at this point in the history
Auto-generate token if the issuer is available
  • Loading branch information
jhiemstrawisc authored Sep 30, 2024
2 parents 779ab97 + 2e6a91c commit 98fc081
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 6 deletions.
134 changes: 134 additions & 0 deletions client/acquire_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package client

import (
"context"
"crypto/ecdsa"
"encoding/json"
"fmt"
"io/fs"
Expand All @@ -33,6 +34,7 @@ import (
"sync/atomic"
"time"

"github.com/lestrrat-go/jwx/v2/jwk"
jwt "github.com/lestrrat-go/jwx/v2/jwt"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
Expand All @@ -42,6 +44,8 @@ import (
"github.com/pelicanplatform/pelican/config"
oauth2 "github.com/pelicanplatform/pelican/oauth2"
"github.com/pelicanplatform/pelican/server_structs"
"github.com/pelicanplatform/pelican/token"
"github.com/pelicanplatform/pelican/token_scopes"
)

type (
Expand Down Expand Up @@ -588,7 +592,15 @@ func AcquireToken(destination *url.URL, dirResp server_structs.DirectorResponse,
}
var prefixEntry *config.PrefixEntry
newEntry := false
tryTokenGen := false
if prefixIdx < 0 {
// We prefer to generate a token over registering a new client.
if token, err := generateToken(destination, dirResp, opts); err == nil && token != "" {
log.Debugln("Successfully generated a new token from a local key")
return token, nil
}
tryTokenGen = true

log.Infof("Prefix configuration for %s not in configuration file; will request new client", nsPrefix)
prefixEntry, err = registerClient(dirResp)
if err != nil {
Expand All @@ -600,6 +612,14 @@ func AcquireToken(destination *url.URL, dirResp server_structs.DirectorResponse,
} else {
prefixEntry = &osdfConfig.OSDF.OauthClient[prefixIdx]
if len(prefixEntry.ClientID) == 0 || len(prefixEntry.ClientSecret) == 0 {

// Similarly, here, generate a token before registering a new client.
if token, err := generateToken(destination, dirResp, opts); err == nil && token != "" {
log.Debugln("Successfully generated a new token from a local key")
return token, nil
}
tryTokenGen = true

log.Infof("Prefix configuration for %s missing OAuth2 client information", nsPrefix)
prefixEntry, err = registerClient(dirResp)
if err != nil {
Expand Down Expand Up @@ -676,6 +696,15 @@ func AcquireToken(destination *url.URL, dirResp server_structs.DirectorResponse,
}
}

// If here, we've got a valid OAuth2 client credential but didn't have any luck refreshing -
// try generating the token before requiring a potentially user-interactive flow.
if !tryTokenGen {
if token, err := generateToken(destination, dirResp, opts); err == nil && token != "" {
log.Debugln("Successfully generated a new token from a local key")
return token, nil
}
}

token, err := oauth2.AcquireToken(issuer, prefixEntry, dirResp, destination.Path, opts)
if errors.Is(err, oauth2.ErrUnknownClient) {
// We use anonymously-registered clients; OA4MP can periodically garbage collect these to prevent DoS
Expand Down Expand Up @@ -706,3 +735,108 @@ func AcquireToken(destination *url.URL, dirResp server_structs.DirectorResponse,

return token.AccessToken, nil
}

// Given a URL and a known public key, determine whether the public key
// is valid for the issuer URL.
//
// If valid, returns the corresponding keyId and sets found to true.
func findKeyId(url string, ecPubKey *ecdsa.PublicKey) (keyid string, found bool) {
// Next, download the public keys for the issuer
ctx := context.Background()
issuerInfo, err := config.GetIssuerMetadata(url)
if err != nil {
log.Debugln("Failed to get metadata for", url, ":", err)
return
}
client := &http.Client{Transport: config.GetTransport()}
fetchOption := jwk.WithHTTPClient(client)
jwks, err := jwk.Fetch(ctx, issuerInfo.JwksUri, fetchOption)
if err != nil {
log.Debugln("Failed to fetch the JWKS:", err)
return
}
keyIter := jwks.Keys(ctx)
for keyIter.Next(ctx) {
pair := keyIter.Pair()
key, ok := pair.Value.(jwk.Key)
if !ok {
log.Debugln("Decode of JWK in return JWKS failed")
continue
}
var ecPubKey2 ecdsa.PublicKey
if err = key.Raw(&ecPubKey2); err != nil {
log.Debugln("Failed to convert public key:", err)
continue
}
if ecPubKey2.Equal(ecPubKey) {
return key.KeyID(), true
}
}
return
}

// Check to see if there's a copy of the issuer's pubkey locally; if so, generate an appropriate token directly.
func generateToken(destination *url.URL, dirResp server_structs.DirectorResponse, opts config.TokenGenerationOpts) (tkn string, err error) {
// Check to see if a private key is installed locally
key, err := config.GetIssuerPrivateJWK()
if err != nil {
log.Debugln("Cannot generate a token locally as private key is not present:", err)
return
}
log.Debugln("Trying to generate a token locally from issuer private key")
pubKey, err := key.PublicKey()
if err != nil {
log.Debugln("Cannot generate a token locally as the public key cannot be generated:", err)
return
}
var ecPubKey ecdsa.PublicKey
if err = pubKey.Raw(&ecPubKey); err != nil {
log.Debugln("Failed to convert JWT pub key to ECDSA:", err)
return
}

log.Debugln("Searching issuer public keys for matching key")
// Next, download the public keys for the issuer
var found bool
var keyId, issuer string
for _, issuerUrl := range dirResp.XPelAuthHdr.Issuers {
if issuerUrl == nil {
continue
}
issuer = issuerUrl.String()
keyId, found = findKeyId(issuer, &ecPubKey)
if found {
break
}
}
if !found {
log.Debugln("Failed to find public key at issuer corresponding to local public key")
return
}

tc, err := token.NewTokenConfig(token.TokenProfileWLCG)
if err != nil {
return
}
tc.AddAudienceAny()
tc.Issuer = issuer
tc.Lifetime = time.Hour
tc.Subject = "client_token"
ts := token_scopes.Storage_Read
if opts.Operation == config.TokenSharedWrite {
ts = token_scopes.Storage_Create
}
if after, found := strings.CutPrefix(path.Clean(destination.Path), path.Clean(dirResp.XPelNsHdr.Namespace)); found {
tc.AddResourceScopes(token_scopes.NewResourceScope(ts, after))
} else {
err = errors.New("Destination resource not inside director-provided namespace")
return
}

err = key.Set("kid", keyId)
if err != nil {
return
}
tkn, err = tc.CreateTokenWithKey(key)
return
}
50 changes: 44 additions & 6 deletions client/fed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -752,9 +752,9 @@ func TestNewTransferJob(t *testing.T) {
// use our auth required namespace
mockRemoteUrl, err := url.Parse("/second/namespace/hello_world.txt")
require.NoError(t, err)
_, err = tc.NewTransferJob(context.Background(), mockRemoteUrl, "/dest", false, false)
_, err = tc.NewTransferJob(context.Background(), mockRemoteUrl, "/dest", false, false, client.WithAcquireToken(false))
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to get token for transfer: failed to find or generate a token as required for /second/namespace/hello_world.txt")
assert.Contains(t, err.Error(), "failed to get token for transfer: credential is required for /second/namespace/hello_world.txt but was not discovered")
})

// Test success
Expand Down Expand Up @@ -835,17 +835,17 @@ func TestObjectList(t *testing.T) {
for _, export := range fed.Exports {
listURL := fmt.Sprintf("pelican://%s:%d%s", param.Server_Hostname.GetString(), param.Server_WebPort.GetInt(), export.FederationPrefix)
if !export.Capabilities.PublicReads {
get, err := client.DoList(fed.Ctx, listURL, client.WithTokenLocation(""))
get, err := client.DoList(fed.Ctx, listURL, client.WithTokenLocation(""), client.WithAcquireToken(false))
require.Error(t, err)
assert.Len(t, get, 0)
assert.Contains(t, err.Error(), "failed to get token for transfer: failed to find or generate a token as required")
assert.Contains(t, err.Error(), "failed to get token for transfer: credential is required")

// No error if it's with token
get, err = client.DoList(fed.Ctx, listURL, client.WithTokenLocation(tempToken.Name()))
get, err = client.DoList(fed.Ctx, listURL, client.WithTokenLocation(tempToken.Name()), client.WithAcquireToken(false))
require.NoError(t, err)
require.Len(t, get, 2)
} else {
get, err := client.DoList(fed.Ctx, listURL, client.WithTokenLocation(tempToken.Name()))
get, err := client.DoList(fed.Ctx, listURL, client.WithTokenLocation(tempToken.Name()), client.WithAcquireToken(false))
require.NoError(t, err)
require.Len(t, get, 2)
}
Expand Down Expand Up @@ -962,3 +962,41 @@ func TestClientUnpack(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, int64(11), fi.Size())
}

// A test that generates a token locally from the private key
func TestTokenGenerate(t *testing.T) {
viper.Reset()
server_utils.ResetOriginExports()
fed := fed_test_utils.NewFedTest(t, bothAuthOriginCfg)

// Other set-up items:
testFileContent := "test file content"
// Create the temporary file to upload
tempFile, err := os.CreateTemp(t.TempDir(), "test")
require.NoError(t, err, "Error creating temp file")
defer os.Remove(tempFile.Name())
_, err = tempFile.WriteString(testFileContent)
require.NoError(t, err, "Error writing to temp file")
tempFile.Close()

// Disable progress bars to not reuse the same mpb instance
viper.Set("Logging.DisableProgressBars", true)

// Set path for object to upload/download
for _, export := range fed.Exports {
tempPath := tempFile.Name()
fileName := filepath.Base(tempPath)
uploadURL := fmt.Sprintf("pelican://%s:%s%s/%s/%s", param.Server_Hostname.GetString(), strconv.Itoa(param.Server_WebPort.GetInt()),
export.FederationPrefix, "token_gen", fileName)

// Upload the file with PUT
transferResultsUpload, err := client.DoPut(fed.Ctx, tempFile.Name(), uploadURL, false)
require.NoError(t, err)
assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17))

// Download that same file with GET
transferResultsDownload, err := client.DoGet(fed.Ctx, uploadURL, t.TempDir(), false)
require.NoError(t, err)
assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes)
}
}
1 change: 1 addition & 0 deletions config/issuer_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

type OauthIssuer struct {
Issuer string `json:"issuer"`
JwksUri string `json:"jwks_uri"`
AuthURL string `json:"authorization_endpoint"`
DeviceAuthURL string `json:"device_authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
Expand Down
4 changes: 4 additions & 0 deletions xrootd/launch.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package xrootd
import (
"context"
_ "embed"
"os"
"path/filepath"
"regexp"
"strconv"
Expand Down Expand Up @@ -87,6 +88,9 @@ func makeUnprivilegedXrootdLauncher(daemonName string, configPath string, isCach
"XRD_PELICANCLIENTCERTFILE=" + filepath.Join(xrootdRun, "copied-tls-creds.crt"),
"XRD_PELICANCLIENTKEYFILE=" + filepath.Join(xrootdRun, "copied-tls-creds.crt"),
}
if confDir := os.Getenv("XRD_PLUGINCONFDIR"); confDir != "" {
result.ExtraEnv = append(result.ExtraEnv, "XRD_PLUGINCONFDIR="+confDir)
}
}
return
}
Expand Down

0 comments on commit 98fc081

Please sign in to comment.