Skip to content

Commit

Permalink
Merge pull request PelicanPlatform#977 from joereuss12/client-wrong-u…
Browse files Browse the repository at this point in the history
…ser-agent-branch

Fix bug where project name not set in user-agent
  • Loading branch information
haoming29 authored Mar 29, 2024
2 parents 63a22dc + ee5dff9 commit 2105b05
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 107 deletions.
13 changes: 6 additions & 7 deletions client/handle_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ type (
upload bool
packOption string
attempts []transferAttemptDetails
accounting payloadStruct
project string
err error
}
Expand Down Expand Up @@ -736,7 +735,7 @@ func (te *TransferEngine) runJobHandler() error {
//
// The returned object can be further customized as desired.
// This function does not "submit" the job for execution.
func (tc *TransferClient) NewTransferJob(remoteUrl *url.URL, localPath string, upload bool, recursive bool, options ...TransferOption) (tj *TransferJob, err error) {
func (tc *TransferClient) NewTransferJob(remoteUrl *url.URL, localPath string, upload bool, recursive bool, project string, options ...TransferOption) (tj *TransferJob, err error) {

id, err := uuid.NewV7()
if err != nil {
Expand All @@ -755,6 +754,7 @@ func (tc *TransferClient) NewTransferJob(remoteUrl *url.URL, localPath string, u
upload: upload,
uuid: id,
token: tc.token,
project: project,
}
tj.ctx, tj.cancel = context.WithCancel(tc.ctx)

Expand Down Expand Up @@ -1314,7 +1314,7 @@ func downloadObject(transfer *transferFile) (transferResults TransferResults, er
transferEndpointUrl.Path = transfer.remoteURL.Path
transferEndpoint.Url = &transferEndpointUrl
transferStartTime := time.Now()
if downloaded, timeToFirstByte, serverVersion, err = downloadHTTP(transfer.ctx, transfer.engine, transfer.callback, transferEndpoint, transfer.localPath, size, transfer.token, &transfer.accounting); err != nil {
if downloaded, timeToFirstByte, serverVersion, err = downloadHTTP(transfer.ctx, transfer.engine, transfer.callback, transferEndpoint, transfer.localPath, size, transfer.token, transfer.project); err != nil {
log.Debugln("Failed to download:", err)
transferEndTime := time.Now()
transferTime := transferEndTime.Unix() - transferStartTime.Unix()
Expand Down Expand Up @@ -1386,7 +1386,7 @@ func parseTransferStatus(status string) (int, string) {
// Perform the actual download of the file
//
// Returns the downloaded size, time to 1st byte downloaded, serverVersion and an error if there is one
func downloadHTTP(ctx context.Context, te *TransferEngine, callback TransferCallbackFunc, transfer transferAttemptDetails, dest string, totalSize int64, token string, payload *payloadStruct) (downloaded int64, timeToFirstByte float64, serverVersion string, err error) {
func downloadHTTP(ctx context.Context, te *TransferEngine, callback TransferCallbackFunc, transfer transferAttemptDetails, dest string, totalSize int64, token string, project string) (downloaded int64, timeToFirstByte float64, serverVersion string, err error) {
defer func() {
if r := recover(); r != nil {
log.Errorln("Panic occurred in downloadHTTP:", r)
Expand Down Expand Up @@ -1414,6 +1414,7 @@ func downloadHTTP(ctx context.Context, te *TransferEngine, callback TransferCall

// Create the client, request, and context
client := grab.NewClient()
client.UserAgent = getUserAgent(project)
transport := config.GetTransport()
if !transfer.Proxy {
transport.Proxy = nil
Expand Down Expand Up @@ -1473,9 +1474,7 @@ func downloadHTTP(ctx context.Context, te *TransferEngine, callback TransferCall
req.HTTPRequest.Header.Set("X-Transfer-Status", "true")
req.HTTPRequest.Header.Set("X-Pelican-Timeout", param.Transport_ResponseHeaderTimeout.GetDuration().String())
req.HTTPRequest.Header.Set("TE", "trailers")
if payload != nil && payload.ProjectName != "" {
req.HTTPRequest.Header.Set("User-Agent", getUserAgent(payload.ProjectName))
}
req.HTTPRequest.Header.Set("User-Agent", getUserAgent(project))
req = req.WithContext(ctx)

// Test the transfer speed every 5 seconds
Expand Down
46 changes: 41 additions & 5 deletions client/handle_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func TestSlowTransfers(t *testing.T) {
var err error
// Do a quick timeout
go func() {
_, _, _, err = downloadHTTP(ctx, nil, nil, transfers[0], filepath.Join(t.TempDir(), "test.txt"), -1, "", nil)
_, _, _, err = downloadHTTP(ctx, nil, nil, transfers[0], filepath.Join(t.TempDir(), "test.txt"), -1, "", "")
finishedChannel <- true
}()

Expand Down Expand Up @@ -258,7 +258,7 @@ func TestStoppedTransfer(t *testing.T) {
var err error

go func() {
_, _, _, err = downloadHTTP(ctx, nil, nil, transfers[0], filepath.Join(t.TempDir(), "test.txt"), -1, "", nil)
_, _, _, err = downloadHTTP(ctx, nil, nil, transfers[0], filepath.Join(t.TempDir(), "test.txt"), -1, "", "")
finishedChannel <- true
}()

Expand Down Expand Up @@ -290,7 +290,7 @@ func TestConnectionError(t *testing.T) {
addr := l.Addr().String()
l.Close()

_, _, _, err = downloadHTTP(ctx, nil, nil, transferAttemptDetails{Url: &url.URL{Host: addr, Scheme: "http"}, Proxy: false}, filepath.Join(t.TempDir(), "test.txt"), -1, "", nil)
_, _, _, err = downloadHTTP(ctx, nil, nil, transferAttemptDetails{Url: &url.URL{Host: addr, Scheme: "http"}, Proxy: false}, filepath.Join(t.TempDir(), "test.txt"), -1, "", "")

assert.IsType(t, &ConnectionSetupError{}, err)

Expand Down Expand Up @@ -325,7 +325,7 @@ func TestTrailerError(t *testing.T) {
assert.Equal(t, svr.URL, transfers[0].Url.String())

// Call DownloadHTTP and check if the error is returned correctly
_, _, _, err := downloadHTTP(ctx, nil, nil, transfers[0], filepath.Join(t.TempDir(), "test.txt"), -1, "", nil)
_, _, _, err := downloadHTTP(ctx, nil, nil, transfers[0], filepath.Join(t.TempDir(), "test.txt"), -1, "", "")

assert.NotNil(t, err)
assert.EqualError(t, err, "transfer error: Unable to read test.txt; input/output error")
Expand Down Expand Up @@ -475,7 +475,43 @@ func TestTimeoutHeaderSetForDownload(t *testing.T) {

serverURL, err := url.Parse(server.URL)
assert.NoError(t, err)
_, _, _, err = downloadHTTP(ctx, nil, nil, transferAttemptDetails{Url: serverURL, Proxy: false}, filepath.Join(t.TempDir(), "test.txt"), -1, "", nil)
_, _, _, err = downloadHTTP(ctx, nil, nil, transferAttemptDetails{Url: serverURL, Proxy: false}, filepath.Join(t.TempDir(), "test.txt"), -1, "", "")
assert.NoError(t, err)
viper.Reset()
}

// Server test object for testing user agent
type (
server_test struct {
server *httptest.Server
user_agent *string
}
)

// Test to ensure the user-agent header is being updating in the request made within DownloadHTTP()
func TestProjInUserAgent(t *testing.T) {
ctx, _, _ := test_utils.TestContext(context.Background(), t)

server_test := server_test{}
// Create a mock server to download from
server_test.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Note: we check for this HEAD request because within DownloadHTTP() we make a HEAD request to get the content length
// This request is a different user-agent header (and different request) so we need to ignore it so server_test.user_agent is not overwritten
if r.Method == "HEAD" {
w.WriteHeader(http.StatusNoContent)
return
}
userAgent := r.UserAgent()
server_test.user_agent = &userAgent
}))
defer server_test.server.Close()
defer server_test.server.CloseClientConnections()

serverURL, err := url.Parse(server_test.server.URL)
assert.NoError(t, err)
_, _, _, err = downloadHTTP(ctx, nil, nil, transferAttemptDetails{Url: serverURL, Proxy: false}, filepath.Join(t.TempDir(), "test.txt"), -1, "", "test")
assert.NoError(t, err)

// Test the user-agent header is what we expect it to be
assert.Equal(t, "pelican-client/"+config.GetVersion()+" project/test", *server_test.user_agent)
}
98 changes: 13 additions & 85 deletions client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import (
"os"
"path"
"path/filepath"
"time"

"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
Expand All @@ -45,22 +44,6 @@ import (
"github.com/spf13/viper"
)

type (
payloadStruct struct {
filename string
status string
Owner string
ProjectName string
version string
start1 int64
end1 int64
timestamp int64
downloadTime int64
fileSize int64
downloadSize int64
}
)

// Number of caches to attempt to use in any invocation
var CachesToTry int = 3

Expand Down Expand Up @@ -496,6 +479,8 @@ func DoPut(ctx context.Context, localObject string, remoteDestination string, re
remoteDestUrl.Scheme, strings.Join(understoodSchemes, ", "))
}

project := GetProjectName()

te := NewTransferEngine(ctx)
defer func() {
if err := te.Shutdown(); err != nil {
Expand All @@ -506,7 +491,7 @@ func DoPut(ctx context.Context, localObject string, remoteDestination string, re
if err != nil {
return
}
tj, err := client.NewTransferJob(remoteDestUrl, localObject, true, recursive)
tj, err := client.NewTransferJob(remoteDestUrl, localObject, true, recursive, project)
if err != nil {
return
}
Expand Down Expand Up @@ -607,17 +592,7 @@ func DoGet(ctx context.Context, remoteObject string, localDestination string, re
localDestination = path.Join(localDestPath, remoteObjectFilename)
}

payload := payloadStruct{}
payload.version = config.GetVersion()

//Fill out the payload as much as possible
payload.filename = remoteObjectUrl.Path

parseJobAd(&payload)

start := time.Now()
payload.start1 = start.Unix()

project := GetProjectName()
success := false

te := NewTransferEngine(ctx)
Expand All @@ -630,7 +605,7 @@ func DoGet(ctx context.Context, remoteObject string, localDestination string, re
if err != nil {
return
}
tj, err := tc.NewTransferJob(remoteObjectUrl, localDestination, false, recursive)
tj, err := tc.NewTransferJob(remoteObjectUrl, localDestination, false, recursive, project)
if err != nil {
return
}
Expand All @@ -640,7 +615,6 @@ func DoGet(ctx context.Context, remoteObject string, localDestination string, re
}

transferResults, err = tc.Shutdown()
end := time.Now()
if err == nil {
if tj.lookupErr == nil {
success = true
Expand All @@ -657,20 +631,10 @@ func DoGet(ctx context.Context, remoteObject string, localDestination string, re
}
}

payload.end1 = end.Unix()

payload.timestamp = payload.end1
payload.downloadTime = int64(end.Sub(start).Seconds())

if success {
payload.status = "Success"

// Get the final size of the download file
payload.fileSize = downloaded
payload.downloadSize = downloaded
} else {
log.Error("Http GET failed! Unable to download file:", err)
payload.status = "Fail"
}

if !success {
Expand Down Expand Up @@ -762,8 +726,7 @@ func DoCopy(ctx context.Context, sourceFile string, destination string, recursiv
return nil, errors.New("Do not understand destination scheme")
}

payload := payloadStruct{}
parseJobAd(&payload)
project := GetProjectName()

isPut := destScheme == "stash" || destScheme == "osdf" || destScheme == "pelican"

Expand Down Expand Up @@ -803,18 +766,7 @@ func DoCopy(ctx context.Context, sourceFile string, destination string, recursiv
remoteURL = sourceURL
}

payload.version = config.GetVersion()

//Fill out the payload as much as possible
payload.filename = sourceURL.Path

start := time.Now()
payload.start1 = start.Unix()

// Go thru the download methods
success := false

// switch statement?
var downloaded int64 = 0

te := NewTransferEngine(ctx)
Expand All @@ -827,7 +779,7 @@ func DoCopy(ctx context.Context, sourceFile string, destination string, recursiv
if err != nil {
return
}
tj, err := tc.NewTransferJob(remoteURL, localPath, isPut, recursive)
tj, err := tc.NewTransferJob(remoteURL, localPath, isPut, recursive, project)
if err != nil {
return
}
Expand All @@ -843,8 +795,6 @@ func DoCopy(ctx context.Context, sourceFile string, destination string, recursiv
}
}

end := time.Now()

for _, result := range transferResults {
downloaded += result.TransferredBytes
if err == nil && result.Error != nil {
Expand All @@ -853,20 +803,9 @@ func DoCopy(ctx context.Context, sourceFile string, destination string, recursiv
}
}

payload.end1 = end.Unix()

payload.timestamp = payload.end1
payload.downloadTime = int64(end.Sub(start).Seconds())

if success {
payload.status = "Success"

// Get the final size of the download file
payload.fileSize = downloaded
payload.downloadSize = downloaded
return transferResults, nil
} else {
payload.status = "Fail"
return transferResults, err
}
}
Expand Down Expand Up @@ -917,8 +856,8 @@ func getIPs(name string) []string {

}

func parseJobAd(payload *payloadStruct) {

// This function parses a condor job ad and returns the project name if defined
func GetProjectName() string {
//Parse the .job.ad file for the Owner (username) and ProjectName of the callee.

condorJobAd, isPresent := os.LookupEnv("_CONDOR_JOB_AD")
Expand All @@ -928,7 +867,7 @@ func parseJobAd(payload *payloadStruct) {
} else if _, err := os.Stat(".job.ad"); err == nil {
filename = ".job.ad"
} else {
return
return ""
}

// https://stackoverflow.com/questions/28574609/how-to-apply-regexp-to-content-in-file-go
Expand All @@ -940,34 +879,23 @@ func parseJobAd(payload *payloadStruct) {

// Get all matches from file
// Note: This appears to be invalid regex but is the only thing that appears to work. This way it successfully finds our matches
classadRegex, e := regexp.Compile(`^*\s*(Owner|ProjectName)\s=\s"(.*)"`)
classadRegex, e := regexp.Compile(`^*\s*(ProjectName)\s=\s"*(.*)"*`)
if e != nil {
log.Fatal(e)
}

matches := classadRegex.FindAll(b, -1)
for _, match := range matches {
matchString := strings.TrimSpace(string(match))

if strings.HasPrefix(matchString, "Owner") {
matchParts := strings.Split(strings.TrimSpace(matchString), "=")

if len(matchParts) == 2 { // just confirm we get 2 parts of the string
matchValue := strings.TrimSpace(matchParts[1])
matchValue = strings.Trim(matchValue, "\"") //trim any "" around the match if present
payload.Owner = matchValue
}
}

if strings.HasPrefix(matchString, "ProjectName") {
matchParts := strings.Split(strings.TrimSpace(matchString), "=")

if len(matchParts) == 2 { // just confirm we get 2 parts of the string
matchValue := strings.TrimSpace(matchParts[1])
matchValue = strings.Trim(matchValue, "\"") //trim any "" around the match if present
payload.ProjectName = matchValue
return matchValue
}
}
}

return ""
}
Loading

0 comments on commit 2105b05

Please sign in to comment.