diff --git a/flytecopilot/data/download.go b/flytecopilot/data/download.go
index 0fd1f10bd9..e4efa22222 100644
--- a/flytecopilot/data/download.go
+++ b/flytecopilot/data/download.go
@@ -8,8 +8,10 @@ import (
 	"io/ioutil"
 	"os"
 	"path"
+	"path/filepath"
 	"reflect"
 	"strconv"
+	"sync"
 
 	"github.com/ghodss/yaml"
 	"github.com/golang/protobuf/jsonpb"
@@ -31,57 +33,187 @@ type Downloader struct {
 	mode core.IOStrategy_DownloadMode
 }
 
-// TODO add support for multipart blobs
-func (d Downloader) handleBlob(ctx context.Context, blob *core.Blob, toFilePath string) (interface{}, error) {
-	ref := storage.DataReference(blob.Uri)
-	scheme, _, _, err := ref.Split()
+// TODO add timeout and rate limiting
+// TODO download in chunks
+func (d Downloader) handleBlob(ctx context.Context, blob *core.Blob, toPath string) (interface{}, error) {
+	/*
+		handleBlob retrieves blob data and stores it locally, supporting both single and multipart blob types.
+
+		- The function first validates the blob URI and determines the blob type (single or multipart).
+		- For multipart blobs, it recursively lists all parts and launches one goroutine per part to download
+		  and save it. Reader/writer closures and per-part successes are tracked so that resource leaks and
+		  partial downloads are detected and reported.
+		- For single-part blobs, it downloads the data directly and writes it to the given path.
+
+		Life cycle:
+		1. Blob URI -> blob metadata type check (single vs. multipart) -> recursive List of parts if multipart
+		   -> goroutines download the parts in parallel
+		2. Download part or full blob -> save locally with error checks -> close readers/writers safely,
+		   tracking successes -> return the local path, or an error reporting failed copies or closures
+	*/
+
+	blobRef := storage.DataReference(blob.Uri)
+	scheme, _, _, err := blobRef.Split()
 	if err != nil {
 		return nil, errors.Wrapf(err, "Blob uri incorrectly formatted")
 	}
-	var reader io.ReadCloser
-	if scheme == "http" || scheme == "https" {
-		reader, err = DownloadFileFromHTTP(ctx, ref)
-	} else {
-		if blob.GetMetadata().GetType().Dimensionality == core.BlobType_MULTIPART {
-			logger.Warnf(ctx, "Currently only single part blobs are supported, we will force multipart to be 'path/00000'")
-			ref, err = d.store.ConstructReference(ctx, ref, "000000")
-			if err != nil {
+
+	if blob.GetMetadata().GetType().Dimensionality == core.BlobType_MULTIPART {
+		// Collect all parts of the multipart blob recursively (the List API handles nested directories).
+		// maxItems caps each List call at 100 items; we keep paging until the cursor reports the end.
+		maxItems := 100
+		cursor := storage.NewCursorAtStart()
+		var items []storage.DataReference
+		var absPaths []string
+		for {
+			items, cursor, err = d.store.List(ctx, blobRef, maxItems, cursor)
+			if err != nil {
+				logger.Errorf(ctx, "failed to collect items from multipart blob [%s]", blobRef)
 				return nil, err
 			}
+			// An empty listing must surface as an explicit error; returning the nil err
+			// here would silently report success without downloading anything.
+			if len(items) == 0 {
+				return nil, errors.Errorf("no items found in multipart blob [%s]", blobRef)
+			}
+			for _, item := range items {
+				absPaths = append(absPaths, item.String())
+			}
+			if storage.IsCursorEnd(cursor) {
+				break
+			}
+		}
+
+		// Track the number of successful downloads against the total number of items.
+		downloadSuccess := 0
+		itemCount := len(absPaths)
+		// Track successful closes of readers and writers in the deferred functions.
+		readerCloseSuccessCount := 0
+		writerCloseSuccessCount := 0
+		// The mutex guards the shared counters (and directory creation) against concurrent updates.
+		var mu sync.Mutex
+		var wg sync.WaitGroup
+		for _, absPath := range absPaths {
+			absPath := absPath // capture the loop variable for the goroutine below
+
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				defer func() {
+					if r := recover(); r != nil {
+						logger.Errorf(ctx, "recovered from panic while downloading a blob part: [%v]", r)
+					}
+				}()
+
+				ref := storage.DataReference(absPath)
+				reader, err := DownloadFileFromStorage(ctx, ref, d.store)
+				if err != nil {
+					logger.Errorf(ctx, "Failed to download from ref [%s]", ref)
+					return
+				}
+				defer func() {
+					// Count the close as successful only when it returns no error,
+					// so the final tally reflects real leaks.
+					if err := reader.Close(); err != nil {
+						logger.Errorf(ctx, "failed to close Blob read stream @ref [%s].\n"+
+							"Error: %s", ref, err)
+						return
+					}
+					mu.Lock()
+					readerCloseSuccessCount++
+					mu.Unlock()
+				}()
+
+				_, _, k, err := ref.Split()
+				if err != nil {
+					logger.Errorf(ctx, "Failed to parse ref [%s]", ref)
+					return
+				}
+				newPath := filepath.Join(toPath, k)
+				dir := filepath.Dir(newPath)
+
+				mu.Lock()
+				// os.MkdirAll creates the directory structure if it does not already exist.
+				// 0777: readable, writable, and traversable by everyone (subject to the umask).
+				err = os.MkdirAll(dir, 0777)
+				mu.Unlock()
+				if err != nil {
+					logger.Errorf(ctx, "failed to make dir at path [%s]", dir)
+					return
+				}
+
+				writer, err := os.Create(newPath)
+				if err != nil {
+					logger.Errorf(ctx, "failed to open file at path [%s]", newPath)
+					return
+				}
+				defer func() {
+					if err := writer.Close(); err != nil {
+						logger.Errorf(ctx, "failed to close File write stream.\n"+
+							"Error: [%s]", err)
+						return
+					}
+					mu.Lock()
+					writerCloseSuccessCount++
+					mu.Unlock()
+				}()
+
+				_, err = io.Copy(writer, reader)
+				if err != nil {
+					logger.Errorf(ctx, "failed to write remote data to local filesystem at [%s]", newPath)
+					return
+				}
+				mu.Lock()
+				downloadSuccess++
+				mu.Unlock()
+			}()
+		}
+		// Wait for every goroutine to finish so none leak past this function.
+		wg.Wait()
+		if downloadSuccess != itemCount || readerCloseSuccessCount != itemCount || writerCloseSuccessCount != itemCount {
+			return nil, errors.Errorf(
+				"Failed to copy %d out of %d remote files from [%s] to local [%s].\n"+
+					"Failed to close %d readers.\n"+
+					"Failed to close %d writers.",
+				itemCount-downloadSuccess, itemCount, blobRef, toPath,
+				itemCount-readerCloseSuccessCount, itemCount-writerCloseSuccessCount,
+			)
+		}
+		logger.Infof(ctx, "successfully copied %d remote files from [%s] to local [%s]", downloadSuccess, blobRef, toPath)
+		return toPath, nil
+	} else if blob.GetMetadata().GetType().Dimensionality == core.BlobType_SINGLE {
+		// reader is declared inside this branch so it is not shared with the
+		// goroutines spawned for the multipart case.
+		var reader io.ReadCloser
+		if scheme == "http" || scheme == "https" {
+			reader, err = DownloadFileFromHTTP(ctx, blobRef)
+		} else {
+			reader, err = DownloadFileFromStorage(ctx, blobRef, d.store)
 		}
-		reader, err = DownloadFileFromStorage(ctx, ref, d.store)
-	}
-	if err != nil {
-		logger.Errorf(ctx, "Failed to download from ref [%s]", ref)
-		return nil, err
-	}
-	defer func() {
-		err := reader.Close()
 		if err != nil {
-			logger.Errorf(ctx, "failed to close Blob read stream @ref [%s]. Error: %s", ref, err)
+			logger.Errorf(ctx, "Failed to download from ref [%s]", blobRef)
+			return nil, err
 		}
-	}()
+		defer func() {
+			err := reader.Close()
+			if err != nil {
+				logger.Errorf(ctx, "failed to close Blob read stream @ref [%s]. Error: %s", blobRef, err)
+			}
+		}()
 
-	writer, err := os.Create(toFilePath)
-	if err != nil {
-		return nil, errors.Wrapf(err, "failed to open file at path %s", toFilePath)
-	}
-	defer func() {
-		err := writer.Close()
+		writer, err := os.Create(toPath)
 		if err != nil {
-			logger.Errorf(ctx, "failed to close File write stream. Error: %s", err)
+			return nil, errors.Wrapf(err, "failed to open file at path %s", toPath)
 		}
-	}()
-	v, err := io.Copy(writer, reader)
-	if err != nil {
-		return nil, errors.Wrapf(err, "failed to write remote data to local filesystem")
+		defer func() {
+			err := writer.Close()
+			if err != nil {
+				logger.Errorf(ctx, "failed to close File write stream. Error: %s", err)
+			}
+		}()
+		v, err := io.Copy(writer, reader)
+		if err != nil {
+			return nil, errors.Wrapf(err, "failed to write remote data to local filesystem")
+		}
+		logger.Infof(ctx, "Successfully copied [%d] bytes remote data from [%s] to local [%s]", v, blobRef, toPath)
+		return toPath, nil
 	}
-	logger.Infof(ctx, "Successfully copied [%d] bytes remote data from [%s] to local [%s]", v, ref, toFilePath)
-	return toFilePath, nil
+
+	return nil, errors.Errorf("unexpected Blob type encountered")
 }
 
 func (d Downloader) handleSchema(ctx context.Context, schema *core.Schema, toFilePath string) (interface{}, error) {
-	// TODO Handle schema type
 	return d.handleBlob(ctx, &core.Blob{Uri: schema.Uri, Metadata: &core.BlobMetadata{Type: &core.BlobType{Dimensionality: core.BlobType_MULTIPART}}}, toFilePath)
 }
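A minimal sketch, not part of this change: the mutex-and-counter bookkeeping above can also be expressed with golang.org/x/sync/errgroup (a recent version with SetLimit), which propagates the first error, cancels the shared context, and caps in-flight goroutines. downloadAllParts and the limit of 8 are hypothetical names/values; DownloadFileFromStorage and the storage types are assumed to behave exactly as used in handleBlob.

package data

import (
	"context"
	"io"
	"os"
	"path/filepath"

	"golang.org/x/sync/errgroup"

	"github.com/flyteorg/flyte/flytestdlib/storage"
)

// downloadAllParts is a hypothetical alternative fan-out: the errgroup returns
// the first error, cancels ctx for the remaining parts, and bounds concurrency,
// replacing the mutex-guarded success counters.
func downloadAllParts(ctx context.Context, store *storage.DataStore, parts []storage.DataReference, toDir string) error {
	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(8) // assumed concurrency cap

	for _, part := range parts {
		part := part // capture the loop variable for the closure
		g.Go(func() error {
			reader, err := DownloadFileFromStorage(ctx, part, store)
			if err != nil {
				return err
			}
			defer reader.Close()

			_, _, key, err := part.Split()
			if err != nil {
				return err
			}
			dst := filepath.Join(toDir, key)
			if err := os.MkdirAll(filepath.Dir(dst), 0777); err != nil {
				return err
			}
			writer, err := os.Create(dst)
			if err != nil {
				return err
			}
			defer writer.Close()

			_, err = io.Copy(writer, reader)
			return err
		})
	}
	return g.Wait() // first non-nil error from any part, if any
}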
diff --git a/flytecopilot/data/download_test.go b/flytecopilot/data/download_test.go
new file mode 100644
index 0000000000..1f3b3a7be6
--- /dev/null
+++ b/flytecopilot/data/download_test.go
@@ -0,0 +1,151 @@
+package data
+
+import (
+	"bytes"
+	"context"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
+	"github.com/flyteorg/flyte/flytestdlib/promutils"
+	"github.com/flyteorg/flyte/flytestdlib/storage"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestHandleBlobMultipart(t *testing.T) {
+	t.Run("Successful Query", func(t *testing.T) {
+		s, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
+		assert.NoError(t, err)
+		ref := storage.DataReference("s3://container/folder/file1")
+		s.WriteRaw(context.Background(), ref, 0, storage.Options{}, bytes.NewReader([]byte{}))
+		ref = storage.DataReference("s3://container/folder/file2")
+		s.WriteRaw(context.Background(), ref, 0, storage.Options{}, bytes.NewReader([]byte{}))
+
+		d := Downloader{store: s}
+
+		blob := &core.Blob{
+			Uri: "s3://container/folder",
+			Metadata: &core.BlobMetadata{
+				Type: &core.BlobType{
+					Dimensionality: core.BlobType_MULTIPART,
+				},
+			},
+		}
+
+		toPath := "./inputs"
+		defer func() {
+			err := os.RemoveAll(toPath)
+			if err != nil {
+				t.Errorf("Failed to delete directory: %v", err)
+			}
+		}()
+
+		result, err := d.handleBlob(context.Background(), blob, toPath)
+		assert.NoError(t, err)
+		assert.Equal(t, toPath, result)
+
+		// Check if files were created and data written
+		for _, file := range []string{"file1", "file2"} {
+			if _, err := os.Stat(filepath.Join(toPath, "folder", file)); os.IsNotExist(err) {
+				t.Errorf("expected file %s to exist", file)
+			}
+		}
+	})
+
+	t.Run("No Items", func(t *testing.T) {
+		s, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
+		assert.NoError(t, err)
+
+		d := Downloader{store: s}
+
+		blob := &core.Blob{
+			Uri: "s3://container/folder",
+			Metadata: &core.BlobMetadata{
+				Type: &core.BlobType{
+					Dimensionality: core.BlobType_MULTIPART,
+				},
+			},
+		}
+
+		toPath := "./inputs"
+		defer func() {
+			err := os.RemoveAll(toPath)
+			if err != nil {
+				t.Errorf("Failed to delete directory: %v", err)
+			}
+		}()
+
+		result, err := d.handleBlob(context.Background(), blob, toPath)
+		assert.Error(t, err)
+		assert.Nil(t, result)
+	})
+}
Error: %s", blobRef, err) + } + }() - writer, err := os.Create(toFilePath) - if err != nil { - return nil, errors.Wrapf(err, "failed to open file at path %s", toFilePath) - } - defer func() { - err := writer.Close() + writer, err := os.Create(toPath) if err != nil { - logger.Errorf(ctx, "failed to close File write stream. Error: %s", err) + return nil, errors.Wrapf(err, "failed to open file at path %s", toPath) } - }() - v, err := io.Copy(writer, reader) - if err != nil { - return nil, errors.Wrapf(err, "failed to write remote data to local filesystem") + defer func() { + err := writer.Close() + if err != nil { + logger.Errorf(ctx, "failed to close File write stream. Error: %s", err) + } + }() + v, err := io.Copy(writer, reader) + if err != nil { + return nil, errors.Wrapf(err, "failed to write remote data to local filesystem") + } + logger.Infof(ctx, "Successfully copied [%d] bytes remote data from [%s] to local [%s]", v, blobRef, toPath) + return toPath, nil } - logger.Infof(ctx, "Successfully copied [%d] bytes remote data from [%s] to local [%s]", v, ref, toFilePath) - return toFilePath, nil + + return nil, errors.Errorf("unexpected Blob type encountered") } func (d Downloader) handleSchema(ctx context.Context, schema *core.Schema, toFilePath string) (interface{}, error) { - // TODO Handle schema type return d.handleBlob(ctx, &core.Blob{Uri: schema.Uri, Metadata: &core.BlobMetadata{Type: &core.BlobType{Dimensionality: core.BlobType_MULTIPART}}}, toFilePath) } diff --git a/flytecopilot/data/download_test.go b/flytecopilot/data/download_test.go new file mode 100644 index 0000000000..1f3b3a7be6 --- /dev/null +++ b/flytecopilot/data/download_test.go @@ -0,0 +1,151 @@ +package data + +import ( + "bytes" + "context" + "os" + "path/filepath" + "testing" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flytestdlib/promutils" + "github.com/flyteorg/flyte/flytestdlib/storage" + + "github.com/stretchr/testify/assert" +) + +func TestHandleBlobMultipart(t *testing.T) { + t.Run("Successful Query", func(t *testing.T) { + s, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + ref := storage.DataReference("s3://container/folder/file1") + s.WriteRaw(context.Background(), ref, 0, storage.Options{}, bytes.NewReader([]byte{})) + ref = storage.DataReference("s3://container/folder/file2") + s.WriteRaw(context.Background(), ref, 0, storage.Options{}, bytes.NewReader([]byte{})) + + d := Downloader{store: s} + + blob := &core.Blob{ + Uri: "s3://container/folder", + Metadata: &core.BlobMetadata{ + Type: &core.BlobType{ + Dimensionality: core.BlobType_MULTIPART, + }, + }, + } + + toPath := "./inputs" + defer func() { + err := os.RemoveAll(toPath) + if err != nil { + t.Errorf("Failed to delete directory: %v", err) + } + }() + + result, err := d.handleBlob(context.Background(), blob, toPath) + assert.NoError(t, err) + assert.Equal(t, toPath, result) + + // Check if files were created and data written + for _, file := range []string{"file1", "file2"} { + if _, err := os.Stat(filepath.Join(toPath, "folder", file)); os.IsNotExist(err) { + t.Errorf("expected file %s to exist", file) + } + } + }) + + t.Run("No Items", func(t *testing.T) { + s, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + + d := Downloader{store: s} + + blob := &core.Blob{ + Uri: "s3://container/folder", + Metadata: &core.BlobMetadata{ + Type: &core.BlobType{ + 
diff --git a/flytestdlib/storage/mem_store.go b/flytestdlib/storage/mem_store.go
index d9da9b5b1e..540423a2a0 100644
--- a/flytestdlib/storage/mem_store.go
+++ b/flytestdlib/storage/mem_store.go
@@ -9,6 +9,7 @@ import (
 	"io"
 	"io/ioutil"
 	"os"
+	"strings"
 	"sync"
 )
 
@@ -60,7 +61,20 @@ func (s *InMemoryStore) Head(ctx context.Context, reference DataReference) (Meta
 }
 
 func (s *InMemoryStore) List(ctx context.Context, reference DataReference, maxItems int, cursor Cursor) ([]DataReference, Cursor, error) {
-	return nil, NewCursorAtEnd(), fmt.Errorf("Not implemented yet")
+	var items []DataReference
+	prefix := strings.TrimSuffix(string(reference), "/") + "/"
+
+	// Pagination is not simulated: maxItems and cursor are ignored, every match
+	// is returned in a single batch, and the cursor comes back at the end.
+	for ref := range s.cache {
+		if strings.HasPrefix(ref.String(), prefix) {
+			items = append(items, ref)
+		}
+	}
+
+	if len(items) == 0 {
+		return nil, NewCursorAtEnd(), os.ErrNotExist
+	}
+
+	return items, NewCursorAtEnd(), nil
 }
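A minimal illustration of the in-memory List semantics, assuming only the APIs shown in this diff; listFolder and the s3://bucket/dir paths are invented. All matches under the prefix come back in one batch, in nondeterministic map-iteration order, with the cursor already at the end.

package example

import (
	"bytes"
	"context"

	"github.com/flyteorg/flyte/flytestdlib/storage"
)

// listFolder writes one object and lists the folder that contains it.
func listFolder(ctx context.Context, s *storage.DataStore) ([]storage.DataReference, error) {
	if err := s.WriteRaw(ctx, "s3://bucket/dir/a", 0, storage.Options{}, bytes.NewReader(nil)); err != nil {
		return nil, err
	}
	// maxItems and the start cursor are accepted but not enforced by the in-memory store.
	items, cursor, err := s.List(ctx, "s3://bucket/dir", 100, storage.NewCursorAtStart())
	if err != nil {
		return nil, err
	}
	_ = storage.IsCursorEnd(cursor) // always true for the in-memory implementation
	return items, nil
}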
diff --git a/flytestdlib/storage/storage.go b/flytestdlib/storage/storage.go
index 3d53a4d25f..e365816ff0 100644
--- a/flytestdlib/storage/storage.go
+++ b/flytestdlib/storage/storage.go
@@ -76,6 +76,10 @@ func NewCursorFromCustomPosition(customPosition string) Cursor {
 	}
 }
 
+// IsCursorEnd reports whether the cursor has reached the end of a listing.
+func IsCursorEnd(cursor Cursor) bool {
+	return cursor.cursorState == AtEndCursorState
+}
+
 // DataStore is a simplified interface for accessing and storing data in one of the Cloud stores.
 // Today we rely on Stow for multi-cloud support, but this interface abstracts that part
 type DataStore struct {
@@ -114,7 +118,7 @@ type RawStore interface {
 	// Head gets metadata about the reference. This should generally be a light weight operation.
 	Head(ctx context.Context, reference DataReference) (Metadata, error)
 
-	// List gets a list of items given a prefix, using a paginated API
+	// List gets a list of items (relative paths to the reference input) given a prefix, using a paginated API
 	List(ctx context.Context, reference DataReference, maxItems int, cursor Cursor) ([]DataReference, Cursor, error)
 
 	// ReadRaw retrieves a byte array from the Blob store or an error
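Usage sketch for the new IsCursorEnd helper, mirroring the paging loop in handleBlob; listAll, the prefix argument, and the batch size of 100 are illustrative, and nothing beyond the List and cursor APIs shown above is assumed.

package example

import (
	"context"

	"github.com/flyteorg/flyte/flytestdlib/storage"
)

// listAll accumulates every reference under prefix, paging through List
// results until the returned cursor reports the end of the listing.
func listAll(ctx context.Context, store *storage.DataStore, prefix storage.DataReference) ([]storage.DataReference, error) {
	var all []storage.DataReference
	cursor := storage.NewCursorAtStart()
	for {
		items, next, err := store.List(ctx, prefix, 100, cursor)
		if err != nil {
			return nil, err
		}
		all = append(all, items...)
		if storage.IsCursorEnd(next) {
			return all, nil
		}
		cursor = next
	}
}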