Commit
Add multi file error aggregation strategy
bgedik committed Oct 1, 2024
1 parent 7989209 commit d74d7a2
Showing 7 changed files with 290 additions and 47 deletions.
@@ -3,37 +3,90 @@ package ioutils
import (
"context"
"fmt"
"math"
"path/filepath"
"strings"

"github.com/pkg/errors"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flytestdlib/storage"
)

type RemoteFileOutputReader struct {
outPath io.OutputFilePaths
type ErrorRetriever interface {
HasError(ctx context.Context) (bool, error)
GetError(ctx context.Context) (io.ExecutionError, error)
}
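
Any type with these two methods can back the output reader. As a minimal sketch, a hypothetical in-memory stub (not part of this change) that satisfies the interface for tests could look like:

// fixedErrorRetriever is a hypothetical stub satisfying ErrorRetriever,
// handy for unit tests that need a deterministic error result.
type fixedErrorRetriever struct {
	executionError io.ExecutionError
	present        bool
}

func (f *fixedErrorRetriever) HasError(ctx context.Context) (bool, error) {
	return f.present, nil
}

func (f *fixedErrorRetriever) GetError(ctx context.Context) (io.ExecutionError, error) {
	return f.executionError, nil
}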

type ErrorRetrieverBase struct {
store storage.ComposedProtobufStore
maxPayloadSize int64
}

func (r RemoteFileOutputReader) IsError(ctx context.Context) (bool, error) {
metadata, err := r.store.Head(ctx, r.outPath.GetErrorPath())
type SingleFileErrorRetriever struct {
ErrorRetrieverBase
errorFilePath storage.DataReference
}

func NewSingleFileErrorRetriever(errorFilePath storage.DataReference, store storage.ComposedProtobufStore, maxPayloadSize int64) *SingleFileErrorRetriever {
return &SingleFileErrorRetriever{
ErrorRetrieverBase: ErrorRetrieverBase{
store: store,
maxPayloadSize: maxPayloadSize,
},
errorFilePath: errorFilePath,
}
}

func (s *SingleFileErrorRetriever) HasError(ctx context.Context) (bool, error) {
metadata, err := s.store.Head(ctx, s.errorFilePath)
if err != nil {
return false, errors.Wrapf(err, "failed to read error file @[%s]", r.outPath.GetErrorPath())
return false, errors.Wrapf(err, "failed to read error file @[%s]", s.errorFilePath)
}
if metadata.Exists() {
if metadata.Size() > r.maxPayloadSize {
return false, errors.Wrapf(err, "error file @[%s] is too large [%d] bytes, max allowed [%d] bytes", r.outPath.GetErrorPath(), metadata.Size(), r.maxPayloadSize)
if metadata.Size() > s.maxPayloadSize {
return false, errors.Errorf("error file @[%s] is too large [%d] bytes, max allowed [%d] bytes", s.errorFilePath, metadata.Size(), s.maxPayloadSize)
}
return true, nil
}
return false, nil
}

func (r RemoteFileOutputReader) ReadError(ctx context.Context) (io.ExecutionError, error) {
func errorDoc2ExecutionError(errorDoc *core.ErrorDocument, errorFilePath storage.DataReference) io.ExecutionError {
if errorDoc.Error == nil {
return io.ExecutionError{
IsRecoverable: true,
ExecutionError: &core.ExecutionError{
Code: "ErrorFileBadFormat",
Message: fmt.Sprintf("error not formatted correctly, nil error @path [%s]", errorFilePath),
Kind: core.ExecutionError_SYSTEM,
},
}
}
executionError := io.ExecutionError{
ExecutionError: &core.ExecutionError{
Code: errorDoc.Error.Code,
Message: errorDoc.Error.Message,
Kind: errorDoc.Error.Origin,
},
}

if errorDoc.Error.Kind == core.ContainerError_RECOVERABLE {
executionError.IsRecoverable = true
}

return executionError
}
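
For instance, a recoverable user-origin container error converts to a recoverable execution error. A minimal sketch (the code, message, and path below are made-up values):

// exampleConversion is a hypothetical illustration of errorDoc2ExecutionError.
func exampleConversion() {
	doc := &core.ErrorDocument{
		Error: &core.ContainerError{
			Code:    "OOMKilled", // hypothetical values
			Message: "container exceeded its memory limit",
			Kind:    core.ContainerError_RECOVERABLE,
			Origin:  core.ExecutionError_USER,
		},
	}
	execErr := errorDoc2ExecutionError(doc, "s3://bucket/out/error.pb")
	// execErr.Code == "OOMKilled", execErr.Kind == core.ExecutionError_USER,
	// and execErr.IsRecoverable == true.
	_ = execErr
}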

func (s *SingleFileErrorRetriever) GetError(ctx context.Context) (io.ExecutionError, error) {
errorDoc := &core.ErrorDocument{}
err := r.store.ReadProtobuf(ctx, r.outPath.GetErrorPath(), errorDoc)
err := s.store.ReadProtobuf(ctx, s.errorFilePath, errorDoc)
if err != nil {
if storage.IsNotFound(err) {
return io.ExecutionError{
@@ -45,33 +98,143 @@ func (r RemoteFileOutputReader) ReadError(ctx context.Context) (io.ExecutionErro
},
}, nil
}
return io.ExecutionError{}, errors.Wrapf(err, "failed to read error data from task @[%s]", r.outPath.GetErrorPath())
return io.ExecutionError{}, errors.Wrapf(err, "failed to read error data from task @[%s]", s.errorFilePath)
}

if errorDoc.Error == nil {
return io.ExecutionError{
IsRecoverable: true,
ExecutionError: &core.ExecutionError{
Code: "ErrorFileBadFormat",
Message: fmt.Sprintf("error not formatted correctly, nil error @path [%s]", r.outPath.GetErrorPath()),
Kind: core.ExecutionError_SYSTEM,
},
}, nil
return errorDoc2ExecutionError(errorDoc, s.errorFilePath), nil
}

type EarliestFileErrorRetriever struct {
ErrorRetrieverBase
errorDirPath storage.DataReference
canonicalErrorFilename string
}

func (e *EarliestFileErrorRetriever) parseErrorFilename() (errorFilePathPrefix storage.DataReference, errorFileExtension string, err error) {
// If the canonical error file name is error.pb, we expect multiple error files
// to have name error<suffix>.pb
pieces := strings.Split(e.canonicalErrorFilename, ".")
if len(pieces) != 2 {
err = errors.Errorf("expected canonical error filename to contain a single '.', got %d pieces", len(pieces))
return
}
errorFilePrefix := pieces[0]
scheme, container, key, _ := e.errorDirPath.Split()
errorFilePathPrefix = storage.NewDataReference(scheme, container, filepath.Join(key, errorFilePrefix))
errorFileExtension = fmt.Sprintf(".%s", pieces[1])
return
}
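
So, for a canonical filename of "error.pb" under a directory such as s3://bucket/out (a made-up location), the prefix becomes s3://bucket/out/error and the extension ".pb", which matches files like error-0.pb and error-1.pb. A minimal in-package sketch:

// exampleParse is a hypothetical illustration of parseErrorFilename.
func exampleParse() {
	r := &EarliestFileErrorRetriever{
		errorDirPath:           storage.DataReference("s3://bucket/out"), // hypothetical directory
		canonicalErrorFilename: "error.pb",
	}
	prefix, ext, err := r.parseErrorFilename()
	// prefix == "s3://bucket/out/error", ext == ".pb", err == nil,
	// so error-0.pb and error-1.pb under the directory both match.
	_, _, _ = prefix, ext, err
}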

ee := io.ExecutionError{
ExecutionError: &core.ExecutionError{
Code: errorDoc.Error.Code,
Message: errorDoc.Error.Message,
Kind: errorDoc.Error.Origin,
func (e *EarliestFileErrorRetriever) HasError(ctx context.Context) (bool, error) {
errorFilePathPrefix, errorFileExtension, err := e.parseErrorFilename()
if err != nil {
return false, errors.Wrapf(err, "failed to parse canonical error filename @[%s]", e.canonicalErrorFilename)
}
const maxItems = 1000
cursor := storage.NewCursorAtStart()
for cursor != storage.NewCursorAtEnd() {
var err error
var errorFilePaths []storage.DataReference
errorFilePaths, cursor, err = e.store.List(ctx, errorFilePathPrefix, maxItems, cursor)
if err != nil {
return false, errors.Wrapf(err, "failed to list error files @[%s]", e.errorDirPath)
}
for _, errorFilePath := range errorFilePaths {
if strings.HasSuffix(errorFilePath.String(), errorFileExtension) {
return true, nil
}
}
}
return false, nil
}

func (e *EarliestFileErrorRetriever) GetError(ctx context.Context) (io.ExecutionError, error) {
errorFilePathPrefix, errorFileExtension, err := e.parseErrorFilename()
if err != nil {
return io.ExecutionError{}, errors.Wrapf(err, "failed to parse canonical error filename @[%s]", e.canonicalErrorFilename)
}
const maxItems = 1000
cursor := storage.NewCursorAtStart()
type ErrorFileAndDocument struct {
errorFilePath storage.DataReference
errorDoc *core.ErrorDocument
}
var errorFileAndDocs []ErrorFileAndDocument
for cursor != storage.NewCursorAtEnd() {
var err error
var errorFilePaths []storage.DataReference
errorFilePaths, cursor, err = e.store.List(ctx, errorFilePathPrefix, maxItems, cursor)
if err != nil {
return io.ExecutionError{}, errors.Wrapf(err, "failed to list error files @[%s]", e.errorDirPath)
}
for _, errorFilePath := range errorFilePaths {
if strings.HasSuffix(errorFilePath.String(), errorFileExtension) {
errorDoc := &core.ErrorDocument{}
err := e.store.ReadProtobuf(ctx, errorFilePath, errorDoc)
if err != nil {
return io.ExecutionError{}, errors.Wrapf(err, "failed to read error file @[%s]", errorFilePath.String())
}
errorFileAndDocs = append(errorFileAndDocs, ErrorFileAndDocument{errorFilePath: errorFilePath, errorDoc: errorDoc})
}
}
}

extractTimestampFromErrorDoc := func(errorDoc *core.ErrorDocument) int64 {
// TODO: add optional timestamp to ErrorDocument; until then every document reports
// the same timestamp, so the last matching file wins the >= comparison below.
if errorDoc == nil {
panic("unexpected nil ErrorDocument while extracting timestamp")
}
return 0
}

var earliestTimestamp int64 = math.MaxInt64
earliestExecutionError := io.ExecutionError{}
for _, errorFileAndDoc := range errorFileAndDocs {
timestamp := extractTimestampFromErrorDoc(errorFileAndDoc.errorDoc)
if earliestTimestamp >= timestamp {
earliestExecutionError = errorDoc2ExecutionError(errorFileAndDoc.errorDoc, errorFileAndDoc.errorFilePath)
earliestTimestamp = timestamp
}
}
return earliestExecutionError, nil
}

func NewEarliestFileErrorRetriever(errorDirPath storage.DataReference, canonicalErrorFilename string, store storage.ComposedProtobufStore, maxPayloadSize int64) *EarliestFileErrorRetriever {
return &EarliestFileErrorRetriever{
ErrorRetrieverBase: ErrorRetrieverBase{
store: store,
maxPayloadSize: maxPayloadSize,
},
errorDirPath: errorDirPath,
canonicalErrorFilename: canonicalErrorFilename,
}
}

if errorDoc.Error.Kind == core.ContainerError_RECOVERABLE {
ee.IsRecoverable = true
func NewErrorRetriever(errorAggregationStrategy k8s.ErrorAggregationStrategy, errorDirPath storage.DataReference, errorFilename string, store storage.ComposedProtobufStore, maxPayloadSize int64) ErrorRetriever {
if errorAggregationStrategy == k8s.DefaultErrorAggregationStrategy {
scheme, container, key, _ := errorDirPath.Split()
errorFilePath := storage.NewDataReference(scheme, container, filepath.Join(key, errorFilename))
return NewSingleFileErrorRetriever(errorFilePath, store, maxPayloadSize)
}
if errorAggregationStrategy == k8s.EarliestErrorAggregationStrategy {
return NewEarliestFileErrorRetriever(errorDirPath, errorFilename, store, maxPayloadSize)
}
return nil
}
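
As a usage sketch, the strategy alone decides which retriever backs the reader; the directory, filename, and size limit below are hypothetical values:

// pickRetriever is a hypothetical helper illustrating the factory.
func pickRetriever(store storage.ComposedProtobufStore) ErrorRetriever {
	return NewErrorRetriever(
		k8s.EarliestErrorAggregationStrategy,
		storage.DataReference("s3://bucket/out"), // hypothetical error directory
		"error.pb",                               // canonical error filename
		store,
		10*1024*1024, // hypothetical 10 MiB payload cap
	)
}

Note that NewErrorRetriever returns nil for an unrecognized strategy, so callers should treat a nil result as a configuration error.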

type RemoteFileOutputReader struct {
outPath io.OutputFilePaths
store storage.ComposedProtobufStore
maxPayloadSize int64
errorRetriever ErrorRetriever
}

return ee, nil
func (r RemoteFileOutputReader) IsError(ctx context.Context) (bool, error) {
return r.errorRetriever.HasError(ctx)
}

func (r RemoteFileOutputReader) ReadError(ctx context.Context) (io.ExecutionError, error) {
return r.errorRetriever.GetError(ctx)
}

func (r RemoteFileOutputReader) Exists(ctx context.Context) (bool, error) {
@@ -122,16 +285,25 @@ func (r RemoteFileOutputReader) DeckExists(ctx context.Context) (bool, error) {
return md.Exists(), nil
}

func NewRemoteFileOutputReader(_ context.Context, store storage.ComposedProtobufStore, outPaths io.OutputFilePaths, maxDatasetSize int64) RemoteFileOutputReader {
func NewRemoteFileOutputReader(ctx context.Context, store storage.ComposedProtobufStore, outPaths io.OutputFilePaths, maxDatasetSize int64) RemoteFileOutputReader {
return NewRemoteFileOutputReaderWithErrorAggregationStrategy(ctx, store, outPaths, maxDatasetSize, k8s.DefaultErrorAggregationStrategy)
}

func NewRemoteFileOutputReaderWithErrorAggregationStrategy(_ context.Context, store storage.ComposedProtobufStore, outPaths io.OutputFilePaths, maxDatasetSize int64, errorAggregationStrategy k8s.ErrorAggregationStrategy) RemoteFileOutputReader {
// Note: even though the data store retrieval checks against GetLimitMegabytes, there might be external
// storage implementations, so we keep this check here as well.
maxPayloadSize := maxDatasetSize
if maxPayloadSize == 0 {
maxPayloadSize = storage.GetConfig().Limits.GetLimitMegabytes * 1024 * 1024
}
scheme, container, key, _ := outPaths.GetErrorPath().Split()
errorFilename := filepath.Base(key)
errorDirPath := storage.NewDataReference(scheme, container, filepath.Dir(key))
errorRetriever := NewErrorRetriever(errorAggregationStrategy, errorDirPath, errorFilename, store, maxPayloadSize)
return RemoteFileOutputReader{
outPath: outPaths,
store: store,
maxPayloadSize: maxPayloadSize,
errorRetriever: errorRetriever,
}
}
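
A usage sketch tying the pieces together; buildReader is a hypothetical helper, and ctx, store, and outPaths are assumed to come from the plugin's task execution context:

// buildReader shows how a plugin might opt into the earliest-error strategy.
func buildReader(ctx context.Context, store storage.ComposedProtobufStore, outPaths io.OutputFilePaths) RemoteFileOutputReader {
	return NewRemoteFileOutputReaderWithErrorAggregationStrategy(
		ctx,
		store,
		outPaths,
		0, // 0 falls back to the configured storage size limit
		k8s.EarliestErrorAggregationStrategy,
	)
}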
@@ -2,13 +2,15 @@ package ioutils

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
pluginsIOMock "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flytestdlib/storage"
storageMocks "github.com/flyteorg/flyte/flytestdlib/storage/mocks"
)
@@ -65,11 +67,13 @@ func TestReadOrigin(t *testing.T) {
exists: true,
}, nil)

r := RemoteFileOutputReader{
outPath: opath,
store: store,
maxPayloadSize: 0,
}
maxPayloadSize := int64(0)
r := NewRemoteFileOutputReader(
ctx,
store,
opath,
maxPayloadSize,
)

ee, err := r.ReadError(ctx)
assert.NoError(t, err)
@@ -97,15 +101,63 @@ func TestReadOrigin(t *testing.T) {
casted.Error = errorDoc.Error
}).Return(nil)

r := RemoteFileOutputReader{
outPath: opath,
store: store,
maxPayloadSize: 0,
}
maxPayloadSize := int64(0)
r := NewRemoteFileOutputReader(
ctx,
store,
opath,
maxPayloadSize,
)

ee, err := r.ReadError(ctx)
assert.NoError(t, err)
assert.Equal(t, core.ExecutionError_SYSTEM, ee.Kind)
assert.True(t, ee.IsRecoverable)
})

t.Run("multi-user-error", func(t *testing.T) {
outputPaths := &pluginsIOMock.OutputFilePaths{}
outputPaths.OnGetErrorPath().Return("s3://errors/error.pb")

store := &storageMocks.ComposedProtobufStore{}
store.OnReadProtobufMatch(mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
errorDoc := &core.ErrorDocument{
Error: &core.ContainerError{
Code: "red",
Message: "hi",
Kind: core.ContainerError_NON_RECOVERABLE,
Origin: core.ExecutionError_USER,
},
}
errorFilePath := args.Get(1)
incomingErrorDoc := args.Get(2)
assert.NotNil(t, incomingErrorDoc)
casted := incomingErrorDoc.(*core.ErrorDocument)
casted.Error = errorDoc.Error
casted.Error.Message = fmt.Sprintf("%s-%s", casted.Error.Message, errorFilePath)
}).Return(nil)

store.OnList(ctx, storage.DataReference("s3://errors/error"), 1000, storage.NewCursorAtStart()).Return(
[]storage.DataReference{"error-0.pb", "error-1.pb", "error-2.pb"}, storage.NewCursorAtEnd(), nil)

maxPayloadSize := int64(0)
r := NewRemoteFileOutputReaderWithErrorAggregationStrategy(
ctx,
store,
outputPaths,
maxPayloadSize,
k8s.EarliestErrorAggregationStrategy,
)

hasError, err := r.IsError(ctx)
assert.NoError(t, err)
assert.True(t, hasError)

executionError, err := r.ReadError(ctx)
assert.NoError(t, err)
assert.Equal(t, core.ExecutionError_USER, executionError.Kind)
assert.Equal(t, "red", executionError.Code)
assert.Equal(t, "hi-error-2.pb", executionError.Message)
assert.False(t, executionError.IsRecoverable)
})
}