From e67aae0540f6b72475336922147352c7e92bad86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bu=C4=9Fra=20Gedik?= Date: Wed, 11 Sep 2024 23:49:07 -0700 Subject: [PATCH 01/17] Add listing api to stow storage (#5741) Signed-off-by: Bugra Gedik --- flyteadmin/pkg/common/mocks/storage.go | 4 + flytepropeller/pkg/utils/failing_datastore.go | 4 + flytestdlib/storage/cached_rawstore_test.go | 4 + flytestdlib/storage/mem_store.go | 4 + .../storage/mocks/composed_protobuf_store.go | 48 ++++++++++ flytestdlib/storage/storage.go | 38 ++++++++ flytestdlib/storage/stow_store.go | 46 +++++++++ flytestdlib/storage/stow_store_test.go | 96 ++++++++++++++++++- script/generate_helm.sh | 3 +- 9 files changed, 244 insertions(+), 3 deletions(-) diff --git a/flyteadmin/pkg/common/mocks/storage.go b/flyteadmin/pkg/common/mocks/storage.go index 7e91bf0485..bf29eedd3e 100644 --- a/flyteadmin/pkg/common/mocks/storage.go +++ b/flyteadmin/pkg/common/mocks/storage.go @@ -33,6 +33,10 @@ func (t *TestDataStore) Head(ctx context.Context, reference storage.DataReferenc return t.HeadCb(ctx, reference) } +func (t *TestDataStore) List(ctx context.Context, reference storage.DataReference, maxItems int, cursor storage.Cursor) ([]storage.DataReference, storage.Cursor, error) { + return nil, storage.NewCursorAtEnd(), fmt.Errorf("Not implemented yet") +} + func (t *TestDataStore) ReadProtobuf(ctx context.Context, reference storage.DataReference, msg proto.Message) error { return t.ReadProtobufCb(ctx, reference, msg) } diff --git a/flytepropeller/pkg/utils/failing_datastore.go b/flytepropeller/pkg/utils/failing_datastore.go index f3b65471c7..7948a85b81 100644 --- a/flytepropeller/pkg/utils/failing_datastore.go +++ b/flytepropeller/pkg/utils/failing_datastore.go @@ -27,6 +27,10 @@ func (FailingRawStore) Head(ctx context.Context, reference storage.DataReference return nil, fmt.Errorf("failed metadata fetch") } +func (FailingRawStore) List(ctx context.Context, reference storage.DataReference, maxItems int, cursor storage.Cursor) ([]storage.DataReference, storage.Cursor, error) { + return nil, storage.NewCursorAtEnd(), fmt.Errorf("Not implemented yet") +} + func (FailingRawStore) ReadRaw(ctx context.Context, reference storage.DataReference) (io.ReadCloser, error) { return nil, fmt.Errorf("failed read raw") } diff --git a/flytestdlib/storage/cached_rawstore_test.go b/flytestdlib/storage/cached_rawstore_test.go index b9751d7fa1..9c304790cb 100644 --- a/flytestdlib/storage/cached_rawstore_test.go +++ b/flytestdlib/storage/cached_rawstore_test.go @@ -73,6 +73,10 @@ func (d *dummyStore) Head(ctx context.Context, reference DataReference) (Metadat return d.HeadCb(ctx, reference) } +func (d *dummyStore) List(ctx context.Context, reference DataReference, maxItems int, cursor Cursor) ([]DataReference, Cursor, error) { + return nil, NewCursorAtEnd(), fmt.Errorf("Not implemented yet") +} + func (d *dummyStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { return d.ReadRawCb(ctx, reference) } diff --git a/flytestdlib/storage/mem_store.go b/flytestdlib/storage/mem_store.go index a95a0a49ca..94083f6646 100644 --- a/flytestdlib/storage/mem_store.go +++ b/flytestdlib/storage/mem_store.go @@ -54,6 +54,10 @@ func (s *InMemoryStore) Head(ctx context.Context, reference DataReference) (Meta }, nil } +func (s *InMemoryStore) List(ctx context.Context, reference DataReference, maxItems int, cursor Cursor) ([]DataReference, Cursor, error) { + return nil, NewCursorAtEnd(), fmt.Errorf("Not implemented yet") +} + func (s *InMemoryStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { if raw, found := s.cache[reference]; found { return ioutil.NopCloser(bytes.NewReader(raw)), nil diff --git a/flytestdlib/storage/mocks/composed_protobuf_store.go b/flytestdlib/storage/mocks/composed_protobuf_store.go index c9064c2ac5..49a0ee89dd 100644 --- a/flytestdlib/storage/mocks/composed_protobuf_store.go +++ b/flytestdlib/storage/mocks/composed_protobuf_store.go @@ -194,6 +194,54 @@ func (_m *ComposedProtobufStore) Head(ctx context.Context, reference storage.Dat return r0, r1 } +type ComposedProtobufStore_List struct { + *mock.Call +} + +func (_m ComposedProtobufStore_List) Return(_a0 []storage.DataReference, _a1 storage.Cursor, _a2 error) *ComposedProtobufStore_List { + return &ComposedProtobufStore_List{Call: _m.Call.Return(_a0, _a1, _a2)} +} + +func (_m *ComposedProtobufStore) OnList(ctx context.Context, reference storage.DataReference, maxItems int, cursor storage.Cursor) *ComposedProtobufStore_List { + c_call := _m.On("List", ctx, reference, maxItems, cursor) + return &ComposedProtobufStore_List{Call: c_call} +} + +func (_m *ComposedProtobufStore) OnListMatch(matchers ...interface{}) *ComposedProtobufStore_List { + c_call := _m.On("List", matchers...) + return &ComposedProtobufStore_List{Call: c_call} +} + +// List provides a mock function with given fields: ctx, reference, maxItems, cursor +func (_m *ComposedProtobufStore) List(ctx context.Context, reference storage.DataReference, maxItems int, cursor storage.Cursor) ([]storage.DataReference, storage.Cursor, error) { + ret := _m.Called(ctx, reference, maxItems, cursor) + + var r0 []storage.DataReference + if rf, ok := ret.Get(0).(func(context.Context, storage.DataReference, int, storage.Cursor) []storage.DataReference); ok { + r0 = rf(ctx, reference, maxItems, cursor) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]storage.DataReference) + } + } + + var r1 storage.Cursor + if rf, ok := ret.Get(1).(func(context.Context, storage.DataReference, int, storage.Cursor) storage.Cursor); ok { + r1 = rf(ctx, reference, maxItems, cursor) + } else { + r1 = ret.Get(1).(storage.Cursor) + } + + var r2 error + if rf, ok := ret.Get(2).(func(context.Context, storage.DataReference, int, storage.Cursor) error); ok { + r2 = rf(ctx, reference, maxItems, cursor) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + type ComposedProtobufStore_ReadProtobuf struct { *mock.Call } diff --git a/flytestdlib/storage/storage.go b/flytestdlib/storage/storage.go index 3e84cb7acb..52e6905513 100644 --- a/flytestdlib/storage/storage.go +++ b/flytestdlib/storage/storage.go @@ -40,6 +40,41 @@ type Metadata interface { ContentMD5() string } +type CursorState int + +const ( + // Enum representing state of the cursor + AtStartCursorState CursorState = 0 + AtEndCursorState CursorState = 1 + AtCustomPosCursorState CursorState = 2 +) + +type Cursor struct { + cursorState CursorState + customPosition string +} + +func NewCursorAtStart() Cursor { + return Cursor{ + cursorState: AtStartCursorState, + customPosition: "", + } +} + +func NewCursorAtEnd() Cursor { + return Cursor{ + cursorState: AtEndCursorState, + customPosition: "", + } +} + +func NewCursorFromCustomPosition(customPosition string) Cursor { + return Cursor{ + cursorState: AtCustomPosCursorState, + customPosition: customPosition, + } +} + // DataStore is a simplified interface for accessing and storing data in one of the Cloud stores. // Today we rely on Stow for multi-cloud support, but this interface abstracts that part type DataStore struct { @@ -78,6 +113,9 @@ type RawStore interface { // Head gets metadata about the reference. This should generally be a light weight operation. Head(ctx context.Context, reference DataReference) (Metadata, error) + // List gets a list of items given a prefix, using a paginated API + List(ctx context.Context, reference DataReference, maxItems int, cursor Cursor) ([]DataReference, Cursor, error) + // ReadRaw retrieves a byte array from the Blob store or an error ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) diff --git a/flytestdlib/storage/stow_store.go b/flytestdlib/storage/stow_store.go index ce4a75a0a1..6b731b9c86 100644 --- a/flytestdlib/storage/stow_store.go +++ b/flytestdlib/storage/stow_store.go @@ -92,6 +92,9 @@ type stowMetrics struct { HeadFailure labeled.Counter HeadLatency labeled.StopWatch + ListFailure labeled.Counter + ListLatency labeled.StopWatch + ReadFailure labeled.Counter ReadOpenLatency labeled.StopWatch @@ -251,6 +254,46 @@ func (s *StowStore) Head(ctx context.Context, reference DataReference) (Metadata return StowMetadata{exists: false}, errs.Wrapf(err, "path:%v", k) } +func (s *StowStore) List(ctx context.Context, reference DataReference, maxItems int, cursor Cursor) ([]DataReference, Cursor, error) { + _, c, k, err := reference.Split() + if err != nil { + s.metrics.BadReference.Inc(ctx) + return nil, NewCursorAtEnd(), err + } + + container, err := s.getContainer(ctx, locationIDMain, c) + if err != nil { + return nil, NewCursorAtEnd(), err + } + + t := s.metrics.ListLatency.Start(ctx) + var stowCursor string + if cursor.cursorState == AtStartCursorState { + stowCursor = stow.CursorStart + } else if cursor.cursorState == AtEndCursorState { + return nil, NewCursorAtEnd(), fmt.Errorf("Cursor cannot be at end for the List call") + } else { + stowCursor = cursor.customPosition + } + items, stowCursor, err := container.Items(k, stowCursor, maxItems) + if err == nil { + results := make([]DataReference, len(items)) + for index, item := range items { + results[index] = DataReference(item.URL().String()) + } + if stow.IsCursorEnd(stowCursor) { + cursor = NewCursorAtEnd() + } else { + cursor = NewCursorFromCustomPosition(stowCursor) + } + t.Stop() + return results, cursor, nil + } + + incFailureCounterForError(ctx, s.metrics.ListFailure, err) + return nil, NewCursorAtEnd(), errs.Wrapf(err, "path:%v", k) +} + func (s *StowStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { _, c, k, err := reference.Split() if err != nil { @@ -434,6 +477,9 @@ func newStowMetrics(scope promutils.Scope) *stowMetrics { HeadFailure: labeled.NewCounter("head_failure", "Indicates failure in HEAD for a given reference", scope, labeled.EmitUnlabeledMetric), HeadLatency: labeled.NewStopWatch("head", "Indicates time to fetch metadata using the Head API", time.Millisecond, scope, labeled.EmitUnlabeledMetric), + ListFailure: labeled.NewCounter("list_failure", "Indicates failure in item listing for a given reference", scope, labeled.EmitUnlabeledMetric), + ListLatency: labeled.NewStopWatch("list", "Indicates time to fetch item listing using the List API", time.Millisecond, scope, labeled.EmitUnlabeledMetric), + ReadFailure: labeled.NewCounter("read_failure", "Indicates failure in GET for a given reference", scope, labeled.EmitUnlabeledMetric, failureTypeOption), ReadOpenLatency: labeled.NewStopWatch("read_open", "Indicates time to first byte when reading", time.Millisecond, scope, labeled.EmitUnlabeledMetric), diff --git a/flytestdlib/storage/stow_store_test.go b/flytestdlib/storage/stow_store_test.go index 99678eb8ad..4de273dd93 100644 --- a/flytestdlib/storage/stow_store_test.go +++ b/flytestdlib/storage/stow_store_test.go @@ -10,6 +10,8 @@ import ( "net/url" "os" "path/filepath" + "sort" + "strconv" "testing" "time" @@ -73,8 +75,37 @@ func (m mockStowContainer) Item(id string) (stow.Item, error) { return nil, stow.ErrNotFound } -func (mockStowContainer) Items(prefix, cursor string, count int) ([]stow.Item, string, error) { - return []stow.Item{}, "", nil +func (m mockStowContainer) Items(prefix, cursor string, count int) ([]stow.Item, string, error) { + startIndex := 0 + if cursor != "" { + index, err := strconv.Atoi(cursor) + if err != nil { + return nil, "", fmt.Errorf("Invalid cursor '%s'", cursor) + } + startIndex = index + } + endIndexExc := min(len(m.items), startIndex+count) + + itemKeys := make([]string, len(m.items)) + index := 0 + for key := range m.items { + itemKeys[index] = key + index++ + } + sort.Strings(itemKeys) + + numItems := endIndexExc - startIndex + results := make([]stow.Item, numItems) + for index, itemKey := range itemKeys[startIndex:endIndexExc] { + results[index] = m.items[itemKey] + } + + if endIndexExc == len(m.items) { + cursor = "" + } else { + cursor = fmt.Sprintf("%d", endIndexExc) + } + return results, cursor, nil } func (m mockStowContainer) RemoveItem(id string) error { @@ -361,6 +392,67 @@ func TestStowStore_ReadRaw(t *testing.T) { }) } +func TestStowStore_List(t *testing.T) { + const container = "container" + t.Run("Listing", func(t *testing.T) { + ctx := context.Background() + fn := fQNFn["s3"] + s, err := NewStowRawStore(fn(container), &mockStowLoc{ + ContainerCb: func(id string) (stow.Container, error) { + if id == container { + return newMockStowContainer(container), nil + } + return nil, fmt.Errorf("container is not supported") + }, + CreateContainerCb: func(name string) (stow.Container, error) { + if name == container { + return newMockStowContainer(container), nil + } + return nil, fmt.Errorf("container is not supported") + }, + }, nil, false, metrics) + assert.NoError(t, err) + writeTestFile(ctx, t, s, "s3://container/a/1") + writeTestFile(ctx, t, s, "s3://container/a/2") + var maxResults = 10 + var dataReference DataReference = "s3://container/a" + items, cursor, err := s.List(ctx, dataReference, maxResults, NewCursorAtStart()) + assert.NoError(t, err) + assert.Equal(t, NewCursorAtEnd(), cursor) + assert.Equal(t, []DataReference{"a/1", "a/2"}, items) + }) + + t.Run("Listing with pagination", func(t *testing.T) { + ctx := context.Background() + fn := fQNFn["s3"] + s, err := NewStowRawStore(fn(container), &mockStowLoc{ + ContainerCb: func(id string) (stow.Container, error) { + if id == container { + return newMockStowContainer(container), nil + } + return nil, fmt.Errorf("container is not supported") + }, + CreateContainerCb: func(name string) (stow.Container, error) { + if name == container { + return newMockStowContainer(container), nil + } + return nil, fmt.Errorf("container is not supported") + }, + }, nil, false, metrics) + assert.NoError(t, err) + writeTestFile(ctx, t, s, "s3://container/a/1") + writeTestFile(ctx, t, s, "s3://container/a/2") + var maxResults = 1 + var dataReference DataReference = "s3://container/a" + items, cursor, err := s.List(ctx, dataReference, maxResults, NewCursorAtStart()) + assert.NoError(t, err) + assert.Equal(t, []DataReference{"a/1"}, items) + items, _, err = s.List(ctx, dataReference, maxResults, cursor) + assert.NoError(t, err) + assert.Equal(t, []DataReference{"a/2"}, items) + }) +} + func TestNewLocalStore(t *testing.T) { labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) t.Run("Valid config", func(t *testing.T) { diff --git a/script/generate_helm.sh b/script/generate_helm.sh index 1c836b9002..a0ae15c019 100755 --- a/script/generate_helm.sh +++ b/script/generate_helm.sh @@ -7,7 +7,8 @@ echo "Generating Helm" HELM_SKIP_INSTALL=${HELM_SKIP_INSTALL:-false} if [ "${HELM_SKIP_INSTALL}" != "true" ]; then - curl https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | bash + # See https://github.com/helm/helm/issues/13324 for a breaking change in latest version of helm + curl https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | DESIRED_VERSION=v3.15.4 bash fi helm version From 59bf19158eb8f60613055dcc3f61ee9b71be0531 Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Thu, 12 Sep 2024 11:17:45 -0700 Subject: [PATCH 02/17] Use latest upload/download-artifact action version (#5743) Signed-off-by: Jason Parraga --- .github/workflows/single-binary.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/single-binary.yml b/.github/workflows/single-binary.yml index c40b33e3a4..27ed4fabbc 100644 --- a/.github/workflows/single-binary.yml +++ b/.github/workflows/single-binary.yml @@ -94,7 +94,7 @@ jobs: file: Dockerfile outputs: type=docker,dest=docker/sandbox-bundled/images/tar/amd64/flyte-binary.tar - name: Upload single binary image - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: single-binary-image path: docker/sandbox-bundled/images/tar @@ -133,7 +133,7 @@ jobs: echo "FLYTESNACKS_VERSION=${FLYTESNACKS_VERSION}" >> ${GITHUB_ENV} - name: Checkout uses: actions/checkout@v4 - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: single-binary-image path: docker/sandbox-bundled/images/tar @@ -207,7 +207,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: single-binary-image path: docker/sandbox-bundled/images/tar From 44dc0eb3521c81b201e9ae545c47a05cc1afc632 Mon Sep 17 00:00:00 2001 From: Rob Ulbrich <141313113+robert-ulbrich-mercedes-benz@users.noreply.github.com> Date: Mon, 16 Sep 2024 10:33:53 +0200 Subject: [PATCH 03/17] Introduced SMTP notification (#5535) Signed-off-by: Rob Ulbrich --- flyteadmin/go.mod | 2 +- flyteadmin/pkg/async/notifications/factory.go | 200 ++++++- .../pkg/async/notifications/factory_test.go | 32 +- .../notifications/implementations/emailers.go | 1 + .../implementations/smtp_emailer.go | 158 ++++++ .../implementations/smtp_emailer_test.go | 498 ++++++++++++++++++ .../notifications/interfaces/smtp_client.go | 22 + .../async/notifications/mocks/smtp_client.go | 321 +++++++++++ flyteadmin/pkg/rpc/adminservice/base.go | 5 +- .../interfaces/application_configuration.go | 9 +- flyteadmin/pkg/server/service.go | 17 +- 11 files changed, 1243 insertions(+), 22 deletions(-) create mode 100644 flyteadmin/pkg/async/notifications/implementations/smtp_emailer.go create mode 100644 flyteadmin/pkg/async/notifications/implementations/smtp_emailer_test.go create mode 100644 flyteadmin/pkg/async/notifications/interfaces/smtp_client.go create mode 100644 flyteadmin/pkg/async/notifications/mocks/smtp_client.go diff --git a/flyteadmin/go.mod b/flyteadmin/go.mod index b9eba5b83a..2eec0f8cf3 100644 --- a/flyteadmin/go.mod +++ b/flyteadmin/go.mod @@ -51,6 +51,7 @@ require ( github.com/wolfeidau/humanhash v1.1.0 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.47.0 go.opentelemetry.io/otel v1.24.0 + golang.org/x/net v0.27.0 golang.org/x/oauth2 v0.16.0 golang.org/x/time v0.5.0 google.golang.org/api v0.155.0 @@ -189,7 +190,6 @@ require ( go.opentelemetry.io/proto/otlp v1.1.0 // indirect golang.org/x/crypto v0.25.0 // indirect golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect - golang.org/x/net v0.27.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.22.0 // indirect golang.org/x/term v0.22.0 // indirect diff --git a/flyteadmin/pkg/async/notifications/factory.go b/flyteadmin/pkg/async/notifications/factory.go index f94129a1d5..483978238e 100644 --- a/flyteadmin/pkg/async/notifications/factory.go +++ b/flyteadmin/pkg/async/notifications/factory.go @@ -18,6 +18,7 @@ import ( "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/interfaces" "github.com/flyteorg/flyte/flyteadmin/pkg/common" runtimeInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/flyteorg/flyte/flytestdlib/promutils" ) @@ -27,6 +28,7 @@ const maxRetries = 3 var enable64decoding = false var msgChan chan []byte + var once sync.Once type PublisherConfig struct { @@ -35,220 +37,404 @@ type PublisherConfig struct { type ProcessorConfig struct { QueueName string + AccountID string } type EmailerConfig struct { SenderEmail string - BaseURL string + + BaseURL string } // For sandbox only + func CreateMsgChan() { + once.Do(func() { + msgChan = make(chan []byte) + }) + } -func GetEmailer(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Emailer { +func GetEmailer(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope, sm core.SecretManager) interfaces.Emailer { + // If an external email service is specified use that instead. + // TODO: Handling of this is messy, see https://github.com/flyteorg/flyte/issues/1063 + if config.NotificationsEmailerConfig.EmailerConfig.ServiceName != "" { + switch config.NotificationsEmailerConfig.EmailerConfig.ServiceName { + case implementations.Sendgrid: + return implementations.NewSendGridEmailer(config, scope) + + case implementations.SMTP: + + return implementations.NewSMTPEmailer(context.Background(), config, scope, sm) + default: + panic(fmt.Errorf("No matching email implementation for %s", config.NotificationsEmailerConfig.EmailerConfig.ServiceName)) + } + } switch config.Type { + case common.AWS: + region := config.AWSConfig.Region + if region == "" { + region = config.Region + } + awsConfig := aws.NewConfig().WithRegion(region).WithMaxRetries(maxRetries) + awsSession, err := session.NewSession(awsConfig) + if err != nil { + panic(err) + } + sesClient := ses.New(awsSession) + return implementations.NewAwsEmailer( + config, + scope, + sesClient, ) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), "Using default noop emailer implementation for config type [%s]", config.Type) + return implementations.NewNoopEmail() + } + } -func NewNotificationsProcessor(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Processor { +func NewNotificationsProcessor(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope, sm core.SecretManager) interfaces.Processor { + reconnectAttempts := config.ReconnectAttempts + reconnectDelay := time.Duration(config.ReconnectDelaySeconds) * time.Second + var sub pubsub.Subscriber + var emailer interfaces.Emailer + switch config.Type { + case common.AWS: + sqsConfig := gizmoAWS.SQSConfig{ - QueueName: config.NotificationsProcessorConfig.QueueName, + + QueueName: config.NotificationsProcessorConfig.QueueName, + QueueOwnerAccountID: config.NotificationsProcessorConfig.AccountID, + // The AWS configuration type uses SNS to SQS for notifications. + // Gizmo by default will decode the SQS message using Base64 decoding. + // However, the message body of SQS is the SNS message format which isn't Base64 encoded. + ConsumeBase64: &enable64decoding, } + if config.AWSConfig.Region != "" { + sqsConfig.Region = config.AWSConfig.Region + } else { + sqsConfig.Region = config.Region + } + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + sub, err = gizmoAWS.NewSubscriber(sqsConfig) + if err != nil { + logger.Warnf(context.TODO(), "Failed to initialize new gizmo aws subscriber with config [%+v] and err: %v", sqsConfig, err) + } + return err + }) if err != nil { + panic(err) + } - emailer = GetEmailer(config, scope) + + emailer = GetEmailer(config, scope, sm) + return implementations.NewProcessor(sub, emailer, scope) + case common.GCP: + projectID := config.GCPConfig.ProjectID + subscription := config.NotificationsProcessorConfig.QueueName + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + sub, err = gizmoGCP.NewSubscriber(context.TODO(), projectID, subscription) + if err != nil { + logger.Warnf(context.TODO(), "Failed to initialize new gizmo gcp subscriber with config [ProjectID: %s, Subscription: %s] and err: %v", projectID, subscription, err) + } + return err + }) + if err != nil { + panic(err) + } - emailer = GetEmailer(config, scope) + + emailer = GetEmailer(config, scope, sm) + return implementations.NewGcpProcessor(sub, emailer, scope) + case common.Sandbox: - emailer = GetEmailer(config, scope) + + emailer = GetEmailer(config, scope, sm) + return implementations.NewSandboxProcessor(msgChan, emailer) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), + "Using default noop notifications processor implementation for config type [%s]", config.Type) + return implementations.NewNoopProcess() + } + } func NewNotificationsPublisher(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Publisher { + reconnectAttempts := config.ReconnectAttempts + reconnectDelay := time.Duration(config.ReconnectDelaySeconds) * time.Second + switch config.Type { + case common.AWS: + snsConfig := gizmoAWS.SNSConfig{ + Topic: config.NotificationsPublisherConfig.TopicName, } + if config.AWSConfig.Region != "" { + snsConfig.Region = config.AWSConfig.Region + } else { + snsConfig.Region = config.Region + } var publisher pubsub.Publisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoAWS.NewPublisher(snsConfig) + return err + }) // Any persistent errors initiating Publisher with Amazon configurations results in a failed start up. + if err != nil { + panic(err) + } + return implementations.NewPublisher(publisher, scope) + case common.GCP: + pubsubConfig := gizmoGCP.Config{ + Topic: config.NotificationsPublisherConfig.TopicName, } + pubsubConfig.ProjectID = config.GCPConfig.ProjectID + var publisher pubsub.MultiPublisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoGCP.NewPublisher(context.TODO(), pubsubConfig) + return err + }) if err != nil { + panic(err) + } + return implementations.NewPublisher(publisher, scope) + case common.Sandbox: + CreateMsgChan() + return implementations.NewSandboxPublisher(msgChan) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), + "Using default noop notifications publisher implementation for config type [%s]", config.Type) + return implementations.NewNoopPublish() + } + } func NewEventsPublisher(config runtimeInterfaces.ExternalEventsConfig, scope promutils.Scope) interfaces.Publisher { + if !config.Enable { + return implementations.NewNoopPublish() + } + reconnectAttempts := config.ReconnectAttempts + reconnectDelay := time.Duration(config.ReconnectDelaySeconds) * time.Second + switch config.Type { + case common.AWS: + snsConfig := gizmoAWS.SNSConfig{ + Topic: config.EventsPublisherConfig.TopicName, } + snsConfig.Region = config.AWSConfig.Region var publisher pubsub.Publisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoAWS.NewPublisher(snsConfig) + return err + }) // Any persistent errors initiating Publisher with Amazon configurations results in a failed start up. + if err != nil { + panic(err) + } + return implementations.NewEventsPublisher(publisher, scope, config.EventsPublisherConfig.EventTypes) + case common.GCP: + pubsubConfig := gizmoGCP.Config{ + Topic: config.EventsPublisherConfig.TopicName, } + pubsubConfig.ProjectID = config.GCPConfig.ProjectID + var publisher pubsub.MultiPublisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoGCP.NewPublisher(context.TODO(), pubsubConfig) + return err + }) if err != nil { + panic(err) + } + return implementations.NewEventsPublisher(publisher, scope, config.EventsPublisherConfig.EventTypes) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), + "Using default noop events publisher implementation for config type [%s]", config.Type) + return implementations.NewNoopPublish() + } + } diff --git a/flyteadmin/pkg/async/notifications/factory_test.go b/flyteadmin/pkg/async/notifications/factory_test.go index 1bfd1f4596..43602525a5 100644 --- a/flyteadmin/pkg/async/notifications/factory_test.go +++ b/flyteadmin/pkg/async/notifications/factory_test.go @@ -9,52 +9,76 @@ import ( "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/implementations" runtimeInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyte/flytestdlib/promutils" ) var ( - scope = promutils.NewScope("test_sandbox_processor") + scope = promutils.NewScope("test_sandbox_processor") + notificationsConfig = runtimeInterfaces.NotificationsConfig{ + Type: "sandbox", } + testEmail = admin.EmailMessage{ + RecipientsEmail: []string{ + "a@example.com", + "b@example.com", }, + SenderEmail: "no-reply@example.com", + SubjectLine: "Test email", - Body: "This is a sample email.", + + Body: "This is a sample email.", } ) func TestGetEmailer(t *testing.T) { + defer func() { r := recover(); assert.NotNil(t, r) }() + cfg := runtimeInterfaces.NotificationsConfig{ + NotificationsEmailerConfig: runtimeInterfaces.NotificationsEmailerConfig{ + EmailerConfig: runtimeInterfaces.EmailServerConfig{ + ServiceName: "unsupported", }, }, } - GetEmailer(cfg, promutils.NewTestScope()) + GetEmailer(cfg, promutils.NewTestScope(), &mocks.SecretManager{}) // shouldn't reach here + t.Errorf("did not panic") + } func TestNewNotificationPublisherAndProcessor(t *testing.T) { + testSandboxPublisher := NewNotificationsPublisher(notificationsConfig, scope) + assert.IsType(t, testSandboxPublisher, &implementations.SandboxPublisher{}) - testSandboxProcessor := NewNotificationsProcessor(notificationsConfig, scope) + + testSandboxProcessor := NewNotificationsProcessor(notificationsConfig, scope, &mocks.SecretManager{}) + assert.IsType(t, testSandboxProcessor, &implementations.SandboxProcessor{}) go func() { + testSandboxProcessor.StartProcessing() + }() assert.Nil(t, testSandboxPublisher.Publish(context.Background(), "TEST_NOTIFICATION", &testEmail)) assert.Nil(t, testSandboxProcessor.StopProcessing()) + } diff --git a/flyteadmin/pkg/async/notifications/implementations/emailers.go b/flyteadmin/pkg/async/notifications/implementations/emailers.go index e630b5a4ea..0da3fbf600 100644 --- a/flyteadmin/pkg/async/notifications/implementations/emailers.go +++ b/flyteadmin/pkg/async/notifications/implementations/emailers.go @@ -4,4 +4,5 @@ type ExternalEmailer = string const ( Sendgrid ExternalEmailer = "sendgrid" + SMTP ExternalEmailer = "smtp" ) diff --git a/flyteadmin/pkg/async/notifications/implementations/smtp_emailer.go b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer.go new file mode 100644 index 0000000000..5a705bc0c1 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer.go @@ -0,0 +1,158 @@ +package implementations + +import ( + "crypto/tls" + "fmt" + "net/smtp" + "strings" + + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + + "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/interfaces" + "github.com/flyteorg/flyte/flyteadmin/pkg/errors" + runtimeInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyte/flytestdlib/logger" + "github.com/flyteorg/flyte/flytestdlib/promutils" +) + +type SMTPEmailer struct { + config *runtimeInterfaces.NotificationsEmailerConfig + systemMetrics emailMetrics + tlsConf *tls.Config + auth *smtp.Auth + smtpClient interfaces.SMTPClient + CreateSMTPClientFunc func(connectString string) (interfaces.SMTPClient, error) +} + +func (s *SMTPEmailer) createClient(ctx context.Context) (interfaces.SMTPClient, error) { + newClient, err := s.CreateSMTPClientFunc(s.config.EmailerConfig.SMTPServer + ":" + s.config.EmailerConfig.SMTPPort) + + if err != nil { + return nil, s.emailError(ctx, fmt.Sprintf("Error creating email client: %s", err)) + } + + if err = newClient.Hello("localhost"); err != nil { + return nil, s.emailError(ctx, fmt.Sprintf("Error initiating connection to SMTP server: %s", err)) + } + + if ok, _ := newClient.Extension("STARTTLS"); ok { + if err = newClient.StartTLS(s.tlsConf); err != nil { + return nil, s.emailError(ctx, fmt.Sprintf("Error initiating connection to SMTP server: %s", err)) + } + } + + if ok, _ := newClient.Extension("AUTH"); ok { + if err = newClient.Auth(*s.auth); err != nil { + return nil, s.emailError(ctx, fmt.Sprintf("Error authenticating email client: %s", err)) + } + } + + return newClient, nil +} + +func (s *SMTPEmailer) SendEmail(ctx context.Context, email *admin.EmailMessage) error { + + if s.smtpClient == nil || s.smtpClient.Noop() != nil { + + if s.smtpClient != nil { + err := s.smtpClient.Close() + if err != nil { + logger.Info(ctx, err) + } + } + smtpClient, err := s.createClient(ctx) + + if err != nil { + return s.emailError(ctx, fmt.Sprintf("Error creating SMTP email client: %s", err)) + } + + s.smtpClient = smtpClient + } + + if err := s.smtpClient.Mail(email.SenderEmail); err != nil { + return s.emailError(ctx, fmt.Sprintf("Error creating email instance: %s", err)) + } + + for _, recipient := range email.RecipientsEmail { + if err := s.smtpClient.Rcpt(recipient); err != nil { + return s.emailError(ctx, fmt.Sprintf("Error adding email recipient: %s", err)) + } + } + + writer, err := s.smtpClient.Data() + + if err != nil { + return s.emailError(ctx, fmt.Sprintf("Error adding email recipient: %s", err)) + } + + _, err = writer.Write([]byte(createMailBody(s.config.Sender, email))) + + if err != nil { + return s.emailError(ctx, fmt.Sprintf("Error writing mail body: %s", err)) + } + + err = writer.Close() + + if err != nil { + return s.emailError(ctx, fmt.Sprintf("Error closing mail body: %s", err)) + } + + s.systemMetrics.SendSuccess.Inc() + return nil +} + +func (s *SMTPEmailer) emailError(ctx context.Context, error string) error { + s.systemMetrics.SendError.Inc() + logger.Error(ctx, error) + return errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails") +} + +func createMailBody(emailSender string, email *admin.EmailMessage) string { + headerMap := make(map[string]string) + headerMap["From"] = emailSender + headerMap["To"] = strings.Join(email.RecipientsEmail, ",") + headerMap["Subject"] = email.SubjectLine + headerMap["Content-Type"] = "text/html; charset=\"UTF-8\"" + + mailMessage := "" + + for k, v := range headerMap { + mailMessage += fmt.Sprintf("%s: %s\r\n", k, v) + } + + mailMessage += "\r\n" + email.Body + + return mailMessage +} + +func NewSMTPEmailer(ctx context.Context, config runtimeInterfaces.NotificationsConfig, scope promutils.Scope, sm core.SecretManager) interfaces.Emailer { + var tlsConfiguration *tls.Config + emailConf := config.NotificationsEmailerConfig.EmailerConfig + + smtpPassword, err := sm.Get(ctx, emailConf.SMTPPasswordSecretName) + if err != nil { + logger.Debug(ctx, "No SMTP password found.") + smtpPassword = "" + } + + auth := smtp.PlainAuth("", emailConf.SMTPUsername, smtpPassword, emailConf.SMTPServer) + + // #nosec G402 + tlsConfiguration = &tls.Config{ + InsecureSkipVerify: emailConf.SMTPSkipTLSVerify, + ServerName: emailConf.SMTPServer, + } + + return &SMTPEmailer{ + config: &config.NotificationsEmailerConfig, + systemMetrics: newEmailMetrics(scope.NewSubScope("smtp")), + tlsConf: tlsConfiguration, + auth: &auth, + CreateSMTPClientFunc: func(connectString string) (interfaces.SMTPClient, error) { + return smtp.Dial(connectString) + }, + } +} diff --git a/flyteadmin/pkg/async/notifications/implementations/smtp_emailer_test.go b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer_test.go new file mode 100644 index 0000000000..558a5d6408 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer_test.go @@ -0,0 +1,498 @@ +package implementations + +import ( + "context" + "crypto/tls" + "errors" + "net/smtp" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/codes" + + notification_interfaces "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/interfaces" + notification_mocks "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/mocks" + flyte_errors "github.com/flyteorg/flyte/flyteadmin/pkg/errors" + "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyte/flytestdlib/promutils" +) + +type StringWriter struct { + buffer string + writeErr error + closeErr error +} + +func (s *StringWriter) Write(p []byte) (n int, err error) { + s.buffer = s.buffer + string(p) + return len(p), s.writeErr +} + +func (s *StringWriter) Close() error { + return s.closeErr +} + +func getNotificationsEmailerConfig() interfaces.NotificationsConfig { + return interfaces.NotificationsConfig{ + Type: "", + Region: "", + AWSConfig: interfaces.AWSConfig{}, + GCPConfig: interfaces.GCPConfig{}, + NotificationsPublisherConfig: interfaces.NotificationsPublisherConfig{}, + NotificationsProcessorConfig: interfaces.NotificationsProcessorConfig{}, + NotificationsEmailerConfig: interfaces.NotificationsEmailerConfig{ + EmailerConfig: interfaces.EmailServerConfig{ + ServiceName: SMTP, + SMTPServer: "smtpServer", + SMTPPort: "smtpPort", + SMTPUsername: "smtpUsername", + SMTPPasswordSecretName: "smtp_password", + }, + Subject: "subject", + Sender: "sender", + Body: "body"}, + ReconnectAttempts: 1, + ReconnectDelaySeconds: 2} +} + +func TestEmailCreation(t *testing.T) { + email := admin.EmailMessage{ + RecipientsEmail: []string{"john@doe.com", "teresa@tester.com"}, + SubjectLine: "subject", + Body: "Email Body", + SenderEmail: "sender@sender.com", + } + + body := createMailBody("sender@sender.com", &email) + assert.Contains(t, body, "From: sender@sender.com\r\n") + assert.Contains(t, body, "To: john@doe.com,teresa@tester.com") + assert.Contains(t, body, "Subject: subject\r\n") + assert.Contains(t, body, "Content-Type: text/html; charset=\"UTF-8\"\r\n") + assert.Contains(t, body, "Email Body") +} + +func TestNewSmtpEmailer(t *testing.T) { + secretManagerMock := mocks.SecretManager{} + secretManagerMock.On("Get", mock.Anything, "smtp_password").Return("password", nil) + + notificationsConfig := getNotificationsEmailerConfig() + + smtpEmailer := NewSMTPEmailer(context.Background(), notificationsConfig, promutils.NewTestScope(), &secretManagerMock) + + assert.NotNil(t, smtpEmailer) +} + +func TestCreateClient(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Hello", "localhost").Return(nil) + smtpClient.On("Extension", "STARTTLS").Return(true, "") + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil) + smtpClient.On("Extension", "AUTH").Return(true, "") + smtpClient.On("Auth", auth).Return(nil) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Nil(t, err) + assert.NotNil(t, client) + +} + +func TestCreateClientErrorCreatingClient(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, errors.New("error creating client")) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestCreateClientErrorHello(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Hello", "localhost").Return(errors.New("Error with hello")) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestCreateClientErrorStartTLS(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(errors.New("Error with startls")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestCreateClientErrorAuth(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(errors.New("Error with hello")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestSendMail(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + stringWriter := StringWriter{buffer: ""} + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "alice@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "bob@flyte.org").Return(nil).Times(1) + smtpClient.On("Data").Return(&stringWriter, nil).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.True(t, strings.Contains(stringWriter.buffer, "From: sender")) + assert.True(t, strings.Contains(stringWriter.buffer, "To: alice@flyte.org,bob@flyte.org")) + assert.True(t, strings.Contains(stringWriter.buffer, "Subject: subject")) + assert.True(t, strings.Contains(stringWriter.buffer, "This is an email.")) + assert.Nil(t, err) + +} + +func TestSendMailCreateClientError(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(errors.New("error hello")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorMail(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(errors.New("error sending mail")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorRecipient(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "alice@flyte.org").Return(errors.New("error adding recipient")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorData(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "alice@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "bob@flyte.org").Return(nil).Times(1) + smtpClient.On("Data").Return(nil, errors.New("error creating data writer")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorWriting(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + stringWriter := StringWriter{buffer: "", writeErr: errors.New("error writing"), closeErr: nil} + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "alice@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "bob@flyte.org").Return(nil).Times(1) + smtpClient.On("Data").Return(&stringWriter, nil).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorClose(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + stringWriter := StringWriter{buffer: "", writeErr: nil, closeErr: errors.New("error writing")} + + smtpClient := ¬ification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "alice@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "bob@flyte.org").Return(nil).Times(1) + smtpClient.On("Data").Return(&stringWriter, nil).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.True(t, strings.Contains(stringWriter.buffer, "From: sender")) + assert.True(t, strings.Contains(stringWriter.buffer, "To: alice@flyte.org,bob@flyte.org")) + assert.True(t, strings.Contains(stringWriter.buffer, "Subject: subject")) + assert.True(t, strings.Contains(stringWriter.buffer, "This is an email.")) + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func createSMTPEmailer(smtpClient notification_interfaces.SMTPClient, tlsConf *tls.Config, auth *smtp.Auth, creationErr error) *SMTPEmailer { + secretManagerMock := mocks.SecretManager{} + secretManagerMock.On("Get", mock.Anything, "smtp_password").Return("password", nil) + + notificationsConfig := getNotificationsEmailerConfig() + + return &SMTPEmailer{ + config: ¬ificationsConfig.NotificationsEmailerConfig, + systemMetrics: newEmailMetrics(promutils.NewTestScope()), + tlsConf: tlsConf, + auth: auth, + CreateSMTPClientFunc: func(connectString string) (notification_interfaces.SMTPClient, error) { + return smtpClient, creationErr + }, + smtpClient: smtpClient, + } +} diff --git a/flyteadmin/pkg/async/notifications/interfaces/smtp_client.go b/flyteadmin/pkg/async/notifications/interfaces/smtp_client.go new file mode 100644 index 0000000000..bdc6171f46 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/interfaces/smtp_client.go @@ -0,0 +1,22 @@ +package interfaces + +import ( + "crypto/tls" + "io" + "net/smtp" +) + +// This interface is introduced to allow for mocking of the smtp.Client object. + +//go:generate mockery -name=SMTPClient -output=../mocks -case=underscore +type SMTPClient interface { + Hello(localName string) error + Extension(ext string) (bool, string) + Auth(a smtp.Auth) error + StartTLS(config *tls.Config) error + Noop() error + Close() error + Mail(from string) error + Rcpt(to string) error + Data() (io.WriteCloser, error) +} diff --git a/flyteadmin/pkg/async/notifications/mocks/smtp_client.go b/flyteadmin/pkg/async/notifications/mocks/smtp_client.go new file mode 100644 index 0000000000..11dafefc9c --- /dev/null +++ b/flyteadmin/pkg/async/notifications/mocks/smtp_client.go @@ -0,0 +1,321 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + io "io" + smtp "net/smtp" + + mock "github.com/stretchr/testify/mock" + + tls "crypto/tls" +) + +// SMTPClient is an autogenerated mock type for the SMTPClient type +type SMTPClient struct { + mock.Mock +} + +type SMTPClient_Auth struct { + *mock.Call +} + +func (_m SMTPClient_Auth) Return(_a0 error) *SMTPClient_Auth { + return &SMTPClient_Auth{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnAuth(a smtp.Auth) *SMTPClient_Auth { + c_call := _m.On("Auth", a) + return &SMTPClient_Auth{Call: c_call} +} + +func (_m *SMTPClient) OnAuthMatch(matchers ...interface{}) *SMTPClient_Auth { + c_call := _m.On("Auth", matchers...) + return &SMTPClient_Auth{Call: c_call} +} + +// Auth provides a mock function with given fields: a +func (_m *SMTPClient) Auth(a smtp.Auth) error { + ret := _m.Called(a) + + var r0 error + if rf, ok := ret.Get(0).(func(smtp.Auth) error); ok { + r0 = rf(a) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_Close struct { + *mock.Call +} + +func (_m SMTPClient_Close) Return(_a0 error) *SMTPClient_Close { + return &SMTPClient_Close{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnClose() *SMTPClient_Close { + c_call := _m.On("Close") + return &SMTPClient_Close{Call: c_call} +} + +func (_m *SMTPClient) OnCloseMatch(matchers ...interface{}) *SMTPClient_Close { + c_call := _m.On("Close", matchers...) + return &SMTPClient_Close{Call: c_call} +} + +// Close provides a mock function with given fields: +func (_m *SMTPClient) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_Data struct { + *mock.Call +} + +func (_m SMTPClient_Data) Return(_a0 io.WriteCloser, _a1 error) *SMTPClient_Data { + return &SMTPClient_Data{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *SMTPClient) OnData() *SMTPClient_Data { + c_call := _m.On("Data") + return &SMTPClient_Data{Call: c_call} +} + +func (_m *SMTPClient) OnDataMatch(matchers ...interface{}) *SMTPClient_Data { + c_call := _m.On("Data", matchers...) + return &SMTPClient_Data{Call: c_call} +} + +// Data provides a mock function with given fields: +func (_m *SMTPClient) Data() (io.WriteCloser, error) { + ret := _m.Called() + + var r0 io.WriteCloser + if rf, ok := ret.Get(0).(func() io.WriteCloser); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.WriteCloser) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type SMTPClient_Extension struct { + *mock.Call +} + +func (_m SMTPClient_Extension) Return(_a0 bool, _a1 string) *SMTPClient_Extension { + return &SMTPClient_Extension{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *SMTPClient) OnExtension(ext string) *SMTPClient_Extension { + c_call := _m.On("Extension", ext) + return &SMTPClient_Extension{Call: c_call} +} + +func (_m *SMTPClient) OnExtensionMatch(matchers ...interface{}) *SMTPClient_Extension { + c_call := _m.On("Extension", matchers...) + return &SMTPClient_Extension{Call: c_call} +} + +// Extension provides a mock function with given fields: ext +func (_m *SMTPClient) Extension(ext string) (bool, string) { + ret := _m.Called(ext) + + var r0 bool + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(ext) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 string + if rf, ok := ret.Get(1).(func(string) string); ok { + r1 = rf(ext) + } else { + r1 = ret.Get(1).(string) + } + + return r0, r1 +} + +type SMTPClient_Hello struct { + *mock.Call +} + +func (_m SMTPClient_Hello) Return(_a0 error) *SMTPClient_Hello { + return &SMTPClient_Hello{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnHello(localName string) *SMTPClient_Hello { + c_call := _m.On("Hello", localName) + return &SMTPClient_Hello{Call: c_call} +} + +func (_m *SMTPClient) OnHelloMatch(matchers ...interface{}) *SMTPClient_Hello { + c_call := _m.On("Hello", matchers...) + return &SMTPClient_Hello{Call: c_call} +} + +// Hello provides a mock function with given fields: localName +func (_m *SMTPClient) Hello(localName string) error { + ret := _m.Called(localName) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(localName) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_Mail struct { + *mock.Call +} + +func (_m SMTPClient_Mail) Return(_a0 error) *SMTPClient_Mail { + return &SMTPClient_Mail{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnMail(from string) *SMTPClient_Mail { + c_call := _m.On("Mail", from) + return &SMTPClient_Mail{Call: c_call} +} + +func (_m *SMTPClient) OnMailMatch(matchers ...interface{}) *SMTPClient_Mail { + c_call := _m.On("Mail", matchers...) + return &SMTPClient_Mail{Call: c_call} +} + +// Mail provides a mock function with given fields: from +func (_m *SMTPClient) Mail(from string) error { + ret := _m.Called(from) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(from) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_Noop struct { + *mock.Call +} + +func (_m SMTPClient_Noop) Return(_a0 error) *SMTPClient_Noop { + return &SMTPClient_Noop{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnNoop() *SMTPClient_Noop { + c_call := _m.On("Noop") + return &SMTPClient_Noop{Call: c_call} +} + +func (_m *SMTPClient) OnNoopMatch(matchers ...interface{}) *SMTPClient_Noop { + c_call := _m.On("Noop", matchers...) + return &SMTPClient_Noop{Call: c_call} +} + +// Noop provides a mock function with given fields: +func (_m *SMTPClient) Noop() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_Rcpt struct { + *mock.Call +} + +func (_m SMTPClient_Rcpt) Return(_a0 error) *SMTPClient_Rcpt { + return &SMTPClient_Rcpt{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnRcpt(to string) *SMTPClient_Rcpt { + c_call := _m.On("Rcpt", to) + return &SMTPClient_Rcpt{Call: c_call} +} + +func (_m *SMTPClient) OnRcptMatch(matchers ...interface{}) *SMTPClient_Rcpt { + c_call := _m.On("Rcpt", matchers...) + return &SMTPClient_Rcpt{Call: c_call} +} + +// Rcpt provides a mock function with given fields: to +func (_m *SMTPClient) Rcpt(to string) error { + ret := _m.Called(to) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(to) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_StartTLS struct { + *mock.Call +} + +func (_m SMTPClient_StartTLS) Return(_a0 error) *SMTPClient_StartTLS { + return &SMTPClient_StartTLS{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnStartTLS(config *tls.Config) *SMTPClient_StartTLS { + c_call := _m.On("StartTLS", config) + return &SMTPClient_StartTLS{Call: c_call} +} + +func (_m *SMTPClient) OnStartTLSMatch(matchers ...interface{}) *SMTPClient_StartTLS { + c_call := _m.On("StartTLS", matchers...) + return &SMTPClient_StartTLS{Call: c_call} +} + +// StartTLS provides a mock function with given fields: config +func (_m *SMTPClient) StartTLS(config *tls.Config) error { + ret := _m.Called(config) + + var r0 error + if rf, ok := ret.Get(0).(func(*tls.Config) error); ok { + r0 = rf(config) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/flyteadmin/pkg/rpc/adminservice/base.go b/flyteadmin/pkg/rpc/adminservice/base.go index 8df2c595c7..491a24a1f0 100644 --- a/flyteadmin/pkg/rpc/adminservice/base.go +++ b/flyteadmin/pkg/rpc/adminservice/base.go @@ -20,6 +20,7 @@ import ( workflowengineImpl "github.com/flyteorg/flyte/flyteadmin/pkg/workflowengine/impl" "github.com/flyteorg/flyte/flyteadmin/plugins" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/flyteorg/flyte/flytestdlib/promutils" "github.com/flyteorg/flyte/flytestdlib/storage" @@ -45,7 +46,7 @@ type AdminService struct { const defaultRetries = 3 func NewAdminServer(ctx context.Context, pluginRegistry *plugins.Registry, configuration runtimeIfaces.Configuration, - kubeConfig, master string, dataStorageClient *storage.DataStore, adminScope promutils.Scope) *AdminService { + kubeConfig, master string, dataStorageClient *storage.DataStore, adminScope promutils.Scope, sm core.SecretManager) *AdminService { applicationConfiguration := configuration.ApplicationConfiguration().GetTopLevelConfig() panicCounter := adminScope.MustNewCounter("initialization_panic", @@ -81,7 +82,7 @@ func NewAdminServer(ctx context.Context, pluginRegistry *plugins.Registry, confi pluginRegistry.RegisterDefault(plugins.PluginIDWorkflowExecutor, workflowExecutor) publisher := notifications.NewNotificationsPublisher(*configuration.ApplicationConfiguration().GetNotificationsConfig(), adminScope) - processor := notifications.NewNotificationsProcessor(*configuration.ApplicationConfiguration().GetNotificationsConfig(), adminScope) + processor := notifications.NewNotificationsProcessor(*configuration.ApplicationConfiguration().GetNotificationsConfig(), adminScope, sm) eventPublisher := notifications.NewEventsPublisher(*configuration.ApplicationConfiguration().GetExternalEventsConfig(), adminScope) go func() { logger.Info(ctx, "Started processing notifications.") diff --git a/flyteadmin/pkg/runtime/interfaces/application_configuration.go b/flyteadmin/pkg/runtime/interfaces/application_configuration.go index 3505150919..e3453db0f7 100644 --- a/flyteadmin/pkg/runtime/interfaces/application_configuration.go +++ b/flyteadmin/pkg/runtime/interfaces/application_configuration.go @@ -492,8 +492,13 @@ type NotificationsProcessorConfig struct { type EmailServerConfig struct { ServiceName string `json:"serviceName"` // Only one of these should be set. - APIKeyEnvVar string `json:"apiKeyEnvVar"` - APIKeyFilePath string `json:"apiKeyFilePath"` + APIKeyEnvVar string `json:"apiKeyEnvVar"` + APIKeyFilePath string `json:"apiKeyFilePath"` + SMTPServer string `json:"smtpServer"` + SMTPPort string `json:"smtpPort"` + SMTPSkipTLSVerify bool `json:"smtpSkipTLSVerify"` + SMTPUsername string `json:"smtpUsername"` + SMTPPasswordSecretName string `json:"smtpPasswordSecretName"` } // This section handles the configuration of notifications emails. diff --git a/flyteadmin/pkg/server/service.go b/flyteadmin/pkg/server/service.go index 587ea86e3b..840d0d9f17 100644 --- a/flyteadmin/pkg/server/service.go +++ b/flyteadmin/pkg/server/service.go @@ -43,6 +43,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/clients/go/assets" grpcService "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/gateway/flyteidl/service" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/secretmanager" "github.com/flyteorg/flyte/flytestdlib/contextutils" "github.com/flyteorg/flyte/flytestdlib/logger" @@ -82,7 +83,7 @@ func SetMetricKeys(appConfig *runtimeIfaces.ApplicationConfig) { // Creates a new gRPC Server with all the configuration func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig, storageCfg *storage.Config, authCtx interfaces.AuthenticationContext, - scope promutils.Scope, opts ...grpc.ServerOption) (*grpc.Server, error) { + scope promutils.Scope, sm core.SecretManager, opts ...grpc.ServerOption) (*grpc.Server, error) { logger.Infof(ctx, "Registering default middleware with blanket auth validation") pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer( @@ -152,7 +153,7 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c } configuration := runtime2.NewConfigurationProvider() - adminServer := adminservice.NewAdminServer(ctx, pluginRegistry, configuration, cfg.KubeConfig, cfg.Master, dataStorageClient, adminScope) + adminServer := adminservice.NewAdminServer(ctx, pluginRegistry, configuration, cfg.KubeConfig, cfg.Master, dataStorageClient, adminScope, sm) grpcService.RegisterAdminServiceServer(grpcServer, adminServer) if cfg.Security.UseAuth { grpcService.RegisterAuthMetadataServiceServer(grpcServer, authCtx.AuthMetadataService()) @@ -339,12 +340,15 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, // This will parse configuration and create the necessary objects for dealing with auth var authCtx interfaces.AuthenticationContext var err error + + sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) + // This code is here to support authentication without SSL. This setup supports a network topology where // Envoy does the SSL termination. The final hop is made over localhost only on a trusted machine. // Warning: Running authentication without SSL in any other topology is a severe security flaw. // See the auth.Config object for additional settings as well. if cfg.Security.UseAuth { - sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) + var oauth2Provider interfaces.OAuth2Provider var oauth2ResourceServer interfaces.OAuth2ResourceServer if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf { @@ -373,7 +377,7 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, } } - grpcServer, err := newGRPCServer(ctx, pluginRegistry, cfg, storageConfig, authCtx, scope) + grpcServer, err := newGRPCServer(ctx, pluginRegistry, cfg, storageConfig, authCtx, scope, sm) if err != nil { return fmt.Errorf("failed to create a newGRPCServer. Error: %w", err) } @@ -448,13 +452,14 @@ func serveGatewaySecure(ctx context.Context, pluginRegistry *plugins.Registry, c additionalHandlers map[string]func(http.ResponseWriter, *http.Request), scope promutils.Scope) error { certPool, cert, err := GetSslCredentials(ctx, cfg.Security.Ssl.CertificateFile, cfg.Security.Ssl.KeyFile) + sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) + if err != nil { return err } // This will parse configuration and create the necessary objects for dealing with auth var authCtx interfaces.AuthenticationContext if cfg.Security.UseAuth { - sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) var oauth2Provider interfaces.OAuth2Provider var oauth2ResourceServer interfaces.OAuth2ResourceServer if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf { @@ -483,7 +488,7 @@ func serveGatewaySecure(ctx context.Context, pluginRegistry *plugins.Registry, c } } - grpcServer, err := newGRPCServer(ctx, pluginRegistry, cfg, storageCfg, authCtx, scope, grpc.Creds(credentials.NewServerTLSFromCert(cert))) + grpcServer, err := newGRPCServer(ctx, pluginRegistry, cfg, storageCfg, authCtx, scope, sm, grpc.Creds(credentials.NewServerTLSFromCert(cert))) if err != nil { return fmt.Errorf("failed to create a newGRPCServer. Error: %w", err) } From 7989209e15600b56fcf0f4c4a7c9af7bfeab6f3e Mon Sep 17 00:00:00 2001 From: Prafulla Mahindrakar Date: Mon, 16 Sep 2024 09:52:43 -0700 Subject: [PATCH 04/17] Added literal offloading for array node map tasks (#5697) * Added literal offloading for array node map tasks Signed-off-by: pmahindrakar-oss * fix Signed-off-by: pmahindrakar-oss * feedback Signed-off-by: pmahindrakar-oss * feedback Signed-off-by: pmahindrakar-oss * nit Signed-off-by: pmahindrakar-oss * feedback Signed-off-by: pmahindrakar-oss * add missing flag files Signed-off-by: pmahindrakar-oss * disabling flag until flytekit release is confirmed Signed-off-by: pmahindrakar-oss * nit Signed-off-by: pmahindrakar-oss --------- Signed-off-by: pmahindrakar-oss --- flyteadmin/go.mod | 1 + flyteadmin/go.sum | 2 + .../manager/impl/testutils/mock_requests.go | 22 +++ .../impl/validation/execution_validator.go | 8 +- .../validation/execution_validator_test.go | 34 ++++ .../pkg/manager/impl/validation/validation.go | 2 +- flytepropeller/go.mod | 1 + flytepropeller/go.sum | 2 + .../pkg/apis/flyteworkflow/v1alpha1/iface.go | 4 + .../pkg/compiler/transformers/k8s/inputs.go | 8 +- .../pkg/controller/config/config.go | 119 +++++++---- .../pkg/controller/config/config_flags.go | 4 + .../controller/config/config_flags_test.go | 56 ++++++ .../pkg/controller/config/config_test.go | 54 +++++ flytepropeller/pkg/controller/controller.go | 4 +- .../pkg/controller/nodes/array/handler.go | 17 +- .../controller/nodes/array/handler_test.go | 6 +- .../pkg/controller/nodes/common/utils.go | 98 +++++++++- .../pkg/controller/nodes/common/utils_test.go | 185 ++++++++++++++++++ .../pkg/controller/nodes/executor.go | 8 +- .../pkg/controller/nodes/executor_test.go | 32 +-- .../nodes/factory/handler_factory.go | 45 +++-- .../pkg/controller/workflow/executor_test.go | 18 +- go.mod | 1 + go.sum | 2 + 25 files changed, 642 insertions(+), 91 deletions(-) create mode 100644 flytepropeller/pkg/controller/config/config_test.go diff --git a/flyteadmin/go.mod b/flyteadmin/go.mod index 2eec0f8cf3..cfc2bfa010 100644 --- a/flyteadmin/go.mod +++ b/flyteadmin/go.mod @@ -82,6 +82,7 @@ require ( cloud.google.com/go/pubsub v1.34.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect + github.com/Masterminds/semver v1.5.0 // indirect github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 // indirect github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 // indirect github.com/beorn7/perks v1.0.1 // indirect diff --git a/flyteadmin/go.sum b/flyteadmin/go.sum index 049add4bbc..5b7d47b2a6 100644 --- a/flyteadmin/go.sum +++ b/flyteadmin/go.sum @@ -75,6 +75,8 @@ github.com/DataDog/datadog-go v3.4.1+incompatible/go.mod h1:LButxg5PwREeZtORoXG3 github.com/DataDog/datadog-go v4.0.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= github.com/DataDog/opencensus-go-exporter-datadog v0.0.0-20191210083620-6965a1cfed68/go.mod h1:gMGUEe16aZh0QN941HgDjwrdjU4iTthPoz2/AtDRADE= github.com/Masterminds/semver v1.4.2/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= +github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= +github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/Masterminds/semver/v3 v3.0.3/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA= github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA= diff --git a/flyteadmin/pkg/manager/impl/testutils/mock_requests.go b/flyteadmin/pkg/manager/impl/testutils/mock_requests.go index b868612269..b3d01897f1 100644 --- a/flyteadmin/pkg/manager/impl/testutils/mock_requests.go +++ b/flyteadmin/pkg/manager/impl/testutils/mock_requests.go @@ -241,6 +241,28 @@ func GetExecutionRequest() *admin.ExecutionCreateRequest { } } +func GetExecutionRequestWithOffloadedInputs(inputParam string, literalValue *core.Literal) *admin.ExecutionCreateRequest { + execReq := GetExecutionRequest() + execReq.Inputs = &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": { + Value: &core.Literal_OffloadedMetadata{ + OffloadedMetadata: &core.LiteralOffloadedMetadata{ + Uri: "s3://bucket/key", + SizeBytes: 100, + InferredType: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRING, + }, + }, + }, + }, + }, + }, + } + return execReq +} + func GetSampleWorkflowSpecForTest() *admin.WorkflowSpec { return &admin.WorkflowSpec{ Template: &core.WorkflowTemplate{ diff --git a/flyteadmin/pkg/manager/impl/validation/execution_validator.go b/flyteadmin/pkg/manager/impl/validation/execution_validator.go index 0a21165c93..f7b385b8a8 100644 --- a/flyteadmin/pkg/manager/impl/validation/execution_validator.go +++ b/flyteadmin/pkg/manager/impl/validation/execution_validator.go @@ -100,7 +100,13 @@ func CheckAndFetchInputsForExecution( } executionInputMap[name] = expectedInput.GetDefault() } else { - inputType := validators.LiteralTypeForLiteral(executionInputMap[name]) + var inputType *core.LiteralType + switch executionInputMap[name].GetValue().(type) { + case *core.Literal_OffloadedMetadata: + inputType = executionInputMap[name].GetOffloadedMetadata().GetInferredType() + default: + inputType = validators.LiteralTypeForLiteral(executionInputMap[name]) + } if !validators.AreTypesCastable(inputType, expectedInput.GetVar().GetType()) { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid %s input wrong type. Expected %s, but got %s", name, expectedInput.GetVar().GetType(), inputType) } diff --git a/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go b/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go index 1329dc6f96..7e5f991788 100644 --- a/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go @@ -105,6 +105,40 @@ func TestGetExecutionInputs(t *testing.T) { assert.EqualValues(t, expectedMap, actualInputs) } +func TestGetExecutionWithOffloadedInputs(t *testing.T) { + execLiteral := &core.Literal{ + Value: &core.Literal_OffloadedMetadata{ + OffloadedMetadata: &core.LiteralOffloadedMetadata{ + Uri: "s3://bucket/key", + SizeBytes: 100, + InferredType: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRING, + }, + }, + }, + }, + } + executionRequest := testutils.GetExecutionRequestWithOffloadedInputs("foo", execLiteral) + lpRequest := testutils.GetLaunchPlanRequest() + + actualInputs, err := CheckAndFetchInputsForExecution( + executionRequest.Inputs, + lpRequest.Spec.FixedInputs, + lpRequest.Spec.DefaultInputs, + ) + expectedMap := core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": execLiteral, + "bar": coreutils.MustMakeLiteral("bar-value"), + }, + } + assert.Nil(t, err) + assert.NotNil(t, actualInputs) + assert.EqualValues(t, expectedMap.GetLiterals()["foo"], actualInputs.Literals["foo"]) + assert.EqualValues(t, expectedMap.GetLiterals()["bar"], actualInputs.Literals["bar"]) +} + func TestValidateExecInputsWrongType(t *testing.T) { executionRequest := testutils.GetExecutionRequest() lpRequest := testutils.GetLaunchPlanRequest() diff --git a/flyteadmin/pkg/manager/impl/validation/validation.go b/flyteadmin/pkg/manager/impl/validation/validation.go index 6c9bd2fdbb..894eaee435 100644 --- a/flyteadmin/pkg/manager/impl/validation/validation.go +++ b/flyteadmin/pkg/manager/impl/validation/validation.go @@ -234,7 +234,7 @@ func validateLiteralMap(inputMap *core.LiteralMap, fieldName string) error { if name == "" { return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "missing key in %s", fieldName) } - if fixedInput == nil || fixedInput.GetValue() == nil { + if fixedInput.GetValue() == nil && fixedInput.GetOffloadedMetadata() == nil { return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "missing valid literal in %s %s", fieldName, name) } if isDateTime(fixedInput) { diff --git a/flytepropeller/go.mod b/flytepropeller/go.mod index 5d828f9e9b..f579049aff 100644 --- a/flytepropeller/go.mod +++ b/flytepropeller/go.mod @@ -4,6 +4,7 @@ go 1.22 require ( github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 + github.com/Masterminds/semver v1.5.0 github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/fatih/color v1.13.0 github.com/flyteorg/flyte/flyteidl v0.0.0-00010101000000-000000000000 diff --git a/flytepropeller/go.sum b/flytepropeller/go.sum index 8bbdd06eba..07a92b902b 100644 --- a/flytepropeller/go.sum +++ b/flytepropeller/go.sum @@ -64,6 +64,8 @@ github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 h1:xJ0dAkuxJXf github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295/go.mod h1:e0aH495YLkrsIe9fhedd6aSR6fgU/qhKvtroi6y7G/M= github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625 h1:cQyO5JQ2iuHnEcF3v24kdDMsgh04RjyFPDtuvD6PCE0= github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625/go.mod h1:6PnrZv6zUDkrNMw0mIoGRmGBR7i9LulhKPmxFq4rUiM= +github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= +github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/aws/aws-sdk-go v1.44.2 h1:5VBk5r06bgxgRKVaUtm1/4NT/rtrnH2E4cnAYv5zgQc= diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go index bcd1064e67..486ac35a16 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -554,6 +554,10 @@ func GetOutputsFile(outputDir DataReference) DataReference { return outputDir + "/outputs.pb" } +func GetOutputsLiteralMetadataFile(literalKey string, outputDir DataReference) DataReference { + return outputDir + DataReference(fmt.Sprintf("/%s_offloaded_metadata.pb", literalKey)) +} + func GetInputsFile(inputDir DataReference) DataReference { return inputDir + "/inputs.pb" } diff --git a/flytepropeller/pkg/compiler/transformers/k8s/inputs.go b/flytepropeller/pkg/compiler/transformers/k8s/inputs.go index 2d967c560e..0976df669b 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/inputs.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/inputs.go @@ -35,7 +35,13 @@ func validateInputs(nodeID common.NodeID, iface *core.TypedInterface, inputs cor continue } - inputType := validators.LiteralTypeForLiteral(inputVal) + var inputType *core.LiteralType + switch inputVal.GetValue().(type) { + case *core.Literal_OffloadedMetadata: + inputType = inputVal.GetOffloadedMetadata().GetInferredType() + default: + inputType = validators.LiteralTypeForLiteral(inputVal) + } if !validators.AreTypesCastable(inputType, v.Type) { errs.Collect(errors.NewMismatchingTypesErr(nodeID, inputVar, v.Type.String(), inputType.String())) continue diff --git a/flytepropeller/pkg/controller/config/config.go b/flytepropeller/pkg/controller/config/config.go index a0217e186a..2d61c94970 100644 --- a/flytepropeller/pkg/controller/config/config.go +++ b/flytepropeller/pkg/controller/config/config.go @@ -34,12 +34,16 @@ package config import ( + "context" + "fmt" "time" + "github.com/Masterminds/semver" "k8s.io/apimachinery/pkg/types" "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/contextutils" + "github.com/flyteorg/flyte/flytestdlib/logger" ) //go:generate pflags Config --default-var=defaultConfig @@ -120,6 +124,14 @@ var ( EventVersion: 0, DefaultParallelismBehavior: ParallelismBehaviorUnlimited, }, + LiteralOffloadingConfig: LiteralOffloadingConfig{ + Enabled: false, // Default keep this disabled and we will followup when flytekit is released with the offloaded changes. + SupportedSDKVersions: map[string]string{ // The key is the SDK name (matches the supported SDK in core.RuntimeMetadata_RuntimeType) and the value is the minimum supported version + "FLYTE_SDK": "1.13.5", // Expected release number with flytekit support from this PR https://github.com/flyteorg/flytekit/pull/2685 + }, + MinSizeInMBForOffloading: 10, // 10 MB is the default size for offloading + MaxSizeInMBForOffloading: 1000, // 1 GB is the default size before failing fast. + }, } ) @@ -127,40 +139,79 @@ var ( // the base configuration to start propeller // NOTE: when adding new fields, do not mark them as "omitempty" if it's desirable to read the value from env variables. type Config struct { - KubeConfigPath string `json:"kube-config" pflag:",Path to kubernetes client config file."` - MasterURL string `json:"master"` - Workers int `json:"workers" pflag:",Number of threads to process workflows"` - WorkflowReEval config.Duration `json:"workflow-reeval-duration" pflag:",Frequency of re-evaluating workflows"` - DownstreamEval config.Duration `json:"downstream-eval-duration" pflag:",Frequency of re-evaluating downstream tasks"` - LimitNamespace string `json:"limit-namespace" pflag:",Namespaces to watch for this propeller"` - ProfilerPort config.Port `json:"prof-port" pflag:",Profiler port"` - MetadataPrefix string `json:"metadata-prefix,omitempty" pflag:",MetadataPrefix should be used if all the metadata for Flyte executions should be stored under a specific prefix in CloudStorage. If not specified, the data will be stored in the base container directly."` - DefaultRawOutputPrefix string `json:"rawoutput-prefix" pflag:",a fully qualified storage path of the form s3://flyte/abc/..., where all data sandboxes should be stored."` - Queue CompositeQueueConfig `json:"queue,omitempty" pflag:",Workflow workqueue configuration, affects the way the work is consumed from the queue."` - MetricsPrefix string `json:"metrics-prefix" pflag:",An optional prefix for all published metrics."` - MetricKeys []string `json:"metrics-keys" pflag:",Metrics labels applied to prometheus metrics emitted by the service."` - EnableAdminLauncher bool `json:"enable-admin-launcher" pflag:"Enable remote Workflow launcher to Admin"` - MaxWorkflowRetries int `json:"max-workflow-retries" pflag:"Maximum number of retries per workflow"` - MaxTTLInHours int `json:"max-ttl-hours" pflag:"Maximum number of hours a completed workflow should be retained. Number between 1-23 hours"` - GCInterval config.Duration `json:"gc-interval" pflag:"Run periodic GC every 30 minutes"` - LeaderElection LeaderElectionConfig `json:"leader-election,omitempty" pflag:",Config for leader election."` - PublishK8sEvents bool `json:"publish-k8s-events" pflag:",Enable events publishing to K8s events API."` - MaxDatasetSizeBytes int64 `json:"max-output-size-bytes" pflag:",Deprecated! Use storage.limits.maxDownloadMBs instead"` - EnableGrpcLatencyMetrics bool `json:"enable-grpc-latency-metrics" pflag:",Enable grpc latency metrics. Note Histograms metrics can be expensive on Prometheus servers."` - KubeConfig KubeClientConfig `json:"kube-client-config" pflag:",Configuration to control the Kubernetes client"` - NodeConfig NodeConfig `json:"node-config,omitempty" pflag:",config for a workflow node"` - MaxStreakLength int `json:"max-streak-length" pflag:",Maximum number of consecutive rounds that one propeller worker can use for one workflow - >1 => turbo-mode is enabled."` - EventConfig EventConfig `json:"event-config,omitempty" pflag:",Configures execution event behavior."` - IncludeShardKeyLabel []string `json:"include-shard-key-label" pflag:",Include the specified shard key label in the k8s FlyteWorkflow CRD label selector"` - ExcludeShardKeyLabel []string `json:"exclude-shard-key-label" pflag:",Exclude the specified shard key label from the k8s FlyteWorkflow CRD label selector"` - IncludeProjectLabel []string `json:"include-project-label" pflag:",Include the specified project label in the k8s FlyteWorkflow CRD label selector"` - ExcludeProjectLabel []string `json:"exclude-project-label" pflag:",Exclude the specified project label from the k8s FlyteWorkflow CRD label selector"` - IncludeDomainLabel []string `json:"include-domain-label" pflag:",Include the specified domain label in the k8s FlyteWorkflow CRD label selector"` - ExcludeDomainLabel []string `json:"exclude-domain-label" pflag:",Exclude the specified domain label from the k8s FlyteWorkflow CRD label selector"` - ClusterID string `json:"cluster-id" pflag:",Unique cluster id running this flytepropeller instance with which to annotate execution events"` - CreateFlyteWorkflowCRD bool `json:"create-flyteworkflow-crd" pflag:",Enable creation of the FlyteWorkflow CRD on startup"` - NodeExecutionWorkerCount int `json:"node-execution-worker-count" pflag:",Number of workers to evaluate node executions, currently only used for array nodes"` - ArrayNode ArrayNodeConfig `json:"array-node-config,omitempty" pflag:",Configuration for array nodes"` + KubeConfigPath string `json:"kube-config" pflag:",Path to kubernetes client config file."` + MasterURL string `json:"master"` + Workers int `json:"workers" pflag:",Number of threads to process workflows"` + WorkflowReEval config.Duration `json:"workflow-reeval-duration" pflag:",Frequency of re-evaluating workflows"` + DownstreamEval config.Duration `json:"downstream-eval-duration" pflag:",Frequency of re-evaluating downstream tasks"` + LimitNamespace string `json:"limit-namespace" pflag:",Namespaces to watch for this propeller"` + ProfilerPort config.Port `json:"prof-port" pflag:",Profiler port"` + MetadataPrefix string `json:"metadata-prefix,omitempty" pflag:",MetadataPrefix should be used if all the metadata for Flyte executions should be stored under a specific prefix in CloudStorage. If not specified, the data will be stored in the base container directly."` + DefaultRawOutputPrefix string `json:"rawoutput-prefix" pflag:",a fully qualified storage path of the form s3://flyte/abc/..., where all data sandboxes should be stored."` + Queue CompositeQueueConfig `json:"queue,omitempty" pflag:",Workflow workqueue configuration, affects the way the work is consumed from the queue."` + MetricsPrefix string `json:"metrics-prefix" pflag:",An optional prefix for all published metrics."` + MetricKeys []string `json:"metrics-keys" pflag:",Metrics labels applied to prometheus metrics emitted by the service."` + EnableAdminLauncher bool `json:"enable-admin-launcher" pflag:"Enable remote Workflow launcher to Admin"` + MaxWorkflowRetries int `json:"max-workflow-retries" pflag:"Maximum number of retries per workflow"` + MaxTTLInHours int `json:"max-ttl-hours" pflag:"Maximum number of hours a completed workflow should be retained. Number between 1-23 hours"` + GCInterval config.Duration `json:"gc-interval" pflag:"Run periodic GC every 30 minutes"` + LeaderElection LeaderElectionConfig `json:"leader-election,omitempty" pflag:",Config for leader election."` + PublishK8sEvents bool `json:"publish-k8s-events" pflag:",Enable events publishing to K8s events API."` + MaxDatasetSizeBytes int64 `json:"max-output-size-bytes" pflag:",Deprecated! Use storage.limits.maxDownloadMBs instead"` + EnableGrpcLatencyMetrics bool `json:"enable-grpc-latency-metrics" pflag:",Enable grpc latency metrics. Note Histograms metrics can be expensive on Prometheus servers."` + KubeConfig KubeClientConfig `json:"kube-client-config" pflag:",Configuration to control the Kubernetes client"` + NodeConfig NodeConfig `json:"node-config,omitempty" pflag:",config for a workflow node"` + MaxStreakLength int `json:"max-streak-length" pflag:",Maximum number of consecutive rounds that one propeller worker can use for one workflow - >1 => turbo-mode is enabled."` + EventConfig EventConfig `json:"event-config,omitempty" pflag:",Configures execution event behavior."` + IncludeShardKeyLabel []string `json:"include-shard-key-label" pflag:",Include the specified shard key label in the k8s FlyteWorkflow CRD label selector"` + ExcludeShardKeyLabel []string `json:"exclude-shard-key-label" pflag:",Exclude the specified shard key label from the k8s FlyteWorkflow CRD label selector"` + IncludeProjectLabel []string `json:"include-project-label" pflag:",Include the specified project label in the k8s FlyteWorkflow CRD label selector"` + ExcludeProjectLabel []string `json:"exclude-project-label" pflag:",Exclude the specified project label from the k8s FlyteWorkflow CRD label selector"` + IncludeDomainLabel []string `json:"include-domain-label" pflag:",Include the specified domain label in the k8s FlyteWorkflow CRD label selector"` + ExcludeDomainLabel []string `json:"exclude-domain-label" pflag:",Exclude the specified domain label from the k8s FlyteWorkflow CRD label selector"` + ClusterID string `json:"cluster-id" pflag:",Unique cluster id running this flytepropeller instance with which to annotate execution events"` + CreateFlyteWorkflowCRD bool `json:"create-flyteworkflow-crd" pflag:",Enable creation of the FlyteWorkflow CRD on startup"` + NodeExecutionWorkerCount int `json:"node-execution-worker-count" pflag:",Number of workers to evaluate node executions, currently only used for array nodes"` + ArrayNode ArrayNodeConfig `json:"array-node-config,omitempty" pflag:",Configuration for array nodes"` + LiteralOffloadingConfig LiteralOffloadingConfig `json:"literalOffloadingConfig" pflag:",config used for literal offloading."` +} + +type LiteralOffloadingConfig struct { + Enabled bool + // Maps flytekit and union SDK names to minimum supported version that can handle reading offloaded literals. + SupportedSDKVersions map[string]string + // Default, 10Mbs. Determines the size of a literal at which to trigger offloading + MinSizeInMBForOffloading int64 + // Fail fast threshold + MaxSizeInMBForOffloading int64 +} + +// IsSupportedSDKVersion returns true if the provided SDK and version are supported by the literal offloading config. +func (l LiteralOffloadingConfig) IsSupportedSDKVersion(sdk string, versionString string) bool { + if leastSupportedVersion, ok := l.SupportedSDKVersions[sdk]; ok { + c, err := semver.NewConstraint(fmt.Sprintf(">= %s", leastSupportedVersion)) + if err != nil { + // This should never happen + logger.Warnf(context.TODO(), "Failed to parse version constraint %s", leastSupportedVersion) + return false + } + version, err := semver.NewVersion(versionString) + if err != nil { + // This should never happen + logger.Warnf(context.TODO(), "Failed to parse version %s", versionString) + return false + } + return c.Check(version) + } + return false +} + +// GetSupportedSDKVersion returns the least supported version for the provided SDK. +func (l LiteralOffloadingConfig) GetSupportedSDKVersion(sdk string) string { + if leastSupportedVersion, ok := l.SupportedSDKVersions[sdk]; ok { + return leastSupportedVersion + } + return "" } // KubeClientConfig contains the configuration used by flytepropeller to configure its internal Kubernetes Client. diff --git a/flytepropeller/pkg/controller/config/config_flags.go b/flytepropeller/pkg/controller/config/config_flags.go index 858fc8a8ba..b2e88e88e6 100755 --- a/flytepropeller/pkg/controller/config/config_flags.go +++ b/flytepropeller/pkg/controller/config/config_flags.go @@ -112,5 +112,9 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "node-execution-worker-count"), defaultConfig.NodeExecutionWorkerCount, "Number of workers to evaluate node executions, currently only used for array nodes") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "array-node-config.event-version"), defaultConfig.ArrayNode.EventVersion, "ArrayNode eventing version. 0 => legacy (drop-in replacement for maptask), 1 => new") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "array-node-config.default-parallelism-behavior"), defaultConfig.ArrayNode.DefaultParallelismBehavior, "Default parallelism behavior for array nodes") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.Enabled"), defaultConfig.LiteralOffloadingConfig.Enabled, "") + cmdFlags.StringToString(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.SupportedSDKVersions"), defaultConfig.LiteralOffloadingConfig.SupportedSDKVersions, "") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.MinSizeInMBForOffloading"), defaultConfig.LiteralOffloadingConfig.MinSizeInMBForOffloading, "") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.MaxSizeInMBForOffloading"), defaultConfig.LiteralOffloadingConfig.MaxSizeInMBForOffloading, "") return cmdFlags } diff --git a/flytepropeller/pkg/controller/config/config_flags_test.go b/flytepropeller/pkg/controller/config/config_flags_test.go index 27e7b76efa..aadb24b36a 100755 --- a/flytepropeller/pkg/controller/config/config_flags_test.go +++ b/flytepropeller/pkg/controller/config/config_flags_test.go @@ -967,4 +967,60 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_literalOffloadingConfig.Enabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("literalOffloadingConfig.Enabled", testValue) + if vBool, err := cmdFlags.GetBool("literalOffloadingConfig.Enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.LiteralOffloadingConfig.Enabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_literalOffloadingConfig.SupportedSDKVersions", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "a=1,b=2" + + cmdFlags.Set("literalOffloadingConfig.SupportedSDKVersions", testValue) + if vStringToString, err := cmdFlags.GetStringToString("literalOffloadingConfig.SupportedSDKVersions"); err == nil { + testDecodeRaw_Config(t, vStringToString, &actual.LiteralOffloadingConfig.SupportedSDKVersions) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_literalOffloadingConfig.MinSizeInMBForOffloading", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("literalOffloadingConfig.MinSizeInMBForOffloading", testValue) + if vInt64, err := cmdFlags.GetInt64("literalOffloadingConfig.MinSizeInMBForOffloading"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt64), &actual.LiteralOffloadingConfig.MinSizeInMBForOffloading) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_literalOffloadingConfig.MaxSizeInMBForOffloading", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("literalOffloadingConfig.MaxSizeInMBForOffloading", testValue) + if vInt64, err := cmdFlags.GetInt64("literalOffloadingConfig.MaxSizeInMBForOffloading"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt64), &actual.LiteralOffloadingConfig.MaxSizeInMBForOffloading) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } diff --git a/flytepropeller/pkg/controller/config/config_test.go b/flytepropeller/pkg/controller/config/config_test.go new file mode 100644 index 0000000000..afc9ed2fea --- /dev/null +++ b/flytepropeller/pkg/controller/config/config_test.go @@ -0,0 +1,54 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsSupportedSDKVersion(t *testing.T) { + t.Run("supported version", func(t *testing.T) { + config := LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "0.16.0", + }, + } + assert.True(t, config.IsSupportedSDKVersion("flytekit", "0.16.0")) + }) + + t.Run("unsupported version", func(t *testing.T) { + config := LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "0.16.0", + }, + } + assert.False(t, config.IsSupportedSDKVersion("flytekit", "0.15.0")) + }) + + t.Run("unsupported SDK", func(t *testing.T) { + config := LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "0.16.0", + }, + } + assert.False(t, config.IsSupportedSDKVersion("unknown", "0.16.0")) + }) + + t.Run("invalid version", func(t *testing.T) { + config := LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "0.16.0", + }, + } + assert.False(t, config.IsSupportedSDKVersion("flytekit", "invalid")) + }) + + t.Run("invalid constraint", func(t *testing.T) { + config := LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "invalid", + }, + } + assert.False(t, config.IsSupportedSDKVersion("flytekit", "0.16.0")) + }) +} diff --git a/flytepropeller/pkg/controller/controller.go b/flytepropeller/pkg/controller/controller.go index c59aa9745d..39047e811d 100644 --- a/flytepropeller/pkg/controller/controller.go +++ b/flytepropeller/pkg/controller/controller.go @@ -436,14 +436,14 @@ func New(ctx context.Context, cfg *config.Config, kubeClientset kubernetes.Inter recoveryClient := recovery.NewClient(adminClient) nodeHandlerFactory, err := factory.NewHandlerFactory(ctx, launchPlanActor, launchPlanActor, - kubeClient, kubeClientset, catalogClient, recoveryClient, &cfg.EventConfig, cfg.ClusterID, signalClient, scope) + kubeClient, kubeClientset, catalogClient, recoveryClient, &cfg.EventConfig, cfg.LiteralOffloadingConfig, cfg.ClusterID, signalClient, scope) if err != nil { return nil, errors.Wrapf(err, "failed to create node handler factory") } nodeExecutor, err := nodes.NewExecutor(ctx, cfg.NodeConfig, store, controller.enqueueWorkflowForNodeUpdates, eventSink, launchPlanActor, launchPlanActor, storage.DataReference(cfg.DefaultRawOutputPrefix), kubeClient, - catalogClient, recoveryClient, &cfg.EventConfig, cfg.ClusterID, signalClient, nodeHandlerFactory, scope) + catalogClient, recoveryClient, cfg.LiteralOffloadingConfig, &cfg.EventConfig, cfg.ClusterID, signalClient, nodeHandlerFactory, scope) if err != nil { return nil, errors.Wrapf(err, "Failed to create Controller.") } diff --git a/flytepropeller/pkg/controller/nodes/array/handler.go b/flytepropeller/pkg/controller/nodes/array/handler.go index a101ed5a30..5e9f910e14 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler.go +++ b/flytepropeller/pkg/controller/nodes/array/handler.go @@ -45,6 +45,7 @@ var ( // arrayNodeHandler is a handle implementation for processing array nodes type arrayNodeHandler struct { eventConfig *config.EventConfig + literalOffloadingConfig config.LiteralOffloadingConfig gatherOutputsRequestChannel chan *gatherOutputsRequest metrics metrics nodeExecutionRequestChannel chan *nodeExecutionRequest @@ -498,7 +499,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // attempt best effort at initializing outputLiterals with output variable names. currently // only TaskNode and WorkflowNode contain node interfaces. outputLiterals := make(map[string]*idlcore.Literal) - switch arrayNode.GetSubNodeSpec().GetKind() { case v1alpha1.NodeKindTask: taskID := *arrayNode.GetSubNodeSpec().TaskRef @@ -547,6 +547,18 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu return handler.UnknownTransition, fmt.Errorf("worker error(s) encountered: %s", workerErrorCollector.Summary(events.MaxErrorMessageLength)) } + // only offload literal if config is enabled for this feature. + if a.literalOffloadingConfig.Enabled { + for outputLiteralKey, outputLiteral := range outputLiterals { + // if the size of the output Literal is > threshold then we write the literal to the offloaded store and populate the literal with its zero value and update the offloaded url + // use the OffloadLargeLiteralKey to create {OffloadLargeLiteralKey}_offloaded_metadata.pb file in the datastore. + // Update the url in the outputLiteral with the offloaded url and also update the size of the literal. + offloadedOutputFile := v1alpha1.GetOutputsLiteralMetadataFile(outputLiteralKey, nCtx.NodeStatus().GetOutputDir()) + if err := common.OffloadLargeLiteral(ctx, nCtx.DataStore(), offloadedOutputFile, outputLiteral, a.literalOffloadingConfig); err != nil { + return handler.UnknownTransition, err + } + } + } outputLiteralMap := &idlcore.LiteralMap{ Literals: outputLiterals, } @@ -649,7 +661,7 @@ func (a *arrayNodeHandler) Setup(_ context.Context, _ interfaces.SetupContext) e } // New initializes a new arrayNodeHandler -func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, scope promutils.Scope) (interfaces.NodeHandler, error) { +func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, literalOffloadingConfig config.LiteralOffloadingConfig, scope promutils.Scope) (interfaces.NodeHandler, error) { // create k8s PluginState byte mocks to reuse instead of creating for each subNode evaluation pluginStateBytesNotStarted, err := bytesFromK8sPluginState(k8s.PluginState{Phase: k8s.PluginPhaseNotStarted}) if err != nil { @@ -676,6 +688,7 @@ func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, scope pr arrayScope := scope.NewSubScope("array") return &arrayNodeHandler{ eventConfig: deepCopiedEventConfig, + literalOffloadingConfig: literalOffloadingConfig, gatherOutputsRequestChannel: make(chan *gatherOutputsRequest), metrics: newMetrics(arrayScope), nodeExecutionRequestChannel: make(chan *nodeExecutionRequest), diff --git a/flytepropeller/pkg/controller/nodes/array/handler_test.go b/flytepropeller/pkg/controller/nodes/array/handler_test.go index 648d70e36c..cb2f2898a6 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/array/handler_test.go @@ -52,6 +52,8 @@ func createArrayNodeHandler(ctx context.Context, t *testing.T, nodeHandler inter adminClient := launchplan.NewFailFastLaunchPlanExecutor() enqueueWorkflowFunc := func(workflowID v1alpha1.WorkflowID) {} eventConfig := &config.EventConfig{ErrorOnAlreadyExists: true} + offloadingConfig := config.LiteralOffloadingConfig{Enabled: false} + literalOffloadingConfig := config.LiteralOffloadingConfig{Enabled: true, MinSizeInMBForOffloading: 1024, MaxSizeInMBForOffloading: 1024 * 1024} mockEventSink := eventmocks.NewMockEventSink() mockHandlerFactory := &mocks.HandlerFactory{} mockHandlerFactory.OnGetHandlerMatch(mock.Anything).Return(nodeHandler, nil) @@ -62,11 +64,11 @@ func createArrayNodeHandler(ctx context.Context, t *testing.T, nodeHandler inter // create node executor nodeExecutor, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, dataStore, enqueueWorkflowFunc, mockEventSink, adminClient, - adminClient, "s3://bucket/", mockKubeClient, noopCatalogClient, mockRecoveryClient, eventConfig, "clusterID", mockSignalClient, mockHandlerFactory, scope) + adminClient, "s3://bucket/", mockKubeClient, noopCatalogClient, mockRecoveryClient, offloadingConfig, eventConfig, "clusterID", mockSignalClient, mockHandlerFactory, scope) assert.NoError(t, err) // return ArrayNodeHandler - arrayNodeHandler, err := New(nodeExecutor, eventConfig, scope) + arrayNodeHandler, err := New(nodeExecutor, eventConfig, literalOffloadingConfig, scope) if err != nil { return nil, err } diff --git a/flytepropeller/pkg/controller/nodes/common/utils.go b/flytepropeller/pkg/controller/nodes/common/utils.go index 04ddc5183d..89bb0afe2e 100644 --- a/flytepropeller/pkg/controller/nodes/common/utils.go +++ b/flytepropeller/pkg/controller/nodes/common/utils.go @@ -2,17 +2,28 @@ package common import ( "context" + "fmt" "strconv" + "github.com/golang/protobuf/proto" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/encoding" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/validators" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/config" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/handler" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flyte/flytestdlib/logger" + "github.com/flyteorg/flyte/flytestdlib/storage" ) -const maxUniqueIDLength = 20 +const ( + maxUniqueIDLength = 20 + MB = 1024 * 1024 // 1 MB in bytes (1 MiB) +) // GenerateUniqueID is the UniqueId of a node is unique within a given workflow execution. // In order to achieve that we track the lineage of the node. @@ -67,3 +78,88 @@ func GetTargetEntity(ctx context.Context, nCtx interfaces.NodeExecutionContext) } return targetEntity } + +// OffloadLargeLiteral offloads the large literal if meets the threshold conditions +func OffloadLargeLiteral(ctx context.Context, datastore *storage.DataStore, dataReference storage.DataReference, + toBeOffloaded *idlcore.Literal, literalOffloadingConfig config.LiteralOffloadingConfig) error { + literalSizeBytes := int64(proto.Size(toBeOffloaded)) + literalSizeMB := literalSizeBytes / MB + // check if the literal is large + if literalSizeMB >= literalOffloadingConfig.MaxSizeInMBForOffloading { + errString := fmt.Sprintf("Literal size [%d] MB is larger than the max size [%d] MB for offloading", literalSizeMB, literalOffloadingConfig.MaxSizeInMBForOffloading) + logger.Errorf(ctx, errString) + return fmt.Errorf(errString) + } + if literalSizeMB < literalOffloadingConfig.MinSizeInMBForOffloading { + logger.Debugf(ctx, "Literal size [%d] MB is smaller than the min size [%d] MB for offloading", literalSizeMB, literalOffloadingConfig.MinSizeInMBForOffloading) + return nil + } + + inferredType := validators.LiteralTypeForLiteral(toBeOffloaded) + if inferredType == nil { + errString := "Failed to determine literal type for offloaded literal" + logger.Errorf(ctx, errString) + return fmt.Errorf(errString) + } + + // offload the literal + if err := datastore.WriteProtobuf(ctx, dataReference, storage.Options{}, toBeOffloaded); err != nil { + logger.Errorf(ctx, "Failed to offload literal at location [%s] with error [%s]", dataReference, err) + return err + } + + // update the literal with the offloaded URI, size and inferred type + toBeOffloaded.Value = &idlcore.Literal_OffloadedMetadata{ + OffloadedMetadata: &idlcore.LiteralOffloadedMetadata{ + Uri: dataReference.String(), + SizeBytes: uint64(literalSizeBytes), + InferredType: inferredType, + }, + } + logger.Infof(ctx, "Offloaded literal at location [%s] with size [%d] MB and inferred type [%s]", dataReference, literalSizeMB, inferredType) + return nil +} + +// CheckOffloadingCompat checks if the upstream and downstream nodes are compatible with the literal offloading feature and returns an error if not contained in phase info object +func CheckOffloadingCompat(ctx context.Context, nCtx interfaces.NodeExecutionContext, inputLiterals map[string]*core.Literal, node v1alpha1.ExecutableNode, literalOffloadingConfig config.LiteralOffloadingConfig) *handler.PhaseInfo { + consumesOffloadLiteral := false + for _, val := range inputLiterals { + if val != nil && val.GetOffloadedMetadata() != nil { + consumesOffloadLiteral = true + break + } + } + if !consumesOffloadLiteral { + return nil + } + var phaseInfo handler.PhaseInfo + + // Return early if the node is not of type NodeKindTask + if node.GetKind() != v1alpha1.NodeKindTask { + return nil + } + + // Process NodeKindTask + taskID := *node.GetTaskID() + taskNode, err := nCtx.ExecutionContext().GetTask(taskID) + if err != nil { + phaseInfo = handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "GetTaskIDFailure", err.Error(), nil) + return &phaseInfo + } + runtimeData := taskNode.CoreTask().GetMetadata().GetRuntime() + if !literalOffloadingConfig.IsSupportedSDKVersion(runtimeData.GetType().String(), runtimeData.GetVersion()) { + if !literalOffloadingConfig.Enabled { + errMsg := fmt.Sprintf("task [%s] is trying to consume offloaded literals but feature is not enabled", taskID) + logger.Errorf(ctx, errMsg) + phaseInfo = handler.PhaseInfoFailure(core.ExecutionError_USER, "LiteralOffloadingDisabled", errMsg, nil) + return &phaseInfo + } + leastSupportedVersion := literalOffloadingConfig.GetSupportedSDKVersion(runtimeData.GetType().String()) + errMsg := fmt.Sprintf("Literal offloading is not supported for this task as its registered with SDK version [%s] which is less than the least supported version [%s] for this feature", runtimeData.GetVersion(), leastSupportedVersion) + logger.Errorf(ctx, errMsg) + phaseInfo = handler.PhaseInfoFailure(core.ExecutionError_USER, "LiteralOffloadingNotSupported", errMsg, nil) + return &phaseInfo + } + + return nil +} diff --git a/flytepropeller/pkg/controller/nodes/common/utils_test.go b/flytepropeller/pkg/controller/nodes/common/utils_test.go index 9e451da69a..7d5ce1e372 100644 --- a/flytepropeller/pkg/controller/nodes/common/utils_test.go +++ b/flytepropeller/pkg/controller/nodes/common/utils_test.go @@ -1,11 +1,22 @@ package common import ( + "context" "testing" "github.com/stretchr/testify/assert" + idlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/validators" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/config" + executorMocks "github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors/mocks" + nodeMocks "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/interfaces/mocks" + "github.com/flyteorg/flyte/flytestdlib/contextutils" + "github.com/flyteorg/flyte/flytestdlib/promutils" + "github.com/flyteorg/flyte/flytestdlib/promutils/labeled" + "github.com/flyteorg/flyte/flytestdlib/storage" ) type ParentInfo struct { @@ -66,3 +77,177 @@ func TestCreateParentInfoNil(t *testing.T) { assert.Equal(t, uint32(1), parent.CurrentAttempt()) assert.True(t, parent.IsInDynamicChain()) } + +func init() { + labeled.SetMetricKeys(contextutils.AppNameKey) +} + +func TestOffloadLargeLiteral(t *testing.T) { + t.Run("offload successful with valid size", func(t *testing.T) { + ctx := context.Background() + datastore, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + dataReference := storage.DataReference("foo/bar") + toBeOffloaded := &idlCore.Literal{ + Value: &idlCore.Literal_Scalar{ + Scalar: &idlCore.Scalar{ + Value: &idlCore.Scalar_Primitive{ + Primitive: &idlCore.Primitive{ + Value: &idlCore.Primitive_Integer{ + Integer: 1, + }, + }, + }, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + MinSizeInMBForOffloading: 0, + MaxSizeInMBForOffloading: 1, + } + inferredType := validators.LiteralTypeForLiteral(toBeOffloaded) + err := OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, literalOffloadingConfig) + assert.NoError(t, err) + assert.Equal(t, "foo/bar", toBeOffloaded.GetOffloadedMetadata().GetUri()) + assert.Equal(t, uint64(6), toBeOffloaded.GetOffloadedMetadata().GetSizeBytes()) + assert.Equal(t, inferredType.GetSimple(), toBeOffloaded.GetOffloadedMetadata().InferredType.GetSimple()) + + }) + + t.Run("offload fails with size larger than max", func(t *testing.T) { + ctx := context.Background() + datastore, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + dataReference := storage.DataReference("foo/bar") + toBeOffloaded := &idlCore.Literal{ + Value: &idlCore.Literal_Scalar{ + Scalar: &idlCore.Scalar{ + Value: &idlCore.Scalar_Primitive{ + Primitive: &idlCore.Primitive{ + Value: &idlCore.Primitive_Integer{ + Integer: 1, + }, + }, + }, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + MinSizeInMBForOffloading: 0, + MaxSizeInMBForOffloading: 0, + } + err := OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, literalOffloadingConfig) + assert.Error(t, err) + }) + + t.Run("offload not attempted with size smaller than min", func(t *testing.T) { + ctx := context.Background() + datastore, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + dataReference := storage.DataReference("foo/bar") + toBeOffloaded := &idlCore.Literal{ + Value: &idlCore.Literal_Scalar{ + Scalar: &idlCore.Scalar{ + Value: &idlCore.Scalar_Primitive{ + Primitive: &idlCore.Primitive{ + Value: &idlCore.Primitive_Integer{ + Integer: 1, + }, + }, + }, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + MinSizeInMBForOffloading: 2, + MaxSizeInMBForOffloading: 3, + } + err := OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, literalOffloadingConfig) + assert.NoError(t, err) + assert.Nil(t, toBeOffloaded.GetOffloadedMetadata()) + }) +} + +func TestCheckOffloadingCompat(t *testing.T) { + ctx := context.Background() + nCtx := &nodeMocks.NodeExecutionContext{} + executionContext := &executorMocks.ExecutionContext{} + executableTask := &mocks.ExecutableTask{} + node := &mocks.ExecutableNode{} + node.OnGetKind().Return(v1alpha1.NodeKindTask) + nCtx.OnExecutionContext().Return(executionContext) + executionContext.OnGetTask("task1").Return(executableTask, nil) + executableTask.OnCoreTask().Return(&idlCore.TaskTemplate{ + Metadata: &idlCore.TaskMetadata{ + Runtime: &idlCore.RuntimeMetadata{ + Type: idlCore.RuntimeMetadata_FLYTE_SDK, + Version: "0.16.0", + }, + }, + }) + taskID := "task1" + node.OnGetTaskID().Return(&taskID) + t.Run("supported version success", func(t *testing.T) { + inputLiterals := map[string]*idlCore.Literal{ + "foo": { + Value: &idlCore.Literal_OffloadedMetadata{ + OffloadedMetadata: &idlCore.LiteralOffloadedMetadata{}, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + idlCore.RuntimeMetadata_FLYTE_SDK.String(): "0.16.0", + }, + Enabled: true, + } + phaseInfo := CheckOffloadingCompat(ctx, nCtx, inputLiterals, node, literalOffloadingConfig) + assert.Nil(t, phaseInfo) + }) + t.Run("unsupported version", func(t *testing.T) { + inputLiterals := map[string]*idlCore.Literal{ + "foo": { + Value: &idlCore.Literal_OffloadedMetadata{ + OffloadedMetadata: &idlCore.LiteralOffloadedMetadata{}, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + idlCore.RuntimeMetadata_FLYTE_SDK.String(): "0.17.0", + }, + Enabled: true, + } + phaseInfo := CheckOffloadingCompat(ctx, nCtx, inputLiterals, node, literalOffloadingConfig) + assert.NotNil(t, phaseInfo) + assert.Equal(t, idlCore.ExecutionError_USER, phaseInfo.GetErr().GetKind()) + assert.Equal(t, "LiteralOffloadingNotSupported", phaseInfo.GetErr().GetCode()) + }) + t.Run("offloading config disabled with offloaded data", func(t *testing.T) { + inputLiterals := map[string]*idlCore.Literal{ + "foo": { + Value: &idlCore.Literal_OffloadedMetadata{ + OffloadedMetadata: &idlCore.LiteralOffloadedMetadata{}, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + Enabled: false, + } + phaseInfo := CheckOffloadingCompat(ctx, nCtx, inputLiterals, node, literalOffloadingConfig) + assert.NotNil(t, phaseInfo) + assert.Equal(t, idlCore.ExecutionError_USER, phaseInfo.GetErr().GetKind()) + assert.Equal(t, "LiteralOffloadingDisabled", phaseInfo.GetErr().GetCode()) + }) + t.Run("offloading config enabled with no offloaded data", func(t *testing.T) { + inputLiterals := map[string]*idlCore.Literal{ + "foo": { + Value: &idlCore.Literal_Scalar{ + Scalar: &idlCore.Scalar{}, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + Enabled: true, + } + phaseInfo := CheckOffloadingCompat(ctx, nCtx, inputLiterals, node, literalOffloadingConfig) + assert.Nil(t, phaseInfo) + }) +} diff --git a/flytepropeller/pkg/controller/nodes/executor.go b/flytepropeller/pkg/controller/nodes/executor.go index 47c91edc51..2c3103e4ad 100644 --- a/flytepropeller/pkg/controller/nodes/executor.go +++ b/flytepropeller/pkg/controller/nodes/executor.go @@ -491,6 +491,7 @@ type nodeExecutor struct { defaultExecutionDeadline time.Duration enqueueWorkflow v1alpha1.EnqueueWorkflow eventConfig *config.EventConfig + literalOffloadingConfig config.LiteralOffloadingConfig interruptibleFailureThreshold int32 maxNodeRetriesForSystemFailures uint32 metrics *nodeMetrics @@ -764,6 +765,10 @@ func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructur } if nodeInputs != nil { + p := common.CheckOffloadingCompat(ctx, nCtx, nodeInputs.Literals, node, c.literalOffloadingConfig) + if p != nil { + return *p, nil + } inputsFile := v1alpha1.GetInputsFile(dataDir) if err := c.store.WriteProtobuf(ctx, inputsFile, storage.Options{}, nodeInputs); err != nil { c.metrics.InputsWriteFailure.Inc(ctx) @@ -1417,7 +1422,7 @@ func (c *nodeExecutor) HandleNode(ctx context.Context, dag executors.DAGStructur func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, eventSink events.EventSink, workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, defaultRawOutputPrefix storage.DataReference, kubeClient executors.Client, - catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, + catalogClient catalog.Client, recoveryClient recovery.Client, literalOffloadingConfig config.LiteralOffloadingConfig, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, nodeHandlerFactory interfaces.HandlerFactory, scope promutils.Scope) (interfaces.Node, error) { // TODO we may want to make this configurable. @@ -1469,6 +1474,7 @@ func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *stora defaultExecutionDeadline: nodeConfig.DefaultDeadlines.DefaultNodeExecutionDeadline.Duration, enqueueWorkflow: enQWorkflow, eventConfig: eventConfig, + literalOffloadingConfig: literalOffloadingConfig, interruptibleFailureThreshold: nodeConfig.InterruptibleFailureThreshold, maxNodeRetriesForSystemFailures: uint32(nodeConfig.MaxNodeRetriesOnSystemFailures), metrics: metrics, diff --git a/flytepropeller/pkg/controller/nodes/executor_test.go b/flytepropeller/pkg/controller/nodes/executor_test.go index ea7da42112..7fc4c05992 100644 --- a/flytepropeller/pkg/controller/nodes/executor_test.go +++ b/flytepropeller/pkg/controller/nodes/executor_test.go @@ -69,7 +69,7 @@ func TestSetInputsForStartNode(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) exec, err := NewExecutor(ctx, config.GetConfig().NodeConfig, mockStorage, enQWf, eventMocks.NewMockEventSink(), adminClient, - adminClient, "s3://bucket/", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + adminClient, "s3://bucket/", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) inputs := &core.LiteralMap{ Literals: map[string]*core.Literal{ @@ -116,7 +116,7 @@ func TestSetInputsForStartNode(t *testing.T) { failStorage := createFailingDatastore(t, testScope.NewSubScope("failing")) execFail, err := NewExecutor(ctx, config.GetConfig().NodeConfig, failStorage, enQWf, eventMocks.NewMockEventSink(), adminClient, - adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) t.Run("StorageFailure", func(t *testing.T) { w := createDummyBaseWorkflow(mockStorage) @@ -145,7 +145,7 @@ func TestNodeExecutor_Initialize(t *testing.T) { hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -156,7 +156,7 @@ func TestNodeExecutor_Initialize(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("error")) - execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -176,7 +176,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -281,7 +281,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) - execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -696,7 +696,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { nodeConfig := config.GetConfig().NodeConfig nodeConfig.EnableCRDebugMetadata = test.enableCRDebugMetadata execIface, err := NewExecutor(ctx, nodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -771,7 +771,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -885,7 +885,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -952,7 +952,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -983,7 +983,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -1018,7 +1018,7 @@ func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -1131,7 +1131,7 @@ func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -1249,7 +1249,7 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) // Node not yet started @@ -1889,7 +1889,7 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -2666,7 +2666,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Cache(t *testing.T) { mockHandlerFactory.OnGetHandler(v1alpha1.NodeKindTask).Return(mockHandler, nil) nodeExecutor, err := NewExecutor(ctx, nodeConfig, dataStore, enqueueWorkflow, mockEventSink, adminClient, adminClient, rawOutputPrefix, fakeKubeClient, catalogClient, - recoveryClient, eventConfig, testClusterID, signalClient, mockHandlerFactory, testScope) + recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, mockHandlerFactory, testScope) assert.NoError(t, err) return nodeExecutor diff --git a/flytepropeller/pkg/controller/nodes/factory/handler_factory.go b/flytepropeller/pkg/controller/nodes/factory/handler_factory.go index 424bd15f10..72dcff5310 100644 --- a/flytepropeller/pkg/controller/nodes/factory/handler_factory.go +++ b/flytepropeller/pkg/controller/nodes/factory/handler_factory.go @@ -28,16 +28,17 @@ import ( type handlerFactory struct { handlers map[v1alpha1.NodeKind]interfaces.NodeHandler - workflowLauncher launchplan.Executor - launchPlanReader launchplan.Reader - kubeClient executors.Client - kubeClientset kubernetes.Interface - catalogClient catalog.Client - recoveryClient recovery.Client - eventConfig *config.EventConfig - clusterID string - signalClient service.SignalServiceClient - scope promutils.Scope + workflowLauncher launchplan.Executor + launchPlanReader launchplan.Reader + kubeClient executors.Client + kubeClientset kubernetes.Interface + catalogClient catalog.Client + recoveryClient recovery.Client + eventConfig *config.EventConfig + literalOffloadingConfig config.LiteralOffloadingConfig + clusterID string + signalClient service.SignalServiceClient + scope promutils.Scope } func (f *handlerFactory) GetHandler(kind v1alpha1.NodeKind) (interfaces.NodeHandler, error) { @@ -54,7 +55,7 @@ func (f *handlerFactory) Setup(ctx context.Context, executor interfaces.Node, se return err } - arrayHandler, err := array.New(executor, f.eventConfig, f.scope) + arrayHandler, err := array.New(executor, f.eventConfig, f.literalOffloadingConfig, f.scope) if err != nil { return err } @@ -79,18 +80,20 @@ func (f *handlerFactory) Setup(ctx context.Context, executor interfaces.Node, se func NewHandlerFactory(ctx context.Context, workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, kubeClient executors.Client, kubeClientset kubernetes.Interface, catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, + literalOffloadingConfig config.LiteralOffloadingConfig, clusterID string, signalClient service.SignalServiceClient, scope promutils.Scope) (interfaces.HandlerFactory, error) { return &handlerFactory{ - workflowLauncher: workflowLauncher, - launchPlanReader: launchPlanReader, - kubeClient: kubeClient, - kubeClientset: kubeClientset, - catalogClient: catalogClient, - recoveryClient: recoveryClient, - eventConfig: eventConfig, - clusterID: clusterID, - signalClient: signalClient, - scope: scope, + workflowLauncher: workflowLauncher, + launchPlanReader: launchPlanReader, + kubeClient: kubeClient, + kubeClientset: kubeClientset, + catalogClient: catalogClient, + recoveryClient: recoveryClient, + eventConfig: eventConfig, + literalOffloadingConfig: literalOffloadingConfig, + clusterID: clusterID, + signalClient: signalClient, + scope: scope, }, nil } diff --git a/flytepropeller/pkg/controller/workflow/executor_test.go b/flytepropeller/pkg/controller/workflow/executor_test.go index 85667b0e26..a3d028e94b 100644 --- a/flytepropeller/pkg/controller/workflow/executor_test.go +++ b/flytepropeller/pkg/controller/workflow/executor_test.go @@ -242,11 +242,11 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Error(t *testing.T) { recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() - handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, config.LiteralOffloadingConfig{}, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) execStatsHolder, err := execStats.NewExecutionStatsHolder() @@ -328,11 +328,11 @@ func TestWorkflowExecutor_HandleFlyteWorkflow(t *testing.T) { recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() - handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, config.LiteralOffloadingConfig{}, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) execStatsHolder, err := execStats.NewExecutionStatsHolder() @@ -398,7 +398,7 @@ func BenchmarkWorkflowExecutor(b *testing.B) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() handlerFactory := &nodemocks.HandlerFactory{} nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, scope) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, scope) assert.NoError(b, err) execStatsHolder, err := execStats.NewExecutionStatsHolder() @@ -512,7 +512,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) { handlerFactory.OnGetHandlerMatch(mock.Anything).Return(h, nil) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) execStatsHolder, err := execStats.NewExecutionStatsHolder() @@ -613,11 +613,11 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Events(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() recoveryClient := &recoveryMocks.Client{} - handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, config.LiteralOffloadingConfig{}, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) execStatsHolder, err := execStats.NewExecutionStatsHolder() assert.NoError(t, err) @@ -685,7 +685,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_EventFailure(t *testing.T) { handlerFactory.OnSetupMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) handlerFactory.OnGetHandlerMatch(mock.Anything).Return(h, nil) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, nodeEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) t.Run("EventAlreadyInTerminalStateError", func(t *testing.T) { diff --git a/go.mod b/go.mod index 8fd55ed61a..8c8053def6 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.4.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625 // indirect + github.com/Masterminds/semver v1.5.0 // indirect github.com/NYTimes/gizmo v1.3.6 // indirect github.com/Shopify/sarama v1.26.4 // indirect github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 // indirect diff --git a/go.sum b/go.sum index ae60f26800..68eebb1fde 100644 --- a/go.sum +++ b/go.sum @@ -77,6 +77,8 @@ github.com/DataDog/opencensus-go-exporter-datadog v0.0.0-20191210083620-6965a1cf github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625 h1:cQyO5JQ2iuHnEcF3v24kdDMsgh04RjyFPDtuvD6PCE0= github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625/go.mod h1:6PnrZv6zUDkrNMw0mIoGRmGBR7i9LulhKPmxFq4rUiM= github.com/Masterminds/semver v1.4.2/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= +github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= +github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/Masterminds/semver/v3 v3.0.3/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA= github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA= From 367459b4256c46234274bbf83087382110594711 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 18 Sep 2024 00:18:44 -0700 Subject: [PATCH 05/17] pytorch object.inv moved (#5755) Signed-off-by: Yee Hing Tong --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 24f6feb97e..2be3b0185f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -287,7 +287,7 @@ "python": ("https://docs.python.org/3", None), "numpy": ("https://numpy.org/doc/stable", None), "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), - "torch": ("https://pytorch.org/docs/master/", None), + "torch": ("https://pytorch.org/docs/main/", None), "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), "matplotlib": ("https://matplotlib.org", None), "pandera": ("https://pandera.readthedocs.io/en/stable/", None), From 312910d7ae308818a9d294dd76697a666e40fc0c Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 18 Sep 2024 19:15:34 +0800 Subject: [PATCH 06/17] [RFC] Binary IDL With MessagePack Bytes (#5742) --- .../5741-binary-idl-with-message-pack.md | 620 ++++++++++++++++++ 1 file changed, 620 insertions(+) create mode 100644 rfc/system/5741-binary-idl-with-message-pack.md diff --git a/rfc/system/5741-binary-idl-with-message-pack.md b/rfc/system/5741-binary-idl-with-message-pack.md new file mode 100644 index 0000000000..ae04b0903f --- /dev/null +++ b/rfc/system/5741-binary-idl-with-message-pack.md @@ -0,0 +1,620 @@ +# Binary IDL With MessagePack Bytes + +**Authors:** + +- [@Han-Ru](https://github.com/future-outlier) +- [@Yee Hing Tong](https://github.com/wild-endeavor) +- [@Ping-Su](https://github.com/pingsutw) +- [@Eduardo Apolinario](https://github.com/eapolinario) +- [@Haytham Abuelfutuh](https://github.com/EngHabu) +- [@Ketan Umare](https://github.com/kumare3) + +## 1 Executive Summary +### Literal Value +Literal Value will be `Binary`. + +Use `bytes` in `Binary` instead of `Protobuf struct`. + +- To Literal + +| Before | Now | +|-----------------------------------|----------------------------------------------| +| Python Val -> JSON String -> Protobuf Struct | Python Val -> (Dict ->) Bytes -> Binary (value: MessagePack Bytes, tag: msgpack) IDL Object | + +- To Python Value + +| Before | Now | +|-----------------------------------|----------------------------------------------| +| Protobuf Struct -> JSON String -> Python Val | Binary (value: MessagePack Bytes, tag: msgpack) IDL Object -> Bytes -> (Dict ->) -> Python Val | + + +Note: + +1. If a Python value can't be directly converted to `MessagePack Bytes`, we can first convert it to a `Dict`, and then convert it to `MessagePack Bytes`. + + - **For example:** The Pydantic-to-literal function workflow will be: + `BaseModel` -> `dict` -> `MessagePack Bytes` -> `Binary (value: MessagePack Bytes, tag: msgpack) IDL Object`. + + - **For pure `dict` in Python:** The to-literal function workflow will be: + `dict` -> `MessagePack Bytes` -> `Binary (value: MessagePack Bytes, tag: msgpack) IDL Object`. + +2. There is **NO JSON** involved in the new type at all. Only **JSON Schema** is used to construct `DataClass` or `Pydantic BaseModel`. + + +### Literal Type +Literal Type will be `SimpleType.STRUCT`. +`Json Schema` will be stored in `Literal Type's metadata`. + +1. Dataclass, Pydantic BaseModel and pure dict in python will all use `SimpleType.STRUCT`. +2. We will put `Json Schema` in Literal Type's `metadata` field, this will be used in flytekit remote api to construct dataclass/Pydantic BaseModel by `Json Schema`. +3. We will use libraries written in golang to compare `Json Schema` to solve this issue: ["[BUG] Union types fail for e.g. two different dataclasses"](https://github.com/flyteorg/flyte/issues/5489). + +Note: The `metadata` of `Literal Type` and `Literal Value` are not the same. + +## 2 Motivation + +Prior to this RFC, in flytekit, when handling dataclasses, Pydantic base models, and dictionaries, we store data using a JSON string within Protobuf struct datatype. + +This approach causes issues with integers, as Protobuf struct does not support int types, leading to their conversion to floats. + +This results in performance issues since we need to recursively iterate through all attributes/keys in dataclasses and dictionaries to ensure floats types are converted to int. + +In addition to performance issues, the required code is complicated and error prone. + +Note: We have more than 10 issues about dict, dataclass and Pydantic. + +This feature can solve them all. + +## 3 Proposed Implementation +### Before +```python +@task +def t1() -> dict: + ... + return {"a": 1} # Protobuf Struct {"a": 1.0} + +@task +def t2(a: dict): + print(a["integer"]) # wrong, will be a float +``` +### After +```python +@task +def t1() -> dict: # Literal(scalar=Scalar(binary=Binary(value=b'msgpack_bytes', tag="msgpack"))) + ... + return {"a": 1} # Protobuf Binary value=b'\x81\xa1a\x01', produced by msgpack + +@task +def t2(a: dict): + print(a["integer"]) # correct, it will be a integer +``` + +#### Note +- We will use implement `to_python_value` to every type transformer to ensure backward compatibility. +For example, `Binary IDL Object` -> python value and `Protobuf Struct IDL Object` -> python value are both supported. + +### How to turn a value to bytes? +#### Use MsgPack to convert a value into bytes +##### Python +```python +import msgpack + +# Encode +def to_literal(): + msgpack_bytes = msgpack.dumps(python_val) + return Literal(scalar=Scalar(binary=Binary(value=b'msgpack_bytes', tag="msgpack"))) + +# Decode +def to_python_value(): + # lv: literal value + if lv.scalar.binary.tag == "msgpack": + msgpack_bytes = lv.scalar.binary.value + else: + raise ValueError(f"{tag} is not supported to decode this Binary Literal: {lv.scalar.binary}.") + return msgpack.loads(msgpack_bytes) +``` +reference: https://github.com/msgpack/msgpack-python + +##### Golang +```go +package main + +import ( + "fmt" + "github.com/shamaton/msgpack/v2" +) + +func main() { + // Example data to encode + data := map[string]int{"a": 1} + + // Encode the data + encodedData, err := msgpack.Marshal(data) + if err != nil { + panic(err) + } + + // Print the encoded data + fmt.Printf("Encoded data: %x\n", encodedData) // Output: 81a16101 + + // Decode the data + var decodedData map[string]int + err = msgpack.Unmarshal(encodedData, &decodedData) + if err != nil { + panic(err) + } + + // Print the decoded data + fmt.Printf("Decoded data: %+v\n", decodedData) // Output: map[a:1] +} +``` + +reference: [shamaton/msgpack GitHub Repository](https://github.com/shamaton/msgpack) + +Notes: + +1. **MessagePack Implementations**: + - We can explore all MessagePack implementations for Golang at the [MessagePack official website](https://msgpack.org/index.html). + +2. **Library Comparison**: + - The library [github.com/vmihailenco/msgpack](https://github.com/vmihailenco/msgpack) doesn't support strict type deserialization (for example, `map[int]string`), but [github.com/shamaton/msgpack/v2](https://github.com/shamaton/msgpack) supports this feature. This is super important for backward compatibility. + +3. **Library Popularity**: + - While [github.com/shamaton/msgpack/v2](https://github.com/shamaton/msgpack) has fewer stars on GitHub, it has proven to be reliable in various test cases. All cases created by me have passed successfully, which you can find in this [pull request](https://github.com/flyteorg/flytekit/pull/2751). + +4. **Project Activity**: + - [github.com/shamaton/msgpack/v2](https://github.com/shamaton/msgpack) is still an actively maintained project. The author responds quickly to issues and questions, making it a well-supported choice for projects requiring ongoing maintenance and active support. + +5. **Testing Process**: + - I initially started with [github.com/vmihailenco/msgpack](https://github.com/vmihailenco/msgpack) but switched to [github.com/shamaton/msgpack/v2](https://github.com/shamaton/msgpack) due to its better support for strict typing and the active support provided by the author. + + +##### JavaScript +```javascript +import { encode, decode } from '@msgpack/msgpack'; + +// Example data to encode +const data = { a: 1 }; + +// Encode the data +const encodedData = encode(data); + +// Print the encoded data +console.log(encodedData); // + +// Decode the data +const decodedData = decode(encodedData); + +// Print the decoded data +console.log(decodedData); // { a: 1 } +``` +reference: https://github.com/msgpack/msgpack-javascript + +### FlyteIDL +#### Literal Value + +Here is the [IDL definition](https://github.com/flyteorg/flyte/blob/7989209e15600b56fcf0f4c4a7c9af7bfeab6f3e/flyteidl/protos/flyteidl/core/literals.proto#L42-L47). + +The `bytes` field is used for serialized data, and the `tag` field specifies the serialization format identifier. +#### Literal Type +```proto +import "google/protobuf/struct.proto"; + +enum SimpleType { + NONE = 0; + INTEGER = 1; + FLOAT = 2; + STRING = 3; + BOOLEAN = 4; + DATETIME = 5; + DURATION = 6; + BINARY = 7; + ERROR = 8; + STRUCT = 9; // Use this one. +} +message LiteralType { + SimpleType simple = 1; // Use this one. + google.protobuf.Struct metadata = 6; // Store Json Schema to differentiate different dataclass. +} +``` + +### FlytePropeller +1. Attribute Access for dictionary, Dataclass, and Pydantic in workflow. +Dict[type, type] is supported already, we have to support Dataclass, Pydantic and dict now. +```python +from flytekit import task, workflow +from dataclasses import dataclass + +@dataclass +class DC: + a: int + +@task +def t1() -> DC: + return DC(a=1) + +@task +def t2(x: int): + print("x:", x) + return + +@workflow +def wf(): + o = t1() + t2(x=o.a) +``` +2. Create a Literal Type for Scalar when doing type validation. +```go +func literalTypeForScalar(scalar *core.Scalar) *core.LiteralType { + ... + case *core.Scalar_Binary: + literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_BINARY}} + ... + return literalType +} +``` +3. Support input and default input. +```go +// Literal Input +func ExtractFromLiteral(literal *core.Literal) (interface{}, error) { + switch literalValue := literal.Value.(type) { + case *core.Literal_Scalar: + ... + case *core.Scalar_Binary: + return scalarValue.Binary, nil + } +} +// Default Input +func MakeDefaultLiteralForType(typ *core.LiteralType) (*core.Literal, error) { + switch t := typ.GetType().(type) { + case *core.LiteralType_Simple: + case core.SimpleType_BINARY: + return MakeLiteral([]byte{}) + } +} +// Use Message Pack as Default Tag for deserialization. +// "tag" will default be "msgpack" +func MakeBinaryLiteral(v []byte, tag string) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: v, + Tag: tag, + }, + }, + }, + }, + } +} +``` +4. Compiler +```go +func (t trivialChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { + if upstreamType.GetEnumType() != nil { + if t.literalType.GetSimple() == flyte.SimpleType_STRING { + return true + } + } + + if t.literalType.GetEnumType() != nil { + if upstreamType.GetSimple() == flyte.SimpleType_STRING { + return true + } + } + + if GetTagForType(upstreamType) != "" && GetTagForType(t.literalType) != GetTagForType(upstreamType) { + return false + } + + // Here is the new way to check if dataclass/pydantic BaseModel are castable or not. + if upstreamTypeCopy.GetSimple() == flyte.SimpleType_STRUCT &&\ + downstreamTypeCopy.GetSimple() == flyte.SimpleType_STRUCT { + // Json Schema is stored in Metadata + upstreamMetadata := upstreamTypeCopy.GetMetadata() + downstreamMetadata := downstreamTypeCopy.GetMetadata() + + // There's bug in flytekit's dataclass Transformer to generate JSON Scheam before, + // in some case, we the JSON Schema will be nil, so we can only pass it to support + // backward compatible. (reference task should be supported.) + if upstreamMetadata == nil || downstreamMetadata == nil { + return true + } + + return isSameTypeInJSON(upstreamMetadata, downstreamMetadata) ||\ + isSuperTypeInJSON(upstreamMetadata, downstreamMetadata) + } + + upstreamTypeCopy := *upstreamType + downstreamTypeCopy := *t.literalType + upstreamTypeCopy.Structure = &flyte.TypeStructure{} + downstreamTypeCopy.Structure = &flyte.TypeStructure{} + upstreamTypeCopy.Metadata = &structpb.Struct{} + downstreamTypeCopy.Metadata = &structpb.Struct{} + upstreamTypeCopy.Annotation = &flyte.TypeAnnotation{} + downstreamTypeCopy.Annotation = &flyte.TypeAnnotation{} + return upstreamTypeCopy.String() == downstreamTypeCopy.String() +} +``` +### FlyteKit +#### Attribute Access + +In all transformers, we should implement a function called `from_binary_idl` to convert the Binary IDL Object into the desired type. + +A base method can be added to the `TypeTransformer` class, allowing child classes to override it as needed. + +During attribute access, Flyte Propeller will deserialize the msgpack bytes into a map object in golang, retrieve the specific attribute, and then serialize it back into msgpack bytes (resulting in a Binary IDL Object containing msgpack bytes). + +This implies that when converting a literal to a Python value, we will receive `msgpack bytes` instead of the `expected Python type`. + +```python +# In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True. +# This is relevant for cases like Dict[int, str]. +# If strict_map_key=False is not used, the decoder will raise an error when trying to decode keys that are not strictly typed. +def _default_flytekit_decoder(data: bytes) -> Any: + return msgpack.unpackb(data, raw=False, strict_map_key=False) + + +def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]: + # Handle msgpack serialization + if binary_idl_object.tag == "msgpack": + try: + # Retrieve the existing decoder for the expected type + decoder = self._msgpack_decoder[expected_python_type] + except KeyError: + # Create a new decoder if not already cached + decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_flytekit_decoder) + self._msgpack_decoder[expected_python_type] = decoder + # Decode the binary IDL object into the expected Python type + return decoder.decode(binary_idl_object.value) + else: + # Raise an error if the binary format is not supported + raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}") +``` + +Note: +1. This base method can handle primitive types, nested typed dictionaries, nested typed lists, and combinations of nested typed dictionaries and lists. + +2. Dataclass transformer needs its own `from_binary_idl` method to handle specific cases such as [discriminated classes](https://github.com/flyteorg/flyte/issues/5588). + +3. Flyte types (e.g., FlyteFile, FlyteDirectory, StructuredDataset, and FlyteSchema) will need their own `from_binary_idl` methods, as they must handle downloading files from remote object storage when converting literals to Python values. + +For example, see the FlyteFile implementation: https://github.com/flyteorg/flytekit/pull/2751/files#diff-22cf9c7153b54371b4a77331ddf276a082cf4b3c5e7bd1595dd67232288594fdR522-R552 + +#### pyflyte run +The behavior will remain unchanged. + +We will pass the value to our class, which inherits from `click.ParamType`, and use the corresponding type transformer to convert the input to the correct type. + +### Dict Transformer +There are 2 cases in Dict Transformer, `Dict[type, type]` and `dict`. + +For `Dict[type, type]`, we will stay everything the same as before. + +#### Literal Value +For `dict`, the life cycle of it will be as below. + +Before: +- `to_literal`: `dict` -> `JSON String` -> `Protobuf Struct` +- `to_python_val`: `Protobuf Struct` -> `JSON String` -> `dict` + +After: +- `to_literal`: `dict` -> `msgpack bytes` -> `Binary(value=b'msgpack_bytes', tag="msgpack")` +- `to_python_val`: `Binary(value=b'msgpack_bytes', tag="msgpack")` -> `msgpack bytes` -> `dict` + +#### JSON Schema +The JSON Schema of `dict` will be empty. +### Dataclass Transformer +#### Literal Value +Before: +- `to_literal`: `dataclass` -> `JSON String` -> `Protobuf Struct` +- `to_python_val`: `Protobuf Struct` -> `JSON String` -> `dataclass` + +After: +- `to_literal`: `dataclass` -> `msgpack bytes` -> `Binary(value=b'msgpack_bytes', tag="msgpack")` +- `to_python_val`: `Binary(value=b'msgpack_bytes', tag="msgpack")` -> `msgpack bytes` -> `dataclass` + +Note: We will use mashumaro's `MessagePackEncoder` and `MessagePackDecoder` to serialize and deserialize dataclass value in python. +```python +from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder +``` + +#### Literal Type's TypeStructure's dataclass_type +This is used for compiling dataclass attribute access. + +With it, we can retrieve the literal type of an attribute and validate it in Flyte's propeller compiler. + +For more details, check here: https://github.com/flyteorg/flytekit/blob/fb55841f8660b2a31e99381dd06e42f8cd22758e/flytekit/core/type_engine.py#L454-L525 + +#### JSON Schema +The JSON Schema of `dataclass` will be generated by `marshmallow` or `mashumaro`. +Check here: https://github.com/flyteorg/flytekit/blob/8c6f6f0f17d113447e1b10b03e25a34bad79685c/flytekit/core/type_engine.py#L442-L474 + + +### Pydantic Transformer +#### Literal Value +Pydantic can't be serialized to `msgpack bytes` directly. +But `dict` can. + +- `to_literal`: `BaseModel` -> `dict` -> `msgpack bytes` -> `Binary(value=b'msgpack_bytes', tag="msgpack")` +- `to_python_val`: `Binary(value=b'msgpack_bytes', tag="msgpack")` -> `msgpack bytes` -> `dict` -> `BaseModel` + +Note: Pydantic BaseModel can't be serialized directly by `msgpack`, but this implementation will still ensure 100% correct. + +```python +@dataclass +class DC_inside: + a: int + b: float + +@dataclass +class DC: + a: int + b: float + c: str + d: Dict[str, int] + e: DC_inside + +class MyDCModel(BaseModel): + dc: DC + +my_dc = MyDCModel(dc=DC(a=1, b=2.0, c="3", d={"4": 5}, e=DC_inside(a=6, b=7.0))) +# {'dc': {'a': 1, 'b': 2.0, 'c': '3', 'd': {'4': 5}, 'e': {'a': 6, 'b': 7.0}}} +``` + +#### Literal Type's TypeStructure's dataclass_type +This is used for compiling Pydantic BaseModel attribute access. + +With it, we can retrieve an attribute's literal type and validate it in Flyte's propeller compiler. + +Although this feature is not currently implemented, it will function similarly to the dataclass transformer in the future. + +#### JSON Schema +The JSON Schema of `BaseModel` will be generated by Pydantic's API. +```python +@dataclass +class DC_inside: + a: int + b: float + +@dataclass +class DC: + a: int + b: float + c: str + d: Dict[str, int] + e: DC_inside + +class MyDCModel(BaseModel): + dc: DC + +my_dc = MyDCModel(dc=DC(a=1, b=2.0, c="3", d={"4": 5}, e=DC_inside(a=6, b=7.0))) +my_dc.model_json_schema() +""" +{'$defs': {'DC': {'properties': {'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'number'}, 'c': {'title': 'C', 'type': 'string'}, 'd': {'additionalProperties': {'type': 'integer'}, 'title': 'D', 'type': 'object'}, 'e': {'$ref': '#/$defs/DC_inside'}}, 'required': ['a', 'b', 'c', 'd', 'e'], 'title': 'DC', 'type': 'object'}, 'DC_inside': {'properties': {'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'number'}}, 'required': ['a', 'b'], 'title': 'DC_inside', 'type': 'object'}}, 'properties': {'dc': {'$ref': '#/$defs/DC'}}, 'required': ['dc'], 'title': 'MyDCModel', 'type': 'object'} +""" +``` + +### FlyteCtl + +In FlyteCtl, we can construct input for the execution. + +When we receive `SimpleType.STRUCT`, we can construct a `Binary IDL Object` using the following logic in `flyteidl/clients/go/coreutils/literals.go`: + +```go +if newT.Simple == core.SimpleType_STRUCT { + if _, isValueStringType := v.(string); !isValueStringType { + byteValue, err := msgpack.Marshal(v) + if err != nil { + return nil, fmt.Errorf("unable to marshal to json string for struct value %v", v) + } + strValue = string(byteValue) + } +} +``` + +This is how users can create an execution by using a YAML file: +```bash +flytectl create execution --execFile ./flytectl/create_dataclass_task.yaml -p flytesnacks -d development +``` + +Example YAML file (`create_dataclass_task.yaml`): +```yaml +iamRoleARN: "" +inputs: + input: + a: 1 + b: 3.14 + c: example_string + d: + "1": 100 + "2": 200 + e: + a: 1 + b: 3.14 +envs: {} +kubeServiceAcct: "" +targetDomain: "" +targetProject: "" +task: dataclass_example.dataclass_task +version: OSyTikiBTAkjBgrL5JVOVw +``` + +### FlyteCopilot + +When we need to pass an attribute access value to a copilot task, we must modify the code to convert a Binary Literal value with the `msgpack` tag into a primitive value. + +(Currently, we will only support primitive values.) + +You can reference the relevant section of code here: + +[FlyteCopilot - Data Download](https://github.com/flyteorg/flyte/blob/7989209e15600b56fcf0f4c4a7c9af7bfeab6f3e/flytecopilot/data/download.go#L88-L95) + +### FlyteConsole +#### How users input into launch form? +When FlyteConsole receives a literal type of `SimpleType.STRUCT`, the input method depends on the availability of a JSON schema: + +1. No JSON Schema provided: + +Input is expected as `a Javascript Object` (e.g., `{"a": 1}`). + +2. JSON Schema provided: + +Users can input values based on the schema's expected type and construct an appropriate `Javascript Object`. + +Note: + +For `dataclass` and Pydantic `BaseModel`, a JSON schema will be provided in their literal type, and the input form will be constructed accordingly. + +##### What happens after the user enters data? + +Input values -> Javascript Object -> msgpack bytes -> Binary IDL With MessagePack Bytes + +#### Displaying Inputs/Outputs in the Console +Use `msgpack` to deserialize bytes into an Object and display it in Flyte Console. + +#### Copying Inputs/Outputs in the Console +Allow users to copy the `Object` to the clipboard, as currently implemented. + +#### Pasting and Copying from FlyteConsole +Currently, we should support JSON pasting if the content is a JavaScript object. However, there's a question of how we might handle other formats like YAML or MsgPack bytes, especially if copied from a binary file. + +For now, focusing on supporting JSON pasting makes sense. However, adding support for YAML and MsgPack bytes could be valuable future enhancements. + +## 4 Metrics & Dashboards + +None + +## 5 Drawbacks + +None + +## 6 Alternatives + +MsgPack is a good choice because it's more smaller and faster than UTF-8 Encoded JSON String. + +You can see the performance comparison here: https://github.com/flyteorg/flyte/pull/5607#issuecomment-2333174325 + +We will use `msgpack` to do it. + +## 7 Potential Impact and Dependencies +None. + +## 8. Unresolved Questions +### Conditional Branch +Currently, our support for `DataClass/BaseModel/Dict[type, type]` within conditional branches is incomplete. At present, we only support comparisons of primitive types. However, there are two key challenges when attempting to handle these more complex types: + +1. **Primitive Type Comparison vs. Binary IDL Object:** + - In conditional branches, we receive a `Binary IDL Object` during type comparison, which needs to be converted into a `Primitive IDL Object`. + - The issue is that we don't know the expected Python type or primitive type beforehand, making this conversion ambiguous. + +2. **MsgPack Incompatibility with `Primitive_Datetime` and `Primitive_Duration`:** + - MsgPack does not natively support the `Primitive_Datetime` and `Primitive_Duration` types, and instead converts them to strings. + - This can lead to inconsistencies in comparison logic. One potential workaround is to convert both types to strings for comparison. However, it is uncertain whether this approach is the best solution. + +## 9 Conclusion + +1. Binary IDL with MessagePack Bytes provides a better representation for dataclasses, Pydantic BaseModels, and untyped dictionaries in Flyte. + +2. This approach ensures 100% accuracy of each attribute and enables attribute access. From fd841bafa5d0db88ab0a0e05861cf127583a0f48 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Wed, 18 Sep 2024 08:25:49 -0700 Subject: [PATCH 07/17] Rename literal offloading flags and enable offloading in local config (#5754) * Rename literal offloading flags and enable offloading in local config Signed-off-by: Eduardo Apolinario * Run `make -C flytepropeller generate` Signed-off-by: Eduardo Apolinario --------- Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario --- flyte-single-binary-local.yaml | 2 ++ .../pkg/controller/config/config.go | 8 +++---- .../pkg/controller/config/config_flags.go | 8 +++---- .../controller/config/config_flags_test.go | 24 +++++++++---------- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/flyte-single-binary-local.yaml b/flyte-single-binary-local.yaml index 3a4dc780b2..487eae5bf1 100644 --- a/flyte-single-binary-local.yaml +++ b/flyte-single-binary-local.yaml @@ -24,6 +24,8 @@ propeller: create-flyteworkflow-crd: true kube-config: $HOME/.flyte/sandbox/kubeconfig rawoutput-prefix: s3://my-s3-bucket/data + literal-offloading-config: + enabled: true server: kube-config: $HOME/.flyte/sandbox/kubeconfig diff --git a/flytepropeller/pkg/controller/config/config.go b/flytepropeller/pkg/controller/config/config.go index 2d61c94970..f045058e95 100644 --- a/flytepropeller/pkg/controller/config/config.go +++ b/flytepropeller/pkg/controller/config/config.go @@ -173,17 +173,17 @@ type Config struct { CreateFlyteWorkflowCRD bool `json:"create-flyteworkflow-crd" pflag:",Enable creation of the FlyteWorkflow CRD on startup"` NodeExecutionWorkerCount int `json:"node-execution-worker-count" pflag:",Number of workers to evaluate node executions, currently only used for array nodes"` ArrayNode ArrayNodeConfig `json:"array-node-config,omitempty" pflag:",Configuration for array nodes"` - LiteralOffloadingConfig LiteralOffloadingConfig `json:"literalOffloadingConfig" pflag:",config used for literal offloading."` + LiteralOffloadingConfig LiteralOffloadingConfig `json:"literal-offloading-config" pflag:",config used for literal offloading."` } type LiteralOffloadingConfig struct { Enabled bool // Maps flytekit and union SDK names to minimum supported version that can handle reading offloaded literals. - SupportedSDKVersions map[string]string + SupportedSDKVersions map[string]string `json:"supported-sdk-versions" pflag:",Maps flytekit and union SDK names to minimum supported version that can handle reading offloaded literals."` // Default, 10Mbs. Determines the size of a literal at which to trigger offloading - MinSizeInMBForOffloading int64 + MinSizeInMBForOffloading int64 `json:"min-size-in-mb-for-offloading" pflag:",Size of a literal at which to trigger offloading"` // Fail fast threshold - MaxSizeInMBForOffloading int64 + MaxSizeInMBForOffloading int64 `json:"max-size-in-mb-for-offloading" pflag:",Size of a literal at which to fail fast"` } // IsSupportedSDKVersion returns true if the provided SDK and version are supported by the literal offloading config. diff --git a/flytepropeller/pkg/controller/config/config_flags.go b/flytepropeller/pkg/controller/config/config_flags.go index b2e88e88e6..d8496a56fe 100755 --- a/flytepropeller/pkg/controller/config/config_flags.go +++ b/flytepropeller/pkg/controller/config/config_flags.go @@ -112,9 +112,9 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "node-execution-worker-count"), defaultConfig.NodeExecutionWorkerCount, "Number of workers to evaluate node executions, currently only used for array nodes") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "array-node-config.event-version"), defaultConfig.ArrayNode.EventVersion, "ArrayNode eventing version. 0 => legacy (drop-in replacement for maptask), 1 => new") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "array-node-config.default-parallelism-behavior"), defaultConfig.ArrayNode.DefaultParallelismBehavior, "Default parallelism behavior for array nodes") - cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.Enabled"), defaultConfig.LiteralOffloadingConfig.Enabled, "") - cmdFlags.StringToString(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.SupportedSDKVersions"), defaultConfig.LiteralOffloadingConfig.SupportedSDKVersions, "") - cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.MinSizeInMBForOffloading"), defaultConfig.LiteralOffloadingConfig.MinSizeInMBForOffloading, "") - cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.MaxSizeInMBForOffloading"), defaultConfig.LiteralOffloadingConfig.MaxSizeInMBForOffloading, "") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "literal-offloading-config.Enabled"), defaultConfig.LiteralOffloadingConfig.Enabled, "") + cmdFlags.StringToString(fmt.Sprintf("%v%v", prefix, "literal-offloading-config.supported-sdk-versions"), defaultConfig.LiteralOffloadingConfig.SupportedSDKVersions, "Maps flytekit and union SDK names to minimum supported version that can handle reading offloaded literals.") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "literal-offloading-config.min-size-in-mb-for-offloading"), defaultConfig.LiteralOffloadingConfig.MinSizeInMBForOffloading, "Size of a literal at which to trigger offloading") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "literal-offloading-config.max-size-in-mb-for-offloading"), defaultConfig.LiteralOffloadingConfig.MaxSizeInMBForOffloading, "Size of a literal at which to fail fast") return cmdFlags } diff --git a/flytepropeller/pkg/controller/config/config_flags_test.go b/flytepropeller/pkg/controller/config/config_flags_test.go index aadb24b36a..109dc47b28 100755 --- a/flytepropeller/pkg/controller/config/config_flags_test.go +++ b/flytepropeller/pkg/controller/config/config_flags_test.go @@ -967,13 +967,13 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_literalOffloadingConfig.Enabled", func(t *testing.T) { + t.Run("Test_literal-offloading-config.Enabled", func(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("literalOffloadingConfig.Enabled", testValue) - if vBool, err := cmdFlags.GetBool("literalOffloadingConfig.Enabled"); err == nil { + cmdFlags.Set("literal-offloading-config.Enabled", testValue) + if vBool, err := cmdFlags.GetBool("literal-offloading-config.Enabled"); err == nil { testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.LiteralOffloadingConfig.Enabled) } else { @@ -981,13 +981,13 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_literalOffloadingConfig.SupportedSDKVersions", func(t *testing.T) { + t.Run("Test_literal-offloading-config.supported-sdk-versions", func(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "a=1,b=2" - cmdFlags.Set("literalOffloadingConfig.SupportedSDKVersions", testValue) - if vStringToString, err := cmdFlags.GetStringToString("literalOffloadingConfig.SupportedSDKVersions"); err == nil { + cmdFlags.Set("literal-offloading-config.supported-sdk-versions", testValue) + if vStringToString, err := cmdFlags.GetStringToString("literal-offloading-config.supported-sdk-versions"); err == nil { testDecodeRaw_Config(t, vStringToString, &actual.LiteralOffloadingConfig.SupportedSDKVersions) } else { @@ -995,13 +995,13 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_literalOffloadingConfig.MinSizeInMBForOffloading", func(t *testing.T) { + t.Run("Test_literal-offloading-config.min-size-in-mb-for-offloading", func(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("literalOffloadingConfig.MinSizeInMBForOffloading", testValue) - if vInt64, err := cmdFlags.GetInt64("literalOffloadingConfig.MinSizeInMBForOffloading"); err == nil { + cmdFlags.Set("literal-offloading-config.min-size-in-mb-for-offloading", testValue) + if vInt64, err := cmdFlags.GetInt64("literal-offloading-config.min-size-in-mb-for-offloading"); err == nil { testDecodeJson_Config(t, fmt.Sprintf("%v", vInt64), &actual.LiteralOffloadingConfig.MinSizeInMBForOffloading) } else { @@ -1009,13 +1009,13 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_literalOffloadingConfig.MaxSizeInMBForOffloading", func(t *testing.T) { + t.Run("Test_literal-offloading-config.max-size-in-mb-for-offloading", func(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("literalOffloadingConfig.MaxSizeInMBForOffloading", testValue) - if vInt64, err := cmdFlags.GetInt64("literalOffloadingConfig.MaxSizeInMBForOffloading"); err == nil { + cmdFlags.Set("literal-offloading-config.max-size-in-mb-for-offloading", testValue) + if vInt64, err := cmdFlags.GetInt64("literal-offloading-config.max-size-in-mb-for-offloading"); err == nil { testDecodeJson_Config(t, fmt.Sprintf("%v", vInt64), &actual.LiteralOffloadingConfig.MaxSizeInMBForOffloading) } else { From 2d9c35b06f06a642e526c2f56437dfff55d88f35 Mon Sep 17 00:00:00 2001 From: Prafulla Mahindrakar Date: Wed, 18 Sep 2024 08:39:38 -0700 Subject: [PATCH 08/17] Add semver check for dev and beta version (#5757) * Add semver check for dev and beta version Signed-off-by: pmahindrakar-oss * add unsupported version tests Signed-off-by: pmahindrakar-oss * passing context and refactored to table driven utests Signed-off-by: pmahindrakar-oss --------- Signed-off-by: pmahindrakar-oss --- .../pkg/controller/config/config.go | 26 ++- .../pkg/controller/config/config_test.go | 158 +++++++++++++----- .../pkg/controller/nodes/common/utils.go | 2 +- 3 files changed, 134 insertions(+), 52 deletions(-) diff --git a/flytepropeller/pkg/controller/config/config.go b/flytepropeller/pkg/controller/config/config.go index f045058e95..f9b46a7037 100644 --- a/flytepropeller/pkg/controller/config/config.go +++ b/flytepropeller/pkg/controller/config/config.go @@ -36,6 +36,7 @@ package config import ( "context" "fmt" + "regexp" "time" "github.com/Masterminds/semver" @@ -133,6 +134,9 @@ var ( MaxSizeInMBForOffloading: 1000, // 1 GB is the default size before failing fast. }, } + + // This regex is used to sanitize semver versions passed to IsSupportedSDK checks for literal offloading feature. + sanitizeProtoRegex = regexp.MustCompile(`v?(\d+\.\d+\.\d+)`) ) // Config that uses the flytestdlib Config module to generate commandline and load config files. This configuration is @@ -187,18 +191,24 @@ type LiteralOffloadingConfig struct { } // IsSupportedSDKVersion returns true if the provided SDK and version are supported by the literal offloading config. -func (l LiteralOffloadingConfig) IsSupportedSDKVersion(sdk string, versionString string) bool { +func (l LiteralOffloadingConfig) IsSupportedSDKVersion(ctx context.Context, sdk string, versionString string) bool { + regexMatches := sanitizeProtoRegex.FindStringSubmatch(versionString) + if len(regexMatches) > 1 { + logger.Infof(ctx, "original: %s, semVer: %s", versionString, regexMatches[1]) + } else { + logger.Warnf(ctx, "no match found for: %s", versionString) + return false + } + version, err := semver.NewVersion(regexMatches[1]) + if err != nil { + logger.Warnf(ctx, "Failed to parse version %s", versionString) + return false + } if leastSupportedVersion, ok := l.SupportedSDKVersions[sdk]; ok { c, err := semver.NewConstraint(fmt.Sprintf(">= %s", leastSupportedVersion)) if err != nil { // This should never happen - logger.Warnf(context.TODO(), "Failed to parse version constraint %s", leastSupportedVersion) - return false - } - version, err := semver.NewVersion(versionString) - if err != nil { - // This should never happen - logger.Warnf(context.TODO(), "Failed to parse version %s", versionString) + logger.Warnf(ctx, "Failed to parse version constraint %s", leastSupportedVersion) return false } return c.Check(version) diff --git a/flytepropeller/pkg/controller/config/config_test.go b/flytepropeller/pkg/controller/config/config_test.go index afc9ed2fea..507643a569 100644 --- a/flytepropeller/pkg/controller/config/config_test.go +++ b/flytepropeller/pkg/controller/config/config_test.go @@ -1,54 +1,126 @@ package config import ( + "context" "testing" "github.com/stretchr/testify/assert" ) func TestIsSupportedSDKVersion(t *testing.T) { - t.Run("supported version", func(t *testing.T) { - config := LiteralOffloadingConfig{ - SupportedSDKVersions: map[string]string{ - "flytekit": "0.16.0", - }, - } - assert.True(t, config.IsSupportedSDKVersion("flytekit", "0.16.0")) - }) - - t.Run("unsupported version", func(t *testing.T) { - config := LiteralOffloadingConfig{ - SupportedSDKVersions: map[string]string{ - "flytekit": "0.16.0", - }, - } - assert.False(t, config.IsSupportedSDKVersion("flytekit", "0.15.0")) - }) - - t.Run("unsupported SDK", func(t *testing.T) { - config := LiteralOffloadingConfig{ - SupportedSDKVersions: map[string]string{ - "flytekit": "0.16.0", - }, - } - assert.False(t, config.IsSupportedSDKVersion("unknown", "0.16.0")) - }) - - t.Run("invalid version", func(t *testing.T) { - config := LiteralOffloadingConfig{ - SupportedSDKVersions: map[string]string{ - "flytekit": "0.16.0", - }, - } - assert.False(t, config.IsSupportedSDKVersion("flytekit", "invalid")) - }) + ctx := context.Background() + tests := []struct { + name string + config LiteralOffloadingConfig + sdk string + version string + expectedResult bool + }{ + { + name: "supported version", + config: LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "0.16.0", + }, + }, + sdk: "flytekit", + version: "0.16.0", + expectedResult: true, + }, + { + name: "unsupported version", + config: LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "0.16.0", + }, + }, + sdk: "flytekit", + version: "0.15.0", + expectedResult: false, + }, + { + name: "unsupported SDK", + config: LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "0.16.0", + }, + }, + sdk: "unknown", + version: "0.16.0", + expectedResult: false, + }, + { + name: "invalid version", + config: LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "0.16.0", + }, + }, + sdk: "flytekit", + version: "invalid", + expectedResult: false, + }, + { + name: "invalid constraint", + config: LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "invalid", + }, + }, + sdk: "flytekit", + version: "0.16.0", + expectedResult: false, + }, + { + name: "supported dev version", + config: LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "1.13.4", + }, + }, + sdk: "flytekit", + version: "1.13.4.dev12+g990b450ea.d20240917", + expectedResult: true, + }, + { + name: "supported beta version", + config: LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "1.13.4", + }, + }, + sdk: "flytekit", + version: "v1.13.6b0", + expectedResult: true, + }, + { + name: "unsupported dev version", + config: LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "1.13.4", + }, + }, + sdk: "flytekit", + version: "1.13.3.dev12+g990b450ea.d20240917", + expectedResult: false, + }, + { + name: "unsupported beta version", + config: LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "1.13.4", + }, + }, + sdk: "flytekit", + version: "v1.13.3b0", + expectedResult: false, + }, + } - t.Run("invalid constraint", func(t *testing.T) { - config := LiteralOffloadingConfig{ - SupportedSDKVersions: map[string]string{ - "flytekit": "invalid", - }, - } - assert.False(t, config.IsSupportedSDKVersion("flytekit", "0.16.0")) - }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.config.IsSupportedSDKVersion(ctx, tt.sdk, tt.version) + assert.Equal(t, tt.expectedResult, result) + }) + } } diff --git a/flytepropeller/pkg/controller/nodes/common/utils.go b/flytepropeller/pkg/controller/nodes/common/utils.go index 89bb0afe2e..b02d830fe9 100644 --- a/flytepropeller/pkg/controller/nodes/common/utils.go +++ b/flytepropeller/pkg/controller/nodes/common/utils.go @@ -147,7 +147,7 @@ func CheckOffloadingCompat(ctx context.Context, nCtx interfaces.NodeExecutionCon return &phaseInfo } runtimeData := taskNode.CoreTask().GetMetadata().GetRuntime() - if !literalOffloadingConfig.IsSupportedSDKVersion(runtimeData.GetType().String(), runtimeData.GetVersion()) { + if !literalOffloadingConfig.IsSupportedSDKVersion(ctx, runtimeData.GetType().String(), runtimeData.GetVersion()) { if !literalOffloadingConfig.Enabled { errMsg := fmt.Sprintf("task [%s] is trying to consume offloaded literals but feature is not enabled", taskID) logger.Errorf(ctx, errMsg) From 792ce1ab21c082c88299da04413d0c348bfae851 Mon Sep 17 00:00:00 2001 From: Prafulla Mahindrakar Date: Wed, 18 Sep 2024 11:50:32 -0700 Subject: [PATCH 09/17] Add comments for semver regex (#5758) Signed-off-by: pmahindrakar-oss --- flytepropeller/pkg/controller/config/config.go | 1 + 1 file changed, 1 insertion(+) diff --git a/flytepropeller/pkg/controller/config/config.go b/flytepropeller/pkg/controller/config/config.go index f9b46a7037..488ada1127 100644 --- a/flytepropeller/pkg/controller/config/config.go +++ b/flytepropeller/pkg/controller/config/config.go @@ -136,6 +136,7 @@ var ( } // This regex is used to sanitize semver versions passed to IsSupportedSDK checks for literal offloading feature. + // It matches against 1.13.3 in v1.13.3b0 (beta version) or 1.13.3 in 1.13.3.dev12+g990b450ea.d20240917(dev version) sanitizeProtoRegex = regexp.MustCompile(`v?(\d+\.\d+\.\d+)`) ) From 0eaa20e8b92245555be8fd725f97e52fbed54e78 Mon Sep 17 00:00:00 2001 From: Christoph Paulik Date: Wed, 18 Sep 2024 21:20:56 +0200 Subject: [PATCH 10/17] Without this caching is actually not enabled (#5739) Signed-off-by: Christoph Paulik Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario --- charts/flyte-core/templates/propeller/configmap.yaml | 7 ++++--- .../eks/flyte_aws_scheduler_helm_generated.yaml | 11 ++++++----- deployment/eks/flyte_helm_dataplane_generated.yaml | 11 ++++++----- deployment/eks/flyte_helm_generated.yaml | 11 ++++++----- deployment/gcp/flyte_helm_dataplane_generated.yaml | 11 ++++++----- deployment/gcp/flyte_helm_generated.yaml | 11 ++++++----- deployment/sandbox/flyte_helm_generated.yaml | 11 ++++++----- docker/sandbox-bundled/manifests/complete-agent.yaml | 4 ++-- docker/sandbox-bundled/manifests/complete.yaml | 4 ++-- docker/sandbox-bundled/manifests/dev.yaml | 4 ++-- 10 files changed, 46 insertions(+), 39 deletions(-) diff --git a/charts/flyte-core/templates/propeller/configmap.yaml b/charts/flyte-core/templates/propeller/configmap.yaml index 65a5458cfc..3e1b94ef2c 100644 --- a/charts/flyte-core/templates/propeller/configmap.yaml +++ b/charts/flyte-core/templates/propeller/configmap.yaml @@ -48,9 +48,10 @@ data: {{- end }} storage.yaml: | {{ tpl (include "storage" .) $ | nindent 4 }} cache.yaml: | - cache: - max_size_mbs: {{ .Values.flytepropeller.cacheSizeMbs }} - target_gc_percent: 70 + storage: + cache: + max_size_mbs: {{ .Values.flytepropeller.cacheSizeMbs }} + target_gc_percent: 70 {{- with .Values.configmap.task_logs }} task_logs.yaml: | {{ tpl (toYaml .) $ | nindent 4 }} {{- end }} diff --git a/deployment/eks/flyte_aws_scheduler_helm_generated.yaml b/deployment/eks/flyte_aws_scheduler_helm_generated.yaml index dc4ac5e800..9ee5eecc49 100644 --- a/deployment/eks/flyte_aws_scheduler_helm_generated.yaml +++ b/deployment/eks/flyte_aws_scheduler_helm_generated.yaml @@ -513,9 +513,10 @@ data: limits: maxDownloadMBs: 10 cache.yaml: | - cache: - max_size_mbs: 1024 - target_gc_percent: 70 + storage: + cache: + max_size_mbs: 1024 + target_gc_percent: 70 task_logs.yaml: | plugins: logs: @@ -1278,7 +1279,7 @@ spec: template: metadata: annotations: - configChecksum: "33bc4dd986fdb015ce49d998deedf122e119579ec09db311e67276230678a70" + configChecksum: "16400bbf28ab823ca433a04379f00fecb11530c3c265dd1bdc3e30209ac6a6d" prometheus.io/path: "/metrics" prometheus.io/port: "10254" labels: @@ -1362,7 +1363,7 @@ spec: app.kubernetes.io/name: flyte-pod-webhook app.kubernetes.io/version: v1.13.1 annotations: - configChecksum: "33bc4dd986fdb015ce49d998deedf122e119579ec09db311e67276230678a70" + configChecksum: "16400bbf28ab823ca433a04379f00fecb11530c3c265dd1bdc3e30209ac6a6d" prometheus.io/path: "/metrics" prometheus.io/port: "10254" spec: diff --git a/deployment/eks/flyte_helm_dataplane_generated.yaml b/deployment/eks/flyte_helm_dataplane_generated.yaml index 03640c4c05..aa649b24ea 100644 --- a/deployment/eks/flyte_helm_dataplane_generated.yaml +++ b/deployment/eks/flyte_helm_dataplane_generated.yaml @@ -177,9 +177,10 @@ data: limits: maxDownloadMBs: 10 cache.yaml: | - cache: - max_size_mbs: 1024 - target_gc_percent: 70 + storage: + cache: + max_size_mbs: 1024 + target_gc_percent: 70 task_logs.yaml: | plugins: logs: @@ -429,7 +430,7 @@ spec: template: metadata: annotations: - configChecksum: "33bc4dd986fdb015ce49d998deedf122e119579ec09db311e67276230678a70" + configChecksum: "16400bbf28ab823ca433a04379f00fecb11530c3c265dd1bdc3e30209ac6a6d" prometheus.io/path: "/metrics" prometheus.io/port: "10254" labels: @@ -513,7 +514,7 @@ spec: app.kubernetes.io/name: flyte-pod-webhook app.kubernetes.io/version: v1.13.1 annotations: - configChecksum: "33bc4dd986fdb015ce49d998deedf122e119579ec09db311e67276230678a70" + configChecksum: "16400bbf28ab823ca433a04379f00fecb11530c3c265dd1bdc3e30209ac6a6d" prometheus.io/path: "/metrics" prometheus.io/port: "10254" spec: diff --git a/deployment/eks/flyte_helm_generated.yaml b/deployment/eks/flyte_helm_generated.yaml index c2e861857c..6e95514c44 100644 --- a/deployment/eks/flyte_helm_generated.yaml +++ b/deployment/eks/flyte_helm_generated.yaml @@ -544,9 +544,10 @@ data: limits: maxDownloadMBs: 10 cache.yaml: | - cache: - max_size_mbs: 1024 - target_gc_percent: 70 + storage: + cache: + max_size_mbs: 1024 + target_gc_percent: 70 task_logs.yaml: | plugins: logs: @@ -1408,7 +1409,7 @@ spec: template: metadata: annotations: - configChecksum: "33bc4dd986fdb015ce49d998deedf122e119579ec09db311e67276230678a70" + configChecksum: "16400bbf28ab823ca433a04379f00fecb11530c3c265dd1bdc3e30209ac6a6d" prometheus.io/path: "/metrics" prometheus.io/port: "10254" labels: @@ -1492,7 +1493,7 @@ spec: app.kubernetes.io/name: flyte-pod-webhook app.kubernetes.io/version: v1.13.1 annotations: - configChecksum: "33bc4dd986fdb015ce49d998deedf122e119579ec09db311e67276230678a70" + configChecksum: "16400bbf28ab823ca433a04379f00fecb11530c3c265dd1bdc3e30209ac6a6d" prometheus.io/path: "/metrics" prometheus.io/port: "10254" spec: diff --git a/deployment/gcp/flyte_helm_dataplane_generated.yaml b/deployment/gcp/flyte_helm_dataplane_generated.yaml index be2feeb698..40f2542640 100644 --- a/deployment/gcp/flyte_helm_dataplane_generated.yaml +++ b/deployment/gcp/flyte_helm_dataplane_generated.yaml @@ -180,9 +180,10 @@ data: limits: maxDownloadMBs: 10 cache.yaml: | - cache: - max_size_mbs: 1024 - target_gc_percent: 70 + storage: + cache: + max_size_mbs: 1024 + target_gc_percent: 70 task_logs.yaml: | plugins: k8s-array: @@ -437,7 +438,7 @@ spec: template: metadata: annotations: - configChecksum: "64a3f5e546eddd8126d03c005460b964f428a0da31c0bfac5f70c63fbf3d635" + configChecksum: "b11d6c5dd0dd16bced82fb44fd2fb31c2f27a134543ab7edaf060da42ab27e4" prometheus.io/path: "/metrics" prometheus.io/port: "10254" labels: @@ -520,7 +521,7 @@ spec: app.kubernetes.io/name: flyte-pod-webhook app.kubernetes.io/version: v1.13.1 annotations: - configChecksum: "64a3f5e546eddd8126d03c005460b964f428a0da31c0bfac5f70c63fbf3d635" + configChecksum: "b11d6c5dd0dd16bced82fb44fd2fb31c2f27a134543ab7edaf060da42ab27e4" prometheus.io/path: "/metrics" prometheus.io/port: "10254" spec: diff --git a/deployment/gcp/flyte_helm_generated.yaml b/deployment/gcp/flyte_helm_generated.yaml index acd3985c6d..6f75ece6a1 100644 --- a/deployment/gcp/flyte_helm_generated.yaml +++ b/deployment/gcp/flyte_helm_generated.yaml @@ -560,9 +560,10 @@ data: limits: maxDownloadMBs: 10 cache.yaml: | - cache: - max_size_mbs: 1024 - target_gc_percent: 70 + storage: + cache: + max_size_mbs: 1024 + target_gc_percent: 70 task_logs.yaml: | plugins: k8s-array: @@ -1431,7 +1432,7 @@ spec: template: metadata: annotations: - configChecksum: "64a3f5e546eddd8126d03c005460b964f428a0da31c0bfac5f70c63fbf3d635" + configChecksum: "b11d6c5dd0dd16bced82fb44fd2fb31c2f27a134543ab7edaf060da42ab27e4" prometheus.io/path: "/metrics" prometheus.io/port: "10254" labels: @@ -1514,7 +1515,7 @@ spec: app.kubernetes.io/name: flyte-pod-webhook app.kubernetes.io/version: v1.13.1 annotations: - configChecksum: "64a3f5e546eddd8126d03c005460b964f428a0da31c0bfac5f70c63fbf3d635" + configChecksum: "b11d6c5dd0dd16bced82fb44fd2fb31c2f27a134543ab7edaf060da42ab27e4" prometheus.io/path: "/metrics" prometheus.io/port: "10254" spec: diff --git a/deployment/sandbox/flyte_helm_generated.yaml b/deployment/sandbox/flyte_helm_generated.yaml index c68497cf1b..9072c486d3 100644 --- a/deployment/sandbox/flyte_helm_generated.yaml +++ b/deployment/sandbox/flyte_helm_generated.yaml @@ -683,9 +683,10 @@ data: limits: maxDownloadMBs: 10 cache.yaml: | - cache: - max_size_mbs: 0 - target_gc_percent: 70 + storage: + cache: + max_size_mbs: 0 + target_gc_percent: 70 task_logs.yaml: | plugins: logs: @@ -7182,7 +7183,7 @@ spec: template: metadata: annotations: - configChecksum: "84d449758d51dc641aff55ae07f4376a860b5038e8407cb9d2444c4f895d953" + configChecksum: "8a003fdbed4b3801328c26cb5c202e4ca113875347388366d29a302d610e7e4" prometheus.io/path: "/metrics" prometheus.io/port: "10254" labels: @@ -7258,7 +7259,7 @@ spec: app.kubernetes.io/name: flyte-pod-webhook app.kubernetes.io/version: v1.13.1 annotations: - configChecksum: "84d449758d51dc641aff55ae07f4376a860b5038e8407cb9d2444c4f895d953" + configChecksum: "8a003fdbed4b3801328c26cb5c202e4ca113875347388366d29a302d610e7e4" prometheus.io/path: "/metrics" prometheus.io/port: "10254" spec: diff --git a/docker/sandbox-bundled/manifests/complete-agent.yaml b/docker/sandbox-bundled/manifests/complete-agent.yaml index 1c7f42143a..761469bd83 100644 --- a/docker/sandbox-bundled/manifests/complete-agent.yaml +++ b/docker/sandbox-bundled/manifests/complete-agent.yaml @@ -816,7 +816,7 @@ type: Opaque --- apiVersion: v1 data: - haSharedSecret: bTM0eldMZ3IxWE1jZXZ3cg== + haSharedSecret: aVh1N3lZb0F1c2l0NHVuRg== proxyPassword: "" proxyUsername: "" kind: Secret @@ -1413,7 +1413,7 @@ spec: metadata: annotations: checksum/config: 8f50e768255a87f078ba8b9879a0c174c3e045ffb46ac8723d2eedbe293c8d81 - checksum/secret: f6b8f4ae413528aa882ed401dca20e11f8947e1ca829df6bedbf16698de03cdf + checksum/secret: 042e6b21a3852a65952e0701cd9667e53bfef57590eea4d116b261472f29a882 labels: app: docker-registry release: flyte-sandbox diff --git a/docker/sandbox-bundled/manifests/complete.yaml b/docker/sandbox-bundled/manifests/complete.yaml index 1cb7b26c28..e80cc05d20 100644 --- a/docker/sandbox-bundled/manifests/complete.yaml +++ b/docker/sandbox-bundled/manifests/complete.yaml @@ -798,7 +798,7 @@ type: Opaque --- apiVersion: v1 data: - haSharedSecret: Wkl6V2diWk50S1Ftd1R3Nw== + haSharedSecret: Q2EyanRtd1JjWmVKS2tHMw== proxyPassword: "" proxyUsername: "" kind: Secret @@ -1362,7 +1362,7 @@ spec: metadata: annotations: checksum/config: 8f50e768255a87f078ba8b9879a0c174c3e045ffb46ac8723d2eedbe293c8d81 - checksum/secret: ab69c0935db6eb0dd11b9df6716336d8f993cb949d112fd0417d1133a88000cc + checksum/secret: 01201329c98a1417f04feeef00dc21e67cf73d99ac9b99486ce5788eca0c282c labels: app: docker-registry release: flyte-sandbox diff --git a/docker/sandbox-bundled/manifests/dev.yaml b/docker/sandbox-bundled/manifests/dev.yaml index ae2244449e..fe04e4c059 100644 --- a/docker/sandbox-bundled/manifests/dev.yaml +++ b/docker/sandbox-bundled/manifests/dev.yaml @@ -499,7 +499,7 @@ metadata: --- apiVersion: v1 data: - haSharedSecret: Z1BQdzNCck9SbGF2bFBNcg== + haSharedSecret: N3dIemE2TnF1b3l1SWdNTw== proxyPassword: "" proxyUsername: "" kind: Secret @@ -934,7 +934,7 @@ spec: metadata: annotations: checksum/config: 8f50e768255a87f078ba8b9879a0c174c3e045ffb46ac8723d2eedbe293c8d81 - checksum/secret: 42de0a9877ccf019375a303e7a813c44ec6dc92758e3af6c0f10aeb4c3a1071c + checksum/secret: ca7423805c2fd3a98507c790af575a6e5389f50d6baa09bd8c49cb59c4452340 labels: app: docker-registry release: flyte-sandbox From 1502d234a3ddea8018ec779f1f3d79c64b70cd9e Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 19 Sep 2024 14:36:00 +0800 Subject: [PATCH 11/17] [Flyte][1][IDL] Binary IDL With MessagePack (#5751) * [Flyte][1][IDL] Binary IDL With MessagePack Signed-off-by: Future-Outlier * add comments in compiler Signed-off-by: Future-Outlier --------- Signed-off-by: Future-Outlier --- flyteidl/clients/go/assets/admin.swagger.json | 38 +++++++++++++------ .../gen/pb-es/flyteidl/core/literals_pb.ts | 4 ++ .../gen/pb-go/flyteidl/core/literals.pb.go | 4 +- .../cacheservice/cacheservice.swagger.json | 6 ++- .../datacatalog/datacatalog.swagger.json | 6 ++- .../flyteidl/service/admin.swagger.json | 6 ++- .../flyteidl/service/agent.swagger.json | 6 ++- .../flyteidl/service/dataproxy.swagger.json | 6 ++- .../external_plugin_service.swagger.json | 6 ++- .../flyteidl/service/signal.swagger.json | 6 ++- flyteidl/gen/pb_rust/flyteidl.core.rs | 2 + flyteidl/protos/flyteidl/core/literals.proto | 4 +- .../pkg/compiler/validators/bindings.go | 2 +- 13 files changed, 66 insertions(+), 30 deletions(-) diff --git a/flyteidl/clients/go/assets/admin.swagger.json b/flyteidl/clients/go/assets/admin.swagger.json index 6ebfd70f8d..241baeb53c 100644 --- a/flyteidl/clients/go/assets/admin.swagger.json +++ b/flyteidl/clients/go/assets/admin.swagger.json @@ -6661,10 +6661,12 @@ "properties": { "value": { "type": "string", - "format": "byte" + "format": "byte", + "description": "Serialized data (MessagePack) for supported types like Dataclass, Pydantic BaseModel, and untyped dict." }, "tag": { - "type": "string" + "type": "string", + "description": "The serialization format identifier (e.g., MessagePack). Consumers must define unique tags and validate them before deserialization." } }, "description": "A simple byte array with a tag to help different parts of the system communicate about what is in the byte array.\nIt's strongly advisable that consumers of this type define a unique tag and validate the tag before parsing the data." @@ -7418,6 +7420,10 @@ "$ref": "#/definitions/coreLiteralMap", "description": "A map of strings to literals." }, + "offloaded_metadata": { + "$ref": "#/definitions/coreLiteralOffloadedMetadata", + "description": "Offloaded literal metadata\nWhen you deserialize the offloaded metadata, it would be of Literal and its type would be defined by LiteralType stored in offloaded_metadata." + }, "hash": { "type": "string", "title": "A hash representing this literal.\nThis is used for caching purposes. For more details refer to RFC 1893\n(https://github.com/flyteorg/flyte/blob/master/rfc/system/1893-caching-of-offloaded-objects.md)" @@ -7428,15 +7434,6 @@ "type": "string" }, "description": "Additional metadata for literals." - }, - "uri": { - "type": "string", - "description": "If this literal is offloaded, this field will contain metadata including the offload location." - }, - "size_bytes": { - "type": "string", - "format": "uint64", - "description": "Includes information about the size of the literal." } }, "description": "A simple value. This supports any level of nesting (e.g. array of array of array of Blobs) as well as simple primitives." @@ -7466,6 +7463,25 @@ }, "description": "A map of literals. This is a workaround since oneofs in proto messages cannot contain a repeated field." }, + "coreLiteralOffloadedMetadata": { + "type": "object", + "properties": { + "uri": { + "type": "string", + "description": "The location of the offloaded core.Literal." + }, + "size_bytes": { + "type": "string", + "format": "uint64", + "description": "The size of the offloaded data." + }, + "inferred_type": { + "$ref": "#/definitions/coreLiteralType", + "description": "The inferred literal type of the offloaded data." + } + }, + "description": "A message that contains the metadata of the offloaded data." + }, "coreLiteralType": { "type": "object", "properties": { diff --git a/flyteidl/gen/pb-es/flyteidl/core/literals_pb.ts b/flyteidl/gen/pb-es/flyteidl/core/literals_pb.ts index 95ebbd9de9..69fcde4375 100644 --- a/flyteidl/gen/pb-es/flyteidl/core/literals_pb.ts +++ b/flyteidl/gen/pb-es/flyteidl/core/literals_pb.ts @@ -215,11 +215,15 @@ export class BlobMetadata extends Message { */ export class Binary extends Message { /** + * Serialized data (MessagePack) for supported types like Dataclass, Pydantic BaseModel, and untyped dict. + * * @generated from field: bytes value = 1; */ value = new Uint8Array(0); /** + * The serialization format identifier (e.g., MessagePack). Consumers must define unique tags and validate them before deserialization. + * * @generated from field: string tag = 2; */ tag = ""; diff --git a/flyteidl/gen/pb-go/flyteidl/core/literals.pb.go b/flyteidl/gen/pb-go/flyteidl/core/literals.pb.go index 3f6e223749..2225e74077 100644 --- a/flyteidl/gen/pb-go/flyteidl/core/literals.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/core/literals.pb.go @@ -315,8 +315,8 @@ type Binary struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Value []byte `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"` - Tag string `protobuf:"bytes,2,opt,name=tag,proto3" json:"tag,omitempty"` + Value []byte `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"` // Serialized data (MessagePack) for supported types like Dataclass, Pydantic BaseModel, and untyped dict. + Tag string `protobuf:"bytes,2,opt,name=tag,proto3" json:"tag,omitempty"` // The serialization format identifier (e.g., MessagePack). Consumers must define unique tags and validate them before deserialization. } func (x *Binary) Reset() { diff --git a/flyteidl/gen/pb-go/gateway/flyteidl/cacheservice/cacheservice.swagger.json b/flyteidl/gen/pb-go/gateway/flyteidl/cacheservice/cacheservice.swagger.json index c30c350754..204e9e7122 100644 --- a/flyteidl/gen/pb-go/gateway/flyteidl/cacheservice/cacheservice.swagger.json +++ b/flyteidl/gen/pb-go/gateway/flyteidl/cacheservice/cacheservice.swagger.json @@ -117,10 +117,12 @@ "properties": { "value": { "type": "string", - "format": "byte" + "format": "byte", + "description": "Serialized data (MessagePack) for supported types like Dataclass, Pydantic BaseModel, and untyped dict." }, "tag": { - "type": "string" + "type": "string", + "description": "The serialization format identifier (e.g., MessagePack). Consumers must define unique tags and validate them before deserialization." } }, "description": "A simple byte array with a tag to help different parts of the system communicate about what is in the byte array.\nIt's strongly advisable that consumers of this type define a unique tag and validate the tag before parsing the data." diff --git a/flyteidl/gen/pb-go/gateway/flyteidl/datacatalog/datacatalog.swagger.json b/flyteidl/gen/pb-go/gateway/flyteidl/datacatalog/datacatalog.swagger.json index dbfcc5b85e..990cc1ec4a 100644 --- a/flyteidl/gen/pb-go/gateway/flyteidl/datacatalog/datacatalog.swagger.json +++ b/flyteidl/gen/pb-go/gateway/flyteidl/datacatalog/datacatalog.swagger.json @@ -91,10 +91,12 @@ "properties": { "value": { "type": "string", - "format": "byte" + "format": "byte", + "description": "Serialized data (MessagePack) for supported types like Dataclass, Pydantic BaseModel, and untyped dict." }, "tag": { - "type": "string" + "type": "string", + "description": "The serialization format identifier (e.g., MessagePack). Consumers must define unique tags and validate them before deserialization." } }, "description": "A simple byte array with a tag to help different parts of the system communicate about what is in the byte array.\nIt's strongly advisable that consumers of this type define a unique tag and validate the tag before parsing the data." diff --git a/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json b/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json index ef81380d1e..241baeb53c 100644 --- a/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json +++ b/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json @@ -6661,10 +6661,12 @@ "properties": { "value": { "type": "string", - "format": "byte" + "format": "byte", + "description": "Serialized data (MessagePack) for supported types like Dataclass, Pydantic BaseModel, and untyped dict." }, "tag": { - "type": "string" + "type": "string", + "description": "The serialization format identifier (e.g., MessagePack). Consumers must define unique tags and validate them before deserialization." } }, "description": "A simple byte array with a tag to help different parts of the system communicate about what is in the byte array.\nIt's strongly advisable that consumers of this type define a unique tag and validate the tag before parsing the data." diff --git a/flyteidl/gen/pb-go/gateway/flyteidl/service/agent.swagger.json b/flyteidl/gen/pb-go/gateway/flyteidl/service/agent.swagger.json index 373b9c4c3d..070b6a8c60 100644 --- a/flyteidl/gen/pb-go/gateway/flyteidl/service/agent.swagger.json +++ b/flyteidl/gen/pb-go/gateway/flyteidl/service/agent.swagger.json @@ -903,10 +903,12 @@ "properties": { "value": { "type": "string", - "format": "byte" + "format": "byte", + "description": "Serialized data (MessagePack) for supported types like Dataclass, Pydantic BaseModel, and untyped dict." }, "tag": { - "type": "string" + "type": "string", + "description": "The serialization format identifier (e.g., MessagePack). Consumers must define unique tags and validate them before deserialization." } }, "description": "A simple byte array with a tag to help different parts of the system communicate about what is in the byte array.\nIt's strongly advisable that consumers of this type define a unique tag and validate the tag before parsing the data." diff --git a/flyteidl/gen/pb-go/gateway/flyteidl/service/dataproxy.swagger.json b/flyteidl/gen/pb-go/gateway/flyteidl/service/dataproxy.swagger.json index bff6ca737a..20f32b743d 100644 --- a/flyteidl/gen/pb-go/gateway/flyteidl/service/dataproxy.swagger.json +++ b/flyteidl/gen/pb-go/gateway/flyteidl/service/dataproxy.swagger.json @@ -207,10 +207,12 @@ "properties": { "value": { "type": "string", - "format": "byte" + "format": "byte", + "description": "Serialized data (MessagePack) for supported types like Dataclass, Pydantic BaseModel, and untyped dict." }, "tag": { - "type": "string" + "type": "string", + "description": "The serialization format identifier (e.g., MessagePack). Consumers must define unique tags and validate them before deserialization." } }, "description": "A simple byte array with a tag to help different parts of the system communicate about what is in the byte array.\nIt's strongly advisable that consumers of this type define a unique tag and validate the tag before parsing the data." diff --git a/flyteidl/gen/pb-go/gateway/flyteidl/service/external_plugin_service.swagger.json b/flyteidl/gen/pb-go/gateway/flyteidl/service/external_plugin_service.swagger.json index 029c42ffd3..e690cc556c 100644 --- a/flyteidl/gen/pb-go/gateway/flyteidl/service/external_plugin_service.swagger.json +++ b/flyteidl/gen/pb-go/gateway/flyteidl/service/external_plugin_service.swagger.json @@ -233,10 +233,12 @@ "properties": { "value": { "type": "string", - "format": "byte" + "format": "byte", + "description": "Serialized data (MessagePack) for supported types like Dataclass, Pydantic BaseModel, and untyped dict." }, "tag": { - "type": "string" + "type": "string", + "description": "The serialization format identifier (e.g., MessagePack). Consumers must define unique tags and validate them before deserialization." } }, "description": "A simple byte array with a tag to help different parts of the system communicate about what is in the byte array.\nIt's strongly advisable that consumers of this type define a unique tag and validate the tag before parsing the data." diff --git a/flyteidl/gen/pb-go/gateway/flyteidl/service/signal.swagger.json b/flyteidl/gen/pb-go/gateway/flyteidl/service/signal.swagger.json index d325ed4764..841cb04f26 100644 --- a/flyteidl/gen/pb-go/gateway/flyteidl/service/signal.swagger.json +++ b/flyteidl/gen/pb-go/gateway/flyteidl/service/signal.swagger.json @@ -283,10 +283,12 @@ "properties": { "value": { "type": "string", - "format": "byte" + "format": "byte", + "description": "Serialized data (MessagePack) for supported types like Dataclass, Pydantic BaseModel, and untyped dict." }, "tag": { - "type": "string" + "type": "string", + "description": "The serialization format identifier (e.g., MessagePack). Consumers must define unique tags and validate them before deserialization." } }, "description": "A simple byte array with a tag to help different parts of the system communicate about what is in the byte array.\nIt's strongly advisable that consumers of this type define a unique tag and validate the tag before parsing the data." diff --git a/flyteidl/gen/pb_rust/flyteidl.core.rs b/flyteidl/gen/pb_rust/flyteidl.core.rs index 441609be89..bfbf82203d 100644 --- a/flyteidl/gen/pb_rust/flyteidl.core.rs +++ b/flyteidl/gen/pb_rust/flyteidl.core.rs @@ -400,8 +400,10 @@ pub struct BlobMetadata { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Binary { + /// Serialized data (MessagePack) for supported types like Dataclass, Pydantic BaseModel, and untyped dict. #[prost(bytes="vec", tag="1")] pub value: ::prost::alloc::vec::Vec, + /// The serialization format identifier (e.g., MessagePack). Consumers must define unique tags and validate them before deserialization. #[prost(string, tag="2")] pub tag: ::prost::alloc::string::String, } diff --git a/flyteidl/protos/flyteidl/core/literals.proto b/flyteidl/protos/flyteidl/core/literals.proto index 1eb004482c..66e4821867 100644 --- a/flyteidl/protos/flyteidl/core/literals.proto +++ b/flyteidl/protos/flyteidl/core/literals.proto @@ -42,8 +42,8 @@ message BlobMetadata { // A simple byte array with a tag to help different parts of the system communicate about what is in the byte array. // It's strongly advisable that consumers of this type define a unique tag and validate the tag before parsing the data. message Binary { - bytes value = 1; - string tag = 2; + bytes value = 1; // Serialized data (MessagePack) for supported types like Dataclass, Pydantic BaseModel, and untyped dict. + string tag = 2; // The serialization format identifier (e.g., MessagePack). Consumers must define unique tags and validate them before deserialization. } // A strongly typed schema that defines the interface of data retrieved from the underlying storage medium. diff --git a/flytepropeller/pkg/compiler/validators/bindings.go b/flytepropeller/pkg/compiler/validators/bindings.go index 337d04966d..53535ba260 100644 --- a/flytepropeller/pkg/compiler/validators/bindings.go +++ b/flytepropeller/pkg/compiler/validators/bindings.go @@ -147,7 +147,7 @@ func validateBinding(w c.WorkflowBuilder, node c.Node, nodeParam string, binding } else if sourceType.GetMapValueType() != nil { sourceType = sourceType.GetMapValueType() } else if sourceType.GetStructure() != nil && sourceType.GetStructure().GetDataclassType() != nil { - + // This is for retrieving the literal type of an attribute in a dataclass or Pydantic BaseModel tmpType, exist = sourceType.GetStructure().GetDataclassType()[attr.GetStringValue()] if !exist { From 28f65b30ed745af683e80f9c95400fda411d9c07 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 19 Sep 2024 17:19:27 +0800 Subject: [PATCH 12/17] [Flyte][2][Literal Type For Scalar] Binary IDL With MessagePack (#5761) * [Flyte][1][IDL] Binary IDL With MessagePack Signed-off-by: Future-Outlier * add comments in compiler Signed-off-by: Future-Outlier * [Flyte][2][Literal Type For Scalar] Binary IDL With MessagePack Signed-off-by: Future-Outlier * fix propeller's test Signed-off-by: Future-Outlier * add tests for float as input Signed-off-by: Future-Outlier * fix propeller's test Signed-off-by: Future-Outlier --------- Signed-off-by: Future-Outlier --- flyteadmin/go.sum | 2 + flytepropeller/go.mod | 1 + flytepropeller/go.sum | 2 + .../pkg/compiler/validators/bindings_test.go | 13 ++-- .../pkg/compiler/validators/utils.go | 9 ++- .../pkg/compiler/validators/utils_test.go | 77 +++++++++++++++++++ go.sum | 2 + 7 files changed, 98 insertions(+), 8 deletions(-) diff --git a/flyteadmin/go.sum b/flyteadmin/go.sum index 5b7d47b2a6..31a1714ed7 100644 --- a/flyteadmin/go.sum +++ b/flyteadmin/go.sum @@ -1174,6 +1174,8 @@ github.com/sendgrid/sendgrid-go v3.10.0+incompatible/go.mod h1:QRQt+LX/NmgVEvmdR github.com/serenize/snaker v0.0.0-20171204205717-a683aaf2d516/go.mod h1:Yow6lPLSAXx2ifx470yD/nUe22Dv5vBvxK/UK9UUTVs= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= diff --git a/flytepropeller/go.mod b/flytepropeller/go.mod index f579049aff..6b55e8909e 100644 --- a/flytepropeller/go.mod +++ b/flytepropeller/go.mod @@ -22,6 +22,7 @@ require ( github.com/mitchellh/mapstructure v1.5.0 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.16.0 + github.com/shamaton/msgpack/v2 v2.2.2 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 diff --git a/flytepropeller/go.sum b/flytepropeller/go.sum index 07a92b902b..dc0ccb0464 100644 --- a/flytepropeller/go.sum +++ b/flytepropeller/go.sum @@ -374,6 +374,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= diff --git a/flytepropeller/pkg/compiler/validators/bindings_test.go b/flytepropeller/pkg/compiler/validators/bindings_test.go index 7e5b388391..bcb498eebd 100644 --- a/flytepropeller/pkg/compiler/validators/bindings_test.go +++ b/flytepropeller/pkg/compiler/validators/bindings_test.go @@ -776,9 +776,9 @@ func TestValidateBindings(t *testing.T) { _, ok := ValidateBindings(wf, n, bindings, vars, true, c.EdgeDirectionBidirectional, compileErrors) assert.False(t, ok) assert.Equal(t, "MismatchingTypes", string(compileErrors.Errors().List()[0].Code())) - assert.Equal(t, "Code: MismatchingTypes, Node Id: node1, Description: Variable [x]"+ - " (type [union_type:{variants:{simple:INTEGER structure:{tag:\"int\"}}}]) doesn't match expected type"+ - " [union_type:{variants:{simple:INTEGER structure:{tag:\"int_other\"}}}].", compileErrors.Errors().List()[0].Error()) + assert.Contains(t, compileErrors.Errors().List()[0].Error(), "Code: MismatchingTypes, Node Id: node1, Description: Variable [x]") + assert.Contains(t, compileErrors.Errors().List()[0].Error(), "(type [union_type:{variants:{simple:INTEGER") + assert.Contains(t, compileErrors.Errors().List()[0].Error(), "doesn't match expected type") }) t.Run("List of Int to List of Unions Binding", func(t *testing.T) { @@ -1210,10 +1210,9 @@ func TestValidateBindings(t *testing.T) { _, ok := ValidateBindings(wf, n, bindings, vars, true, c.EdgeDirectionBidirectional, compileErrors) assert.False(t, ok) assert.Equal(t, "MismatchingTypes", string(compileErrors.Errors().List()[0].Code())) - assert.Equal(t, "Code: MismatchingTypes, Node Id: node1, Description: The output variable 'n2.n2_out'"+ - " has type [simple:INTEGER], but it's assigned to the input variable 'n.x' which has type"+ - " type [union_type:{variants:{simple:STRING structure:{tag:\"str\"}} variants:{simple:INTEGER structure:{tag:\"int1\"}}"+ - " variants:{simple:INTEGER structure:{tag:\"int2\"}}}].", compileErrors.Errors().List()[0].Error()) + assert.Contains(t, compileErrors.Errors().List()[0].Error(), "Code: MismatchingTypes, Node Id: node1,") + assert.Contains(t, compileErrors.Errors().List()[0].Error(), "Description: The output variable 'n2.n2_out'") + assert.Contains(t, compileErrors.Errors().List()[0].Error(), "has type [simple:INTEGER], but it's assigned to the input variable 'n.x' which has type") }) t.Run("Union Promise Union Literal", func(t *testing.T) { diff --git a/flytepropeller/pkg/compiler/validators/utils.go b/flytepropeller/pkg/compiler/validators/utils.go index 5f41a6e65e..fb4ba04548 100644 --- a/flytepropeller/pkg/compiler/validators/utils.go +++ b/flytepropeller/pkg/compiler/validators/utils.go @@ -44,7 +44,14 @@ func literalTypeForScalar(scalar *core.Scalar) *core.LiteralType { literalType = &core.LiteralType{Type: &core.LiteralType_Blob{Blob: scalar.GetBlob().GetMetadata().GetType()}} case *core.Scalar_Binary: - literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_BINARY}} + // If the binary has a tag, treat it as a structured type (e.g., dict, dataclass, Pydantic BaseModel). + // Otherwise, treat it as raw binary data. + // Reference: https://github.com/flyteorg/flyte/blob/master/rfc/system/5741-binary-idl-with-message-pack.md + if len(v.Binary.Tag) > 0 { + literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRUCT}} + } else { + literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_BINARY}} + } case *core.Scalar_Schema: literalType = &core.LiteralType{ Type: &core.LiteralType_Schema{ diff --git a/flytepropeller/pkg/compiler/validators/utils_test.go b/flytepropeller/pkg/compiler/validators/utils_test.go index 4a37f100dc..26e34988c3 100644 --- a/flytepropeller/pkg/compiler/validators/utils_test.go +++ b/flytepropeller/pkg/compiler/validators/utils_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/golang/protobuf/proto" + "github.com/shamaton/msgpack/v2" "github.com/stretchr/testify/assert" "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" @@ -16,6 +17,82 @@ func TestLiteralTypeForLiterals(t *testing.T) { assert.Equal(t, core.SimpleType_NONE.String(), lt.GetSimple().String()) }) + t.Run("binary idl with raw binary data and no tag", func(t *testing.T) { + // Some arbitrary raw binary data + rawBinaryData := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A} + + lv := &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: rawBinaryData, + Tag: "", + }, + }, + }, + }, + } + lt := LiteralTypeForLiteral(lv) + assert.Equal(t, core.SimpleType_BINARY.String(), lt.GetSimple().String()) + }) + + t.Run("binary idl with messagepack input map[int]strings", func(t *testing.T) { + // Create a map[int]string and serialize it using MessagePack. + data := map[int]string{ + 1: "hello", + 2: "world", + -1: "foo", + } + // Serializing the map using MessagePack + serializedBinaryData, err := msgpack.Marshal(data) + if err != nil { + t.Fatalf("failed to serialize map: %v", err) + } + lv := &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: serializedBinaryData, + Tag: "msgpack", + }, + }, + }, + }, + } + lt := LiteralTypeForLiteral(lv) + assert.Equal(t, core.SimpleType_STRUCT.String(), lt.GetSimple().String()) + }) + + t.Run("binary idl with messagepack input map[float]strings", func(t *testing.T) { + // Create a map[float]string and serialize it using MessagePack. + data := map[float64]string{ + 1.0: "hello", + 5.0: "world", + -1.0: "foo", + } + // Serializing the map using MessagePack + serializedBinaryData, err := msgpack.Marshal(data) + if err != nil { + t.Fatalf("failed to serialize map: %v", err) + } + lv := &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: serializedBinaryData, + Tag: "msgpack", + }, + }, + }, + }, + } + lt := LiteralTypeForLiteral(lv) + assert.Equal(t, core.SimpleType_STRUCT.String(), lt.GetSimple().String()) + }) + t.Run("homogeneous", func(t *testing.T) { lt := literalTypeForLiterals([]*core.Literal{ coreutils.MustMakeLiteral(5), diff --git a/go.sum b/go.sum index 68eebb1fde..63453bbd87 100644 --- a/go.sum +++ b/go.sum @@ -1210,6 +1210,8 @@ github.com/sendgrid/sendgrid-go v3.10.0+incompatible/go.mod h1:QRQt+LX/NmgVEvmdR github.com/serenize/snaker v0.0.0-20171204205717-a683aaf2d516/go.mod h1:Yow6lPLSAXx2ifx470yD/nUe22Dv5vBvxK/UK9UUTVs= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= From 0c3782bfdfc3b1a85e7b904ec7c9f79a7d9126a5 Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Mon, 23 Sep 2024 13:36:46 -0700 Subject: [PATCH 13/17] Backoff on etcd errors (#5710) Signed-off-by: Haytham Abuelfutuh --- flytepropeller/pkg/controller/nodes/task/backoff/handler.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/flytepropeller/pkg/controller/nodes/task/backoff/handler.go b/flytepropeller/pkg/controller/nodes/task/backoff/handler.go index fc890c7a09..aadbfed514 100644 --- a/flytepropeller/pkg/controller/nodes/task/backoff/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/backoff/handler.go @@ -195,8 +195,12 @@ func IsResourceQuotaExceeded(err error) bool { return apiErrors.IsForbidden(err) && strings.Contains(err.Error(), "exceeded quota") } +func IsEtcdError(err error) bool { + return apiErrors.IsForbidden(err) && strings.Contains(err.Error(), "etcdserver:") +} + func IsBackOffError(err error) bool { - return IsResourceQuotaExceeded(err) || apiErrors.IsTooManyRequests(err) || apiErrors.IsServerTimeout(err) + return IsResourceQuotaExceeded(err) || apiErrors.IsTooManyRequests(err) || apiErrors.IsServerTimeout(err) || IsEtcdError(err) } func GetComputeResourceAndQuantity(err error, resourceRegex *regexp.Regexp) v1.ResourceList { From cff510d19029d2a5b2ecea54c3d31b00d397967f Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Tue, 24 Sep 2024 06:31:05 +0800 Subject: [PATCH 14/17] [Docs][flyteagent] Update Databricks Agent Setup to V2 (#5766) Signed-off-by: Future-Outlier --- docs/deployment/agents/databricks.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/deployment/agents/databricks.rst b/docs/deployment/agents/databricks.rst index 0458fb3667..b81c9a28aa 100644 --- a/docs/deployment/agents/databricks.rst +++ b/docs/deployment/agents/databricks.rst @@ -147,7 +147,7 @@ Specify agent configuration container: container container_array: k8s-array sidecar: sidecar - spark: agent-service + databricks: agent-service enabled-plugins: - container - sidecar @@ -171,7 +171,7 @@ Specify agent configuration default-for-task-types: - container: container - container_array: k8s-array - - spark: agent-service + - databricks: agent-service .. group-tab:: Flyte core @@ -192,7 +192,7 @@ Specify agent configuration container: container sidecar: sidecar container_array: k8s-array - spark: agent-service + databricks: agent-service Add the Databricks access token ------------------------------- From 5ad419560065689b9be05daceb61ad7d22013797 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Tue, 24 Sep 2024 06:31:46 +0800 Subject: [PATCH 15/17] [Docs][flyteagent] Improve Agent Secret Setup (#5765) Signed-off-by: Future-Outlier --- docs/deployment/agents/chatgpt.rst | 38 +++++-------------- docs/deployment/agents/databricks.rst | 38 +++++-------------- docs/deployment/agents/openai_batch.rst | 34 ++++------------- .../deploying_agents_to_the_flyte_sandbox.md | 28 ++------------ 4 files changed, 28 insertions(+), 110 deletions(-) diff --git a/docs/deployment/agents/chatgpt.rst b/docs/deployment/agents/chatgpt.rst index cb0b44fa39..9483d95d35 100644 --- a/docs/deployment/agents/chatgpt.rst +++ b/docs/deployment/agents/chatgpt.rst @@ -78,45 +78,25 @@ Specify agent configuration Add the OpenAI API token ------------------------------- -1. Install flyteagent pod using helm: +1. Install the flyteagent pod using helm: -.. code-block:: +.. code-block:: bash helm repo add flyteorg https://flyteorg.github.io/flyte helm install flyteagent flyteorg/flyteagent --namespace flyte -2. Get the base64 value of your OpenAI API token: +2. Set Your OpenAI API Token as a Secret (Base64 Encoded): -.. code-block:: +.. code-block:: bash - echo -n "" | base64 + SECRET_VALUE=$(echo -n "" | base64) && \ + kubectl patch secret flyteagent -n flyte --patch "{\"data\":{\"flyte_openai_api_key\":\"$SECRET_VALUE\"}}" -3. Edit the flyteagent secret: +3. Restart development: - .. code-block:: bash - - kubectl edit secret flyteagent -n flyte - - .. code-block:: yaml - :emphasize-lines: 3 - - apiVersion: v1 - data: - flyte_openai_api_key: - kind: Secret - metadata: - annotations: - meta.helm.sh/release-name: flyteagent - meta.helm.sh/release-namespace: flyte - creationTimestamp: "2023-10-04T04:09:03Z" - labels: - app.kubernetes.io/managed-by: Helm - name: flyteagent - namespace: flyte - resourceVersion: "753" - uid: 5ac1e1b6-2a4c-4e26-9001-d4ba72c39e54 - type: Opaque +.. code-block:: bash + kubectl rollout restart deployment flyteagent -n flyte Upgrade the Flyte Helm release ------------------------------ diff --git a/docs/deployment/agents/databricks.rst b/docs/deployment/agents/databricks.rst index b81c9a28aa..b419144021 100644 --- a/docs/deployment/agents/databricks.rst +++ b/docs/deployment/agents/databricks.rst @@ -199,45 +199,25 @@ Add the Databricks access token You have to set the Databricks token to the Flyte configuration. -1. Install flyteagent pod using helm +1. Install the flyteagent pod using helm .. code-block:: helm repo add flyteorg https://flyteorg.github.io/flyte helm install flyteagent flyteorg/flyteagent --namespace flyte -2. Get the base64 value of your Databricks token. +2. Set Your Databricks Token as a Secret (Base64 Encoded): -.. code-block:: +.. code-block:: bash - echo -n "" | base64 + SECRET_VALUE=$(echo -n "" | base64) && \ + kubectl patch secret flyteagent -n flyte --patch "{\"data\":{\"flyte_databricks_access_token\":\"$SECRET_VALUE\"}}" -3. Edit the flyteagent secret - - .. code-block:: bash - - kubectl edit secret flyteagent -n flyte - - .. code-block:: yaml - :emphasize-lines: 3 - - apiVersion: v1 - data: - flyte_databricks_access_token: - kind: Secret - metadata: - annotations: - meta.helm.sh/release-name: flyteagent - meta.helm.sh/release-namespace: flyte - creationTimestamp: "2023-10-04T04:09:03Z" - labels: - app.kubernetes.io/managed-by: Helm - name: flyteagent - namespace: flyte - resourceVersion: "753" - uid: 5ac1e1b6-2a4c-4e26-9001-d4ba72c39e54 - type: Opaque +3. Restart development: + +.. code-block:: bash + kubectl rollout restart deployment flyteagent -n flyte Upgrade the deployment ---------------------- diff --git a/docs/deployment/agents/openai_batch.rst b/docs/deployment/agents/openai_batch.rst index 2cfa70471a..8e1c622b73 100644 --- a/docs/deployment/agents/openai_batch.rst +++ b/docs/deployment/agents/openai_batch.rst @@ -65,38 +65,18 @@ Add the OpenAI API token helm repo add flyteorg https://flyteorg.github.io/flyte helm install flyteagent flyteorg/flyteagent --namespace flyte -2. Get the base64 value of your OpenAI API token: +2. Set Your OpenAI API Token as a Secret (Base64 Encoded): -.. code-block:: - - echo -n "" | base64 +.. code-block:: bash -3. Edit the flyteagent secret: + SECRET_VALUE=$(echo -n "" | base64) && \ + kubectl patch secret flyteagent -n flyte --patch "{\"data\":{\"flyte_openai_api_key\":\"$SECRET_VALUE\"}}" - .. code-block:: bash +3. Restart development: - kubectl edit secret flyteagent -n flyte - - .. code-block:: yaml - :emphasize-lines: 3 - - apiVersion: v1 - data: - FLYTE_OPENAI_API_KEY: - kind: Secret - metadata: - annotations: - meta.helm.sh/release-name: flyteagent - meta.helm.sh/release-namespace: flyte - creationTimestamp: "2023-10-04T04:09:03Z" - labels: - app.kubernetes.io/managed-by: Helm - name: flyteagent - namespace: flyte - resourceVersion: "753" - uid: 5ac1e1b6-2a4c-4e26-9001-d4ba72c39e54 - type: Opaque +.. code-block:: bash + kubectl rollout restart deployment flyteagent -n flyte Upgrade the Flyte Helm release ------------------------------ diff --git a/docs/flyte_agents/deploying_agents_to_the_flyte_sandbox.md b/docs/flyte_agents/deploying_agents_to_the_flyte_sandbox.md index c4f1a2881e..2f068eb681 100644 --- a/docs/flyte_agents/deploying_agents_to_the_flyte_sandbox.md +++ b/docs/flyte_agents/deploying_agents_to_the_flyte_sandbox.md @@ -52,30 +52,8 @@ image: localhost:30000/flyteagent:example 3. Set up your secrets: Let's take Databricks agent as an example: ```bash -kubectl edit secret flyteagent -n flyte -``` -Get your `BASE64_ENCODED_DATABRICKS_TOKEN`: -```bash -echo -n "" | base64 -``` -Add your token to the `data` field: -```yaml -apiVersion: v1 -data: - flyte_databricks_access_token: -kind: Secret -metadata: - annotations: - meta.helm.sh/release-name: flyteagent - meta.helm.sh/release-namespace: flyte - creationTimestamp: "2023-10-04T04:09:03Z" - labels: - app.kubernetes.io/managed-by: Helm - name: flyteagent - namespace: flyte - resourceVersion: "753" - uid: 5ac1e1b6-2a4c-4e26-9001-d4ba72c39e54 -type: Opaque +SECRET_VALUE=$(echo -n "" | base64) && \ +kubectl patch secret flyteagent -n flyte --patch "{\"data\":{\"flyte_databricks_access_token\":\"$SECRET_VALUE\"}}" ``` :::{note} Please ensure two things: @@ -85,7 +63,7 @@ Please ensure two things: 4. Restart development: ```bash -kubectl rollout restart deployment flyte-sandbox -n flyte +kubectl rollout restart deployment flyteagent -n flyte ``` 5. Test your agent remotely in the Flyte sandbox: From c0fc6d4638eb05a4e987b31bb2a1ff3ca00ef383 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 23 Sep 2024 17:19:29 -0700 Subject: [PATCH 16/17] [RFC] Offloaded Raw Literals (#5103) --- rfc/system/5103-offloaded-literal.md | 119 +++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 rfc/system/5103-offloaded-literal.md diff --git a/rfc/system/5103-offloaded-literal.md b/rfc/system/5103-offloaded-literal.md new file mode 100644 index 0000000000..fb835a6aab --- /dev/null +++ b/rfc/system/5103-offloaded-literal.md @@ -0,0 +1,119 @@ +# [RFC] Offloaded Raw Literals + +**Authors:** + +- @wild-endeavor +- @EngHabu +- @katrogan + +## 1 Executive Summary + +Flyte depends on a series of `inputs.pb` and `outputs.pb` files to do communication between nodes. This has typically served us well, except for the occasional map task that produces a large Literal collection output. This large collection typically exceeds default gRPC and configured storage limits. We sometimes also run into this issue for large dataclasses. This RFC proposes a mechanism that allows the offloading of any Literal, to get around size limitations for passing large Literal protobuf messages in the system. + +## 2 Motivation +A [cursory search](https://discuss.flyte.org/?threads%5Bquery%5D=LIMIT_EXCEEDED) of Slack history shows a few times that this has come up before (and I remember some other instances, I think that search term just wasn't included). This is something that we've historically addressed by just increasing the size of the grpc message that's allowed, but this is an unsustainable solution and severely reduces the utility of large-fan-out map tasks. + +## 3 Proposed Implementation +We propose configuring propeller to offload large literal collections, using the following config + +```yaml +type LiteralOffloadingConfig struct { + Enabled bool + // Maps flytekit SDK to minimum supported version that can handle reading offloaded literals. + SupportedSDKVersions map[string]string + // Default, 10Mbs. Determines the size of a literal at which to trigger offloading + MinSizeInMBForOffloading uint64 + // Fail fast threshold + MaxSizeInMBForOffloading uint64 +} + +``` + +### 3.1 Offloaded Literal IDL +Update the `Literal` [message](https://github.com/flyteorg/flyte/blob/4a7c3c0040b1995a43939407b99ca3e87b1eb752/flyteidl/protos/flyteidl/core/literals.proto#L94-L114) +like so + +```protobuf +message Literal { + oneof value { + // A simple value. + Scalar scalar = 1; + // A collection of literals to allow nesting. + LiteralCollection collection = 2; + // A map of strings to literals. + LiteralMap map = 3; + } + ... + // ** new below this line ** + // If this literal is offloaded, this field will contain metadata including the offload location. + string uri = 6; + // Includes information about the size of the literal. + uint64 size_bytes = 7; +} + +``` + +### 3.2 Flyte Propeller +Once offloading is enabled in the deployment config, flytepropeller can read from the [RuntimeMetadata](https://github.com/flyteorg/flyte/blob/f448a0358d8706a09b65b96543134f629327d755/flyteidl/protos/flyteidl/core/tasks.proto#L71-L87) in the task config to determine the SDK version. + +When writing outputs in the [remote_file_output_writer](https://github.com/flyteorg/flyte/blob/2ca31119d6b9258661a71f38e450f93b6692402c/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_writer.go#L56-L84) the source code should detect whether the literal size exceeds the configured minimum and +- if the task is using a newer SDK version that supports reading offloaded literals, offload the literal to the configured storage backend and update the literal with the offload URI and size. +- if the task is using an older SDK version that doesn't support offloaded literals, fail the task with an error message indicating that the task output is too large and the user should update their SDK version. Downstream tasks will need to understand how to consume the offloaded literal and will need to be on newer version of the SDK as well. + +### 3.3 Flytekit & Copilot +Flytekit and Copilot will both need to detect that a Literal has been offloaded and know to download it. +- in Flytekit this can be done by checking the `uri` field in the Literal message when converting a literal [to_python_value](https://github.com/flyteorg/flytekit/blob/e394af0be9f904fbf8be675eaa8b8cdc24311ced/flytekit/core/type_engine.py#L1134) +- in Copilot, the data downloader [literal handling](https://github.com/flyteorg/flyte/blob/5f4199899922ca63f7690c82dfca42a783db64c3/flytecopilot/data/download.go#L219-L248) should fetch the value + +As a follow-up, we can also implement literal offloading in the SDK for conventional python tasks. Flytekit should also know how to offload the data. This should be done transparently to the user. + +We should fail fast in the SDK for too-large literals as part of the initial round of changes. + +**Open Question:** How will flytekit know to fail though if propeller hasn't been updated? + +### 3.4 Other Implications +#### Flytekit Remote +Flytekit Remote will need to be updated to handle offloaded literals. In order to fetch offloaded literals by URI, users must now authenticate with their cloud provider on their machines using a role which has read access to the _metadata bucket_. + +#### Console +Console code should show the offloaded literal URI and gracefully handle nil Literal [values](https://github.com/flyteorg/flyte/blob/4a7c3c0040b1995a43939407b99ca3e87b1eb752/flyteidl/protos/flyteidl/core/literals.proto#L96-L105). + +## 4 Metrics & Dashboards + +*What are the main metrics we should be measuring? For example, when interacting with an external system, it might be the external system latency. When adding a new table, how fast would it fill up?* + +## 5 Drawbacks + +*Are there any reasons why we should not do this? Here we aim to evaluate risk and check ourselves.* + +## 6 Alternatives + +Alternate suggestions that were proposed include + +* For map tasks, change the type of the output to a Union of the current user defined List and a new Offloaded type. We felt this would be a bit awkward since it changes the user-facing type itself (like if you were to pull up the map task definition in the API endpoint). It's also not extensible to other types of literals (maps of large dataclasses for example). + +* Build off of the input wrapper construct that's still in [PR](https://github.com/flyteorg/flyte/pull/4298). The idea was to have the wrapper contain in large cases, a reference to the data, and in small cases, the data itself. We didn't fully like this idea because the entire input set or output set needs to be offloaded. + * If the task downstream of a map task takes both the output list, along with some other input, after creating and upload the large pb file for the map task's output, Propeller would need to re-upload the entire large list or map (one time for each downstream task). If the offloading is done per literal, Propeller can just upload once and use. +* Modify the workflow CRD to include the offloading bits so that they're respected at execution time, and serialized at registration time. This is a bit heavier handed than just respecting the SDK version + +## 7 Potential Impact and Dependencies + +There's a couple edge cases that will just not work. + +* If the map task is of an older flytekit version but for some reason the downstream task is of a newer version, Propeller will fail unnecessarily. +* If the map task is a newer version, but the downstream task is an older version, the downstream task will fail correctly. +* If workflow is using an older SDK version and launches a child workflow with a newer SDK version, the parent workflow will fail to resolve the child workflow outputs + +Are there concerns about the fact that if we're offloading data once, and then sharing the pointer, we're no longer copying-by-value? Does this break any of the guarantees of Flyte and will we need to be more careful in the future around other changes to avoid issues? + +## 8 Unresolved questions + +Should we create a new oneof that's offloaded? + +Is there anything around sampled data, or automatically computed actual metadata (like number of elements in the list) that we should do? + +## 9 Conclusion + +Moving to literal offloading fixes a common and frustrating pain point around map tasks. It's a relatively simple change that should have a big impact on the usability of Flyte. + +``` From c01a059baeb5ce552eb35ba38956451b2e405361 Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Tue, 24 Sep 2024 07:20:53 -0700 Subject: [PATCH 17/17] Use pluggable clock for auto refresh cache and make unit tests fast (#5767) * Use pluggable clock and make unit tests fast Signed-off-by: Jason Parraga * lint fixes Signed-off-by: Jason Parraga --------- Signed-off-by: Jason Parraga --- flytestdlib/.golangci.yml | 1 - flytestdlib/cache/auto_refresh.go | 51 ++++++++++++++++----- flytestdlib/cache/auto_refresh_test.go | 63 ++++++++++++++++---------- 3 files changed, 78 insertions(+), 37 deletions(-) diff --git a/flytestdlib/.golangci.yml b/flytestdlib/.golangci.yml index 7f4dbc80e8..e3bff2320b 100644 --- a/flytestdlib/.golangci.yml +++ b/flytestdlib/.golangci.yml @@ -26,7 +26,6 @@ linters: - structcheck - typecheck - unconvert - - unparam - unused - varcheck diff --git a/flytestdlib/cache/auto_refresh.go b/flytestdlib/cache/auto_refresh.go index 8218e577a8..19d38bdaff 100644 --- a/flytestdlib/cache/auto_refresh.go +++ b/flytestdlib/cache/auto_refresh.go @@ -9,8 +9,8 @@ import ( lru "github.com/hashicorp/golang-lru" "github.com/prometheus/client_golang/prometheus" - "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/util/workqueue" + "k8s.io/utils/clock" "github.com/flyteorg/flyte/flytestdlib/contextutils" "github.com/flyteorg/flyte/flytestdlib/errors" @@ -125,6 +125,7 @@ type autoRefresh struct { workqueue workqueue.RateLimitingInterface parallelizm int lock sync.RWMutex + clock clock.Clock } func getEvictionFunction(counter prometheus.Counter) func(key interface{}, value interface{}) { @@ -165,17 +166,29 @@ func (w *autoRefresh) Start(ctx context.Context) error { } enqueueCtx := contextutils.WithGoroutineLabel(ctx, fmt.Sprintf("%v-enqueue", w.name)) - - go wait.Until(func() { - err := w.enqueueBatches(enqueueCtx) - if err != nil { - logger.Errorf(enqueueCtx, "Failed to enqueue. Error: %v", err) - } - }, w.syncPeriod, enqueueCtx.Done()) + go w.enqueueLoop(enqueueCtx) return nil } +func (w *autoRefresh) enqueueLoop(ctx context.Context) { + timer := w.clock.NewTimer(w.syncPeriod) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-timer.C(): + err := w.enqueueBatches(ctx) + if err != nil { + logger.Errorf(ctx, "Failed to enqueue. Error: %v", err) + } + timer.Reset(w.syncPeriod) + } + } +} + // Update updates the item only if it exists in the cache, return true if we updated the item. func (w *autoRefresh) Update(id ItemID, item Item) (ok bool) { w.lock.Lock() @@ -221,7 +234,7 @@ func (w *autoRefresh) GetOrCreate(id ItemID, item Item) (Item, error) { batch := make([]ItemWrapper, 0, 1) batch = append(batch, itemWrapper{id: id, item: item}) w.workqueue.AddRateLimited(&batch) - w.processing.Store(id, time.Now()) + w.processing.Store(id, w.clock.Now()) return item, nil } @@ -265,7 +278,7 @@ func (w *autoRefresh) enqueueBatches(ctx context.Context) error { b := batch w.workqueue.AddRateLimited(&b) for i := 1; i < len(b); i++ { - w.processing.Store(b[i].GetID(), time.Now()) + w.processing.Store(b[i].GetID(), w.clock.Now()) } } @@ -365,7 +378,7 @@ func (w *autoRefresh) inProcessing(key interface{}) bool { item, found := w.processing.Load(key) if found { // handle potential race conditions where the item is in processing but not in the workqueue - if timeItem, ok := item.(time.Time); ok && time.Since(timeItem) > (w.syncPeriod*5) { + if timeItem, ok := item.(time.Time); ok && w.clock.Since(timeItem) > (w.syncPeriod*5) { w.processing.Delete(key) return false } @@ -377,6 +390,11 @@ func (w *autoRefresh) inProcessing(key interface{}) bool { // Instantiates a new AutoRefresh Cache that syncs items in batches. func NewAutoRefreshBatchedCache(name string, createBatches CreateBatchesFunc, syncCb SyncFunc, syncRateLimiter workqueue.RateLimiter, resyncPeriod time.Duration, parallelizm, size int, scope promutils.Scope) (AutoRefresh, error) { + return newAutoRefreshBatchedCacheWithClock(name, createBatches, syncCb, syncRateLimiter, resyncPeriod, parallelizm, size, scope, clock.RealClock{}) +} + +func newAutoRefreshBatchedCacheWithClock(name string, createBatches CreateBatchesFunc, syncCb SyncFunc, syncRateLimiter workqueue.RateLimiter, + resyncPeriod time.Duration, parallelizm, size int, scope promutils.Scope, clock clock.WithTicker) (AutoRefresh, error) { metrics := newMetrics(scope) lruCache, err := lru.NewWithEvict(size, getEvictionFunction(metrics.Evictions)) @@ -394,7 +412,11 @@ func NewAutoRefreshBatchedCache(name string, createBatches CreateBatchesFunc, sy processing: &sync.Map{}, toDelete: newSyncSet(), syncPeriod: resyncPeriod, - workqueue: workqueue.NewNamedRateLimitingQueue(syncRateLimiter, scope.CurrentScope()), + workqueue: workqueue.NewRateLimitingQueueWithConfig(syncRateLimiter, workqueue.RateLimitingQueueConfig{ + Name: scope.CurrentScope(), + Clock: clock, + }), + clock: clock, } return cache, nil @@ -406,3 +428,8 @@ func NewAutoRefreshCache(name string, syncCb SyncFunc, syncRateLimiter workqueue return NewAutoRefreshBatchedCache(name, SingleItemBatches, syncCb, syncRateLimiter, resyncPeriod, parallelizm, size, scope) } + +func newAutoRefreshCacheWithClock(name string, syncCb SyncFunc, syncRateLimiter workqueue.RateLimiter, resyncPeriod time.Duration, + parallelizm, size int, scope promutils.Scope, clock clock.WithTicker) (AutoRefresh, error) { + return newAutoRefreshBatchedCacheWithClock(name, SingleItemBatches, syncCb, syncRateLimiter, resyncPeriod, parallelizm, size, scope, clock) +} diff --git a/flytestdlib/cache/auto_refresh_test.go b/flytestdlib/cache/auto_refresh_test.go index 5e1c49777e..e6b210bfcd 100644 --- a/flytestdlib/cache/auto_refresh_test.go +++ b/flytestdlib/cache/auto_refresh_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "k8s.io/client-go/util/workqueue" + testingclock "k8s.io/utils/clock/testing" "github.com/flyteorg/flyte/flytestdlib/atomic" "github.com/flyteorg/flyte/flytestdlib/errors" @@ -74,12 +75,13 @@ func (p *panickingSyncer) sync(_ context.Context, _ Batch) ([]ItemSyncResponse, } func TestCacheFour(t *testing.T) { - testResyncPeriod := 10 * time.Millisecond + testResyncPeriod := 5 * time.Second rateLimiter := workqueue.DefaultControllerRateLimiter() + fakeClock := testingclock.NewFakeClock(time.Now()) t.Run("normal operation", func(t *testing.T) { // the size of the cache is at least as large as the number of items we're storing - cache, err := NewAutoRefreshCache("fake1", syncFakeItem, rateLimiter, testResyncPeriod, 10, 10, promutils.NewTestScope()) + cache, err := newAutoRefreshCacheWithClock("fake1", syncFakeItem, rateLimiter, testResyncPeriod, 10, 10, promutils.NewTestScope(), fakeClock) assert.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) @@ -94,18 +96,21 @@ func TestCacheFour(t *testing.T) { } assert.EventuallyWithT(t, func(c *assert.CollectT) { + // trigger periodic sync + fakeClock.Step(testResyncPeriod) + for i := 1; i <= 10; i++ { item, err := cache.Get(fmt.Sprintf("%d", i)) assert.NoError(c, err) assert.Equal(c, 10, item.(fakeCacheItem).val) } - }, 3*time.Second, 100*time.Millisecond) + }, 3*time.Second, time.Millisecond) cancel() }) t.Run("Not Found", func(t *testing.T) { // the size of the cache is at least as large as the number of items we're storing - cache, err := NewAutoRefreshCache("fake2", syncFakeItem, rateLimiter, testResyncPeriod, 10, 2, promutils.NewTestScope()) + cache, err := newAutoRefreshCacheWithClock("fake2", syncFakeItem, rateLimiter, testResyncPeriod, 10, 2, promutils.NewTestScope(), fakeClock) assert.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) @@ -133,7 +138,7 @@ func TestCacheFour(t *testing.T) { }) t.Run("Enqueue nothing", func(t *testing.T) { - cache, err := NewAutoRefreshCache("fake3", syncTerminalItem, rateLimiter, testResyncPeriod, 10, 2, promutils.NewTestScope()) + cache, err := newAutoRefreshCacheWithClock("fake3", syncTerminalItem, rateLimiter, testResyncPeriod, 10, 2, promutils.NewTestScope(), fakeClock) assert.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) @@ -147,15 +152,16 @@ func TestCacheFour(t *testing.T) { assert.NoError(t, err) } - // Wait half a second for all resync periods to complete + // Enqueue first batch + fakeClock.Step(testResyncPeriod) // If the cache tries to enqueue the item, a panic will be thrown. - time.Sleep(500 * time.Millisecond) + fakeClock.Step(testResyncPeriod) cancel() }) t.Run("Test update and delete cache", func(t *testing.T) { - cache, err := NewAutoRefreshCache("fake3", syncTerminalItem, rateLimiter, testResyncPeriod, 10, 2, promutils.NewTestScope()) + cache, err := newAutoRefreshCacheWithClock("fake3", syncTerminalItem, rateLimiter, testResyncPeriod, 10, 2, promutils.NewTestScope(), fakeClock) assert.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) @@ -167,24 +173,27 @@ func TestCacheFour(t *testing.T) { }) assert.NoError(t, err) - // Wait half a second for all resync periods to complete // If the cache tries to enqueue the item, a panic will be thrown. - time.Sleep(500 * time.Millisecond) + fakeClock.Step(testResyncPeriod) err = cache.DeleteDelayed(itemID) assert.NoError(t, err) - time.Sleep(500 * time.Millisecond) - item, err := cache.Get(itemID) - assert.Nil(t, item) - assert.Error(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // trigger a sync + fakeClock.Step(testResyncPeriod) + + item, err := cache.Get(itemID) + assert.Nil(c, item) + assert.Error(c, err) + }, 3*time.Second, time.Millisecond) cancel() }) t.Run("Test panic on sync and shutdown", func(t *testing.T) { syncer := &panickingSyncer{} - cache, err := NewAutoRefreshCache("fake3", syncer.sync, rateLimiter, testResyncPeriod, 10, 2, promutils.NewTestScope()) + cache, err := newAutoRefreshCacheWithClock("fake3", syncer.sync, rateLimiter, testResyncPeriod, 10, 2, promutils.NewTestScope(), fakeClock) assert.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) @@ -198,15 +207,12 @@ func TestCacheFour(t *testing.T) { // wait for all workers to run assert.Eventually(t, func() bool { + // trigger a sync + fakeClock.Step(testResyncPeriod) + return syncer.callCount.Load() == int32(10) }, 5*time.Second, time.Millisecond) - // wait some more time - time.Sleep(500 * time.Millisecond) - - // all workers should have shut down. - assert.Equal(t, int32(10), syncer.callCount.Load()) - cancel() }) } @@ -214,6 +220,7 @@ func TestCacheFour(t *testing.T) { func TestQueueBuildUp(t *testing.T) { testResyncPeriod := time.Hour rateLimiter := workqueue.DefaultControllerRateLimiter() + fakeClock := testingclock.NewFakeClock(time.Now()) syncCount := atomic.NewInt32(0) m := sync.Map{} @@ -231,7 +238,7 @@ func TestQueueBuildUp(t *testing.T) { } size := 100 - cache, err := NewAutoRefreshCache("fake2", alwaysFailing, rateLimiter, testResyncPeriod, 10, size, promutils.NewTestScope()) + cache, err := newAutoRefreshCacheWithClock("fake2", alwaysFailing, rateLimiter, testResyncPeriod, 10, size, promutils.NewTestScope(), fakeClock) assert.NoError(t, err) ctx := context.Background() @@ -244,16 +251,24 @@ func TestQueueBuildUp(t *testing.T) { } assert.NoError(t, cache.Start(ctx)) - time.Sleep(5 * time.Second) - assert.Equal(t, int32(size), syncCount.Load()) + + // wait for all workers to run + assert.Eventually(t, func() bool { + // trigger a sync and unlock the work queue + fakeClock.Step(time.Millisecond) + + return syncCount.Load() == int32(size) + }, 5*time.Second, time.Millisecond) } func TestInProcessing(t *testing.T) { syncPeriod := time.Millisecond + fakeClock := testingclock.NewFakeClock(time.Now()) cache := &autoRefresh{ processing: &sync.Map{}, syncPeriod: syncPeriod, + clock: fakeClock, } assert.False(t, cache.inProcessing("test"))