From d74d7a2e52e9e2e4b07d8c0f5ee5a7b8a972462f Mon Sep 17 00:00:00 2001 From: Bugra Gedik Date: Tue, 1 Oct 2024 21:46:16 +0000 Subject: [PATCH] Add multi file error aggregation strategy --- .../ioutils/remote_file_output_reader.go | 228 +++++++++++++++--- .../ioutils/remote_file_output_reader_test.go | 72 +++++- .../go/tasks/pluginmachinery/k8s/plugin.go | 12 + .../nodes/task/k8s/plugin_manager.go | 4 +- flytestdlib/storage/storage.go | 5 + flytestdlib/storage/stow_store.go | 10 +- flytestdlib/storage/stow_store_test.go | 6 +- 7 files changed, 290 insertions(+), 47 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader.go b/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader.go index 27d7748701..02439e2daf 100644 --- a/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader.go +++ b/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader.go @@ -3,37 +3,90 @@ package ioutils import ( "context" "fmt" + "math" + "path/filepath" + "strings" "github.com/pkg/errors" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flytestdlib/storage" ) -type RemoteFileOutputReader struct { - outPath io.OutputFilePaths +type ErrorRetriever interface { + HasError(ctx context.Context) (bool, error) + GetError(ctx context.Context) (io.ExecutionError, error) +} + +type ErrorRetrieverBase struct { store storage.ComposedProtobufStore maxPayloadSize int64 } -func (r RemoteFileOutputReader) IsError(ctx context.Context) (bool, error) { - metadata, err := r.store.Head(ctx, r.outPath.GetErrorPath()) +type SingleFileErrorRetriever struct { + ErrorRetrieverBase + errorFilePath storage.DataReference +} + +func NewSingleFileErrorRetriever(errorFilePath storage.DataReference, store storage.ComposedProtobufStore, maxPayloadSize int64) *SingleFileErrorRetriever { + return &SingleFileErrorRetriever{ + ErrorRetrieverBase: ErrorRetrieverBase{ + store: store, + maxPayloadSize: maxPayloadSize, + }, + errorFilePath: errorFilePath, + } +} + +func (s *SingleFileErrorRetriever) HasError(ctx context.Context) (bool, error) { + metadata, err := s.store.Head(ctx, s.errorFilePath) if err != nil { - return false, errors.Wrapf(err, "failed to read error file @[%s]", r.outPath.GetErrorPath()) + return false, errors.Wrapf(err, "failed to read error file @[%s]", s.errorFilePath) } if metadata.Exists() { - if metadata.Size() > r.maxPayloadSize { - return false, errors.Wrapf(err, "error file @[%s] is too large [%d] bytes, max allowed [%d] bytes", r.outPath.GetErrorPath(), metadata.Size(), r.maxPayloadSize) + if metadata.Size() > s.maxPayloadSize { + return false, errors.Wrapf(err, "error file @[%s] is too large [%d] bytes, max allowed [%d] bytes", s.errorFilePath, metadata.Size(), s.maxPayloadSize) } return true, nil } return false, nil } -func (r RemoteFileOutputReader) ReadError(ctx context.Context) (io.ExecutionError, error) { +func errorDoc2ExecutionError(errorDoc *core.ErrorDocument, errorFilePath storage.DataReference) io.ExecutionError { + if errorDoc.Error == nil { + return io.ExecutionError{ + IsRecoverable: true, + ExecutionError: &core.ExecutionError{ + Code: "ErrorFileBadFormat", + Message: fmt.Sprintf("error not formatted correctly, nil error @path [%s]", errorFilePath), + Kind: core.ExecutionError_SYSTEM, + }, + } + } + executionError := io.ExecutionError{ + ExecutionError: &core.ExecutionError{ + Code: errorDoc.Error.Code, + Message: errorDoc.Error.Message, + Kind: errorDoc.Error.Origin, + }, + } + + if errorDoc.Error.Kind == core.ContainerError_RECOVERABLE { + executionError.IsRecoverable = true + } + + if errorDoc.Error.Kind == core.ContainerError_RECOVERABLE { + executionError.IsRecoverable = true + } + + return executionError +} + +func (s *SingleFileErrorRetriever) GetError(ctx context.Context) (io.ExecutionError, error) { errorDoc := &core.ErrorDocument{} - err := r.store.ReadProtobuf(ctx, r.outPath.GetErrorPath(), errorDoc) + err := s.store.ReadProtobuf(ctx, storage.DataReference(s.errorFilePath), errorDoc) if err != nil { if storage.IsNotFound(err) { return io.ExecutionError{ @@ -45,33 +98,143 @@ func (r RemoteFileOutputReader) ReadError(ctx context.Context) (io.ExecutionErro }, }, nil } - return io.ExecutionError{}, errors.Wrapf(err, "failed to read error data from task @[%s]", r.outPath.GetErrorPath()) + return io.ExecutionError{}, errors.Wrapf(err, "failed to read error data from task @[%s]", s.errorFilePath) } - if errorDoc.Error == nil { - return io.ExecutionError{ - IsRecoverable: true, - ExecutionError: &core.ExecutionError{ - Code: "ErrorFileBadFormat", - Message: fmt.Sprintf("error not formatted correctly, nil error @path [%s]", r.outPath.GetErrorPath()), - Kind: core.ExecutionError_SYSTEM, - }, - }, nil + return errorDoc2ExecutionError(errorDoc, s.errorFilePath), nil +} + +type EarliestFileErrorRetriever struct { + ErrorRetrieverBase + errorDirPath storage.DataReference + canonicalErrorFilename string +} + +func (e *EarliestFileErrorRetriever) parseErrorFilename() (errorFilePathPrefix storage.DataReference, errorFileExtension string, err error) { + // If the canonical error file name is error.pb, we expect multiple error files + // to have name error.pb + pieces := strings.Split(e.canonicalErrorFilename, ".") + if len(pieces) != 2 { + err = errors.Errorf("expected canoncal error filename to have a single ., got %d", len(pieces)) + return } + errorFilePrefix := pieces[0] + scheme, container, key, _ := e.errorDirPath.Split() + errorFilePathPrefix = storage.NewDataReference(scheme, container, filepath.Join(key, errorFilePrefix)) + errorFileExtension = fmt.Sprintf(".%s", pieces[1]) + return +} - ee := io.ExecutionError{ - ExecutionError: &core.ExecutionError{ - Code: errorDoc.Error.Code, - Message: errorDoc.Error.Message, - Kind: errorDoc.Error.Origin, +func (e *EarliestFileErrorRetriever) HasError(ctx context.Context) (bool, error) { + errorFilePathPrefix, errorFileExtension, err := e.parseErrorFilename() + if err != nil { + return false, errors.Wrapf(err, "failed to parse canonical error filename @[%s]", e.canonicalErrorFilename) + } + const maxItems = 1000 + cursor := storage.NewCursorAtStart() + for cursor != storage.NewCursorAtEnd() { + var err error + var errorFilePaths []storage.DataReference + errorFilePaths, cursor, err = e.store.List(ctx, errorFilePathPrefix, maxItems, cursor) + if err != nil { + return false, errors.Wrapf(err, "failed to list error files @[%s]", e.errorDirPath) + } + for _, errorFilePath := range errorFilePaths { + if strings.HasSuffix(errorFilePath.String(), errorFileExtension) { + return true, nil + } + } + } + return false, nil +} + +func (e *EarliestFileErrorRetriever) GetError(ctx context.Context) (io.ExecutionError, error) { + errorFilePathPrefix, errorFileExtension, err := e.parseErrorFilename() + if err != nil { + return io.ExecutionError{}, errors.Wrapf(err, "failed to parse canonical error filename @[%s]", e.canonicalErrorFilename) + } + const maxItems = 1000 + cursor := storage.NewCursorAtStart() + type ErrorFileAndDocument struct { + errorFilePath storage.DataReference + errorDoc *core.ErrorDocument + } + var errorFileAndDocs []ErrorFileAndDocument + for cursor != storage.NewCursorAtEnd() { + var err error + var errorFilePaths []storage.DataReference + errorFilePaths, cursor, err = e.store.List(ctx, errorFilePathPrefix, maxItems, cursor) + if err != nil { + return io.ExecutionError{}, errors.Wrapf(err, "failed to list error files @[%s]", e.errorDirPath) + } + for _, errorFilePath := range errorFilePaths { + if strings.HasSuffix(errorFilePath.String(), errorFileExtension) { + errorDoc := &core.ErrorDocument{} + err := e.store.ReadProtobuf(ctx, errorFilePath, errorDoc) + if err != nil { + return io.ExecutionError{}, errors.Wrapf(err, "failed to read error file @[%s]", errorFilePath.String()) + } + errorFileAndDocs = append(errorFileAndDocs, ErrorFileAndDocument{errorFilePath: errorFilePath, errorDoc: errorDoc}) + } + } + } + + extractTimestampFromErrorDoc := func(errorDoc *core.ErrorDocument) int64 { + // TODO: add optional timestamp to ErrorDocument + if errorDoc == nil { + panic("") + } + return 0 + } + + var earliestTimestamp int64 = math.MaxInt64 + earliestExecutionError := io.ExecutionError{} + for _, errorFileAndDoc := range errorFileAndDocs { + timestamp := extractTimestampFromErrorDoc(errorFileAndDoc.errorDoc) + if earliestTimestamp >= timestamp { + earliestExecutionError = errorDoc2ExecutionError(errorFileAndDoc.errorDoc, errorFileAndDoc.errorFilePath) + earliestTimestamp = timestamp + } + } + return earliestExecutionError, nil +} + +func NewEarliestFileErrorRetriever(errorDirPath storage.DataReference, canonicalErrorFilename string, store storage.ComposedProtobufStore, maxPayloadSize int64) *EarliestFileErrorRetriever { + return &EarliestFileErrorRetriever{ + ErrorRetrieverBase: ErrorRetrieverBase{ + store: store, + maxPayloadSize: maxPayloadSize, }, + errorDirPath: errorDirPath, + canonicalErrorFilename: canonicalErrorFilename, } +} - if errorDoc.Error.Kind == core.ContainerError_RECOVERABLE { - ee.IsRecoverable = true +func NewErrorRetriever(errorAggregationStrategy k8s.ErrorAggregationStrategy, errorDirPath storage.DataReference, errorFilename string, store storage.ComposedProtobufStore, maxPayloadSize int64) ErrorRetriever { + if errorAggregationStrategy == k8s.DefaultErrorAggregationStrategy { + scheme, container, key, _ := errorDirPath.Split() + errorFilePath := storage.NewDataReference(scheme, container, filepath.Join(key, errorFilename)) + return NewSingleFileErrorRetriever(errorFilePath, store, maxPayloadSize) + } + if errorAggregationStrategy == k8s.EarliestErrorAggregationStrategy { + return NewEarliestFileErrorRetriever(errorDirPath, errorFilename, store, maxPayloadSize) } + return nil +} + +type RemoteFileOutputReader struct { + outPath io.OutputFilePaths + store storage.ComposedProtobufStore + maxPayloadSize int64 + errorRetriever ErrorRetriever +} - return ee, nil +func (r RemoteFileOutputReader) IsError(ctx context.Context) (bool, error) { + return r.errorRetriever.HasError(ctx) +} + +func (r RemoteFileOutputReader) ReadError(ctx context.Context) (io.ExecutionError, error) { + return r.errorRetriever.GetError(ctx) } func (r RemoteFileOutputReader) Exists(ctx context.Context) (bool, error) { @@ -122,16 +285,25 @@ func (r RemoteFileOutputReader) DeckExists(ctx context.Context) (bool, error) { return md.Exists(), nil } -func NewRemoteFileOutputReader(_ context.Context, store storage.ComposedProtobufStore, outPaths io.OutputFilePaths, maxDatasetSize int64) RemoteFileOutputReader { +func NewRemoteFileOutputReader(context context.Context, store storage.ComposedProtobufStore, outPaths io.OutputFilePaths, maxDatasetSize int64) RemoteFileOutputReader { + return NewRemoteFileOutputReaderWithErrorAggregationStrategy(context, store, outPaths, maxDatasetSize, k8s.DefaultErrorAggregationStrategy) +} + +func NewRemoteFileOutputReaderWithErrorAggregationStrategy(_ context.Context, store storage.ComposedProtobufStore, outPaths io.OutputFilePaths, maxDatasetSize int64, errorAggregationStrategy k8s.ErrorAggregationStrategy) RemoteFileOutputReader { // Note: even though the data store retrieval checks against GetLimitMegabytes, there might be external // storage implementations, so we keep this check here as well. maxPayloadSize := maxDatasetSize if maxPayloadSize == 0 { maxPayloadSize = storage.GetConfig().Limits.GetLimitMegabytes * 1024 * 1024 } + scheme, container, key, _ := outPaths.GetErrorPath().Split() + errorFilename := filepath.Base(key) + errorDirPath := storage.NewDataReference(scheme, container, filepath.Dir(key)) + errorRetriever := NewErrorRetriever(errorAggregationStrategy, errorDirPath, errorFilename, store, maxPayloadSize) return RemoteFileOutputReader{ outPath: outPaths, store: store, maxPayloadSize: maxPayloadSize, + errorRetriever: errorRetriever, } } diff --git a/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go b/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go index 251a3adc55..d699403601 100644 --- a/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go @@ -2,6 +2,7 @@ package ioutils import ( "context" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -9,6 +10,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" pluginsIOMock "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flytestdlib/storage" storageMocks "github.com/flyteorg/flyte/flytestdlib/storage/mocks" ) @@ -65,11 +67,13 @@ func TestReadOrigin(t *testing.T) { exists: true, }, nil) - r := RemoteFileOutputReader{ - outPath: opath, - store: store, - maxPayloadSize: 0, - } + maxPayloadSize := int64(0) + r := NewRemoteFileOutputReader( + ctx, + store, + opath, + maxPayloadSize, + ) ee, err := r.ReadError(ctx) assert.NoError(t, err) @@ -97,15 +101,63 @@ func TestReadOrigin(t *testing.T) { casted.Error = errorDoc.Error }).Return(nil) - r := RemoteFileOutputReader{ - outPath: opath, - store: store, - maxPayloadSize: 0, - } + maxPayloadSize := int64(0) + r := NewRemoteFileOutputReader( + ctx, + store, + opath, + maxPayloadSize, + ) ee, err := r.ReadError(ctx) assert.NoError(t, err) assert.Equal(t, core.ExecutionError_SYSTEM, ee.Kind) assert.True(t, ee.IsRecoverable) }) + + t.Run("multi-user-error", func(t *testing.T) { + outputPaths := &pluginsIOMock.OutputFilePaths{} + outputPaths.OnGetErrorPath().Return("s3://errors/error.pb") + + store := &storageMocks.ComposedProtobufStore{} + store.OnReadProtobufMatch(mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + errorDoc := &core.ErrorDocument{ + Error: &core.ContainerError{ + Code: "red", + Message: "hi", + Kind: core.ContainerError_NON_RECOVERABLE, + Origin: core.ExecutionError_USER, + }, + } + errorFilePath := args.Get(1) + incomingErrorDoc := args.Get(2) + assert.NotNil(t, incomingErrorDoc) + casted := incomingErrorDoc.(*core.ErrorDocument) + casted.Error = errorDoc.Error + casted.Error.Message = fmt.Sprintf("%s-%s", casted.Error.Message, errorFilePath) + }).Return(nil) + + store.OnList(ctx, storage.DataReference("s3://errors/error"), 1000, storage.NewCursorAtStart()).Return( + []storage.DataReference{"error-0.pb", "error-1.pb", "error-2.pb"}, storage.NewCursorAtEnd(), nil) + + maxPayloadSize := int64(0) + r := NewRemoteFileOutputReaderWithErrorAggregationStrategy( + ctx, + store, + outputPaths, + maxPayloadSize, + k8s.EarliestErrorAggregationStrategy, + ) + + hasError, err := r.IsError(ctx) + assert.NoError(t, err) + assert.True(t, hasError) + + executionError, err := r.ReadError(ctx) + assert.NoError(t, err) + assert.Equal(t, core.ExecutionError_USER, executionError.Kind) + assert.Equal(t, "red", executionError.Code) + assert.Equal(t, "hi-error-2.pb", executionError.Message) + assert.False(t, executionError.IsRecoverable) + }) } diff --git a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go index 38a84f9b2b..2e4a531fd1 100644 --- a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go @@ -30,6 +30,16 @@ type PluginEntry struct { CustomKubeClient func(ctx context.Context) (pluginsCore.KubeClient, error) } +type ErrorAggregationStrategy int + +const ( + // Single error file from a single container + DefaultErrorAggregationStrategy ErrorAggregationStrategy = iota + + // Earliest error from potentially multiple error files + EarliestErrorAggregationStrategy +) + // System level properties that this Plugin supports type PluginProperties struct { // Disables the inclusion of OwnerReferences in kubernetes resources that this plugin is responsible for. @@ -45,6 +55,8 @@ type PluginProperties struct { // override that behavior unless the resource that gets created for this plugin does not consume resources (cluster's // cpu/memory... etc. or external resources) once the plugin's Plugin.GetTaskPhase() returns a terminal phase. DisableDeleteResourceOnFinalize bool + // Specifies how errors are aggregated + ErrorAggregationStrategy ErrorAggregationStrategy } // Special context passed in to plugins when checking task phase diff --git a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go index 42d3ad9b85..17935a89e7 100644 --- a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go +++ b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager.go @@ -290,7 +290,9 @@ func (e *PluginManager) checkResourcePhase(ctx context.Context, tCtx pluginsCore var opReader io.OutputReader if pCtx.ow == nil { logger.Infof(ctx, "Plugin [%s] returned no outputReader, assuming file based outputs", e.id) - opReader = ioutils.NewRemoteFileOutputReader(ctx, tCtx.DataStore(), tCtx.OutputWriter(), 0) + opReader = ioutils.NewRemoteFileOutputReaderWithErrorAggregationStrategy( + ctx, tCtx.DataStore(), tCtx.OutputWriter(), 0, + e.plugin.GetProperties().ErrorAggregationStrategy) } else { logger.Infof(ctx, "Plugin [%s] returned outputReader", e.id) opReader = pCtx.ow.GetReader() diff --git a/flytestdlib/storage/storage.go b/flytestdlib/storage/storage.go index 52e6905513..3d53a4d25f 100644 --- a/flytestdlib/storage/storage.go +++ b/flytestdlib/storage/storage.go @@ -8,6 +8,7 @@ package storage import ( "context" + "fmt" "io" "net/url" "strings" @@ -171,3 +172,7 @@ func (r DataReference) Split() (scheme, container, key string, err error) { func (r DataReference) String() string { return string(r) } + +func NewDataReference(scheme string, container string, key string) DataReference { + return DataReference(fmt.Sprintf("%s://%s/%s", scheme, container, key)) +} diff --git a/flytestdlib/storage/stow_store.go b/flytestdlib/storage/stow_store.go index 6b731b9c86..c1950c10de 100644 --- a/flytestdlib/storage/stow_store.go +++ b/flytestdlib/storage/stow_store.go @@ -255,13 +255,13 @@ func (s *StowStore) Head(ctx context.Context, reference DataReference) (Metadata } func (s *StowStore) List(ctx context.Context, reference DataReference, maxItems int, cursor Cursor) ([]DataReference, Cursor, error) { - _, c, k, err := reference.Split() + scheme, containerName, key, err := reference.Split() if err != nil { s.metrics.BadReference.Inc(ctx) return nil, NewCursorAtEnd(), err } - container, err := s.getContainer(ctx, locationIDMain, c) + container, err := s.getContainer(ctx, locationIDMain, containerName) if err != nil { return nil, NewCursorAtEnd(), err } @@ -275,11 +275,11 @@ func (s *StowStore) List(ctx context.Context, reference DataReference, maxItems } else { stowCursor = cursor.customPosition } - items, stowCursor, err := container.Items(k, stowCursor, maxItems) + items, stowCursor, err := container.Items(key, stowCursor, maxItems) if err == nil { results := make([]DataReference, len(items)) for index, item := range items { - results[index] = DataReference(item.URL().String()) + results[index] = DataReference(fmt.Sprintf("%s://%s/%s", scheme, containerName, item.URL().String())) } if stow.IsCursorEnd(stowCursor) { cursor = NewCursorAtEnd() @@ -291,7 +291,7 @@ func (s *StowStore) List(ctx context.Context, reference DataReference, maxItems } incFailureCounterForError(ctx, s.metrics.ListFailure, err) - return nil, NewCursorAtEnd(), errs.Wrapf(err, "path:%v", k) + return nil, NewCursorAtEnd(), errs.Wrapf(err, "path:%v", key) } func (s *StowStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { diff --git a/flytestdlib/storage/stow_store_test.go b/flytestdlib/storage/stow_store_test.go index 4de273dd93..aec59051f3 100644 --- a/flytestdlib/storage/stow_store_test.go +++ b/flytestdlib/storage/stow_store_test.go @@ -419,7 +419,7 @@ func TestStowStore_List(t *testing.T) { items, cursor, err := s.List(ctx, dataReference, maxResults, NewCursorAtStart()) assert.NoError(t, err) assert.Equal(t, NewCursorAtEnd(), cursor) - assert.Equal(t, []DataReference{"a/1", "a/2"}, items) + assert.Equal(t, []DataReference{"s3://container/a/1", "s3://container/a/2"}, items) }) t.Run("Listing with pagination", func(t *testing.T) { @@ -446,10 +446,10 @@ func TestStowStore_List(t *testing.T) { var dataReference DataReference = "s3://container/a" items, cursor, err := s.List(ctx, dataReference, maxResults, NewCursorAtStart()) assert.NoError(t, err) - assert.Equal(t, []DataReference{"a/1"}, items) + assert.Equal(t, []DataReference{"s3://container/a/1"}, items) items, _, err = s.List(ctx, dataReference, maxResults, cursor) assert.NoError(t, err) - assert.Equal(t, []DataReference{"a/2"}, items) + assert.Equal(t, []DataReference{"s3://container/a/2"}, items) }) }