diff --git a/client/acquire_token.go b/client/acquire_token.go index 9183658e7..20e6d4b68 100644 --- a/client/acquire_token.go +++ b/client/acquire_token.go @@ -20,6 +20,7 @@ package client import ( "context" + "crypto/ecdsa" "encoding/json" "fmt" "io/fs" @@ -33,6 +34,7 @@ import ( "sync/atomic" "time" + "github.com/lestrrat-go/jwx/v2/jwk" jwt "github.com/lestrrat-go/jwx/v2/jwt" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -42,6 +44,8 @@ import ( "github.com/pelicanplatform/pelican/config" oauth2 "github.com/pelicanplatform/pelican/oauth2" "github.com/pelicanplatform/pelican/server_structs" + "github.com/pelicanplatform/pelican/token" + "github.com/pelicanplatform/pelican/token_scopes" ) type ( @@ -588,7 +592,15 @@ func AcquireToken(destination *url.URL, dirResp server_structs.DirectorResponse, } var prefixEntry *config.PrefixEntry newEntry := false + tryTokenGen := false if prefixIdx < 0 { + // We prefer to generate a token over registering a new client. + if token, err := generateToken(destination, dirResp, opts); err == nil && token != "" { + log.Debugln("Successfully generated a new token from a local key") + return token, nil + } + tryTokenGen = true + log.Infof("Prefix configuration for %s not in configuration file; will request new client", nsPrefix) prefixEntry, err = registerClient(dirResp) if err != nil { @@ -600,6 +612,14 @@ func AcquireToken(destination *url.URL, dirResp server_structs.DirectorResponse, } else { prefixEntry = &osdfConfig.OSDF.OauthClient[prefixIdx] if len(prefixEntry.ClientID) == 0 || len(prefixEntry.ClientSecret) == 0 { + + // Similarly, here, generate a token before registering a new client. + if token, err := generateToken(destination, dirResp, opts); err == nil && token != "" { + log.Debugln("Successfully generated a new token from a local key") + return token, nil + } + tryTokenGen = true + log.Infof("Prefix configuration for %s missing OAuth2 client information", nsPrefix) prefixEntry, err = registerClient(dirResp) if err != nil { @@ -676,6 +696,15 @@ func AcquireToken(destination *url.URL, dirResp server_structs.DirectorResponse, } } + // If here, we've got a valid OAuth2 client credential but didn't have any luck refreshing - + // try generating the token before requiring a potentially user-interactive flow. + if !tryTokenGen { + if token, err := generateToken(destination, dirResp, opts); err == nil && token != "" { + log.Debugln("Successfully generated a new token from a local key") + return token, nil + } + } + token, err := oauth2.AcquireToken(issuer, prefixEntry, dirResp, destination.Path, opts) if errors.Is(err, oauth2.ErrUnknownClient) { // We use anonymously-registered clients; OA4MP can periodically garbage collect these to prevent DoS @@ -706,3 +735,108 @@ func AcquireToken(destination *url.URL, dirResp server_structs.DirectorResponse, return token.AccessToken, nil } + +// Given a URL and a known public key, determine whether the public key +// is valid for the issuer URL. +// +// If valid, returns the corresponding keyId and sets found to true. +func findKeyId(url string, ecPubKey *ecdsa.PublicKey) (keyid string, found bool) { + // Next, download the public keys for the issuer + ctx := context.Background() + issuerInfo, err := config.GetIssuerMetadata(url) + if err != nil { + log.Debugln("Failed to get metadata for", url, ":", err) + return + } + client := &http.Client{Transport: config.GetTransport()} + fetchOption := jwk.WithHTTPClient(client) + jwks, err := jwk.Fetch(ctx, issuerInfo.JwksUri, fetchOption) + if err != nil { + log.Debugln("Failed to fetch the JWKS:", err) + return + } + keyIter := jwks.Keys(ctx) + for keyIter.Next(ctx) { + pair := keyIter.Pair() + key, ok := pair.Value.(jwk.Key) + if !ok { + log.Debugln("Decode of JWK in return JWKS failed") + continue + } + var ecPubKey2 ecdsa.PublicKey + if err = key.Raw(&ecPubKey2); err != nil { + log.Debugln("Failed to convert public key:", err) + continue + } + if ecPubKey2.Equal(ecPubKey) { + return key.KeyID(), true + } + } + return +} + +// Check to see if there's a copy of the issuer's pubkey locally; if so, generate an appropriate token directly. +func generateToken(destination *url.URL, dirResp server_structs.DirectorResponse, opts config.TokenGenerationOpts) (tkn string, err error) { + // Check to see if a private key is installed locally + key, err := config.GetIssuerPrivateJWK() + if err != nil { + log.Debugln("Cannot generate a token locally as private key is not present:", err) + return + } + log.Debugln("Trying to generate a token locally from issuer private key") + pubKey, err := key.PublicKey() + if err != nil { + log.Debugln("Cannot generate a token locally as the public key cannot be generated:", err) + return + } + var ecPubKey ecdsa.PublicKey + if err = pubKey.Raw(&ecPubKey); err != nil { + log.Debugln("Failed to convert JWT pub key to ECDSA:", err) + return + } + + log.Debugln("Searching issuer public keys for matching key") + // Next, download the public keys for the issuer + var found bool + var keyId, issuer string + for _, issuerUrl := range dirResp.XPelAuthHdr.Issuers { + if issuerUrl == nil { + continue + } + issuer = issuerUrl.String() + keyId, found = findKeyId(issuer, &ecPubKey) + if found { + break + } + } + if !found { + log.Debugln("Failed to find public key at issuer corresponding to local public key") + return + } + + tc, err := token.NewTokenConfig(token.TokenProfileWLCG) + if err != nil { + return + } + tc.AddAudienceAny() + tc.Issuer = issuer + tc.Lifetime = time.Hour + tc.Subject = "client_token" + ts := token_scopes.Storage_Read + if opts.Operation == config.TokenSharedWrite { + ts = token_scopes.Storage_Create + } + if after, found := strings.CutPrefix(path.Clean(destination.Path), path.Clean(dirResp.XPelNsHdr.Namespace)); found { + tc.AddResourceScopes(token_scopes.NewResourceScope(ts, after)) + } else { + err = errors.New("Destination resource not inside director-provided namespace") + return + } + + err = key.Set("kid", keyId) + if err != nil { + return + } + tkn, err = tc.CreateTokenWithKey(key) + return +} diff --git a/client/fed_test.go b/client/fed_test.go index 93e2176f4..ac3ac804a 100644 --- a/client/fed_test.go +++ b/client/fed_test.go @@ -752,9 +752,9 @@ func TestNewTransferJob(t *testing.T) { // use our auth required namespace mockRemoteUrl, err := url.Parse("/second/namespace/hello_world.txt") require.NoError(t, err) - _, err = tc.NewTransferJob(context.Background(), mockRemoteUrl, "/dest", false, false) + _, err = tc.NewTransferJob(context.Background(), mockRemoteUrl, "/dest", false, false, client.WithAcquireToken(false)) require.Error(t, err) - assert.Contains(t, err.Error(), "failed to get token for transfer: failed to find or generate a token as required for /second/namespace/hello_world.txt") + assert.Contains(t, err.Error(), "failed to get token for transfer: credential is required for /second/namespace/hello_world.txt but was not discovered") }) // Test success @@ -835,17 +835,17 @@ func TestObjectList(t *testing.T) { for _, export := range fed.Exports { listURL := fmt.Sprintf("pelican://%s:%d%s", param.Server_Hostname.GetString(), param.Server_WebPort.GetInt(), export.FederationPrefix) if !export.Capabilities.PublicReads { - get, err := client.DoList(fed.Ctx, listURL, client.WithTokenLocation("")) + get, err := client.DoList(fed.Ctx, listURL, client.WithTokenLocation(""), client.WithAcquireToken(false)) require.Error(t, err) assert.Len(t, get, 0) - assert.Contains(t, err.Error(), "failed to get token for transfer: failed to find or generate a token as required") + assert.Contains(t, err.Error(), "failed to get token for transfer: credential is required") // No error if it's with token - get, err = client.DoList(fed.Ctx, listURL, client.WithTokenLocation(tempToken.Name())) + get, err = client.DoList(fed.Ctx, listURL, client.WithTokenLocation(tempToken.Name()), client.WithAcquireToken(false)) require.NoError(t, err) require.Len(t, get, 2) } else { - get, err := client.DoList(fed.Ctx, listURL, client.WithTokenLocation(tempToken.Name())) + get, err := client.DoList(fed.Ctx, listURL, client.WithTokenLocation(tempToken.Name()), client.WithAcquireToken(false)) require.NoError(t, err) require.Len(t, get, 2) } @@ -962,3 +962,41 @@ func TestClientUnpack(t *testing.T) { require.NoError(t, err) assert.Equal(t, int64(11), fi.Size()) } + +// A test that generates a token locally from the private key +func TestTokenGenerate(t *testing.T) { + viper.Reset() + server_utils.ResetOriginExports() + fed := fed_test_utils.NewFedTest(t, bothAuthOriginCfg) + + // Other set-up items: + testFileContent := "test file content" + // Create the temporary file to upload + tempFile, err := os.CreateTemp(t.TempDir(), "test") + require.NoError(t, err, "Error creating temp file") + defer os.Remove(tempFile.Name()) + _, err = tempFile.WriteString(testFileContent) + require.NoError(t, err, "Error writing to temp file") + tempFile.Close() + + // Disable progress bars to not reuse the same mpb instance + viper.Set("Logging.DisableProgressBars", true) + + // Set path for object to upload/download + for _, export := range fed.Exports { + tempPath := tempFile.Name() + fileName := filepath.Base(tempPath) + uploadURL := fmt.Sprintf("pelican://%s:%s%s/%s/%s", param.Server_Hostname.GetString(), strconv.Itoa(param.Server_WebPort.GetInt()), + export.FederationPrefix, "token_gen", fileName) + + // Upload the file with PUT + transferResultsUpload, err := client.DoPut(fed.Ctx, tempFile.Name(), uploadURL, false) + require.NoError(t, err) + assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) + + // Download that same file with GET + transferResultsDownload, err := client.DoGet(fed.Ctx, uploadURL, t.TempDir(), false) + require.NoError(t, err) + assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) + } +} diff --git a/config/issuer_metadata.go b/config/issuer_metadata.go index 080666b67..5e77370c3 100644 --- a/config/issuer_metadata.go +++ b/config/issuer_metadata.go @@ -29,6 +29,7 @@ import ( type OauthIssuer struct { Issuer string `json:"issuer"` + JwksUri string `json:"jwks_uri"` AuthURL string `json:"authorization_endpoint"` DeviceAuthURL string `json:"device_authorization_endpoint"` TokenURL string `json:"token_endpoint"` diff --git a/xrootd/launch.go b/xrootd/launch.go index a1ec277a6..78de57ff6 100644 --- a/xrootd/launch.go +++ b/xrootd/launch.go @@ -23,6 +23,7 @@ package xrootd import ( "context" _ "embed" + "os" "path/filepath" "regexp" "strconv" @@ -87,6 +88,9 @@ func makeUnprivilegedXrootdLauncher(daemonName string, configPath string, isCach "XRD_PELICANCLIENTCERTFILE=" + filepath.Join(xrootdRun, "copied-tls-creds.crt"), "XRD_PELICANCLIENTKEYFILE=" + filepath.Join(xrootdRun, "copied-tls-creds.crt"), } + if confDir := os.Getenv("XRD_PLUGINCONFDIR"); confDir != "" { + result.ExtraEnv = append(result.ExtraEnv, "XRD_PLUGINCONFDIR="+confDir) + } } return }