diff --git a/client/fed_linux_test.go b/client/fed_linux_test.go index b76db620e..14771dfea 100644 --- a/client/fed_linux_test.go +++ b/client/fed_linux_test.go @@ -21,10 +21,8 @@ package client_test import ( - "context" "fmt" "os" - "path" "path/filepath" "strconv" "testing" @@ -35,9 +33,9 @@ import ( "github.com/pelicanplatform/pelican/fed_test_utils" "github.com/pelicanplatform/pelican/param" "github.com/pelicanplatform/pelican/server_utils" - "github.com/pelicanplatform/pelican/test_utils" "github.com/pelicanplatform/pelican/token" "github.com/pelicanplatform/pelican/token_scopes" + log "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -45,13 +43,13 @@ import ( func TestRecursiveUploadsAndDownloads(t *testing.T) { // Create instance of test federation - ctx, _, _ := test_utils.TestContext(context.Background(), t) viper.Reset() server_utils.ResetOriginExports() fed := fed_test_utils.NewFedTest(t, mixedAuthOriginCfg) - //////////////////////////SETUP/////////////////////////// + te := client.NewTransferEngine(fed.Ctx) + // Create a token file issuer, err := config.GetServerIssuerURL() require.NoError(t, err) @@ -79,122 +77,57 @@ func TestRecursiveUploadsAndDownloads(t *testing.T) { // Make our test directories and files tempDir, err := os.MkdirTemp("", "UploadDir") assert.NoError(t, err) + innerTempDir, err := os.MkdirTemp(tempDir, "InnerUploadDir") + assert.NoError(t, err) + defer os.RemoveAll(tempDir) defer os.RemoveAll(tempDir) - permissions := os.FileMode(0777) + permissions := os.FileMode(0755) err = os.Chmod(tempDir, permissions) require.NoError(t, err) + err = os.Chmod(innerTempDir, permissions) + require.NoError(t, err) testFileContent1 := "test file content" testFileContent2 := "more test file content!" + innerTestFileContent := "this content is within another dir!" tempFile1, err := os.CreateTemp(tempDir, "test1") assert.NoError(t, err, "Error creating temp1 file") tempFile2, err := os.CreateTemp(tempDir, "test1") assert.NoError(t, err, "Error creating temp2 file") + innerTempFile, err := os.CreateTemp(innerTempDir, "testInner") + assert.NoError(t, err, "Error creating inner test file") defer os.Remove(tempFile1.Name()) defer os.Remove(tempFile2.Name()) + defer os.Remove(innerTempFile.Name()) + _, err = tempFile1.WriteString(testFileContent1) assert.NoError(t, err, "Error writing to temp1 file") tempFile1.Close() _, err = tempFile2.WriteString(testFileContent2) assert.NoError(t, err, "Error writing to temp2 file") tempFile2.Close() - - t.Run("testPelicanRecursiveGetAndPutOsdfURL", func(t *testing.T) { - config.SetPreferredPrefix("pelican") - for _, export := range *fed.Exports { - // Set path for object to upload/download - tempPath := tempDir - dirName := filepath.Base(tempPath) - // Note: minimally fixing this test as it is soon to be replaced - uploadURL := fmt.Sprintf("pelican://%s:%s%s/%s/%s", param.Server_Hostname.GetString(), strconv.Itoa(param.Server_WebPort.GetInt()), - export.FederationPrefix, "pel_osdf", dirName) - - ////////////////////////////////////////////////////////// - - // Upload the file with PUT - transferDetailsUpload, err := client.DoPut(ctx, tempDir, uploadURL, true, client.WithTokenLocation(tempToken.Name())) - require.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { - countBytes17 := 0 - countBytes23 := 0 - // Verify we got the correct files back (have to do this since files upload in different orders at times) - for _, transfer := range transferDetailsUpload { - transferredBytes := transfer.TransferredBytes - switch transferredBytes { - case int64(17): - countBytes17++ - continue - case int64(23): - countBytes23++ - continue - default: - // We got a byte amount we are not expecting - t.Fatal("did not upload proper amount of bytes") - } - } - if countBytes17 != 1 || countBytes23 != 1 { - // We would hit this case if 1 counter got hit twice for some reason - t.Fatal("One of the files was not uploaded correctly") - } - } else if len(transferDetailsUpload) != 2 { - t.Fatalf("Amount of transfers results returned for upload was not correct. Transfer details returned: %d", len(transferDetailsUpload)) - } - - // Download the files we just uploaded - var transferDetailsDownload []client.TransferResults - if export.Capabilities.PublicReads { - transferDetailsDownload, err = client.DoGet(ctx, uploadURL, t.TempDir(), true) - } else { - transferDetailsDownload, err = client.DoGet(ctx, uploadURL, t.TempDir(), true, client.WithTokenLocation(tempToken.Name())) - } - assert.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { - countBytesUploadIdx0 := 0 - countBytesUploadIdx1 := 0 - // Verify we got the correct files back (have to do this since files upload in different orders at times) - // In this case, we want to match them to the sizes of the uploaded files - for _, transfer := range transferDetailsUpload { - transferredBytes := transfer.TransferredBytes - switch transferredBytes { - case transferDetailsUpload[0].TransferredBytes: - countBytesUploadIdx0++ - continue - case transferDetailsUpload[1].TransferredBytes: - countBytesUploadIdx1++ - continue - default: - // We got a byte amount we are not expecting - t.Fatal("did not download proper amount of bytes") - } - } - if countBytesUploadIdx0 != 1 || countBytesUploadIdx1 != 1 { - // We would hit this case if 1 counter got hit twice for some reason - t.Fatal("One of the files was not downloaded correctly") - } else if len(transferDetailsDownload) != 2 { - t.Fatalf("Amount of transfers results returned for download was not correct. Transfer details returned: %d", len(transferDetailsDownload)) - } - } - - } - }) + _, err = innerTempFile.WriteString(innerTestFileContent) + assert.NoError(t, err, "Error writing to inner test file") + innerTempFile.Close() t.Run("testPelicanRecursiveGetAndPutPelicanURL", func(t *testing.T) { - config.SetPreferredPrefix("pelican") + _, err := config.SetPreferredPrefix("PELICAN") + assert.NoError(t, err) for _, export := range *fed.Exports { // Set path for object to upload/download tempPath := tempDir dirName := filepath.Base(tempPath) - uploadURL := fmt.Sprintf("pelican://%s/%s/%s", export.FederationPrefix, "pel_pel", dirName) - - ////////////////////////////////////////////////////////// + uploadURL := fmt.Sprintf("pelican://%s:%s%s/%s/%s", param.Server_Hostname.GetString(), strconv.Itoa(param.Server_WebPort.GetInt()), + export.FederationPrefix, "osdf_osdf", dirName) // Upload the file with PUT - transferDetailsUpload, err := client.DoPut(ctx, tempDir, uploadURL, true, client.WithTokenLocation(tempToken.Name())) + transferDetailsUpload, err := client.DoPut(fed.Ctx, tempDir, uploadURL, true, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { + if err == nil && len(transferDetailsUpload) == 3 { countBytes17 := 0 countBytes23 := 0 + countBytes35 := 0 // Verify we got the correct files back (have to do this since files upload in different orders at times) for _, transfer := range transferDetailsUpload { transferredBytes := transfer.TransferredBytes @@ -205,50 +138,58 @@ func TestRecursiveUploadsAndDownloads(t *testing.T) { case int64(23): countBytes23++ continue + case int64(35): + countBytes35++ + continue default: // We got a byte amount we are not expecting t.Fatal("did not upload proper amount of bytes") } } - if countBytes17 != 1 || countBytes23 != 1 { + if countBytes17 != 1 || countBytes23 != 1 || countBytes35 != 1 { // We would hit this case if 1 counter got hit twice for some reason t.Fatal("One of the files was not uploaded correctly") } - } else if len(transferDetailsUpload) != 2 { + } else if len(transferDetailsUpload) != 3 { t.Fatalf("Amount of transfers results returned for upload was not correct. Transfer details returned: %d", len(transferDetailsUpload)) } // Download the files we just uploaded var transferDetailsDownload []client.TransferResults if export.Capabilities.PublicReads { - transferDetailsDownload, err = client.DoGet(ctx, uploadURL, t.TempDir(), true) + transferDetailsDownload, err = client.DoGet(fed.Ctx, uploadURL, t.TempDir(), true) } else { - transferDetailsDownload, err = client.DoGet(ctx, uploadURL, t.TempDir(), true, client.WithTokenLocation(tempToken.Name())) + transferDetailsDownload, err = client.DoGet(fed.Ctx, uploadURL, t.TempDir(), true, client.WithTokenLocation(tempToken.Name())) } assert.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { - countBytesUploadIdx0 := 0 - countBytesUploadIdx1 := 0 + if err == nil && len(transferDetailsDownload) == 3 { + countBytesDownloadIdx0 := 0 + countBytesDownloadIdx1 := 0 + countBytesDownloadIdx2 := 0 + // Verify we got the correct files back (have to do this since files upload in different orders at times) // In this case, we want to match them to the sizes of the uploaded files - for _, transfer := range transferDetailsUpload { + for _, transfer := range transferDetailsDownload { transferredBytes := transfer.TransferredBytes switch transferredBytes { - case transferDetailsUpload[0].TransferredBytes: - countBytesUploadIdx0++ + case transferDetailsDownload[0].TransferredBytes: + countBytesDownloadIdx0++ + continue + case transferDetailsDownload[1].TransferredBytes: + countBytesDownloadIdx1++ continue - case transferDetailsUpload[1].TransferredBytes: - countBytesUploadIdx1++ + case transferDetailsDownload[2].TransferredBytes: + countBytesDownloadIdx2++ continue default: // We got a byte amount we are not expecting t.Fatal("did not download proper amount of bytes") } } - if countBytesUploadIdx0 != 1 || countBytesUploadIdx1 != 1 { + if countBytesDownloadIdx0 != 1 || countBytesDownloadIdx1 != 1 || countBytesDownloadIdx2 != 1 { // We would hit this case if 1 counter got hit twice for some reason t.Fatal("One of the files was not downloaded correctly") - } else if len(transferDetailsDownload) != 2 { + } else if len(transferDetailsDownload) != 3 { t.Fatalf("Amount of transfers results returned for download was not correct. Transfer details returned: %d", len(transferDetailsDownload)) } } @@ -256,23 +197,29 @@ func TestRecursiveUploadsAndDownloads(t *testing.T) { }) t.Run("testOsdfRecursiveGetAndPutOsdfURL", func(t *testing.T) { - config.SetPreferredPrefix("osdf") + _, err := config.SetPreferredPrefix("OSDF") + assert.NoError(t, err) for _, export := range *fed.Exports { // Set path for object to upload/download tempPath := tempDir dirName := filepath.Base(tempPath) - // Note: minimally fixing this test as it is soon to be replaced - uploadURL := fmt.Sprintf("pelican://%s:%s%s/%s/%s", param.Server_Hostname.GetString(), strconv.Itoa(param.Server_WebPort.GetInt()), - export.FederationPrefix, "osdf_osdf", dirName) + uploadURL := fmt.Sprintf("osdf:///%s/%s/%s", export.FederationPrefix, "osdf_osdf", dirName) + hostname := fmt.Sprintf("%v:%v", param.Server_WebHost.GetString(), param.Server_WebPort.GetInt()) - ////////////////////////////////////////////////////////// + // Set our metadata values in config since that is what this url scheme - prefix combo does in handle_http + metadata, err := config.DiscoverUrlFederation(fed.Ctx, "https://"+hostname) + assert.NoError(t, err) + viper.Set("Federation.DirectorUrl", metadata.DirectorEndpoint) + viper.Set("Federation.RegistryUrl", metadata.NamespaceRegistrationEndpoint) + viper.Set("Federation.DiscoveryUrl", hostname) // Upload the file with PUT - transferDetailsUpload, err := client.DoPut(ctx, tempDir, uploadURL, true, client.WithTokenLocation(tempToken.Name())) + transferDetailsUpload, err := client.DoPut(fed.Ctx, tempDir, uploadURL, true, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { + if err == nil && len(transferDetailsUpload) == 3 { countBytes17 := 0 countBytes23 := 0 + countBytes35 := 0 // Verify we got the correct files back (have to do this since files upload in different orders at times) for _, transfer := range transferDetailsUpload { transferredBytes := transfer.TransferredBytes @@ -283,16 +230,19 @@ func TestRecursiveUploadsAndDownloads(t *testing.T) { case int64(23): countBytes23++ continue + case int64(35): + countBytes35++ + continue default: // We got a byte amount we are not expecting t.Fatal("did not upload proper amount of bytes") } } - if countBytes17 != 1 || countBytes23 != 1 { + if countBytes17 != 1 || countBytes23 != 1 || countBytes35 != 1 { // We would hit this case if 1 counter got hit twice for some reason t.Fatal("One of the files was not uploaded correctly") } - } else if len(transferDetailsUpload) != 2 { + } else if len(transferDetailsUpload) != 3 { t.Fatalf("Amount of transfers results returned for upload was not correct. Transfer details returned: %d", len(transferDetailsUpload)) } @@ -300,62 +250,62 @@ func TestRecursiveUploadsAndDownloads(t *testing.T) { tmpDir := t.TempDir() var transferDetailsDownload []client.TransferResults if export.Capabilities.PublicReads { - transferDetailsDownload, err = client.DoGet(ctx, uploadURL, tmpDir, true) + transferDetailsDownload, err = client.DoGet(fed.Ctx, uploadURL, tmpDir, true) } else { - transferDetailsDownload, err = client.DoGet(ctx, uploadURL, tmpDir, true, client.WithTokenLocation(tempToken.Name())) + transferDetailsDownload, err = client.DoGet(fed.Ctx, uploadURL, tmpDir, true, client.WithTokenLocation(tempToken.Name())) } assert.NoError(t, err) - if err == nil && len(transferDetailsDownload) == 2 { - countBytesUploadIdx0 := 0 - countBytesUploadIdx1 := 0 + if err == nil && len(transferDetailsDownload) == 3 { + countBytesDownloadIdx0 := 0 + countBytesDownloadIdx1 := 0 + countBytesDownloadIdx2 := 0 + // Verify we got the correct files back (have to do this since files upload in different orders at times) // In this case, we want to match them to the sizes of the uploaded files for _, transfer := range transferDetailsDownload { transferredBytes := transfer.TransferredBytes switch transferredBytes { - case transferDetailsUpload[0].TransferredBytes: - countBytesUploadIdx0++ + case transferDetailsDownload[0].TransferredBytes: + countBytesDownloadIdx0++ + continue + case transferDetailsDownload[1].TransferredBytes: + countBytesDownloadIdx1++ continue - case transferDetailsUpload[1].TransferredBytes: - countBytesUploadIdx1++ + case transferDetailsDownload[2].TransferredBytes: + countBytesDownloadIdx2++ continue default: // We got a byte amount we are not expecting t.Fatal("did not download proper amount of bytes") } } - if countBytesUploadIdx0 != 1 || countBytesUploadIdx1 != 1 { + if countBytesDownloadIdx0 != 1 || countBytesDownloadIdx1 != 1 || countBytesDownloadIdx2 != 1 { // We would hit this case if 1 counter got hit twice for some reason t.Fatal("One of the files was not downloaded correctly") + } else if len(transferDetailsDownload) != 3 { + t.Fatalf("Amount of transfers results returned for download was not correct. Transfer details returned: %d", len(transferDetailsDownload)) } - contents, err := os.ReadFile(filepath.Join(tmpDir, path.Join(dirName, path.Base(tempFile2.Name())))) - assert.NoError(t, err) - assert.Equal(t, testFileContent2, string(contents)) - contents, err = os.ReadFile(filepath.Join(tmpDir, path.Join(dirName, path.Base(tempFile1.Name())))) - assert.NoError(t, err) - assert.Equal(t, testFileContent1, string(contents)) - } else if err == nil && len(transferDetailsDownload) != 2 { - t.Fatalf("Number of transfers results returned for download was not correct. Transfer details returned: %d", len(transferDetailsDownload)) } } }) t.Run("testOsdfRecursiveGetAndPutPelicanURL", func(t *testing.T) { - config.SetPreferredPrefix("osdf") + _, err := config.SetPreferredPrefix("OSDF") + assert.NoError(t, err) + for _, export := range *fed.Exports { // Set path for object to upload/download tempPath := tempDir dirName := filepath.Base(tempPath) - uploadURL := fmt.Sprintf("pelican://%s/%s/%s", export.FederationPrefix, "osdf_pel", dirName) - - ////////////////////////////////////////////////////////// - + uploadURL := fmt.Sprintf("pelican://%s:%s%s/%s/%s", param.Server_Hostname.GetString(), strconv.Itoa(param.Server_WebPort.GetInt()), + export.FederationPrefix, "osdf_osdf", dirName) // Upload the file with PUT - transferDetailsUpload, err := client.DoPut(ctx, tempDir, uploadURL, true, client.WithTokenLocation(tempToken.Name())) + transferDetailsUpload, err := client.DoPut(fed.Ctx, tempDir, uploadURL, true, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { + if err == nil && len(transferDetailsUpload) == 3 { countBytes17 := 0 countBytes23 := 0 + countBytes35 := 0 // Verify we got the correct files back (have to do this since files upload in different orders at times) for _, transfer := range transferDetailsUpload { transferredBytes := transfer.TransferredBytes @@ -366,53 +316,68 @@ func TestRecursiveUploadsAndDownloads(t *testing.T) { case int64(23): countBytes23++ continue + case int64(35): + countBytes35++ + continue default: // We got a byte amount we are not expecting t.Fatal("did not upload proper amount of bytes") } } - if countBytes17 != 1 || countBytes23 != 1 { + if countBytes17 != 1 || countBytes23 != 1 || countBytes35 != 1 { // We would hit this case if 1 counter got hit twice for some reason t.Fatal("One of the files was not uploaded correctly") } - } else if len(transferDetailsUpload) != 2 { + } else if len(transferDetailsUpload) != 3 { t.Fatalf("Amount of transfers results returned for upload was not correct. Transfer details returned: %d", len(transferDetailsUpload)) } - // Download the files we just uploaded var transferDetailsDownload []client.TransferResults if export.Capabilities.PublicReads { - transferDetailsDownload, err = client.DoGet(ctx, uploadURL, t.TempDir(), true) + transferDetailsDownload, err = client.DoGet(fed.Ctx, uploadURL, t.TempDir(), true) } else { - transferDetailsDownload, err = client.DoGet(ctx, uploadURL, t.TempDir(), true, client.WithTokenLocation(tempToken.Name())) + transferDetailsDownload, err = client.DoGet(fed.Ctx, uploadURL, t.TempDir(), true, client.WithTokenLocation(tempToken.Name())) } assert.NoError(t, err) - if err == nil && len(transferDetailsUpload) == 2 { - countBytesUploadIdx0 := 0 - countBytesUploadIdx1 := 0 + if err == nil && len(transferDetailsDownload) == 3 { + countBytesDownloadIdx0 := 0 + countBytesDownloadIdx1 := 0 + countBytesDownloadIdx2 := 0 + // Verify we got the correct files back (have to do this since files upload in different orders at times) // In this case, we want to match them to the sizes of the uploaded files - for _, transfer := range transferDetailsUpload { + for _, transfer := range transferDetailsDownload { transferredBytes := transfer.TransferredBytes switch transferredBytes { - case transferDetailsUpload[0].TransferredBytes: - countBytesUploadIdx0++ + case transferDetailsDownload[0].TransferredBytes: + countBytesDownloadIdx0++ + continue + case transferDetailsDownload[1].TransferredBytes: + countBytesDownloadIdx1++ continue - case transferDetailsUpload[1].TransferredBytes: - countBytesUploadIdx1++ + case transferDetailsDownload[2].TransferredBytes: + countBytesDownloadIdx2++ continue default: // We got a byte amount we are not expecting t.Fatal("did not download proper amount of bytes") } } - if countBytesUploadIdx0 != 1 || countBytesUploadIdx1 != 1 { + if countBytesDownloadIdx0 != 1 || countBytesDownloadIdx1 != 1 || countBytesDownloadIdx2 != 1 { // We would hit this case if 1 counter got hit twice for some reason t.Fatal("One of the files was not downloaded correctly") - } else if len(transferDetailsDownload) != 2 { + } else if len(transferDetailsDownload) != 3 { t.Fatalf("Amount of transfers results returned for download was not correct. Transfer details returned: %d", len(transferDetailsDownload)) } } } }) + + t.Cleanup(func() { + if err := te.Shutdown(); err != nil { + log.Errorln("Failure when shutting down transfer engine:", err) + } + // Throw in a viper.Reset for good measure. Keeps our env squeaky clean! + viper.Reset() + }) } diff --git a/client/fed_test.go b/client/fed_test.go index 9a9dc0bf4..b252e304a 100644 --- a/client/fed_test.go +++ b/client/fed_test.go @@ -23,16 +23,13 @@ package client_test import ( "context" _ "embed" - "encoding/json" "fmt" - "io" - "net/http" "os" "path/filepath" + "strconv" "testing" "time" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/stretchr/testify/assert" @@ -41,7 +38,6 @@ import ( "github.com/pelicanplatform/pelican/client" "github.com/pelicanplatform/pelican/config" "github.com/pelicanplatform/pelican/fed_test_utils" - "github.com/pelicanplatform/pelican/launchers" "github.com/pelicanplatform/pelican/param" "github.com/pelicanplatform/pelican/server_utils" "github.com/pelicanplatform/pelican/test_utils" @@ -60,168 +56,6 @@ var ( mixedAuthOriginCfg string ) -func generateFileTestScitoken() (string, error) { - // Issuer is whichever server that initiates the test, so it's the server itself - issuerUrl, err := config.GetServerIssuerURL() - if err != nil { - return "", err - } - if issuerUrl == "" { // if empty, then error - return "", errors.New("Failed to create token: Invalid iss, Server_ExternalWebUrl is empty") - } - - fTestTokenCfg := token.NewWLCGToken() - fTestTokenCfg.Lifetime = time.Minute - fTestTokenCfg.Issuer = issuerUrl - fTestTokenCfg.Subject = "origin" - fTestTokenCfg.AddAudiences(config.GetServerAudience()) - fTestTokenCfg.AddResourceScopes(token_scopes.NewResourceScope(token_scopes.Storage_Read, "/"), - token_scopes.NewResourceScope(token_scopes.Storage_Modify, "/")) - - // CreateToken also handles validation for us - tok, err := fTestTokenCfg.CreateToken() - if err != nil { - return "", errors.Wrap(err, "failed to create file test token:") - } - - return tok, nil -} - -func TestFullUpload(t *testing.T) { - // Setup our test federation - ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) - defer func() { require.NoError(t, egrp.Wait()) }() - defer cancel() - - viper.Reset() - server_utils.ResetOriginExports() - defer viper.Reset() - defer server_utils.ResetOriginExports() - - modules := config.ServerType(0) - modules.Set(config.OriginType) - modules.Set(config.DirectorType) - modules.Set(config.RegistryType) - - // Create our own temp directory (for some reason t.TempDir() does not play well with xrootd) - tmpPathPattern := "XRootD-Test_Origin*" - tmpPath, err := os.MkdirTemp("", tmpPathPattern) - require.NoError(t, err) - - // Need to set permissions or the xrootd process we spawn won't be able to write PID/UID files - permissions := os.FileMode(0755) - err = os.Chmod(tmpPath, permissions) - require.NoError(t, err) - - viper.Set("ConfigDir", tmpPath) - - // Increase the log level; otherwise, its difficult to debug failures - viper.Set("Logging.Level", "Debug") - config.InitConfig() - - originDir, err := os.MkdirTemp("", "Origin") - assert.NoError(t, err) - - // Change the permissions of the temporary directory - permissions = os.FileMode(0777) - err = os.Chmod(originDir, permissions) - require.NoError(t, err) - - viper.Set("Origin.FederationPrefix", "/test") - viper.Set("Origin.StoragePrefix", originDir) - viper.Set("Origin.StorageType", "posix") - // Disable functionality we're not using (and is difficult to make work on Mac) - viper.Set("Origin.EnableCmsd", false) - viper.Set("Origin.EnableMacaroons", false) - viper.Set("Origin.EnableVoms", false) - viper.Set("Origin.EnableWrites", true) - viper.Set("TLSSkipVerify", true) - viper.Set("Server.EnableUI", false) - viper.Set("Registry.DbLocation", filepath.Join(t.TempDir(), "ns-registry.sqlite")) - viper.Set("Origin.RunLocation", tmpPath) - viper.Set("Registry.RequireOriginApproval", false) - viper.Set("Registry.RequireCacheApproval", false) - viper.Set("Logging.Origin.Scitokens", "debug") - viper.Set("Origin.Port", 0) - viper.Set("Server.WebPort", 0) - - err = config.InitServer(ctx, modules) - require.NoError(t, err) - - fedCancel, err := launchers.LaunchModules(ctx, modules) - defer fedCancel() - if err != nil { - log.Errorln("Failure in fedServeInternal:", err) - require.NoError(t, err) - } - - desiredURL := param.Server_ExternalWebUrl.GetString() + "/api/v1.0/health" - err = server_utils.WaitUntilWorking(ctx, "GET", desiredURL, "director", 200) - require.NoError(t, err) - - httpc := http.Client{ - Transport: config.GetTransport(), - } - resp, err := httpc.Get(desiredURL) - require.NoError(t, err) - - assert.Equal(t, resp.StatusCode, http.StatusOK) - - responseBody, err := io.ReadAll(resp.Body) - require.NoError(t, err) - expectedResponse := struct { - Msg string `json:"message"` - }{} - err = json.Unmarshal(responseBody, &expectedResponse) - require.NoError(t, err) - - assert.NotEmpty(t, expectedResponse.Msg) - - t.Run("testFullUpload", func(t *testing.T) { - testFileContent := "test file content" - - // Create the temporary file to upload - tempFile, err := os.CreateTemp(t.TempDir(), "test") - assert.NoError(t, err, "Error creating temp file") - defer os.Remove(tempFile.Name()) - _, err = tempFile.WriteString(testFileContent) - assert.NoError(t, err, "Error writing to temp file") - tempFile.Close() - - // Create a token file - token, err := generateFileTestScitoken() - assert.NoError(t, err) - tempToken, err := os.CreateTemp(t.TempDir(), "token") - assert.NoError(t, err, "Error creating temp token file") - defer os.Remove(tempToken.Name()) - _, err = tempToken.WriteString(token) - assert.NoError(t, err, "Error writing to temp token file") - tempToken.Close() - - // Upload the file - tempPath := tempFile.Name() - fileName := filepath.Base(tempPath) - uploadURL := "stash:///test/" + fileName - - transferResults, err := client.DoCopy(ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) - assert.NoError(t, err, "Error uploading file") - assert.Equal(t, int64(len(testFileContent)), transferResults[0].TransferredBytes, "Uploaded file size does not match") - - // Upload an osdf file - uploadURL = "pelican:///test/stuff/blah.txt" - assert.NoError(t, err, "Error parsing upload URL") - transferResults, err = client.DoCopy(ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) - assert.NoError(t, err, "Error uploading file") - assert.Equal(t, int64(len(testFileContent)), transferResults[0].TransferredBytes, "Uploaded file size does not match") - }) - t.Cleanup(func() { - os.RemoveAll(tmpPath) - os.RemoveAll(originDir) - }) - - viper.Reset() -} - // A test that spins up a federation, and tests object get and put func TestGetAndPutAuth(t *testing.T) { viper.Reset() @@ -270,12 +104,43 @@ func TestGetAndPutAuth(t *testing.T) { // This tests object get/put with a pelican:// url t.Run("testPelicanObjectPutAndGetWithPelicanUrl", func(t *testing.T) { - config.SetPreferredPrefix("pelican") + _, err := config.SetPreferredPrefix("PELICAN") + assert.NoError(t, err) + // Set path for object to upload/download for _, export := range *fed.Exports { tempPath := tempFile.Name() fileName := filepath.Base(tempPath) - uploadURL := fmt.Sprintf("pelican://%s/%s", export.FederationPrefix, fileName) + uploadURL := fmt.Sprintf("pelican://%s:%s%s/%s/%s", param.Server_Hostname.GetString(), strconv.Itoa(param.Server_WebPort.GetInt()), + export.FederationPrefix, "osdf_osdf", fileName) + + // Upload the file with PUT + transferResultsUpload, err := client.DoPut(fed.Ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil { + assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) + } + + // Download that same file with GET + transferResultsDownload, err := client.DoGet(fed.Ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil { + assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) + } + } + }) + + // This tests object get/put with a pelican:// url + t.Run("testOsdfObjectPutAndGetWithPelicanUrl", func(t *testing.T) { + _, err := config.SetPreferredPrefix("OSDF") + assert.NoError(t, err) + + for _, export := range *fed.Exports { + // Set path for object to upload/download + tempPath := tempFile.Name() + fileName := filepath.Base(tempPath) + uploadURL := fmt.Sprintf("pelican://%s:%s%s/%s/%s", param.Server_Hostname.GetString(), strconv.Itoa(param.Server_WebPort.GetInt()), + export.FederationPrefix, "osdf_osdf", fileName) // Upload the file with PUT transferResultsUpload, err := client.DoPut(fed.Ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) @@ -294,14 +159,24 @@ func TestGetAndPutAuth(t *testing.T) { }) // This tests pelican object get/put with an osdf url - t.Run("testPelicanObjectPutAndGetWithOSDFUrl", func(t *testing.T) { - config.SetPreferredPrefix("pelican") + t.Run("testOsdfObjectPutAndGetWithOSDFUrl", func(t *testing.T) { + _, err := config.SetPreferredPrefix("OSDF") + assert.NoError(t, err) + for _, export := range *fed.Exports { // Set path for object to upload/download tempPath := tempFile.Name() fileName := filepath.Base(tempPath) // Minimal fix of test as it is soon to be replaced - uploadURL := fmt.Sprintf("pelican://%s/%s", export.FederationPrefix, fileName) + uploadURL := fmt.Sprintf("osdf://%s/%s", export.FederationPrefix, fileName) + hostname := fmt.Sprintf("%v:%v", param.Server_WebHost.GetString(), param.Server_WebPort.GetInt()) + + // Set our metadata values in config since that is what this url scheme - prefix combo does in handle_http + metadata, err := config.DiscoverUrlFederation(fed.Ctx, "https://"+hostname) + assert.NoError(t, err) + viper.Set("Federation.DirectorUrl", metadata.DirectorEndpoint) + viper.Set("Federation.RegistryUrl", metadata.NamespaceRegistrationEndpoint) + viper.Set("Federation.DiscoveryUrl", hostname) // Upload the file with PUT transferResultsUpload, err := client.DoPut(fed.Ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) @@ -318,25 +193,109 @@ func TestGetAndPutAuth(t *testing.T) { } } }) + t.Cleanup(func() { + // Throw in a viper.Reset for good measure. Keeps our env squeaky clean! + viper.Reset() + }) +} + +// A test that spins up a federation, and tests object get and put +func TestCopyAuth(t *testing.T) { + viper.Reset() + server_utils.ResetOriginExports() + fed := fed_test_utils.NewFedTest(t, bothAuthOriginCfg) + + te := client.NewTransferEngine(fed.Ctx) + + // Other set-up items: + testFileContent := "test file content" + // Create the temporary file to upload + tempFile, err := os.CreateTemp(t.TempDir(), "test") + assert.NoError(t, err, "Error creating temp file") + defer os.Remove(tempFile.Name()) + _, err = tempFile.WriteString(testFileContent) + assert.NoError(t, err, "Error writing to temp file") + tempFile.Close() + + issuer, err := config.GetServerIssuerURL() + require.NoError(t, err) + audience := config.GetServerAudience() + + // Create a token file + tokenConfig := token.NewWLCGToken() + tokenConfig.Lifetime = time.Minute + tokenConfig.Issuer = issuer + tokenConfig.Subject = "origin" + tokenConfig.AddAudiences(audience) + + scopes := []token_scopes.TokenScope{} + readScope, err := token_scopes.Storage_Read.Path("/") + assert.NoError(t, err) + scopes = append(scopes, readScope) + modScope, err := token_scopes.Storage_Modify.Path("/") + assert.NoError(t, err) + scopes = append(scopes, modScope) + tokenConfig.AddScopes(scopes...) + token, err := tokenConfig.CreateToken() + assert.NoError(t, err) + tempToken, err := os.CreateTemp(t.TempDir(), "token") + assert.NoError(t, err, "Error creating temp token file") + defer os.Remove(tempToken.Name()) + _, err = tempToken.WriteString(token) + assert.NoError(t, err, "Error writing to temp token file") + tempToken.Close() + // Disable progress bars to not reuse the same mpb instance + viper.Set("Logging.DisableProgressBars", true) // This tests object get/put with a pelican:// url - t.Run("testOsdfObjectPutAndGetWithPelicanUrl", func(t *testing.T) { - config.SetPreferredPrefix("osdf") + t.Run("testPelicanObjectCopyWithPelicanUrl", func(t *testing.T) { + _, err := config.SetPreferredPrefix("PELICAN") + assert.NoError(t, err) + + // Set path for object to upload/download + for _, export := range *fed.Exports { + tempPath := tempFile.Name() + fileName := filepath.Base(tempPath) + uploadURL := fmt.Sprintf("pelican://%s:%s%s/%s/%s", param.Server_Hostname.GetString(), strconv.Itoa(param.Server_WebPort.GetInt()), + export.FederationPrefix, "osdf_osdf", fileName) + + // Upload the file with PUT + transferResultsUpload, err := client.DoCopy(fed.Ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil { + assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) + } + + // Download that same file with GET + transferResultsDownload, err := client.DoCopy(fed.Ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) + assert.NoError(t, err) + if err == nil { + assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) + } + } + }) + + // This tests object get/put with a pelican:// url + t.Run("testOsdfObjectCopyWithPelicanUrl", func(t *testing.T) { + _, err := config.SetPreferredPrefix("OSDF") + assert.NoError(t, err) + for _, export := range *fed.Exports { // Set path for object to upload/download tempPath := tempFile.Name() fileName := filepath.Base(tempPath) - uploadURL := fmt.Sprintf("pelican://%s/%s", export.FederationPrefix, fileName) + uploadURL := fmt.Sprintf("pelican://%s:%s%s/%s/%s", param.Server_Hostname.GetString(), strconv.Itoa(param.Server_WebPort.GetInt()), + export.FederationPrefix, "osdf_osdf", fileName) // Upload the file with PUT - transferResultsUpload, err := client.DoPut(fed.Ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) + transferResultsUpload, err := client.DoCopy(fed.Ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) if err == nil { assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) } // Download that same file with GET - transferResultsDownload, err := client.DoGet(fed.Ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) + transferResultsDownload, err := client.DoCopy(fed.Ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) if err == nil { assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) @@ -345,35 +304,51 @@ func TestGetAndPutAuth(t *testing.T) { }) // This tests pelican object get/put with an osdf url - t.Run("testOsdfObjectPutAndGetWithOSDFUrl", func(t *testing.T) { - config.SetPreferredPrefix("osdf") + t.Run("testOsdfObjectCopyWithOSDFUrl", func(t *testing.T) { + _, err := config.SetPreferredPrefix("OSDF") + assert.NoError(t, err) + for _, export := range *fed.Exports { // Set path for object to upload/download tempPath := tempFile.Name() fileName := filepath.Base(tempPath) // Minimal fix of test as it is soon to be replaced - uploadURL := fmt.Sprintf("pelican://%s/%s", export.FederationPrefix, fileName) + uploadURL := fmt.Sprintf("osdf://%s/%s", export.FederationPrefix, fileName) + hostname := fmt.Sprintf("%v:%v", param.Server_WebHost.GetString(), param.Server_WebPort.GetInt()) + + // Set our metadata values in config since that is what this url scheme - prefix combo does in handle_http + metadata, err := config.DiscoverUrlFederation(fed.Ctx, "https://"+hostname) + assert.NoError(t, err) + viper.Set("Federation.DirectorUrl", metadata.DirectorEndpoint) + viper.Set("Federation.RegistryUrl", metadata.NamespaceRegistrationEndpoint) + viper.Set("Federation.DiscoveryUrl", hostname) // Upload the file with PUT - transferResultsUpload, err := client.DoPut(fed.Ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) + transferResultsUpload, err := client.DoCopy(fed.Ctx, tempFile.Name(), uploadURL, false, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) if err == nil { assert.Equal(t, transferResultsUpload[0].TransferredBytes, int64(17)) } // Download that same file with GET - transferResultsDownload, err := client.DoGet(fed.Ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) + transferResultsDownload, err := client.DoCopy(fed.Ctx, uploadURL, t.TempDir(), false, client.WithTokenLocation(tempToken.Name())) assert.NoError(t, err) if err == nil { assert.Equal(t, transferResultsDownload[0].TransferredBytes, transferResultsUpload[0].TransferredBytes) } } }) + t.Cleanup(func() { + if err := te.Shutdown(); err != nil { + log.Errorln("Failure when shutting down transfer engine:", err) + } + // Throw in a viper.Reset for good measure. Keeps our env squeaky clean! + viper.Reset() + }) } // A test that spins up the federation, where the origin is in EnablePublicReads mode. Then GET a file from the origin without a token func TestGetPublicRead(t *testing.T) { - ctx, _, _ := test_utils.TestContext(context.Background(), t) viper.Reset() server_utils.ResetOriginExports() @@ -395,16 +370,21 @@ func TestGetPublicRead(t *testing.T) { // Set path for object to upload/download tempPath := tempFile.Name() fileName := filepath.Base(tempPath) - uploadURL := fmt.Sprintf("pelican://%s/%s", export.FederationPrefix, fileName) + uploadURL := fmt.Sprintf("pelican://%s:%s%s/%s", param.Server_Hostname.GetString(), strconv.Itoa(param.Server_WebPort.GetInt()), + export.FederationPrefix, fileName) // Download the file with GET. Shouldn't need a token to succeed - transferResults, err := client.DoGet(ctx, uploadURL, t.TempDir(), false) + transferResults, err := client.DoGet(fed.Ctx, uploadURL, t.TempDir(), false) assert.NoError(t, err) if err == nil { assert.Equal(t, transferResults[0].TransferredBytes, int64(17)) } } }) + t.Cleanup(func() { + // Throw in a viper.Reset for good measure. Keeps our env squeaky clean! + viper.Reset() + }) } // A test that tests the statHttp function @@ -429,7 +409,10 @@ func TestStatHttp(t *testing.T) { // Set path for object to upload/download tempPath := tempFile.Name() fileName := filepath.Base(tempPath) - uploadURL := fmt.Sprintf("pelican://%s/%s", ((*fed.Exports)[0]).FederationPrefix, fileName) + uploadURL := fmt.Sprintf("pelican://%s:%s%s/%s", param.Server_Hostname.GetString(), strconv.Itoa(param.Server_WebPort.GetInt()), + ((*fed.Exports)[0]).FederationPrefix, fileName) + + log.Errorln(uploadURL) // Download the file with GET. Shouldn't need a token to succeed objectSize, err := client.DoStat(ctx, uploadURL) @@ -440,6 +423,8 @@ func TestStatHttp(t *testing.T) { }) t.Run("testStatHttpOSDFScheme", func(t *testing.T) { + _, err := config.SetPreferredPrefix("OSDF") + assert.NoError(t, err) testFileContent := "test file content" // Drop the testFileContent into the origin directory tempFile, err := os.Create(filepath.Join(((*fed.Exports)[0]).StoragePrefix, "test.txt")) @@ -450,11 +435,19 @@ func TestStatHttp(t *testing.T) { viper.Set("Logging.DisableProgressBars", true) - // Set path for object to upload/download tempPath := tempFile.Name() fileName := filepath.Base(tempPath) - // Minimal fix of test as it is soon to be replaced - uploadURL := fmt.Sprintf("pelican://%s/%s", ((*fed.Exports)[0]).FederationPrefix, fileName) + + uploadURL := fmt.Sprintf("osdf://%s/%s", ((*fed.Exports)[0]).FederationPrefix, fileName) + hostname := fmt.Sprintf("%v:%v", param.Server_WebHost.GetString(), param.Server_WebPort.GetInt()) + + // Set our metadata values in config since that is what this url scheme - prefix combo does in handle_http + metadata, err := config.DiscoverUrlFederation(fed.Ctx, "https://"+hostname) + assert.NoError(t, err) + viper.Set("Federation.DirectorUrl", metadata.DirectorEndpoint) + viper.Set("Federation.RegistryUrl", metadata.NamespaceRegistrationEndpoint) + viper.Set("Federation.DiscoveryUrl", hostname) + log.Errorln(uploadURL) // Download the file with GET. Shouldn't need a token to succeed objectSize, err := client.DoStat(ctx, uploadURL) @@ -484,6 +477,6 @@ func TestStatHttp(t *testing.T) { objectSize, err := client.DoStat(ctx, uploadURL) assert.Error(t, err) assert.Equal(t, uint64(0), objectSize) - assert.Contains(t, err.Error(), "Unsupported scheme requested") + assert.Contains(t, err.Error(), "Do not understand the destination scheme: some. Permitted values are file, osdf, pelican, stash, ") }) } diff --git a/client/handle_http.go b/client/handle_http.go index 1a1c3e257..1f02aa560 100644 --- a/client/handle_http.go +++ b/client/handle_http.go @@ -41,14 +41,15 @@ import ( "github.com/VividCortex/ewma" "github.com/google/uuid" + "github.com/jellydator/ttlcache/v3" "github.com/lestrrat-go/option" "github.com/opensaucerer/grab/v3" "github.com/pkg/errors" log "github.com/sirupsen/logrus" - "github.com/spf13/viper" "github.com/studio-b12/gowebdav" "github.com/vbauerster/mpb/v8" "golang.org/x/sync/errgroup" + "golang.org/x/sync/singleflight" "golang.org/x/time/rate" "github.com/pelicanplatform/pelican/config" @@ -59,9 +60,39 @@ import ( var ( progressCtrOnce sync.Once progressCtr *mpb.Progress + + successTTL = ttlcache.DefaultTTL + failureTTL = 5 * time.Minute + + loader = ttlcache.LoaderFunc[string, cacheItem]( + func(c *ttlcache.Cache[string, cacheItem], key string) *ttlcache.Item[string, cacheItem] { + ctx := context.Background() + // Note: setting this timeout mostly for unit tests + ctx, cancel := context.WithTimeout(ctx, param.Transport_ResponseHeaderTimeout.GetDuration()) + defer cancel() + urlFederation, err := config.DiscoverUrlFederation(ctx, key) + if err != nil { + // Set a shorter TTL for failures + item := c.Set(key, cacheItem{err: err}, failureTTL) + return item + } + // Set a longer TTL for successes + item := c.Set(key, cacheItem{ + url: pelicanUrl{ + directorUrl: urlFederation.DirectorEndpoint, + }, + }, successTTL) + return item + }, + ) ) type ( + cacheItem struct { + url pelicanUrl + err error + } + // Error type for when the transfer started to return data then completely stopped StoppedTransferError struct { Err string @@ -182,24 +213,25 @@ type ( // An object able to process transfer jobs. TransferEngine struct { - ctx context.Context // The context provided upon creation of the engine. - cancel context.CancelFunc - egrp *errgroup.Group // The errgroup for the worker goroutines - work chan *clientTransferJob - files chan *clientTransferFile - results chan *clientTransferResults - jobLookupDone chan *clientTransferJob // Indicates the job lookup handler is done with the job - workersActive int - resultsMap map[uuid.UUID]chan *TransferResults - workMap map[uuid.UUID]chan *TransferJob - notifyChan chan bool - closeChan chan bool - closeDoneChan chan bool - ewmaTick *time.Ticker - ewma ewma.MovingAverage - ewmaVal atomic.Int64 - ewmaCtr atomic.Int64 - clientLock sync.RWMutex + ctx context.Context // The context provided upon creation of the engine. + cancel context.CancelFunc + egrp *errgroup.Group // The errgroup for the worker goroutines + work chan *clientTransferJob + files chan *clientTransferFile + results chan *clientTransferResults + jobLookupDone chan *clientTransferJob // Indicates the job lookup handler is done with the job + workersActive int + resultsMap map[uuid.UUID]chan *TransferResults + workMap map[uuid.UUID]chan *TransferJob + notifyChan chan bool + closeChan chan bool + closeDoneChan chan bool + ewmaTick *time.Ticker + ewma ewma.MovingAverage + ewmaVal atomic.Int64 + ewmaCtr atomic.Int64 + clientLock sync.RWMutex + pelicanURLCache *ttlcache.Cache[string, cacheItem] } TransferCallbackFunc = func(path string, downloaded int64, totalSize int64, completed bool) @@ -210,6 +242,7 @@ type ( ctx context.Context cancel context.CancelFunc callback TransferCallbackFunc + engine *TransferEngine skipAcquire bool // Enable/disable the token acquisition logic. Defaults to acquiring a token tokenLocation string // Location of a token file to use for transfers token string // Token that should be used for transfers @@ -232,6 +265,10 @@ type ( NeedsToken bool PackOption string } + + pelicanUrl struct { + directorUrl string + } ) const ( @@ -333,6 +370,86 @@ func (tr TransferResults) ID() string { return tr.jobId.String() } +func (te *TransferEngine) newPelicanURL(remoteUrl *url.URL) (pelicanURL pelicanUrl, err error) { + scheme := remoteUrl.Scheme + if remoteUrl.Host != "" { + if scheme == "osdf" || scheme == "stash" { + // in the osdf/stash case, fix url's that have a hostname + joinedPath, err := url.JoinPath(remoteUrl.Host, remoteUrl.Path) + // Prefix with a / just in case + remoteUrl.Path = path.Join("/", joinedPath) + if err != nil { + log.Errorln("Failed to join remote destination url path:", err) + return pelicanUrl{}, err + } + } else if scheme == "pelican" { + // If we have a host and url is pelican, we need to extract federation data from the host + log.Debugln("Detected pelican:// url, getting federation metadata from specified host", remoteUrl.Host) + federationUrl := &url.URL{} + // federationUrl, _ := url.Parse(remoteUrl.String()) + federationUrl.Scheme = "https" + federationUrl.Path = "" + federationUrl.Host = remoteUrl.Host + + // Check if cache has key of federationURL, if not, loader will add it: + pelicanUrlItem := te.pelicanURLCache.Get(federationUrl.String()) + if pelicanUrlItem.Value().err != nil { + return pelicanUrl{}, pelicanUrlItem.Value().err + } else { + pelicanURL = pelicanUrlItem.Value().url + } + } + } + + // With an osdf:// url scheme, we assume the user will be using the OSDF so load in our osdf metadata for our url + if scheme == "osdf" { + // If we are using an osdf/stash binary, we discovered the federation already --> load into local url metadata + if config.GetPreferredPrefix() == "OSDF" { + log.Debugln("In OSDF mode with osdf:// url; populating metadata with OSDF defaults") + if param.Federation_DirectorUrl.GetString() == "" || param.Federation_DiscoveryUrl.GetString() == "" || param.Federation_RegistryUrl.GetString() == "" { + return pelicanUrl{}, fmt.Errorf("OSDF default metadata is not populated in config") + } else { + pelicanURL.directorUrl = param.Federation_DirectorUrl.GetString() + } + } else if config.GetPreferredPrefix() == "PELICAN" { + // We hit this case when we are using a pelican binary but an osdf:// url, therefore we need to disover the osdf federation + log.Debugln("In Pelican mode with an osdf:// url, populating metadata with OSDF defaults") + // Check if cache has key of federationURL, if not, loader will add it: + pelicanUrlItem := te.pelicanURLCache.Get("osg-htc.org") + if pelicanUrlItem.Value().err != nil { + err = pelicanUrlItem.Value().err + return + } else { + pelicanURL = pelicanUrlItem.Value().url + } + } + } else if scheme == "pelican" && remoteUrl.Host == "" { + // We hit this case when we do not have a hostname with a pelican:// url + if param.Federation_DiscoveryUrl.GetString() == "" { + return pelicanUrl{}, fmt.Errorf("Pelican url scheme without discovery-url detected, please provide a federation discovery-url " + + "(e.g. pelican://) within the hostname or with the -f flag") + } else { + // Check if cache has key of federationURL, if not, loader will add it: + pelicanUrlItem := te.pelicanURLCache.Get(param.Federation_DiscoveryUrl.GetString()) + if pelicanUrlItem != nil { + pelicanURL = pelicanUrlItem.Value().url + } else { + return pelicanUrl{}, fmt.Errorf("Issue getting metadata information from cache") + } + } + } else if scheme == "" { + // If we don't have a url scheme, then our metadata information should be in the config + log.Debugln("No url scheme detected, getting metadata information from configuration") + pelicanURL.directorUrl = param.Federation_DirectorUrl.GetString() + + // If the values do not exist, exit with failure + if pelicanURL.directorUrl == "" { + return pelicanUrl{}, fmt.Errorf("Missing metadata information in config, ensure Federation DirectorUrl, RegistryUrl, and DiscoverUrl are all set") + } + } + return +} + // Returns a new transfer engine object whose lifetime is tied // to the provided context. Will launcher worker goroutines to // handle the underlying transfers @@ -342,21 +459,31 @@ func NewTransferEngine(ctx context.Context) *TransferEngine { work := make(chan *clientTransferJob) files := make(chan *clientTransferFile) results := make(chan *clientTransferResults, 5) + suppressedLoader := ttlcache.NewSuppressedLoader(loader, new(singleflight.Group)) + pelicanURLCache := ttlcache.New( + ttlcache.WithTTL[string, cacheItem](30*time.Minute), + ttlcache.WithLoader(suppressedLoader), + ) + + // Start our cache for url metadata + go pelicanURLCache.Start() + te := &TransferEngine{ - ctx: ctx, - cancel: cancel, - egrp: egrp, - work: work, - files: files, - results: results, - resultsMap: make(map[uuid.UUID]chan *TransferResults), - workMap: make(map[uuid.UUID]chan *TransferJob), - jobLookupDone: make(chan *clientTransferJob), - notifyChan: make(chan bool), - closeChan: make(chan bool), - closeDoneChan: make(chan bool), - ewmaTick: time.NewTicker(ewmaInterval), - ewma: ewma.NewMovingAverage(), + ctx: ctx, + cancel: cancel, + egrp: egrp, + work: work, + files: files, + results: results, + resultsMap: make(map[uuid.UUID]chan *TransferResults), + workMap: make(map[uuid.UUID]chan *TransferJob), + jobLookupDone: make(chan *clientTransferJob), + notifyChan: make(chan bool), + closeChan: make(chan bool), + closeDoneChan: make(chan bool), + ewmaTick: time.NewTicker(ewmaInterval), + ewma: ewma.NewMovingAverage(), + pelicanURLCache: pelicanURLCache, } workerCount := param.Client_WorkerCount.GetInt() if workerCount <= 0 { @@ -421,6 +548,7 @@ func (te *TransferEngine) NewClient(options ...TransferOption) (client *Transfer return } client = &TransferClient{ + engine: te, id: id, results: make(chan *TransferResults), work: make(chan *TransferJob), @@ -463,6 +591,7 @@ func (te *TransferEngine) Shutdown() error { te.Close() <-te.closeDoneChan te.ewmaTick.Stop() + te.pelicanURLCache.Stop() te.cancel() err := te.egrp.Wait() @@ -742,6 +871,12 @@ func (tc *TransferClient) NewTransferJob(remoteUrl *url.URL, localPath string, u return } + pelicanURL, err := tc.engine.newPelicanURL(remoteUrl) + if err != nil { + err = errors.Wrap(err, "error generating metadata for specified url") + return + } + copyUrl := *remoteUrl // Make a copy of the input URL to avoid concurrent issues. tj = &TransferJob{ caches: tc.caches, @@ -773,40 +908,11 @@ func (tc *TransferClient) NewTransferJob(remoteUrl *url.URL, localPath string, u } } - if remoteUrl.Scheme == "pelican" && remoteUrl.Host != "" { - fd := config.GetFederation() - defer config.SetFederation(fd) - config.SetFederation(config.FederationDiscovery{}) - fedUrlCopy := *remoteUrl - fedUrlCopy.Scheme = "https" - fedUrlCopy.Path = "" - fedUrlCopy.RawFragment = "" - fedUrlCopy.RawQuery = "" - viper.Set("Federation.DiscoveryUrl", fedUrlCopy.String()) - if err = config.DiscoverFederation(); err != nil { - return - } - } else if remoteUrl.Scheme == "osdf" { - if remoteUrl.Host != "" { - remoteUrl.Path = path.Clean(path.Join("/", remoteUrl.Host, remoteUrl.Path)) - } - fd := config.GetFederation() - defer config.SetFederation(fd) - config.SetFederation(config.FederationDiscovery{}) - fedUrl := &url.URL{} - fedUrl.Scheme = "https" - fedUrl.Host = "osg-htc.org" - viper.Set("Federation.DiscoveryUrl", fedUrl.String()) - if err = config.DiscoverFederation(); err != nil { - return - } - } - - tj.useDirector = param.Federation_DirectorUrl.GetString() != "" - ns, err := getNamespaceInfo(remoteUrl.Path, param.Federation_DirectorUrl.GetString(), upload) + tj.useDirector = pelicanURL.directorUrl != "" + ns, err := getNamespaceInfo(remoteUrl.Path, pelicanURL.directorUrl, upload) if err != nil { log.Errorln(err) - err = errors.Wrapf(err, "failed to get namespace information for remote URL %s", remoteUrl) + err = errors.Wrapf(err, "failed to get namespace information for remote URL %s", remoteUrl.String()) } tj.namespace = ns @@ -1290,7 +1396,7 @@ func sortAttempts(ctx context.Context, path string, attempts []transferAttemptDe } func downloadObject(transfer *transferFile) (transferResults TransferResults, err error) { - log.Debugln("Downloading file from", transfer.remoteURL, "to", transfer.localPath) + log.Debugln("Downloading object from", transfer.remoteURL, "to", transfer.localPath) // Remove the source from the file path directory := path.Dir(transfer.localPath) var downloaded int64 diff --git a/client/handle_http_test.go b/client/handle_http_test.go index 7b88ae5a2..6e0255c5a 100644 --- a/client/handle_http_test.go +++ b/client/handle_http_test.go @@ -23,6 +23,7 @@ package client import ( "bytes" "context" + "encoding/json" "net" "net/http" "net/http/httptest" @@ -34,6 +35,8 @@ import ( "testing" "time" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -515,3 +518,186 @@ func TestProjInUserAgent(t *testing.T) { // Test the user-agent header is what we expect it to be assert.Equal(t, "pelican-client/"+config.GetVersion()+" project/test", *server_test.user_agent) } + +func TestNewPelicanURL(t *testing.T) { + // Set up our federation and context + ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) + viper.Set("Client.WorkerCount", 5) + te := NewTransferEngine(ctx) + config.InitConfig() + + t.Run("TestOsdfOrStashSchemeWithOSDFPrefixNoError", func(t *testing.T) { + viper.Reset() + _, err := config.SetPreferredPrefix("OSDF") + assert.NoError(t, err) + // Init config to get proper timeouts + config.InitConfig() + + remoteObject := "osdf:///something/somewhere/thatdoesnotexist.txt" + remoteObjectURL, err := url.Parse(remoteObject) + assert.NoError(t, err) + + // Instead of relying on osdf, let's just set our global metadata (osdf prefix does this for us) + viper.Set("Federation.DirectorUrl", "someDirectorUrl") + viper.Set("Federation.DiscoveryUrl", "someDiscoveryUrl") + viper.Set("Federation.RegistryUrl", "someRegistryUrl") + + pelicanURL, err := te.newPelicanURL(remoteObjectURL) + assert.NoError(t, err) + + // Check pelicanURL properly filled out + assert.Equal(t, "someDirectorUrl", pelicanURL.directorUrl) + viper.Reset() + }) + + t.Run("TestOsdfOrStashSchemeWithOSDFPrefixWithError", func(t *testing.T) { + viper.Reset() + _, err := config.SetPreferredPrefix("OSDF") + assert.NoError(t, err) + config.InitConfig() + + remoteObject := "osdf:///something/somewhere/thatdoesnotexist.txt" + remoteObjectURL, err := url.Parse(remoteObject) + assert.NoError(t, err) + + // Instead of relying on osdf, let's just set our global metadata but don't set one piece + viper.Set("Federation.DirectorUrl", "someDirectorUrl") + viper.Set("Federation.DiscoveryUrl", "someDiscoveryUrl") + + _, err = te.newPelicanURL(remoteObjectURL) + // Make sure we get an error + assert.Error(t, err) + viper.Reset() + }) + + t.Run("TestOsdfOrStashSchemeWithPelicanPrefixNoError", func(t *testing.T) { + viper.Reset() + _, err := config.SetPreferredPrefix("PELICAN") + config.InitConfig() + assert.NoError(t, err) + remoteObject := "osdf:///something/somewhere/thatdoesnotexist.txt" + remoteObjectURL, err := url.Parse(remoteObject) + assert.NoError(t, err) + + pelicanURL, err := te.newPelicanURL(remoteObjectURL) + assert.NoError(t, err) + + // Check pelicanURL properly filled out + assert.Equal(t, "https://osdf-director.osg-htc.org", pelicanURL.directorUrl) + viper.Reset() + // Note: can't really test this for an error since that would require osg-htc.org to be down + }) + + t.Run("TestPelicanSchemeNoError", func(t *testing.T) { + viper.Reset() + viper.Set("TLSSkipVerify", true) + config.InitConfig() + err := config.InitClient() + assert.NoError(t, err) + // Create a server that gives us a mock response + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // make our response: + response := config.FederationDiscovery{ + DirectorEndpoint: "director", + NamespaceRegistrationEndpoint: "registry", + JwksUri: "jwks", + BrokerEndpoint: "broker", + } + + responseJSON, err := json.Marshal(response) + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + _, err = w.Write(responseJSON) + assert.NoError(t, err) + })) + defer server.Close() + + serverURL, err := url.Parse(server.URL) + assert.NoError(t, err) + + remoteObject := "pelican://" + serverURL.Host + "/something/somewhere/thatdoesnotexist.txt" + remoteObjectURL, err := url.Parse(remoteObject) + assert.NoError(t, err) + + pelicanURL, err := te.newPelicanURL(remoteObjectURL) + assert.NoError(t, err) + + // Check pelicanURL properly filled out + assert.Equal(t, "director", pelicanURL.directorUrl) + // Check to make sure it was populated in our cache + assert.True(t, te.pelicanURLCache.Has("https://"+serverURL.Host)) + viper.Reset() + }) + + t.Run("TestPelicanSchemeWithError", func(t *testing.T) { + viper.Reset() + config.InitConfig() + + remoteObject := "pelican://some-host/something/somewhere/thatdoesnotexist.txt" + remoteObjectURL, err := url.Parse(remoteObject) + assert.NoError(t, err) + + _, err = te.newPelicanURL(remoteObjectURL) + assert.Error(t, err) + viper.Reset() + }) + + t.Run("TestPelicanSchemeMetadataTimeoutError", func(t *testing.T) { + viper.Reset() + viper.Set("TLSSkipVerify", true) + oldResponseHeaderTimeout := viper.Get("transport.ResponseHeaderTimeout") + viper.Set("transport.ResponseHeaderTimeout", 0.1*float64(time.Millisecond)) + viper.Set("Client.WorkerCount", 5) + err := config.InitClient() + assert.NoError(t, err) + // Create a server that gives us a mock response + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // make our response: + response := config.FederationDiscovery{ + DirectorEndpoint: "director", + NamespaceRegistrationEndpoint: "registry", + JwksUri: "jwks", + BrokerEndpoint: "broker", + } + + responseJSON, err := json.Marshal(response) + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + _, err = w.Write(responseJSON) + assert.NoError(t, err) + })) + defer server.Close() + + serverURL, err := url.Parse(server.URL) + assert.NoError(t, err) + + remoteObject := "pelican://" + serverURL.Host + "/something/somewhere/thatdoesnotexist.txt" + remoteObjectURL, err := url.Parse(remoteObject) + assert.NoError(t, err) + + _, err = te.newPelicanURL(remoteObjectURL) + assert.Error(t, err) + assert.True(t, errors.Is(err, config.MetadataTimeoutErr)) + viper.Set("transport.ResponseHeaderTimeout", oldResponseHeaderTimeout) + }) + + t.Cleanup(func() { + cancel() + if err := egrp.Wait(); err != nil && err != context.Canceled && err != http.ErrServerClosed { + require.NoError(t, err) + } + if err := te.Shutdown(); err != nil { + log.Errorln("Failure when shutting down transfer engine:") + } + // Throw in a viper.Reset for good measure. Keeps our env squeaky clean! + viper.Reset() + }) +} diff --git a/client/main.go b/client/main.go index 2e5872c9d..056b0d9fc 100644 --- a/client/main.go +++ b/client/main.go @@ -41,7 +41,6 @@ import ( "github.com/pelicanplatform/pelican/config" "github.com/pelicanplatform/pelican/namespaces" "github.com/pelicanplatform/pelican/param" - "github.com/spf13/viper" ) // Number of caches to attempt to use in any invocation @@ -178,33 +177,25 @@ func DoStat(ctx context.Context, destination string, options ...TransferOption) return 0, err } - understoodSchemes := []string{"osdf", "pelican", "stash", ""} - - _, foundSource := find(understoodSchemes, destUri.Scheme) - if !foundSource { - log.Errorln("Unknown schema provided:", destUri.Scheme) - return 0, errors.New("Unsupported scheme requested") + // Check if we understand the found url scheme + err = schemeUnderstood(destUri.Scheme) + if err != nil { + return 0, err } - origScheme := destUri.Scheme - if config.GetPreferredPrefix() != "PELICAN" && origScheme == "" { - destUri.Scheme = "osdf" - } - if (destUri.Scheme == "osdf" || destUri.Scheme == "stash") && destUri.Host != "" { - destUri.Path = path.Clean("/" + destUri.Host + "/" + destUri.Path) - destUri.Host = "" - } else if destUri.Scheme == "pelican" { - federationUrl, _ := url.Parse(destUri.String()) - federationUrl.Scheme = "https" - federationUrl.Path = "" - viper.Set("Federation.DiscoveryUrl", federationUrl.String()) - err = config.DiscoverFederation() - if err != nil { - return 0, err + te := NewTransferEngine(ctx) + defer func() { + if err := te.Shutdown(); err != nil { + log.Errorln("Failure when shutting down transfer engine:", err) } + }() + + pelicanURL, err := te.newPelicanURL(destUri) + if err != nil { + return 0, errors.Wrap(err, "Failed to generate pelicanURL object") } - ns, err := getNamespaceInfo(destUri.Path, param.Federation_DirectorUrl.GetString(), false) + ns, err := getNamespaceInfo(destUri.Path, pelicanURL.directorUrl, false) if err != nil { return 0, err } @@ -413,6 +404,7 @@ func getNamespaceInfo(resourcePath, OSDFDirectorUrl string, isPut bool) (ns name } return } else { + log.Debugln("Director URL not found, searching in topology") ns, err = namespaces.MatchNamespace(resourcePath) if err != nil { return @@ -421,6 +413,17 @@ func getNamespaceInfo(resourcePath, OSDFDirectorUrl string, isPut bool) (ns name } } +func schemeUnderstood(scheme string) error { + understoodSchemes := []string{"file", "osdf", "pelican", "stash", ""} + + _, foundDest := find(understoodSchemes, scheme) + if !foundDest { + return errors.Errorf("Do not understand the destination scheme: %s. Permitted values are %s", + scheme, strings.Join(understoodSchemes, ", ")) + } + return nil +} + /* Start of transfer for pelican object put, gets information from the target destination before doing our HTTP PUT request @@ -446,37 +449,13 @@ func DoPut(ctx context.Context, localObject string, remoteDestination string, re return nil, err } remoteDestUrl.Scheme = remoteDestScheme - fd := config.GetFederation() - defer config.SetFederation(fd) - if remoteDestUrl.Host != "" { - if remoteDestUrl.Scheme == "osdf" || remoteDestUrl.Scheme == "stash" { - remoteDestUrl.Path, err = url.JoinPath(remoteDestUrl.Host, remoteDestUrl.Path) - if err != nil { - log.Errorln("Failed to join remote destination url path:", err) - return nil, err - } - } else if remoteDestUrl.Scheme == "pelican" { - - config.SetFederation(config.FederationDiscovery{}) - federationUrl, _ := url.Parse(remoteDestUrl.String()) - federationUrl.Scheme = "https" - federationUrl.Path = "" - viper.Set("Federation.DiscoveryUrl", federationUrl.String()) - err = config.DiscoverFederation() - if err != nil { - return nil, err - } - } - } remoteDestScheme, _ = getTokenName(remoteDestUrl) - understoodSchemes := []string{"file", "osdf", "pelican", ""} - - _, foundDest := find(understoodSchemes, remoteDestScheme) - if !foundDest { - return nil, fmt.Errorf("Do not understand the destination scheme: %s. Permitted values are %s", - remoteDestUrl.Scheme, strings.Join(understoodSchemes, ", ")) + // Check if we understand the found url scheme + err = schemeUnderstood(remoteDestScheme) + if err != nil { + return nil, err } project := GetProjectName() @@ -536,39 +515,14 @@ func DoGet(ctx context.Context, remoteObject string, localDestination string, re return nil, err } remoteObjectUrl.Scheme = remoteObjectScheme - fd := config.GetFederation() - defer config.SetFederation(fd) - - // If there is a host specified, prepend it to the path in the osdf case - if remoteObjectUrl.Host != "" { - if remoteObjectUrl.Scheme == "osdf" { - remoteObjectUrl.Path, err = url.JoinPath(remoteObjectUrl.Host, remoteObjectUrl.Path) - if err != nil { - log.Errorln("Failed to join source url path:", err) - return nil, err - } - } else if remoteObjectUrl.Scheme == "pelican" { - - config.SetFederation(config.FederationDiscovery{}) - federationUrl, _ := url.Parse(remoteObjectUrl.String()) - federationUrl.Scheme = "https" - federationUrl.Path = "" - viper.Set("Federation.DiscoveryUrl", federationUrl.String()) - err = config.DiscoverFederation() - if err != nil { - return nil, err - } - } - } + // This is for condor cases: remoteObjectScheme, _ = getTokenName(remoteObjectUrl) - understoodSchemes := []string{"file", "osdf", "pelican", ""} - - _, foundSource := find(understoodSchemes, remoteObjectScheme) - if !foundSource { - return nil, fmt.Errorf("Do not understand the source scheme: %s. Permitted values are %s", - remoteObjectUrl.Scheme, strings.Join(understoodSchemes, ", ")) + // Check if we understand the found url scheme + err = schemeUnderstood(remoteObjectScheme) + if err != nil { + return nil, err } if remoteObjectScheme == "osdf" || remoteObjectScheme == "pelican" { @@ -667,64 +621,16 @@ func DoCopy(ctx context.Context, sourceFile string, destination string, recursiv sourceURL.Scheme = source_scheme destination, dest_scheme := correctURLWithUnderscore(destination) - dest_url, err := url.Parse(destination) + destURL, err := url.Parse(destination) if err != nil { log.Errorln("Failed to parse destination URL:", err) return nil, err } - dest_url.Scheme = dest_scheme - fd := config.GetFederation() - defer config.SetFederation(fd) - - // If there is a host specified, prepend it to the path in the osdf case - if sourceURL.Host != "" { - if sourceURL.Scheme == "osdf" || sourceURL.Scheme == "stash" { - sourceURL.Path = "/" + path.Join(sourceURL.Host, sourceURL.Path) - } else if sourceURL.Scheme == "pelican" { - config.SetFederation(config.FederationDiscovery{}) - federationUrl, _ := url.Parse(sourceURL.String()) - federationUrl.Scheme = "https" - federationUrl.Path = "" - viper.Set("Federation.DiscoveryUrl", federationUrl.String()) - err = config.DiscoverFederation() - if err != nil { - return nil, err - } - } - } - - if dest_url.Host != "" { - if dest_url.Scheme == "osdf" || dest_url.Scheme == "stash" { - dest_url.Path = "/" + path.Join(dest_url.Host, dest_url.Path) - } else if dest_url.Scheme == "pelican" { - config.SetFederation(config.FederationDiscovery{}) - federationUrl, _ := url.Parse(dest_url.String()) - federationUrl.Scheme = "https" - federationUrl.Path = "" - viper.Set("Federation.DiscoveryUrl", federationUrl.String()) - err = config.DiscoverFederation() - if err != nil { - return nil, err - } - } - } + destURL.Scheme = dest_scheme + // Check for scheme here for when using condor sourceScheme, _ := getTokenName(sourceURL) - destScheme, _ := getTokenName(dest_url) - - understoodSchemes := []string{"stash", "file", "osdf", "pelican", ""} - - _, foundSource := find(understoodSchemes, sourceScheme) - if !foundSource { - log.Errorln("Do not understand source scheme:", sourceURL.Scheme) - return nil, errors.New("Do not understand source scheme") - } - - _, foundDest := find(understoodSchemes, destScheme) - if !foundDest { - log.Errorln("Do not understand destination scheme:", sourceURL.Scheme) - return nil, errors.New("Do not understand destination scheme") - } + destScheme, _ := getTokenName(destURL) project := GetProjectName() @@ -733,13 +639,22 @@ func DoCopy(ctx context.Context, sourceFile string, destination string, recursiv var localPath string var remoteURL *url.URL if isPut { - log.Debugln("Detected object write to remote federation object", dest_url.Path) + // Verify valid scheme + if err = schemeUnderstood(destScheme); err != nil { + return nil, err + } + + log.Debugln("Detected object write to remote federation object", destURL.Path) localPath = sourceFile - remoteURL = dest_url + remoteURL = destURL } else { + // Verify valid scheme + if err = schemeUnderstood(sourceScheme); err != nil { + return nil, err + } - if dest_url.Scheme == "file" { - destination = dest_url.Path + if destURL.Scheme == "file" { + destination = destURL.Path } if sourceScheme == "stash" || sourceScheme == "osdf" || sourceScheme == "pelican" { diff --git a/client/main_test.go b/client/main_test.go index 2bbcd68a0..1b6e654a6 100644 --- a/client/main_test.go +++ b/client/main_test.go @@ -350,3 +350,36 @@ func TestGetProjectName(t *testing.T) { assert.Equal(t, "testProject", projectName) }) } + +func TestSchemeUnderstood(t *testing.T) { + t.Run("TestProperSchemeOsdf", func(t *testing.T) { + scheme := "osdf" + err := schemeUnderstood(scheme) + assert.NoError(t, err) + }) + t.Run("TestProperSchemeStash", func(t *testing.T) { + scheme := "stash" + err := schemeUnderstood(scheme) + assert.NoError(t, err) + }) + t.Run("TestProperSchemePelican", func(t *testing.T) { + scheme := "pelican" + err := schemeUnderstood(scheme) + assert.NoError(t, err) + }) + t.Run("TestProperSchemeFile", func(t *testing.T) { + scheme := "file" + err := schemeUnderstood(scheme) + assert.NoError(t, err) + }) + t.Run("TestProperSchemeEmpty", func(t *testing.T) { + scheme := "" + err := schemeUnderstood(scheme) + assert.NoError(t, err) + }) + t.Run("TestImproperScheme", func(t *testing.T) { + scheme := "ThisSchemeDoesNotExistAndHopefullyNeverWill" + err := schemeUnderstood(scheme) + assert.Error(t, err) + }) +} diff --git a/client/sharing_url.go b/client/sharing_url.go index 5bbcdcdab..81cd69085 100644 --- a/client/sharing_url.go +++ b/client/sharing_url.go @@ -19,6 +19,7 @@ package client import ( + "context" "net/url" "strings" @@ -45,7 +46,7 @@ func getDirectorFromUrl(objectUrl *url.URL) (string, error) { } viper.Set("Federation.DirectorUrl", "") viper.Set("Federation.DiscoveryUrl", discoveryUrl.String()) - if err := config.DiscoverFederation(); err != nil { + if err := config.DiscoverFederation(context.Background()); err != nil { return "", errors.Wrapf(err, "Failed to discover location of the director for the federation %s", objectUrl.Host) } if directorUrl = param.Federation_DirectorUrl.GetString(); directorUrl == "" { @@ -58,7 +59,7 @@ func getDirectorFromUrl(objectUrl *url.URL) (string, error) { objectUrl.Host = "" } viper.Set("Federation.DiscoveryUrl", "https://osg-htc.org") - if err := config.DiscoverFederation(); err != nil { + if err := config.DiscoverFederation(context.Background()); err != nil { return "", errors.Wrap(err, "Failed to discover director for the OSDF") } if directorUrl = param.Federation_DirectorUrl.GetString(); directorUrl == "" { diff --git a/client/sharing_url_test.go b/client/sharing_url_test.go index d41524f74..d51a7f283 100644 --- a/client/sharing_url_test.go +++ b/client/sharing_url_test.go @@ -153,7 +153,8 @@ func TestSharingUrl(t *testing.T) { defer server.Close() myUrl = server.URL - config.SetPreferredPrefix("PELICAN") + _, err := config.SetPreferredPrefix("PELICAN") + assert.NoError(t, err) viper.Set("ConfigDir", t.TempDir()) viper.Set("Logging.Level", "debug") config.InitConfig() @@ -162,7 +163,7 @@ func TestSharingUrl(t *testing.T) { defer os.Unsetenv("PELICAN_SKIP_TERMINAL_CHECK") viper.Set("Federation.DirectorURL", myUrl) viper.Set("ConfigDir", t.TempDir()) - err := config.InitClient() + err = config.InitClient() assert.NoError(t, err) // Call QueryDirector with the test server URL and a source path diff --git a/cmd/object_copy.go b/cmd/object_copy.go index d0daed41f..a3bcccdc3 100644 --- a/cmd/object_copy.go +++ b/cmd/object_copy.go @@ -199,6 +199,7 @@ func copyMain(cmd *cobra.Command, args []string) { var result error lastSrc := "" + for _, src := range source { isRecursive, _ := cmd.Flags().GetBool("recursive") _, result = client.DoCopy(ctx, src, dest, isRecursive, client.WithCallback(pb.callback), client.WithTokenLocation(tokenLocation), client.WithCaches(caches...)) diff --git a/cmd/object_get.go b/cmd/object_get.go index 8301bca26..f66065363 100644 --- a/cmd/object_get.go +++ b/cmd/object_get.go @@ -117,6 +117,7 @@ func getMain(cmd *cobra.Command, args []string) { var result error lastSrc := "" + for _, src := range source { isRecursive, _ := cmd.Flags().GetBool("recursive") _, result = client.DoGet(ctx, src, dest, isRecursive, client.WithCallback(pb.callback), client.WithTokenLocation(tokenLocation), client.WithCaches(caches...)) diff --git a/cmd/object_put.go b/cmd/object_put.go index 31f332e69..fe2eb2178 100644 --- a/cmd/object_put.go +++ b/cmd/object_put.go @@ -87,6 +87,7 @@ func putMain(cmd *cobra.Command, args []string) { var result error lastSrc := "" + for _, src := range source { isRecursive, _ := cmd.Flags().GetBool("recursive") _, result = client.DoPut(ctx, src, dest, isRecursive, client.WithCallback(pb.callback), client.WithTokenLocation(tokenLocation)) diff --git a/cmd/plugin_test.go b/cmd/plugin_test.go index c109826c0..68ee7530e 100644 --- a/cmd/plugin_test.go +++ b/cmd/plugin_test.go @@ -25,6 +25,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" "net/http" "os" @@ -185,7 +186,8 @@ func TestStashPluginMain(t *testing.T) { viper.Reset() server_utils.ResetOriginExports() - config.SetPreferredPrefix("STASH") + _, err := config.SetPreferredPrefix("STASH") + assert.NoError(t, err) // Temp dir for downloads tempDir := os.TempDir() @@ -213,7 +215,7 @@ func TestStashPluginMain(t *testing.T) { // Set path for object to upload/download tempPath := tempFile.Name() fileName := filepath.Base(tempPath) - uploadURL := "pelican:///test/" + fileName + uploadURL := fmt.Sprintf("pelican://%s:%d/test/%s", param.Server_Hostname.GetString(), param.Server_WebPort.GetInt(), fileName) // Download a test file args := []string{uploadURL, tempDir} @@ -231,14 +233,14 @@ func TestStashPluginMain(t *testing.T) { var stderr bytes.Buffer cmd.Stderr = &stderr - err := cmd.Run() + err = cmd.Run() assert.NoError(t, err, stderr.String()) // changing output for "\\" since in windows there are excess "\" printed in debug logs output := strings.Replace(stderr.String(), "\\\\", "\\", -1) // Check captured output for successful download - expectedOutput := "Downloading: pelican:///test/test.txt to " + tempDir + expectedOutput := "Downloading object from pelican:///test/test.txt to " + tempDir assert.Contains(t, output, expectedOutput) successfulDownloadMsg := "HTTP Transfer was successful" assert.Contains(t, output, successfulDownloadMsg) diff --git a/config/config.go b/config/config.go index b9879e4c5..449872478 100644 --- a/config/config.go +++ b/config/config.go @@ -48,6 +48,8 @@ import ( // Structs holding the OAuth2 state (and any other OSDF config needed) type ( + ConfigPrefix string + TokenEntry struct { Expiration int64 `yaml:"expiration"` AccessToken string `yaml:"access_token"` @@ -94,6 +96,12 @@ type ( } ) +const ( + Pelican ConfigPrefix = "PELICAN" + OSDF ConfigPrefix = "OSDF" + Stash ConfigPrefix = "STASH" +) + const ( CacheType ServerType = 1 << iota OriginType @@ -149,6 +157,13 @@ var ( RestartFlag = make(chan any) // A channel flag to restart the server instance that launcher listens to (including cache) MetadataTimeoutErr *MetadataErr = &MetadataErr{msg: "Timeout when querying metadata"} + + validPrefixes = map[string]bool{ + string(Pelican): true, + string(OSDF): true, + string(Stash): true, + "": true, + } ) // This function creates a new MetadataError by wrapping the previous error @@ -357,12 +372,12 @@ func (sType *ServerType) SetString(name string) bool { func GetPreferredPrefix() string { // Testing override to programmatically force different behaviors. if testingPreferredPrefix != "" { - return testingPreferredPrefix + return string(ConfigPrefix(testingPreferredPrefix)) } arg0 := strings.ToUpper(filepath.Base(os.Args[0])) underscore_idx := strings.Index(arg0, "_") if underscore_idx != -1 { - return arg0[0:underscore_idx] + return string(ConfigPrefix(arg0[0:underscore_idx])) } if strings.HasPrefix(arg0, "STASH") { return "STASH" @@ -374,10 +389,14 @@ func GetPreferredPrefix() string { // Override the auto-detected preferred prefix; mostly meant for unittests. // Returns the old preferred prefix. -func SetPreferredPrefix(newPref string) string { - oldPref := testingPreferredPrefix +func SetPreferredPrefix(newPref string) (oldPref string, err error) { + newPref = strings.ToUpper(newPref) + if _, ok := validPrefixes[newPref]; !ok { + return "", errors.New("Invalid prefix provided") + } + oldPrefix := ConfigPrefix(testingPreferredPrefix) testingPreferredPrefix = newPref - return oldPref + return string(oldPrefix), nil } // Get the list of valid prefixes for this binary. Given there's been so @@ -394,7 +413,88 @@ func GetAllPrefixes() []string { return prefixes } -func DiscoverFederation() error { +// This function is for discovering federations as specified by a url during a pelican:// transfer. +// this does not populate global fields and is more temporary per url +func DiscoverUrlFederation(ctx context.Context, federationDiscoveryUrl string) (metadata FederationDiscovery, err error) { + log.Debugln("Performing federation service discovery for specified url against endpoint", federationDiscoveryUrl) + federationUrl, err := url.Parse(federationDiscoveryUrl) + if err != nil { + err = errors.Wrapf(err, "Invalid federation value %s:", federationDiscoveryUrl) + return + } + federationUrl.Scheme = "https" + if len(federationUrl.Path) > 0 && len(federationUrl.Host) == 0 { + federationUrl.Host = federationUrl.Path + federationUrl.Path = "" + } + + discoveryUrl, err := url.Parse(federationUrl.String()) + if err != nil { + err = errors.Wrap(err, "unable to parse federation discovery URL") + return + } + discoveryUrl.Path, err = url.JoinPath(federationUrl.Path, ".well-known/pelican-configuration") + if err != nil { + err = errors.Wrap(err, "Unable to parse federation url because of invalid path") + return + } + + httpClient := http.Client{ + Transport: GetTransport(), + Timeout: time.Second * 5, + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryUrl.String(), nil) + if err != nil { + err = errors.Wrapf(err, "Failure when doing federation metadata request creation for %s", discoveryUrl) + return + } + req.Header.Set("User-Agent", "pelican/"+version) + + result, err := httpClient.Do(req) + if err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + err = MetadataTimeoutErr.Wrap(err) + return + } else { + err = NewMetadataError(err, "Error occured when querying for metadata") + return + } + } + + if result.Body != nil { + defer result.Body.Close() + } + + body, err := io.ReadAll(result.Body) + if err != nil { + return FederationDiscovery{}, errors.Wrapf(err, "Failure when doing federation metadata read to %s", discoveryUrl) + } + + if result.StatusCode != http.StatusOK { + truncatedMessage := string(body) + if len(body) > 1000 { + truncatedMessage = string(body[:1000]) + truncatedMessage += " [... remainder truncated ...]" + } + return FederationDiscovery{}, errors.Errorf("Federation metadata discovery failed with HTTP status %d. Error message: %s", result.StatusCode, truncatedMessage) + } + + metadata = FederationDiscovery{} + err = json.Unmarshal(body, &metadata) + if err != nil { + return FederationDiscovery{}, errors.Wrapf(err, "Failure when parsing federation metadata at %s", discoveryUrl) + } + + log.Debugln("Federation service discovery resulted in director URL", metadata.DirectorEndpoint) + log.Debugln("Federation service discovery resulted in registry URL", metadata.NamespaceRegistrationEndpoint) + log.Debugln("Federation service discovery resulted in JWKS URL", metadata.JwksUri) + log.Debugln("Federation service discovery resulted in broker URL", metadata.BrokerEndpoint) + + return metadata, nil +} + +func DiscoverFederation(ctx context.Context) error { federationStr := param.Federation_DiscoveryUrl.GetString() externalUrlStr := param.Server_ExternalWebUrl.GetString() defer func() { @@ -426,93 +526,46 @@ func DiscoverFederation() error { curRegistryURL := param.Federation_RegistryUrl.GetString() curFederationJwkURL := param.Federation_JwkUrl.GetString() curBrokerURL := param.Federation_BrokerUrl.GetString() - if len(curDirectorURL) != 0 && len(curRegistryURL) != 0 && len(curFederationJwkURL) != 0 { + if curDirectorURL != "" && curRegistryURL != "" && curFederationJwkURL != "" && curBrokerURL != "" { return nil } - log.Debugln("Performing federation service discovery against endpoint", federationStr) federationUrl, err := url.Parse(federationStr) if err != nil { return errors.Wrapf(err, "Invalid federation value %s:", federationStr) } + if federationUrl.Path != "" && federationUrl.Host != "" { // If the host is nothing, then the url is fine, but if we have a host and a path then there is a problem return errors.New("Invalid federation discovery url is set. No path allowed for federation discovery url. Provided url: " + federationStr) } + federationUrl.Scheme = "https" if len(federationUrl.Path) > 0 && len(federationUrl.Host) == 0 { federationUrl.Host = federationUrl.Path federationUrl.Path = "" } - discoveryUrl, err := url.Parse(federationUrl.String()) - if err != nil { - return errors.Wrap(err, "unable to parse federation discovery URL") - } - discoveryUrl.Path, err = url.JoinPath(federationUrl.Path, ".well-known/pelican-configuration") - if err != nil { - return errors.Wrap(err, "Unable to parse federation url because of invalid path") - } - - httpClient := http.Client{ - Transport: GetTransport(), - Timeout: time.Second * 5, - } - req, err := http.NewRequest(http.MethodGet, discoveryUrl.String(), nil) + metadata, err := DiscoverUrlFederation(ctx, federationStr) if err != nil { - return errors.Wrapf(err, "Failure when doing federation metadata request creation for %s", discoveryUrl) - } - req.Header.Set("User-Agent", "pelican/"+version) - - result, err := httpClient.Do(req) - if err != nil { - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { - return MetadataTimeoutErr.Wrap(err) - } else { - return NewMetadataError(err, "Error occured when querying for metadata") - } - } - - if result.Body != nil { - defer result.Body.Close() - } - - body, err := io.ReadAll(result.Body) - if err != nil { - return errors.Wrapf(err, "Failure when doing federation metadata read to %s", discoveryUrl) - } - - if result.StatusCode != http.StatusOK { - truncatedMessage := string(body) - if len(body) > 1000 { - truncatedMessage = string(body[:1000]) - truncatedMessage += " [... remainder truncated ...]" - } - return errors.Errorf("Federation metadata discovery failed with HTTP status %d. Error message: %s", result.StatusCode, truncatedMessage) + return errors.Wrapf(err, "Invalid federation value %s:", federationStr) } - metadata := FederationDiscovery{} - err = json.Unmarshal(body, &metadata) - if err != nil { - return errors.Wrapf(err, "Failure when parsing federation metadata at %s", discoveryUrl) - } + // Set our globals if curDirectorURL == "" { - log.Debugln("Federation service discovery resulted in director URL", metadata.DirectorEndpoint) + log.Debugln("Setting global director url to", metadata.DirectorEndpoint) viper.Set("Federation.DirectorUrl", metadata.DirectorEndpoint) } if curRegistryURL == "" { - log.Debugln("Federation service discovery resulted in registry URL", - metadata.NamespaceRegistrationEndpoint) + log.Debugln("Setting global registry url to", metadata.NamespaceRegistrationEndpoint) viper.Set("Federation.RegistryUrl", metadata.NamespaceRegistrationEndpoint) } if curFederationJwkURL == "" { - log.Debugln("Federation service discovery resulted in JWKS URL", - metadata.JwksUri) + log.Debugln("Setting global jwks url to", metadata.JwksUri) viper.Set("Federation.JwkUrl", metadata.JwksUri) } if curBrokerURL == "" && metadata.BrokerEndpoint != "" { - log.Debugln("Federation service discovery resulted in broker URL", metadata.BrokerEndpoint) + log.Debugln("Setting global broker url to", metadata.BrokerEndpoint) viper.Set("Federation.BrokerUrl", metadata.BrokerEndpoint) } @@ -1142,7 +1195,7 @@ func InitServer(ctx context.Context, currentServers ServerType) error { // Sets up the server log filter mechanism initFilterLogging() - return DiscoverFederation() + return DiscoverFederation(ctx) } func InitClient() error { @@ -1268,5 +1321,5 @@ func InitClient() error { return err } - return DiscoverFederation() + return DiscoverFederation(context.Background()) } diff --git a/config/config_test.go b/config/config_test.go index 0b61ac16f..057460d90 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -30,6 +30,7 @@ import ( "testing" "time" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/spf13/viper" @@ -284,7 +285,49 @@ func TestEnabledServers(t *testing.T) { }) } +// Tests the function setPreferredPrefix: ensures case-insensitivity and invalid values are handled correctly +func TestSetPreferredPrefix(t *testing.T) { + t.Run("TestPelicanPreferredPrefix", func(t *testing.T) { + oldPref, err := SetPreferredPrefix("pelican") + assert.NoError(t, err) + if GetPreferredPrefix() != "PELICAN" { + t.Errorf("Expected preferred prefix to be 'PELICAN', got '%s'", GetPreferredPrefix()) + } + if oldPref != "" { + t.Errorf("Expected old preferred prefix to be empty, got '%s'", oldPref) + } + }) + + t.Run("TestOSDFPreferredPrefix", func(t *testing.T) { + oldPref, err := SetPreferredPrefix("osdf") + assert.NoError(t, err) + if GetPreferredPrefix() != "OSDF" { + t.Errorf("Expected preferred prefix to be 'OSDF', got '%s'", GetPreferredPrefix()) + } + if oldPref != "PELICAN" { + t.Errorf("Expected old preferred prefix to be 'PELICAN', got '%s'", oldPref) + } + }) + + t.Run("TestStashPreferredPrefix", func(t *testing.T) { + oldPref, err := SetPreferredPrefix("stash") + assert.NoError(t, err) + if GetPreferredPrefix() != "STASH" { + t.Errorf("Expected preferred prefix to be 'STASH', got '%s'", GetPreferredPrefix()) + } + if oldPref != "OSDF" { + t.Errorf("Expected old preferred prefix to be 'osdf', got '%s'", oldPref) + } + }) + + t.Run("TestInvalidPreferredPrefix", func(t *testing.T) { + _, err := SetPreferredPrefix("invalid") + assert.Error(t, err) + }) +} + func TestDiscoverFederation(t *testing.T) { + viper.Reset() // Server to be a "mock" federation server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -308,9 +351,8 @@ func TestDiscoverFederation(t *testing.T) { })) defer server.Close() t.Run("testInvalidDiscoveryUrlWithPath", func(t *testing.T) { - viper.Set("tlsskipverify", true) viper.Set("Federation.DiscoveryUrl", server.URL+"/this/is/some/path") - err := DiscoverFederation() + err := DiscoverFederation(context.Background()) assert.Error(t, err) assert.Contains(t, err.Error(), "Invalid federation discovery url is set. No path allowed for federation discovery url. Provided url: ", "Error returned does not contain the correct error") @@ -318,9 +360,8 @@ func TestDiscoverFederation(t *testing.T) { }) t.Run("testValidDiscoveryUrl", func(t *testing.T) { - viper.Set("tlsskipverify", true) viper.Set("Federation.DiscoveryUrl", server.URL) - err := DiscoverFederation() + err := DiscoverFederation(context.Background()) assert.NoError(t, err) // Assert that the metadata matches expectations assert.Equal(t, "director", param.Federation_DirectorUrl.GetString(), "Unexpected DirectorEndpoint") @@ -331,9 +372,8 @@ func TestDiscoverFederation(t *testing.T) { }) t.Run("testOsgHtcUrl", func(t *testing.T) { - viper.Set("tlsskipverify", true) viper.Set("Federation.DiscoveryUrl", "osg-htc.org") - err := DiscoverFederation() + err := DiscoverFederation(context.Background()) assert.NoError(t, err) // Assert that the metadata matches expectations assert.Equal(t, "https://osdf-director.osg-htc.org", param.Federation_DirectorUrl.GetString(), "Unexpected DirectorEndpoint") @@ -477,3 +517,88 @@ func TestInitServerUrl(t *testing.T) { assert.Equal(t, "https://example-registry.com", param.Federation_BrokerUrl.GetString()) }) } +func TestDiscoverUrlFederation(t *testing.T) { + t.Run("TestMetadataDiscoveryTimeout", func(t *testing.T) { + viper.Set("tlsskipverify", true) + err := InitClient() + assert.NoError(t, err) + // Create a server that sleeps for a longer duration than the timeout + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(2 * time.Second) + })) + defer server.Close() + + // Set a short timeout for the test + timeout := 1 * time.Second + + // Create a context with the timeout + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // Call the function with the server URL and the context + _, err = DiscoverUrlFederation(ctx, server.URL) + + // Assert that the error is the expected metadata timeout error + assert.Error(t, err) + assert.True(t, errors.Is(err, MetadataTimeoutErr)) + viper.Reset() + }) + + t.Run("TestCanceledContext", func(t *testing.T) { + // Create a server that waits for the context to be canceled + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-r.Context().Done() + })) + defer server.Close() + + // Create a context and cancel it immediately + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + // Call the function with the server URL and the canceled context + _, err := DiscoverUrlFederation(ctx, server.URL) + + // Assert that the error is the expected context cancel error + assert.Error(t, err) + assert.True(t, errors.Is(err, context.Canceled)) + }) + + t.Run("TestValidDiscovery", func(t *testing.T) { + viper.Set("tlsskipverify", true) + err := InitClient() + assert.NoError(t, err) + // Server to be a "mock" federation + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // make our response: + response := FederationDiscovery{ + DirectorEndpoint: "director", + NamespaceRegistrationEndpoint: "registry", + JwksUri: "jwks", + BrokerEndpoint: "broker", + } + + responseJSON, err := json.Marshal(response) + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + _, err = w.Write(responseJSON) + assert.NoError(t, err) + })) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + metadata, err := DiscoverUrlFederation(ctx, server.URL) + assert.NoError(t, err) + + // Assert that the metadata matches expectations + assert.Equal(t, "director", metadata.DirectorEndpoint, "Unexpected DirectorEndpoint") + assert.Equal(t, "registry", metadata.NamespaceRegistrationEndpoint, "Unexpected NamespaceRegistrationEndpoint") + assert.Equal(t, "jwks", metadata.JwksUri, "Unexpected JwksUri") + assert.Equal(t, "broker", metadata.BrokerEndpoint, "Unexpected BrokerEndpoint") + viper.Reset() + }) +} diff --git a/github_scripts/get_put_test.sh b/github_scripts/get_put_test.sh index acb21af4f..8630d63cb 100755 --- a/github_scripts/get_put_test.sh +++ b/github_scripts/get_put_test.sh @@ -123,7 +123,7 @@ do done # Run pelican object put -./pelican object put ./get_put_tmp/input.txt pelican:///test/input.txt -d -t get_put_tmp/test-token.jwt -l get_put_tmp/putOutput.txt +./pelican object put ./get_put_tmp/input.txt pelican://$HOSTNAME:8444/test/input.txt -d -t get_put_tmp/test-token.jwt -l get_put_tmp/putOutput.txt # Check output of command if grep -q "Dumping response: HTTP/1.1 200 OK" get_put_tmp/putOutput.txt; then @@ -134,7 +134,7 @@ else exit 1 fi -./pelican object get pelican:///test/input.txt get_put_tmp/output.txt -d -t get_put_tmp/test-token.jwt -l get_put_tmp/getOutput.txt +./pelican object get pelican://$HOSTNAME:8444/test/input.txt get_put_tmp/output.txt -d -t get_put_tmp/test-token.jwt -l get_put_tmp/getOutput.txt # Check output of command if grep -q "HTTP Transfer was successful" get_put_tmp/getOutput.txt; then diff --git a/local_cache/cache_test.go b/local_cache/cache_test.go index 95a4a0f08..4fe621a73 100644 --- a/local_cache/cache_test.go +++ b/local_cache/cache_test.go @@ -31,6 +31,7 @@ import ( "net/url" "os" "path/filepath" + "strconv" "testing" "time" @@ -39,6 +40,7 @@ import ( "github.com/pelicanplatform/pelican/fed_test_utils" local_cache "github.com/pelicanplatform/pelican/local_cache" "github.com/pelicanplatform/pelican/param" + "github.com/pelicanplatform/pelican/test_utils" "github.com/pelicanplatform/pelican/token" "github.com/pelicanplatform/pelican/token_scopes" "github.com/pelicanplatform/pelican/utils" @@ -146,17 +148,16 @@ func TestClient(t *testing.T) { viper.Reset() ft := fed_test_utils.NewFedTest(t, authOriginCfg) + ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) + cacheUrl := &url.URL{ Scheme: "unix", Path: param.LocalCache_Socket.GetString(), } t.Run("correct-auth", func(t *testing.T) { - discoveryHost := param.Federation_DiscoveryUrl.GetString() - discoveryUrl, err := url.Parse(discoveryHost) - require.NoError(t, err) - tr, err := client.DoGet(ft.Ctx, "pelican://"+discoveryUrl.Host+"/test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, - client.WithToken(ft.Token), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) + tr, err := client.DoGet(ctx, "pelican://"+param.Server_Hostname.GetString()+":"+strconv.Itoa(param.Server_WebPort.GetInt())+"/test/hello_world.txt", + filepath.Join(tmpDir, "hello_world.txt"), false, client.WithToken(ft.Token), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) assert.NoError(t, err) require.Equal(t, 1, len(tr)) assert.Equal(t, int64(13), tr[0].TransferredBytes) @@ -167,8 +168,8 @@ func TestClient(t *testing.T) { assert.Equal(t, "Hello, World!", string(byteBuff)) }) t.Run("incorrect-auth", func(t *testing.T) { - _, err := client.DoGet(ft.Ctx, "pelican:///test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, - client.WithToken("badtoken"), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) + _, err := client.DoGet(ctx, "pelican://"+param.Server_Hostname.GetString()+":"+strconv.Itoa(param.Server_WebPort.GetInt())+"/test/hello_world.txt", + filepath.Join(tmpDir, "hello_world.txt"), false, client.WithToken("badtoken"), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) assert.Error(t, err) assert.ErrorIs(t, err, &client.ConnectionSetupError{}) var cse *client.ConnectionSetupError @@ -190,11 +191,19 @@ func TestClient(t *testing.T) { token, err := tokConf.CreateToken() require.NoError(t, err) - _, err = client.DoGet(ft.Ctx, "pelican:///test/hello_world.txt.1", filepath.Join(tmpDir, "hello_world.txt.1"), false, - client.WithToken(token), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) + _, err = client.DoGet(ctx, "pelican://"+param.Server_Hostname.GetString()+":"+strconv.Itoa(param.Server_WebPort.GetInt())+"/test/hello_world.txt.1", + filepath.Join(tmpDir, "hello_world.txt.1"), false, client.WithToken(token), client.WithCaches(cacheUrl), client.WithAcquireToken(false)) assert.Error(t, err) assert.Equal(t, "failed to download file: transfer error: failed connection setup: server returned 404 Not Found", err.Error()) }) + t.Cleanup(func() { + cancel() + if err := egrp.Wait(); err != nil && err != context.Canceled && err != http.ErrServerClosed { + require.NoError(t, err) + } + // Throw in a viper.Reset for good measure. Keeps our env squeaky clean! + viper.Reset() + }) } // Test that HEAD requests to the local cache return the correct result @@ -258,6 +267,9 @@ func TestLargeFile(t *testing.T) { viper.Set("Client.MaximumDownloadSpeed", 40*1024*1024) ft := fed_test_utils.NewFedTest(t, pubOriginCfg) + ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) + te := client.NewTransferEngine(ctx) + cacheUrl := &url.URL{ Scheme: "unix", Path: param.LocalCache_Socket.GetString(), @@ -267,16 +279,26 @@ func TestLargeFile(t *testing.T) { require.NoError(t, err) size := writeBigBuffer(t, fp, 100) - discoveryHost := param.Federation_DiscoveryUrl.GetString() - discoveryUrl, err := url.Parse(discoveryHost) require.NoError(t, err) - tr, err := client.DoGet(ft.Ctx, "pelican://"+discoveryUrl.Host+"/test/hello_world.txt", filepath.Join(tmpDir, "hello_world.txt"), false, - client.WithCaches(cacheUrl)) + tr, err := client.DoGet(ctx, "pelican://"+param.Server_Hostname.GetString()+":"+strconv.Itoa(param.Server_WebPort.GetInt())+"/test/hello_world.txt", + filepath.Join(tmpDir, "hello_world.txt"), false, client.WithCaches(cacheUrl)) assert.NoError(t, err) require.Equal(t, 1, len(tr)) assert.Equal(t, int64(size), tr[0].TransferredBytes) assert.NoError(t, tr[0].Error) + t.Cleanup(func() { + cancel() + if err := egrp.Wait(); err != nil && err != context.Canceled && err != http.ErrServerClosed { + require.NoError(t, err) + } + if err := te.Shutdown(); err != nil { + log.Errorln("Failure when shutting down transfer engine:", err) + } + // Throw in a viper.Reset for good measure. Keeps our env squeaky clean! + viper.Reset() + }) + } // Create five 1MB files. Trigger a purge, ensuring that the cleanup is @@ -288,6 +310,9 @@ func TestPurge(t *testing.T) { viper.Set("LocalCache.Size", "5MB") ft := fed_test_utils.NewFedTest(t, pubOriginCfg) + ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) + te := client.NewTransferEngine(ctx) + cacheUrl := &url.URL{ Scheme: "unix", Path: param.LocalCache_Socket.GetString(), @@ -303,8 +328,8 @@ func TestPurge(t *testing.T) { require.NotEqual(t, 0, size) for idx := 0; idx < 5; idx++ { - tr, err := client.DoGet(ft.Ctx, fmt.Sprintf("pelican:///test/hello_world.txt.%d", idx), filepath.Join(tmpDir, fmt.Sprintf("hello_world.txt.%d", idx)), false, - client.WithCaches(cacheUrl)) + tr, err := client.DoGet(ctx, fmt.Sprintf("pelican://"+param.Server_Hostname.GetString()+":"+strconv.Itoa(param.Server_WebPort.GetInt())+"/test/hello_world.txt.%d", idx), + filepath.Join(tmpDir, fmt.Sprintf("hello_world.txt.%d", idx)), false, client.WithCaches(cacheUrl)) assert.NoError(t, err) require.Equal(t, 1, len(tr)) assert.Equal(t, int64(size), tr[0].TransferredBytes) @@ -324,6 +349,17 @@ func TestPurge(t *testing.T) { defer fp.Close() }() } + t.Cleanup(func() { + cancel() + if err := egrp.Wait(); err != nil && err != context.Canceled && err != http.ErrServerClosed { + require.NoError(t, err) + } + if err := te.Shutdown(); err != nil { + log.Errorln("Failure when shutting down transfer engine:", err) + } + // Throw in a viper.Reset for good measure. Keeps our env squeaky clean! + viper.Reset() + }) } // Create four 1MB files (above low-water mark). Force a purge, ensuring that the cleanup is @@ -337,6 +373,9 @@ func TestForcePurge(t *testing.T) { viper.Set("LocalCache.LowWaterMarkPercentage", "80") ft := fed_test_utils.NewFedTest(t, pubOriginCfg) + ctx, cancel, egrp := test_utils.TestContext(context.Background(), t) + te := client.NewTransferEngine(ctx) + issuer, err := config.GetServerIssuerURL() require.NoError(t, err) tokConf := token.NewWLCGToken() @@ -369,8 +408,8 @@ func TestForcePurge(t *testing.T) { require.NotEqual(t, 0, size) for idx := 0; idx < 4; idx++ { - tr, err := client.DoGet(ft.Ctx, fmt.Sprintf("pelican:///test/hello_world.txt.%d", idx), filepath.Join(tmpDir, fmt.Sprintf("hello_world.txt.%d", idx)), false, - client.WithCaches(cacheUrl)) + tr, err := client.DoGet(ctx, fmt.Sprintf("pelican://"+param.Server_Hostname.GetString()+":"+strconv.Itoa(param.Server_WebPort.GetInt())+"/test/hello_world.txt.%d", idx), + filepath.Join(tmpDir, fmt.Sprintf("hello_world.txt.%d", idx)), false, client.WithCaches(cacheUrl)) assert.NoError(t, err) require.Equal(t, 1, len(tr)) assert.Equal(t, int64(size), tr[0].TransferredBytes) @@ -401,4 +440,15 @@ func TestForcePurge(t *testing.T) { defer fp.Close() }() } + t.Cleanup(func() { + cancel() + if err := egrp.Wait(); err != nil && err != context.Canceled && err != http.ErrServerClosed { + require.NoError(t, err) + } + if err := te.Shutdown(); err != nil { + log.Errorln("Failure when shutting down transfer engine:", err) + } + // Throw in a viper.Reset for good measure. Keeps our env squeaky clean! + viper.Reset() + }) } diff --git a/local_cache/local_cache.go b/local_cache/local_cache.go index 3b30d8011..b4cc74f46 100644 --- a/local_cache/local_cache.go +++ b/local_cache/local_cache.go @@ -506,6 +506,7 @@ func (sc *LocalCache) runMux() error { sourceURL := *sc.directorURL sourceURL.Path = path.Join(sourceURL.Path, path.Clean(req.request.path)) + sourceURL.Scheme = "pelican" tj, err := sc.tc.NewTransferJob(&sourceURL, localPath, false, false, "localcache", client.WithToken(req.request.token)) if err != nil { ds := &downloadStatus{} diff --git a/namespaces/namespaces_test.go b/namespaces/namespaces_test.go index 57da4a258..6714999db 100644 --- a/namespaces/namespaces_test.go +++ b/namespaces/namespaces_test.go @@ -101,8 +101,12 @@ func TestMatchNamespace(t *testing.T) { t.Error(err) } // Reset the prefix to get old OSDF fallback behavior. - oldPrefix := config.SetPreferredPrefix("OSDF") - defer config.SetPreferredPrefix(oldPrefix) + oldPrefix, err := config.SetPreferredPrefix("OSDF") + assert.NoError(t, err) + defer func() { + _, err := config.SetPreferredPrefix(oldPrefix) + assert.NoError(t, err) + }() viper.Reset() err = config.InitClient() @@ -248,10 +252,14 @@ func TestDownloadNamespacesFail(t *testing.T) { func TestGetNamespaces(t *testing.T) { // Set the environment to an invalid URL, so it is forced to use the "built-in" namespaces.json os.Setenv("OSDF_TOPOLOGY_NAMESPACE_URL", "https://doesnotexist.org.blah/namespaces.json") - oldPrefix := config.SetPreferredPrefix("OSDF") - defer config.SetPreferredPrefix(oldPrefix) + oldPrefix, err := config.SetPreferredPrefix("OSDF") + assert.NoError(t, err) + defer func() { + _, err := config.SetPreferredPrefix(oldPrefix) + assert.NoError(t, err) + }() viper.Reset() - err := config.InitClient() + err = config.InitClient() assert.Nil(t, err) defer os.Unsetenv("OSDF_TOPOLOGY_NAMESPACE_URL") namespaces, err := GetNamespaces() diff --git a/registry/client_commands_test.go b/registry/client_commands_test.go index 873bcea1f..0298bf85f 100644 --- a/registry/client_commands_test.go +++ b/registry/client_commands_test.go @@ -130,7 +130,8 @@ func TestRegistryKeyChainingOSDF(t *testing.T) { defer cancel() viper.Reset() - _ = config.SetPreferredPrefix("OSDF") + _, err := config.SetPreferredPrefix("OSDF") + assert.NoError(t, err) viper.Set("Federation.DirectorUrl", "https://osdf-director.osg-htc.org") viper.Set("Federation.RegistryUrl", "https://osdf-registry.osg-htc.org") viper.Set("Federation.JwkUrl", "https://osg-htc.org/osdf/public_signing_key.jwks") @@ -142,7 +143,7 @@ func TestRegistryKeyChainingOSDF(t *testing.T) { registrySvr := registryMockup(ctx, t, "OSDFkeychaining") topoSvr := topologyMockup(t, []string{"/topo/foo"}) viper.Set("Federation.TopologyNamespaceURL", topoSvr.URL) - err := createTopologyTable() + err = createTopologyTable() require.NoError(t, err) err = PopulateTopology() require.NoError(t, err) @@ -212,7 +213,8 @@ func TestRegistryKeyChainingOSDF(t *testing.T) { err = NamespaceRegister(privKey, registrySvr.URL+"/api/v1.0/registry", "", "/topo") require.NoError(t, err) - config.SetPreferredPrefix("pelican") + _, err = config.SetPreferredPrefix("pelican") + assert.NoError(t, err) viper.Reset() } diff --git a/registry/registry_db_test.go b/registry/registry_db_test.go index 696944953..62185a39f 100644 --- a/registry/registry_db_test.go +++ b/registry/registry_db_test.go @@ -809,7 +809,8 @@ func TestRegistryTopology(t *testing.T) { }() // Set value so that config.GetPreferredPrefix() returns "OSDF" - config.SetPreferredPrefix("OSDF") + _, err = config.SetPreferredPrefix("OSDF") + assert.NoError(t, err) //Test topology table population err = createTopologyTable() diff --git a/xrootd/authorization_test.go b/xrootd/authorization_test.go index 008dc84a6..dc08ac4d7 100644 --- a/xrootd/authorization_test.go +++ b/xrootd/authorization_test.go @@ -276,10 +276,14 @@ func TestOSDFAuthCreation(t *testing.T) { viper.Set("Origin.RunLocation", dirName) xrootdRun = param.Origin_RunLocation.GetString() } - oldPrefix := config.SetPreferredPrefix("OSDF") - defer config.SetPreferredPrefix(oldPrefix) + oldPrefix, err := config.SetPreferredPrefix("OSDF") + assert.NoError(t, err) + defer func() { + _, err := config.SetPreferredPrefix(oldPrefix) + require.NoError(t, err) + }() - err := os.WriteFile(filepath.Join(dirName, "authfile"), []byte(testInput.authIn), fs.FileMode(0600)) + err = os.WriteFile(filepath.Join(dirName, "authfile"), []byte(testInput.authIn), fs.FileMode(0600)) require.NoError(t, err, "Failure writing test input authfile") err = EmitAuthfile(testInput.server) diff --git a/xrootd/xrootd_config.go b/xrootd/xrootd_config.go index ab863f3d2..960af568e 100644 --- a/xrootd/xrootd_config.go +++ b/xrootd/xrootd_config.go @@ -283,7 +283,7 @@ func CheckCacheXrootdEnv(exportPath string, server server_structs.XRootDServer, filepath.Dir(metaPath)) } - err = config.DiscoverFederation() + err = config.DiscoverFederation(context.Background()) if err != nil { return "", errors.Wrap(err, "Failed to pull information from the federation") }