From e67aae0540f6b72475336922147352c7e92bad86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bu=C4=9Fra=20Gedik?= Date: Wed, 11 Sep 2024 23:49:07 -0700 Subject: [PATCH 1/6] Add listing api to stow storage (#5741) Signed-off-by: Bugra Gedik --- flyteadmin/pkg/common/mocks/storage.go | 4 + flytepropeller/pkg/utils/failing_datastore.go | 4 + flytestdlib/storage/cached_rawstore_test.go | 4 + flytestdlib/storage/mem_store.go | 4 + .../storage/mocks/composed_protobuf_store.go | 48 ++++++++++ flytestdlib/storage/storage.go | 38 ++++++++ flytestdlib/storage/stow_store.go | 46 +++++++++ flytestdlib/storage/stow_store_test.go | 96 ++++++++++++++++++- script/generate_helm.sh | 3 +- 9 files changed, 244 insertions(+), 3 deletions(-) diff --git a/flyteadmin/pkg/common/mocks/storage.go b/flyteadmin/pkg/common/mocks/storage.go index 7e91bf0485..bf29eedd3e 100644 --- a/flyteadmin/pkg/common/mocks/storage.go +++ b/flyteadmin/pkg/common/mocks/storage.go @@ -33,6 +33,10 @@ func (t *TestDataStore) Head(ctx context.Context, reference storage.DataReferenc return t.HeadCb(ctx, reference) } +func (t *TestDataStore) List(ctx context.Context, reference storage.DataReference, maxItems int, cursor storage.Cursor) ([]storage.DataReference, storage.Cursor, error) { + return nil, storage.NewCursorAtEnd(), fmt.Errorf("Not implemented yet") +} + func (t *TestDataStore) ReadProtobuf(ctx context.Context, reference storage.DataReference, msg proto.Message) error { return t.ReadProtobufCb(ctx, reference, msg) } diff --git a/flytepropeller/pkg/utils/failing_datastore.go b/flytepropeller/pkg/utils/failing_datastore.go index f3b65471c7..7948a85b81 100644 --- a/flytepropeller/pkg/utils/failing_datastore.go +++ b/flytepropeller/pkg/utils/failing_datastore.go @@ -27,6 +27,10 @@ func (FailingRawStore) Head(ctx context.Context, reference storage.DataReference return nil, fmt.Errorf("failed metadata fetch") } +func (FailingRawStore) List(ctx context.Context, reference storage.DataReference, maxItems int, cursor 
storage.Cursor) ([]storage.DataReference, storage.Cursor, error) { + return nil, storage.NewCursorAtEnd(), fmt.Errorf("Not implemented yet") +} + func (FailingRawStore) ReadRaw(ctx context.Context, reference storage.DataReference) (io.ReadCloser, error) { return nil, fmt.Errorf("failed read raw") } diff --git a/flytestdlib/storage/cached_rawstore_test.go b/flytestdlib/storage/cached_rawstore_test.go index b9751d7fa1..9c304790cb 100644 --- a/flytestdlib/storage/cached_rawstore_test.go +++ b/flytestdlib/storage/cached_rawstore_test.go @@ -73,6 +73,10 @@ func (d *dummyStore) Head(ctx context.Context, reference DataReference) (Metadat return d.HeadCb(ctx, reference) } +func (d *dummyStore) List(ctx context.Context, reference DataReference, maxItems int, cursor Cursor) ([]DataReference, Cursor, error) { + return nil, NewCursorAtEnd(), fmt.Errorf("Not implemented yet") +} + func (d *dummyStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { return d.ReadRawCb(ctx, reference) } diff --git a/flytestdlib/storage/mem_store.go b/flytestdlib/storage/mem_store.go index a95a0a49ca..94083f6646 100644 --- a/flytestdlib/storage/mem_store.go +++ b/flytestdlib/storage/mem_store.go @@ -54,6 +54,10 @@ func (s *InMemoryStore) Head(ctx context.Context, reference DataReference) (Meta }, nil } +func (s *InMemoryStore) List(ctx context.Context, reference DataReference, maxItems int, cursor Cursor) ([]DataReference, Cursor, error) { + return nil, NewCursorAtEnd(), fmt.Errorf("Not implemented yet") +} + func (s *InMemoryStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { if raw, found := s.cache[reference]; found { return ioutil.NopCloser(bytes.NewReader(raw)), nil diff --git a/flytestdlib/storage/mocks/composed_protobuf_store.go b/flytestdlib/storage/mocks/composed_protobuf_store.go index c9064c2ac5..49a0ee89dd 100644 --- a/flytestdlib/storage/mocks/composed_protobuf_store.go +++ 
b/flytestdlib/storage/mocks/composed_protobuf_store.go @@ -194,6 +194,54 @@ func (_m *ComposedProtobufStore) Head(ctx context.Context, reference storage.Dat return r0, r1 } +type ComposedProtobufStore_List struct { + *mock.Call +} + +func (_m ComposedProtobufStore_List) Return(_a0 []storage.DataReference, _a1 storage.Cursor, _a2 error) *ComposedProtobufStore_List { + return &ComposedProtobufStore_List{Call: _m.Call.Return(_a0, _a1, _a2)} +} + +func (_m *ComposedProtobufStore) OnList(ctx context.Context, reference storage.DataReference, maxItems int, cursor storage.Cursor) *ComposedProtobufStore_List { + c_call := _m.On("List", ctx, reference, maxItems, cursor) + return &ComposedProtobufStore_List{Call: c_call} +} + +func (_m *ComposedProtobufStore) OnListMatch(matchers ...interface{}) *ComposedProtobufStore_List { + c_call := _m.On("List", matchers...) + return &ComposedProtobufStore_List{Call: c_call} +} + +// List provides a mock function with given fields: ctx, reference, maxItems, cursor +func (_m *ComposedProtobufStore) List(ctx context.Context, reference storage.DataReference, maxItems int, cursor storage.Cursor) ([]storage.DataReference, storage.Cursor, error) { + ret := _m.Called(ctx, reference, maxItems, cursor) + + var r0 []storage.DataReference + if rf, ok := ret.Get(0).(func(context.Context, storage.DataReference, int, storage.Cursor) []storage.DataReference); ok { + r0 = rf(ctx, reference, maxItems, cursor) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]storage.DataReference) + } + } + + var r1 storage.Cursor + if rf, ok := ret.Get(1).(func(context.Context, storage.DataReference, int, storage.Cursor) storage.Cursor); ok { + r1 = rf(ctx, reference, maxItems, cursor) + } else { + r1 = ret.Get(1).(storage.Cursor) + } + + var r2 error + if rf, ok := ret.Get(2).(func(context.Context, storage.DataReference, int, storage.Cursor) error); ok { + r2 = rf(ctx, reference, maxItems, cursor) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + 
type ComposedProtobufStore_ReadProtobuf struct { *mock.Call } diff --git a/flytestdlib/storage/storage.go b/flytestdlib/storage/storage.go index 3e84cb7acb..52e6905513 100644 --- a/flytestdlib/storage/storage.go +++ b/flytestdlib/storage/storage.go @@ -40,6 +40,41 @@ type Metadata interface { ContentMD5() string } +type CursorState int + +const ( + // Enum representing state of the cursor + AtStartCursorState CursorState = 0 + AtEndCursorState CursorState = 1 + AtCustomPosCursorState CursorState = 2 +) + +type Cursor struct { + cursorState CursorState + customPosition string +} + +func NewCursorAtStart() Cursor { + return Cursor{ + cursorState: AtStartCursorState, + customPosition: "", + } +} + +func NewCursorAtEnd() Cursor { + return Cursor{ + cursorState: AtEndCursorState, + customPosition: "", + } +} + +func NewCursorFromCustomPosition(customPosition string) Cursor { + return Cursor{ + cursorState: AtCustomPosCursorState, + customPosition: customPosition, + } +} + // DataStore is a simplified interface for accessing and storing data in one of the Cloud stores. // Today we rely on Stow for multi-cloud support, but this interface abstracts that part type DataStore struct { @@ -78,6 +113,9 @@ type RawStore interface { // Head gets metadata about the reference. This should generally be a light weight operation. 
Head(ctx context.Context, reference DataReference) (Metadata, error) + // List gets a list of items given a prefix, using a paginated API + List(ctx context.Context, reference DataReference, maxItems int, cursor Cursor) ([]DataReference, Cursor, error) + // ReadRaw retrieves a byte array from the Blob store or an error ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) diff --git a/flytestdlib/storage/stow_store.go b/flytestdlib/storage/stow_store.go index ce4a75a0a1..6b731b9c86 100644 --- a/flytestdlib/storage/stow_store.go +++ b/flytestdlib/storage/stow_store.go @@ -92,6 +92,9 @@ type stowMetrics struct { HeadFailure labeled.Counter HeadLatency labeled.StopWatch + ListFailure labeled.Counter + ListLatency labeled.StopWatch + ReadFailure labeled.Counter ReadOpenLatency labeled.StopWatch @@ -251,6 +254,46 @@ func (s *StowStore) Head(ctx context.Context, reference DataReference) (Metadata return StowMetadata{exists: false}, errs.Wrapf(err, "path:%v", k) } +func (s *StowStore) List(ctx context.Context, reference DataReference, maxItems int, cursor Cursor) ([]DataReference, Cursor, error) { + _, c, k, err := reference.Split() + if err != nil { + s.metrics.BadReference.Inc(ctx) + return nil, NewCursorAtEnd(), err + } + + container, err := s.getContainer(ctx, locationIDMain, c) + if err != nil { + return nil, NewCursorAtEnd(), err + } + + t := s.metrics.ListLatency.Start(ctx) + var stowCursor string + if cursor.cursorState == AtStartCursorState { + stowCursor = stow.CursorStart + } else if cursor.cursorState == AtEndCursorState { + return nil, NewCursorAtEnd(), fmt.Errorf("Cursor cannot be at end for the List call") + } else { + stowCursor = cursor.customPosition + } + items, stowCursor, err := container.Items(k, stowCursor, maxItems) + if err == nil { + results := make([]DataReference, len(items)) + for index, item := range items { + results[index] = DataReference(item.URL().String()) + } + if stow.IsCursorEnd(stowCursor) { + cursor = 
NewCursorAtEnd() + } else { + cursor = NewCursorFromCustomPosition(stowCursor) + } + t.Stop() + return results, cursor, nil + } + + incFailureCounterForError(ctx, s.metrics.ListFailure, err) + return nil, NewCursorAtEnd(), errs.Wrapf(err, "path:%v", k) +} + func (s *StowStore) ReadRaw(ctx context.Context, reference DataReference) (io.ReadCloser, error) { _, c, k, err := reference.Split() if err != nil { @@ -434,6 +477,9 @@ func newStowMetrics(scope promutils.Scope) *stowMetrics { HeadFailure: labeled.NewCounter("head_failure", "Indicates failure in HEAD for a given reference", scope, labeled.EmitUnlabeledMetric), HeadLatency: labeled.NewStopWatch("head", "Indicates time to fetch metadata using the Head API", time.Millisecond, scope, labeled.EmitUnlabeledMetric), + ListFailure: labeled.NewCounter("list_failure", "Indicates failure in item listing for a given reference", scope, labeled.EmitUnlabeledMetric), + ListLatency: labeled.NewStopWatch("list", "Indicates time to fetch item listing using the List API", time.Millisecond, scope, labeled.EmitUnlabeledMetric), + ReadFailure: labeled.NewCounter("read_failure", "Indicates failure in GET for a given reference", scope, labeled.EmitUnlabeledMetric, failureTypeOption), ReadOpenLatency: labeled.NewStopWatch("read_open", "Indicates time to first byte when reading", time.Millisecond, scope, labeled.EmitUnlabeledMetric), diff --git a/flytestdlib/storage/stow_store_test.go b/flytestdlib/storage/stow_store_test.go index 99678eb8ad..4de273dd93 100644 --- a/flytestdlib/storage/stow_store_test.go +++ b/flytestdlib/storage/stow_store_test.go @@ -10,6 +10,8 @@ import ( "net/url" "os" "path/filepath" + "sort" + "strconv" "testing" "time" @@ -73,8 +75,37 @@ func (m mockStowContainer) Item(id string) (stow.Item, error) { return nil, stow.ErrNotFound } -func (mockStowContainer) Items(prefix, cursor string, count int) ([]stow.Item, string, error) { - return []stow.Item{}, "", nil +func (m mockStowContainer) Items(prefix, cursor string, 
count int) ([]stow.Item, string, error) { + startIndex := 0 + if cursor != "" { + index, err := strconv.Atoi(cursor) + if err != nil { + return nil, "", fmt.Errorf("Invalid cursor '%s'", cursor) + } + startIndex = index + } + endIndexExc := min(len(m.items), startIndex+count) + + itemKeys := make([]string, len(m.items)) + index := 0 + for key := range m.items { + itemKeys[index] = key + index++ + } + sort.Strings(itemKeys) + + numItems := endIndexExc - startIndex + results := make([]stow.Item, numItems) + for index, itemKey := range itemKeys[startIndex:endIndexExc] { + results[index] = m.items[itemKey] + } + + if endIndexExc == len(m.items) { + cursor = "" + } else { + cursor = fmt.Sprintf("%d", endIndexExc) + } + return results, cursor, nil } func (m mockStowContainer) RemoveItem(id string) error { @@ -361,6 +392,67 @@ func TestStowStore_ReadRaw(t *testing.T) { }) } +func TestStowStore_List(t *testing.T) { + const container = "container" + t.Run("Listing", func(t *testing.T) { + ctx := context.Background() + fn := fQNFn["s3"] + s, err := NewStowRawStore(fn(container), &mockStowLoc{ + ContainerCb: func(id string) (stow.Container, error) { + if id == container { + return newMockStowContainer(container), nil + } + return nil, fmt.Errorf("container is not supported") + }, + CreateContainerCb: func(name string) (stow.Container, error) { + if name == container { + return newMockStowContainer(container), nil + } + return nil, fmt.Errorf("container is not supported") + }, + }, nil, false, metrics) + assert.NoError(t, err) + writeTestFile(ctx, t, s, "s3://container/a/1") + writeTestFile(ctx, t, s, "s3://container/a/2") + var maxResults = 10 + var dataReference DataReference = "s3://container/a" + items, cursor, err := s.List(ctx, dataReference, maxResults, NewCursorAtStart()) + assert.NoError(t, err) + assert.Equal(t, NewCursorAtEnd(), cursor) + assert.Equal(t, []DataReference{"a/1", "a/2"}, items) + }) + + t.Run("Listing with pagination", func(t *testing.T) { + ctx := 
context.Background() + fn := fQNFn["s3"] + s, err := NewStowRawStore(fn(container), &mockStowLoc{ + ContainerCb: func(id string) (stow.Container, error) { + if id == container { + return newMockStowContainer(container), nil + } + return nil, fmt.Errorf("container is not supported") + }, + CreateContainerCb: func(name string) (stow.Container, error) { + if name == container { + return newMockStowContainer(container), nil + } + return nil, fmt.Errorf("container is not supported") + }, + }, nil, false, metrics) + assert.NoError(t, err) + writeTestFile(ctx, t, s, "s3://container/a/1") + writeTestFile(ctx, t, s, "s3://container/a/2") + var maxResults = 1 + var dataReference DataReference = "s3://container/a" + items, cursor, err := s.List(ctx, dataReference, maxResults, NewCursorAtStart()) + assert.NoError(t, err) + assert.Equal(t, []DataReference{"a/1"}, items) + items, _, err = s.List(ctx, dataReference, maxResults, cursor) + assert.NoError(t, err) + assert.Equal(t, []DataReference{"a/2"}, items) + }) +} + func TestNewLocalStore(t *testing.T) { labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) t.Run("Valid config", func(t *testing.T) { diff --git a/script/generate_helm.sh b/script/generate_helm.sh index 1c836b9002..a0ae15c019 100755 --- a/script/generate_helm.sh +++ b/script/generate_helm.sh @@ -7,7 +7,8 @@ echo "Generating Helm" HELM_SKIP_INSTALL=${HELM_SKIP_INSTALL:-false} if [ "${HELM_SKIP_INSTALL}" != "true" ]; then - curl https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | bash + # See https://github.com/helm/helm/issues/13324 for a breaking change in latest version of helm + curl https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | DESIRED_VERSION=v3.15.4 bash fi helm version From 59bf19158eb8f60613055dcc3f61ee9b71be0531 Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Thu, 12 Sep 2024 11:17:45 -0700 Subject: [PATCH 2/6] Use latest 
upload/download-artifact action version (#5743) Signed-off-by: Jason Parraga --- .github/workflows/single-binary.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/single-binary.yml b/.github/workflows/single-binary.yml index c40b33e3a4..27ed4fabbc 100644 --- a/.github/workflows/single-binary.yml +++ b/.github/workflows/single-binary.yml @@ -94,7 +94,7 @@ jobs: file: Dockerfile outputs: type=docker,dest=docker/sandbox-bundled/images/tar/amd64/flyte-binary.tar - name: Upload single binary image - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: single-binary-image path: docker/sandbox-bundled/images/tar @@ -133,7 +133,7 @@ jobs: echo "FLYTESNACKS_VERSION=${FLYTESNACKS_VERSION}" >> ${GITHUB_ENV} - name: Checkout uses: actions/checkout@v4 - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: single-binary-image path: docker/sandbox-bundled/images/tar @@ -207,7 +207,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: single-binary-image path: docker/sandbox-bundled/images/tar From 44dc0eb3521c81b201e9ae545c47a05cc1afc632 Mon Sep 17 00:00:00 2001 From: Rob Ulbrich <141313113+robert-ulbrich-mercedes-benz@users.noreply.github.com> Date: Mon, 16 Sep 2024 10:33:53 +0200 Subject: [PATCH 3/6] Introduced SMTP notification (#5535) Signed-off-by: Rob Ulbrich --- flyteadmin/go.mod | 2 +- flyteadmin/pkg/async/notifications/factory.go | 200 ++++++- .../pkg/async/notifications/factory_test.go | 32 +- .../notifications/implementations/emailers.go | 1 + .../implementations/smtp_emailer.go | 158 ++++++ .../implementations/smtp_emailer_test.go | 498 ++++++++++++++++++ .../notifications/interfaces/smtp_client.go | 22 + .../async/notifications/mocks/smtp_client.go | 321 +++++++++++ flyteadmin/pkg/rpc/adminservice/base.go | 5 +- .../interfaces/application_configuration.go | 9 +- 
flyteadmin/pkg/server/service.go | 17 +- 11 files changed, 1243 insertions(+), 22 deletions(-) create mode 100644 flyteadmin/pkg/async/notifications/implementations/smtp_emailer.go create mode 100644 flyteadmin/pkg/async/notifications/implementations/smtp_emailer_test.go create mode 100644 flyteadmin/pkg/async/notifications/interfaces/smtp_client.go create mode 100644 flyteadmin/pkg/async/notifications/mocks/smtp_client.go diff --git a/flyteadmin/go.mod b/flyteadmin/go.mod index b9eba5b83a..2eec0f8cf3 100644 --- a/flyteadmin/go.mod +++ b/flyteadmin/go.mod @@ -51,6 +51,7 @@ require ( github.com/wolfeidau/humanhash v1.1.0 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.47.0 go.opentelemetry.io/otel v1.24.0 + golang.org/x/net v0.27.0 golang.org/x/oauth2 v0.16.0 golang.org/x/time v0.5.0 google.golang.org/api v0.155.0 @@ -189,7 +190,6 @@ require ( go.opentelemetry.io/proto/otlp v1.1.0 // indirect golang.org/x/crypto v0.25.0 // indirect golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect - golang.org/x/net v0.27.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.22.0 // indirect golang.org/x/term v0.22.0 // indirect diff --git a/flyteadmin/pkg/async/notifications/factory.go b/flyteadmin/pkg/async/notifications/factory.go index f94129a1d5..483978238e 100644 --- a/flyteadmin/pkg/async/notifications/factory.go +++ b/flyteadmin/pkg/async/notifications/factory.go @@ -18,6 +18,7 @@ import ( "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/interfaces" "github.com/flyteorg/flyte/flyteadmin/pkg/common" runtimeInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/flyteorg/flyte/flytestdlib/promutils" ) @@ -27,6 +28,7 @@ const maxRetries = 3 var enable64decoding = false var msgChan chan []byte + var once sync.Once type PublisherConfig struct { @@ -35,220 
+37,404 @@ type PublisherConfig struct { type ProcessorConfig struct { QueueName string + AccountID string } type EmailerConfig struct { SenderEmail string - BaseURL string + + BaseURL string } // For sandbox only + func CreateMsgChan() { + once.Do(func() { + msgChan = make(chan []byte) + }) + } -func GetEmailer(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Emailer { +func GetEmailer(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope, sm core.SecretManager) interfaces.Emailer { + // If an external email service is specified use that instead. + // TODO: Handling of this is messy, see https://github.com/flyteorg/flyte/issues/1063 + if config.NotificationsEmailerConfig.EmailerConfig.ServiceName != "" { + switch config.NotificationsEmailerConfig.EmailerConfig.ServiceName { + case implementations.Sendgrid: + return implementations.NewSendGridEmailer(config, scope) + + case implementations.SMTP: + + return implementations.NewSMTPEmailer(context.Background(), config, scope, sm) + default: + panic(fmt.Errorf("No matching email implementation for %s", config.NotificationsEmailerConfig.EmailerConfig.ServiceName)) + } + } switch config.Type { + case common.AWS: + region := config.AWSConfig.Region + if region == "" { + region = config.Region + } + awsConfig := aws.NewConfig().WithRegion(region).WithMaxRetries(maxRetries) + awsSession, err := session.NewSession(awsConfig) + if err != nil { + panic(err) + } + sesClient := ses.New(awsSession) + return implementations.NewAwsEmailer( + config, + scope, + sesClient, ) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), "Using default noop emailer implementation for config type [%s]", config.Type) + return implementations.NewNoopEmail() + } + } -func NewNotificationsProcessor(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Processor { +func NewNotificationsProcessor(config runtimeInterfaces.NotificationsConfig, scope 
promutils.Scope, sm core.SecretManager) interfaces.Processor { + reconnectAttempts := config.ReconnectAttempts + reconnectDelay := time.Duration(config.ReconnectDelaySeconds) * time.Second + var sub pubsub.Subscriber + var emailer interfaces.Emailer + switch config.Type { + case common.AWS: + sqsConfig := gizmoAWS.SQSConfig{ - QueueName: config.NotificationsProcessorConfig.QueueName, + + QueueName: config.NotificationsProcessorConfig.QueueName, + QueueOwnerAccountID: config.NotificationsProcessorConfig.AccountID, + // The AWS configuration type uses SNS to SQS for notifications. + // Gizmo by default will decode the SQS message using Base64 decoding. + // However, the message body of SQS is the SNS message format which isn't Base64 encoded. + ConsumeBase64: &enable64decoding, } + if config.AWSConfig.Region != "" { + sqsConfig.Region = config.AWSConfig.Region + } else { + sqsConfig.Region = config.Region + } + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + sub, err = gizmoAWS.NewSubscriber(sqsConfig) + if err != nil { + logger.Warnf(context.TODO(), "Failed to initialize new gizmo aws subscriber with config [%+v] and err: %v", sqsConfig, err) + } + return err + }) if err != nil { + panic(err) + } - emailer = GetEmailer(config, scope) + + emailer = GetEmailer(config, scope, sm) + return implementations.NewProcessor(sub, emailer, scope) + case common.GCP: + projectID := config.GCPConfig.ProjectID + subscription := config.NotificationsProcessorConfig.QueueName + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + sub, err = gizmoGCP.NewSubscriber(context.TODO(), projectID, subscription) + if err != nil { + logger.Warnf(context.TODO(), "Failed to initialize new gizmo gcp subscriber with config [ProjectID: %s, Subscription: %s] and err: %v", projectID, subscription, err) + } + return err + }) + if err != nil { + panic(err) + } - emailer = GetEmailer(config, scope) + + emailer = GetEmailer(config, 
scope, sm) + return implementations.NewGcpProcessor(sub, emailer, scope) + case common.Sandbox: - emailer = GetEmailer(config, scope) + + emailer = GetEmailer(config, scope, sm) + return implementations.NewSandboxProcessor(msgChan, emailer) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), + "Using default noop notifications processor implementation for config type [%s]", config.Type) + return implementations.NewNoopProcess() + } + } func NewNotificationsPublisher(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Publisher { + reconnectAttempts := config.ReconnectAttempts + reconnectDelay := time.Duration(config.ReconnectDelaySeconds) * time.Second + switch config.Type { + case common.AWS: + snsConfig := gizmoAWS.SNSConfig{ + Topic: config.NotificationsPublisherConfig.TopicName, } + if config.AWSConfig.Region != "" { + snsConfig.Region = config.AWSConfig.Region + } else { + snsConfig.Region = config.Region + } var publisher pubsub.Publisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoAWS.NewPublisher(snsConfig) + return err + }) // Any persistent errors initiating Publisher with Amazon configurations results in a failed start up. 
+ if err != nil { + panic(err) + } + return implementations.NewPublisher(publisher, scope) + case common.GCP: + pubsubConfig := gizmoGCP.Config{ + Topic: config.NotificationsPublisherConfig.TopicName, } + pubsubConfig.ProjectID = config.GCPConfig.ProjectID + var publisher pubsub.MultiPublisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoGCP.NewPublisher(context.TODO(), pubsubConfig) + return err + }) if err != nil { + panic(err) + } + return implementations.NewPublisher(publisher, scope) + case common.Sandbox: + CreateMsgChan() + return implementations.NewSandboxPublisher(msgChan) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), + "Using default noop notifications publisher implementation for config type [%s]", config.Type) + return implementations.NewNoopPublish() + } + } func NewEventsPublisher(config runtimeInterfaces.ExternalEventsConfig, scope promutils.Scope) interfaces.Publisher { + if !config.Enable { + return implementations.NewNoopPublish() + } + reconnectAttempts := config.ReconnectAttempts + reconnectDelay := time.Duration(config.ReconnectDelaySeconds) * time.Second + switch config.Type { + case common.AWS: + snsConfig := gizmoAWS.SNSConfig{ + Topic: config.EventsPublisherConfig.TopicName, } + snsConfig.Region = config.AWSConfig.Region var publisher pubsub.Publisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoAWS.NewPublisher(snsConfig) + return err + }) // Any persistent errors initiating Publisher with Amazon configurations results in a failed start up. 
+ if err != nil { + panic(err) + } + return implementations.NewEventsPublisher(publisher, scope, config.EventsPublisherConfig.EventTypes) + case common.GCP: + pubsubConfig := gizmoGCP.Config{ + Topic: config.EventsPublisherConfig.TopicName, } + pubsubConfig.ProjectID = config.GCPConfig.ProjectID + var publisher pubsub.MultiPublisher + var err error + err = async.Retry(reconnectAttempts, reconnectDelay, func() error { + publisher, err = gizmoGCP.NewPublisher(context.TODO(), pubsubConfig) + return err + }) if err != nil { + panic(err) + } + return implementations.NewEventsPublisher(publisher, scope, config.EventsPublisherConfig.EventTypes) + case common.Local: + fallthrough + default: + logger.Infof(context.Background(), + "Using default noop events publisher implementation for config type [%s]", config.Type) + return implementations.NewNoopPublish() + } + } diff --git a/flyteadmin/pkg/async/notifications/factory_test.go b/flyteadmin/pkg/async/notifications/factory_test.go index 1bfd1f4596..43602525a5 100644 --- a/flyteadmin/pkg/async/notifications/factory_test.go +++ b/flyteadmin/pkg/async/notifications/factory_test.go @@ -9,52 +9,76 @@ import ( "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/implementations" runtimeInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyte/flytestdlib/promutils" ) var ( - scope = promutils.NewScope("test_sandbox_processor") + scope = promutils.NewScope("test_sandbox_processor") + notificationsConfig = runtimeInterfaces.NotificationsConfig{ + Type: "sandbox", } + testEmail = admin.EmailMessage{ + RecipientsEmail: []string{ + "a@example.com", + "b@example.com", }, + SenderEmail: "no-reply@example.com", + SubjectLine: "Test email", - Body: "This is a sample email.", + + Body: "This is a sample email.", } ) func TestGetEmailer(t *testing.T) { 
+ defer func() { r := recover(); assert.NotNil(t, r) }() + cfg := runtimeInterfaces.NotificationsConfig{ + NotificationsEmailerConfig: runtimeInterfaces.NotificationsEmailerConfig{ + EmailerConfig: runtimeInterfaces.EmailServerConfig{ + ServiceName: "unsupported", }, }, } - GetEmailer(cfg, promutils.NewTestScope()) + GetEmailer(cfg, promutils.NewTestScope(), &mocks.SecretManager{}) // shouldn't reach here + t.Errorf("did not panic") + } func TestNewNotificationPublisherAndProcessor(t *testing.T) { + testSandboxPublisher := NewNotificationsPublisher(notificationsConfig, scope) + assert.IsType(t, testSandboxPublisher, &implementations.SandboxPublisher{}) - testSandboxProcessor := NewNotificationsProcessor(notificationsConfig, scope) + + testSandboxProcessor := NewNotificationsProcessor(notificationsConfig, scope, &mocks.SecretManager{}) + assert.IsType(t, testSandboxProcessor, &implementations.SandboxProcessor{}) go func() { + testSandboxProcessor.StartProcessing() + }() assert.Nil(t, testSandboxPublisher.Publish(context.Background(), "TEST_NOTIFICATION", &testEmail)) assert.Nil(t, testSandboxProcessor.StopProcessing()) + } diff --git a/flyteadmin/pkg/async/notifications/implementations/emailers.go b/flyteadmin/pkg/async/notifications/implementations/emailers.go index e630b5a4ea..0da3fbf600 100644 --- a/flyteadmin/pkg/async/notifications/implementations/emailers.go +++ b/flyteadmin/pkg/async/notifications/implementations/emailers.go @@ -4,4 +4,5 @@ type ExternalEmailer = string const ( Sendgrid ExternalEmailer = "sendgrid" + SMTP ExternalEmailer = "smtp" ) diff --git a/flyteadmin/pkg/async/notifications/implementations/smtp_emailer.go b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer.go new file mode 100644 index 0000000000..5a705bc0c1 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer.go @@ -0,0 +1,158 @@ +package implementations + +import ( + "crypto/tls" + "fmt" + "net/smtp" + "strings" + + 
"golang.org/x/net/context" + "google.golang.org/grpc/codes" + + "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/interfaces" + "github.com/flyteorg/flyte/flyteadmin/pkg/errors" + runtimeInterfaces "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyte/flytestdlib/logger" + "github.com/flyteorg/flyte/flytestdlib/promutils" +) + +type SMTPEmailer struct { + config *runtimeInterfaces.NotificationsEmailerConfig + systemMetrics emailMetrics + tlsConf *tls.Config + auth *smtp.Auth + smtpClient interfaces.SMTPClient + CreateSMTPClientFunc func(connectString string) (interfaces.SMTPClient, error) +} + +func (s *SMTPEmailer) createClient(ctx context.Context) (interfaces.SMTPClient, error) { + newClient, err := s.CreateSMTPClientFunc(s.config.EmailerConfig.SMTPServer + ":" + s.config.EmailerConfig.SMTPPort) + + if err != nil { + return nil, s.emailError(ctx, fmt.Sprintf("Error creating email client: %s", err)) + } + + if err = newClient.Hello("localhost"); err != nil { + return nil, s.emailError(ctx, fmt.Sprintf("Error initiating connection to SMTP server: %s", err)) + } + + if ok, _ := newClient.Extension("STARTTLS"); ok { + if err = newClient.StartTLS(s.tlsConf); err != nil { + return nil, s.emailError(ctx, fmt.Sprintf("Error initiating connection to SMTP server: %s", err)) + } + } + + if ok, _ := newClient.Extension("AUTH"); ok { + if err = newClient.Auth(*s.auth); err != nil { + return nil, s.emailError(ctx, fmt.Sprintf("Error authenticating email client: %s", err)) + } + } + + return newClient, nil +} + +func (s *SMTPEmailer) SendEmail(ctx context.Context, email *admin.EmailMessage) error { + + if s.smtpClient == nil || s.smtpClient.Noop() != nil { + + if s.smtpClient != nil { + err := s.smtpClient.Close() + if err != nil { + logger.Info(ctx, err) + } + } + smtpClient, err := 
s.createClient(ctx) + + if err != nil { + return s.emailError(ctx, fmt.Sprintf("Error creating SMTP email client: %s", err)) + } + + s.smtpClient = smtpClient + } + + if err := s.smtpClient.Mail(email.SenderEmail); err != nil { + return s.emailError(ctx, fmt.Sprintf("Error creating email instance: %s", err)) + } + + for _, recipient := range email.RecipientsEmail { + if err := s.smtpClient.Rcpt(recipient); err != nil { + return s.emailError(ctx, fmt.Sprintf("Error adding email recipient: %s", err)) + } + } + + writer, err := s.smtpClient.Data() + + if err != nil { + return s.emailError(ctx, fmt.Sprintf("Error creating email data writer: %s", err)) + } + + _, err = writer.Write([]byte(createMailBody(s.config.Sender, email))) + + if err != nil { + return s.emailError(ctx, fmt.Sprintf("Error writing mail body: %s", err)) + } + + err = writer.Close() + + if err != nil { + return s.emailError(ctx, fmt.Sprintf("Error closing mail body: %s", err)) + } + + s.systemMetrics.SendSuccess.Inc() + return nil +} + +func (s *SMTPEmailer) emailError(ctx context.Context, error string) error { + s.systemMetrics.SendError.Inc() + logger.Error(ctx, error) + return errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails") +} + +func createMailBody(emailSender string, email *admin.EmailMessage) string { + headerMap := make(map[string]string) + headerMap["From"] = emailSender + headerMap["To"] = strings.Join(email.RecipientsEmail, ",") + headerMap["Subject"] = email.SubjectLine + headerMap["Content-Type"] = "text/html; charset=\"UTF-8\"" + + mailMessage := "" + + for k, v := range headerMap { + mailMessage += fmt.Sprintf("%s: %s\r\n", k, v) + } + + mailMessage += "\r\n" + email.Body + + return mailMessage +} + +func NewSMTPEmailer(ctx context.Context, config runtimeInterfaces.NotificationsConfig, scope promutils.Scope, sm core.SecretManager) interfaces.Emailer { + var tlsConfiguration *tls.Config + emailConf := config.NotificationsEmailerConfig.EmailerConfig + + 
smtpPassword, err := sm.Get(ctx, emailConf.SMTPPasswordSecretName) + if err != nil { + logger.Debug(ctx, "No SMTP password found.") + smtpPassword = "" + } + + auth := smtp.PlainAuth("", emailConf.SMTPUsername, smtpPassword, emailConf.SMTPServer) + + // #nosec G402 + tlsConfiguration = &tls.Config{ + InsecureSkipVerify: emailConf.SMTPSkipTLSVerify, + ServerName: emailConf.SMTPServer, + } + + return &SMTPEmailer{ + config: &config.NotificationsEmailerConfig, + systemMetrics: newEmailMetrics(scope.NewSubScope("smtp")), + tlsConf: tlsConfiguration, + auth: &auth, + CreateSMTPClientFunc: func(connectString string) (interfaces.SMTPClient, error) { + return smtp.Dial(connectString) + }, + } +} diff --git a/flyteadmin/pkg/async/notifications/implementations/smtp_emailer_test.go b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer_test.go new file mode 100644 index 0000000000..558a5d6408 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/implementations/smtp_emailer_test.go @@ -0,0 +1,498 @@ +package implementations + +import ( + "context" + "crypto/tls" + "errors" + "net/smtp" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/codes" + + notification_interfaces "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/interfaces" + notification_mocks "github.com/flyteorg/flyte/flyteadmin/pkg/async/notifications/mocks" + flyte_errors "github.com/flyteorg/flyte/flyteadmin/pkg/errors" + "github.com/flyteorg/flyte/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyte/flytestdlib/promutils" +) + +type StringWriter struct { + buffer string + writeErr error + closeErr error +} + +func (s *StringWriter) Write(p []byte) (n int, err error) { + s.buffer = s.buffer + string(p) + return len(p), s.writeErr +} + +func (s *StringWriter) 
Close() error { + return s.closeErr +} + +func getNotificationsEmailerConfig() interfaces.NotificationsConfig { + return interfaces.NotificationsConfig{ + Type: "", + Region: "", + AWSConfig: interfaces.AWSConfig{}, + GCPConfig: interfaces.GCPConfig{}, + NotificationsPublisherConfig: interfaces.NotificationsPublisherConfig{}, + NotificationsProcessorConfig: interfaces.NotificationsProcessorConfig{}, + NotificationsEmailerConfig: interfaces.NotificationsEmailerConfig{ + EmailerConfig: interfaces.EmailServerConfig{ + ServiceName: SMTP, + SMTPServer: "smtpServer", + SMTPPort: "smtpPort", + SMTPUsername: "smtpUsername", + SMTPPasswordSecretName: "smtp_password", + }, + Subject: "subject", + Sender: "sender", + Body: "body"}, + ReconnectAttempts: 1, + ReconnectDelaySeconds: 2} +} + +func TestEmailCreation(t *testing.T) { + email := admin.EmailMessage{ + RecipientsEmail: []string{"john@doe.com", "teresa@tester.com"}, + SubjectLine: "subject", + Body: "Email Body", + SenderEmail: "sender@sender.com", + } + + body := createMailBody("sender@sender.com", &email) + assert.Contains(t, body, "From: sender@sender.com\r\n") + assert.Contains(t, body, "To: john@doe.com,teresa@tester.com") + assert.Contains(t, body, "Subject: subject\r\n") + assert.Contains(t, body, "Content-Type: text/html; charset=\"UTF-8\"\r\n") + assert.Contains(t, body, "Email Body") +} + +func TestNewSmtpEmailer(t *testing.T) { + secretManagerMock := mocks.SecretManager{} + secretManagerMock.On("Get", mock.Anything, "smtp_password").Return("password", nil) + + notificationsConfig := getNotificationsEmailerConfig() + + smtpEmailer := NewSMTPEmailer(context.Background(), notificationsConfig, promutils.NewTestScope(), &secretManagerMock) + + assert.NotNil(t, smtpEmailer) +} + +func TestCreateClient(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := 
&notification_mocks.SMTPClient{} + smtpClient.On("Hello", "localhost").Return(nil) + smtpClient.On("Extension", "STARTTLS").Return(true, "") + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil) + smtpClient.On("Extension", "AUTH").Return(true, "") + smtpClient.On("Auth", auth).Return(nil) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Nil(t, err) + assert.NotNil(t, client) + +} + +func TestCreateClientErrorCreatingClient(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := &notification_mocks.SMTPClient{} + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, errors.New("error creating client")) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestCreateClientErrorHello(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := &notification_mocks.SMTPClient{} + smtpClient.On("Hello", "localhost").Return(errors.New("Error with hello")) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestCreateClientErrorStartTLS(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: 
"localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := &notification_mocks.SMTPClient{} + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(errors.New("Error with startls")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestCreateClientErrorAuth(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := &notification_mocks.SMTPClient{} + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(errors.New("Error with hello")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + client, err := smtpEmailer.createClient(context.Background()) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + assert.Nil(t, client) + +} + +func TestSendMail(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + stringWriter := StringWriter{buffer: ""} + + smtpClient := &notification_mocks.SMTPClient{} + 
smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "alice@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "bob@flyte.org").Return(nil).Times(1) + smtpClient.On("Data").Return(&stringWriter, nil).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.True(t, strings.Contains(stringWriter.buffer, "From: sender")) + assert.True(t, strings.Contains(stringWriter.buffer, "To: alice@flyte.org,bob@flyte.org")) + assert.True(t, strings.Contains(stringWriter.buffer, "Subject: subject")) + assert.True(t, strings.Contains(stringWriter.buffer, "This is an email.")) + assert.Nil(t, err) + +} + +func TestSendMailCreateClientError(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := &notification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(errors.New("error hello")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := 
smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorMail(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := &notification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(errors.New("error sending mail")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorRecipient(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := &notification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + 
smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "alice@flyte.org").Return(errors.New("error adding recipient")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorData(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + smtpClient := &notification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "alice@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "bob@flyte.org").Return(nil).Times(1) 
+ smtpClient.On("Data").Return(nil, errors.New("error creating data writer")).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorWriting(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + stringWriter := StringWriter{buffer: "", writeErr: errors.New("error writing"), closeErr: nil} + + smtpClient := &notification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "alice@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "bob@flyte.org").Return(nil).Times(1) + smtpClient.On("Data").Return(&stringWriter, nil).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.Equal(t, 
flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func TestSendMailErrorClose(t *testing.T) { + auth := smtp.PlainAuth("", "user", "password", "localhost") + + tlsConf := tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + } + + stringWriter := StringWriter{buffer: "", writeErr: nil, closeErr: errors.New("error writing")} + + smtpClient := &notification_mocks.SMTPClient{} + smtpClient.On("Noop").Return(errors.New("no connection")).Times(1) + smtpClient.On("Close").Return(nil).Times(1) + smtpClient.On("Hello", "localhost").Return(nil).Times(1) + smtpClient.On("Extension", "STARTTLS").Return(true, "").Times(1) + smtpClient.On("StartTLS", &tls.Config{ + InsecureSkipVerify: false, + ServerName: "localhost", + MinVersion: tls.VersionTLS13, + }).Return(nil).Times(1) + smtpClient.On("Extension", "AUTH").Return(true, "").Times(1) + smtpClient.On("Auth", auth).Return(nil).Times(1) + smtpClient.On("Mail", "flyte@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "alice@flyte.org").Return(nil).Times(1) + smtpClient.On("Rcpt", "bob@flyte.org").Return(nil).Times(1) + smtpClient.On("Data").Return(&stringWriter, nil).Times(1) + + smtpEmailer := createSMTPEmailer(smtpClient, &tlsConf, &auth, nil) + + err := smtpEmailer.SendEmail(context.Background(), &admin.EmailMessage{ + SubjectLine: "subject", + SenderEmail: "flyte@flyte.org", + RecipientsEmail: []string{"alice@flyte.org", "bob@flyte.org"}, + Body: "This is an email.", + }) + + assert.True(t, strings.Contains(stringWriter.buffer, "From: sender")) + assert.True(t, strings.Contains(stringWriter.buffer, "To: alice@flyte.org,bob@flyte.org")) + assert.True(t, strings.Contains(stringWriter.buffer, "Subject: subject")) + assert.True(t, strings.Contains(stringWriter.buffer, "This is an email.")) + assert.Equal(t, flyte_errors.NewFlyteAdminErrorf(codes.Internal, "errors were seen while sending emails"), err) + +} + +func 
createSMTPEmailer(smtpClient notification_interfaces.SMTPClient, tlsConf *tls.Config, auth *smtp.Auth, creationErr error) *SMTPEmailer { + secretManagerMock := mocks.SecretManager{} + secretManagerMock.On("Get", mock.Anything, "smtp_password").Return("password", nil) + + notificationsConfig := getNotificationsEmailerConfig() + + return &SMTPEmailer{ + config: &notificationsConfig.NotificationsEmailerConfig, + systemMetrics: newEmailMetrics(promutils.NewTestScope()), + tlsConf: tlsConf, + auth: auth, + CreateSMTPClientFunc: func(connectString string) (notification_interfaces.SMTPClient, error) { + return smtpClient, creationErr + }, + smtpClient: smtpClient, + } +} diff --git a/flyteadmin/pkg/async/notifications/interfaces/smtp_client.go b/flyteadmin/pkg/async/notifications/interfaces/smtp_client.go new file mode 100644 index 0000000000..bdc6171f46 --- /dev/null +++ b/flyteadmin/pkg/async/notifications/interfaces/smtp_client.go @@ -0,0 +1,22 @@ +package interfaces + +import ( + "crypto/tls" + "io" + "net/smtp" +) + +// This interface is introduced to allow for mocking of the smtp.Client object. + +//go:generate mockery -name=SMTPClient -output=../mocks -case=underscore +type SMTPClient interface { + Hello(localName string) error + Extension(ext string) (bool, string) + Auth(a smtp.Auth) error + StartTLS(config *tls.Config) error + Noop() error + Close() error + Mail(from string) error + Rcpt(to string) error + Data() (io.WriteCloser, error) +} diff --git a/flyteadmin/pkg/async/notifications/mocks/smtp_client.go b/flyteadmin/pkg/async/notifications/mocks/smtp_client.go new file mode 100644 index 0000000000..11dafefc9c --- /dev/null +++ b/flyteadmin/pkg/async/notifications/mocks/smtp_client.go @@ -0,0 +1,321 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. 
+ +package mocks + +import ( + io "io" + smtp "net/smtp" + + mock "github.com/stretchr/testify/mock" + + tls "crypto/tls" +) + +// SMTPClient is an autogenerated mock type for the SMTPClient type +type SMTPClient struct { + mock.Mock +} + +type SMTPClient_Auth struct { + *mock.Call +} + +func (_m SMTPClient_Auth) Return(_a0 error) *SMTPClient_Auth { + return &SMTPClient_Auth{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnAuth(a smtp.Auth) *SMTPClient_Auth { + c_call := _m.On("Auth", a) + return &SMTPClient_Auth{Call: c_call} +} + +func (_m *SMTPClient) OnAuthMatch(matchers ...interface{}) *SMTPClient_Auth { + c_call := _m.On("Auth", matchers...) + return &SMTPClient_Auth{Call: c_call} +} + +// Auth provides a mock function with given fields: a +func (_m *SMTPClient) Auth(a smtp.Auth) error { + ret := _m.Called(a) + + var r0 error + if rf, ok := ret.Get(0).(func(smtp.Auth) error); ok { + r0 = rf(a) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_Close struct { + *mock.Call +} + +func (_m SMTPClient_Close) Return(_a0 error) *SMTPClient_Close { + return &SMTPClient_Close{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnClose() *SMTPClient_Close { + c_call := _m.On("Close") + return &SMTPClient_Close{Call: c_call} +} + +func (_m *SMTPClient) OnCloseMatch(matchers ...interface{}) *SMTPClient_Close { + c_call := _m.On("Close", matchers...) 
+ return &SMTPClient_Close{Call: c_call} +} + +// Close provides a mock function with given fields: +func (_m *SMTPClient) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_Data struct { + *mock.Call +} + +func (_m SMTPClient_Data) Return(_a0 io.WriteCloser, _a1 error) *SMTPClient_Data { + return &SMTPClient_Data{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *SMTPClient) OnData() *SMTPClient_Data { + c_call := _m.On("Data") + return &SMTPClient_Data{Call: c_call} +} + +func (_m *SMTPClient) OnDataMatch(matchers ...interface{}) *SMTPClient_Data { + c_call := _m.On("Data", matchers...) + return &SMTPClient_Data{Call: c_call} +} + +// Data provides a mock function with given fields: +func (_m *SMTPClient) Data() (io.WriteCloser, error) { + ret := _m.Called() + + var r0 io.WriteCloser + if rf, ok := ret.Get(0).(func() io.WriteCloser); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.WriteCloser) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type SMTPClient_Extension struct { + *mock.Call +} + +func (_m SMTPClient_Extension) Return(_a0 bool, _a1 string) *SMTPClient_Extension { + return &SMTPClient_Extension{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *SMTPClient) OnExtension(ext string) *SMTPClient_Extension { + c_call := _m.On("Extension", ext) + return &SMTPClient_Extension{Call: c_call} +} + +func (_m *SMTPClient) OnExtensionMatch(matchers ...interface{}) *SMTPClient_Extension { + c_call := _m.On("Extension", matchers...) 
+ return &SMTPClient_Extension{Call: c_call} +} + +// Extension provides a mock function with given fields: ext +func (_m *SMTPClient) Extension(ext string) (bool, string) { + ret := _m.Called(ext) + + var r0 bool + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(ext) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 string + if rf, ok := ret.Get(1).(func(string) string); ok { + r1 = rf(ext) + } else { + r1 = ret.Get(1).(string) + } + + return r0, r1 +} + +type SMTPClient_Hello struct { + *mock.Call +} + +func (_m SMTPClient_Hello) Return(_a0 error) *SMTPClient_Hello { + return &SMTPClient_Hello{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnHello(localName string) *SMTPClient_Hello { + c_call := _m.On("Hello", localName) + return &SMTPClient_Hello{Call: c_call} +} + +func (_m *SMTPClient) OnHelloMatch(matchers ...interface{}) *SMTPClient_Hello { + c_call := _m.On("Hello", matchers...) + return &SMTPClient_Hello{Call: c_call} +} + +// Hello provides a mock function with given fields: localName +func (_m *SMTPClient) Hello(localName string) error { + ret := _m.Called(localName) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(localName) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_Mail struct { + *mock.Call +} + +func (_m SMTPClient_Mail) Return(_a0 error) *SMTPClient_Mail { + return &SMTPClient_Mail{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnMail(from string) *SMTPClient_Mail { + c_call := _m.On("Mail", from) + return &SMTPClient_Mail{Call: c_call} +} + +func (_m *SMTPClient) OnMailMatch(matchers ...interface{}) *SMTPClient_Mail { + c_call := _m.On("Mail", matchers...) 
+ return &SMTPClient_Mail{Call: c_call} +} + +// Mail provides a mock function with given fields: from +func (_m *SMTPClient) Mail(from string) error { + ret := _m.Called(from) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(from) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_Noop struct { + *mock.Call +} + +func (_m SMTPClient_Noop) Return(_a0 error) *SMTPClient_Noop { + return &SMTPClient_Noop{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnNoop() *SMTPClient_Noop { + c_call := _m.On("Noop") + return &SMTPClient_Noop{Call: c_call} +} + +func (_m *SMTPClient) OnNoopMatch(matchers ...interface{}) *SMTPClient_Noop { + c_call := _m.On("Noop", matchers...) + return &SMTPClient_Noop{Call: c_call} +} + +// Noop provides a mock function with given fields: +func (_m *SMTPClient) Noop() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_Rcpt struct { + *mock.Call +} + +func (_m SMTPClient_Rcpt) Return(_a0 error) *SMTPClient_Rcpt { + return &SMTPClient_Rcpt{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnRcpt(to string) *SMTPClient_Rcpt { + c_call := _m.On("Rcpt", to) + return &SMTPClient_Rcpt{Call: c_call} +} + +func (_m *SMTPClient) OnRcptMatch(matchers ...interface{}) *SMTPClient_Rcpt { + c_call := _m.On("Rcpt", matchers...) 
+ return &SMTPClient_Rcpt{Call: c_call} +} + +// Rcpt provides a mock function with given fields: to +func (_m *SMTPClient) Rcpt(to string) error { + ret := _m.Called(to) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(to) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SMTPClient_StartTLS struct { + *mock.Call +} + +func (_m SMTPClient_StartTLS) Return(_a0 error) *SMTPClient_StartTLS { + return &SMTPClient_StartTLS{Call: _m.Call.Return(_a0)} +} + +func (_m *SMTPClient) OnStartTLS(config *tls.Config) *SMTPClient_StartTLS { + c_call := _m.On("StartTLS", config) + return &SMTPClient_StartTLS{Call: c_call} +} + +func (_m *SMTPClient) OnStartTLSMatch(matchers ...interface{}) *SMTPClient_StartTLS { + c_call := _m.On("StartTLS", matchers...) + return &SMTPClient_StartTLS{Call: c_call} +} + +// StartTLS provides a mock function with given fields: config +func (_m *SMTPClient) StartTLS(config *tls.Config) error { + ret := _m.Called(config) + + var r0 error + if rf, ok := ret.Get(0).(func(*tls.Config) error); ok { + r0 = rf(config) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/flyteadmin/pkg/rpc/adminservice/base.go b/flyteadmin/pkg/rpc/adminservice/base.go index 8df2c595c7..491a24a1f0 100644 --- a/flyteadmin/pkg/rpc/adminservice/base.go +++ b/flyteadmin/pkg/rpc/adminservice/base.go @@ -20,6 +20,7 @@ import ( workflowengineImpl "github.com/flyteorg/flyte/flyteadmin/pkg/workflowengine/impl" "github.com/flyteorg/flyte/flyteadmin/plugins" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/flyteorg/flyte/flytestdlib/promutils" "github.com/flyteorg/flyte/flytestdlib/storage" @@ -45,7 +46,7 @@ type AdminService struct { const defaultRetries = 3 func NewAdminServer(ctx context.Context, pluginRegistry *plugins.Registry, configuration runtimeIfaces.Configuration, - kubeConfig, 
master string, dataStorageClient *storage.DataStore, adminScope promutils.Scope) *AdminService { + kubeConfig, master string, dataStorageClient *storage.DataStore, adminScope promutils.Scope, sm core.SecretManager) *AdminService { applicationConfiguration := configuration.ApplicationConfiguration().GetTopLevelConfig() panicCounter := adminScope.MustNewCounter("initialization_panic", @@ -81,7 +82,7 @@ func NewAdminServer(ctx context.Context, pluginRegistry *plugins.Registry, confi pluginRegistry.RegisterDefault(plugins.PluginIDWorkflowExecutor, workflowExecutor) publisher := notifications.NewNotificationsPublisher(*configuration.ApplicationConfiguration().GetNotificationsConfig(), adminScope) - processor := notifications.NewNotificationsProcessor(*configuration.ApplicationConfiguration().GetNotificationsConfig(), adminScope) + processor := notifications.NewNotificationsProcessor(*configuration.ApplicationConfiguration().GetNotificationsConfig(), adminScope, sm) eventPublisher := notifications.NewEventsPublisher(*configuration.ApplicationConfiguration().GetExternalEventsConfig(), adminScope) go func() { logger.Info(ctx, "Started processing notifications.") diff --git a/flyteadmin/pkg/runtime/interfaces/application_configuration.go b/flyteadmin/pkg/runtime/interfaces/application_configuration.go index 3505150919..e3453db0f7 100644 --- a/flyteadmin/pkg/runtime/interfaces/application_configuration.go +++ b/flyteadmin/pkg/runtime/interfaces/application_configuration.go @@ -492,8 +492,13 @@ type NotificationsProcessorConfig struct { type EmailServerConfig struct { ServiceName string `json:"serviceName"` // Only one of these should be set. 
- APIKeyEnvVar string `json:"apiKeyEnvVar"` - APIKeyFilePath string `json:"apiKeyFilePath"` + APIKeyEnvVar string `json:"apiKeyEnvVar"` + APIKeyFilePath string `json:"apiKeyFilePath"` + SMTPServer string `json:"smtpServer"` + SMTPPort string `json:"smtpPort"` + SMTPSkipTLSVerify bool `json:"smtpSkipTLSVerify"` + SMTPUsername string `json:"smtpUsername"` + SMTPPasswordSecretName string `json:"smtpPasswordSecretName"` } // This section handles the configuration of notifications emails. diff --git a/flyteadmin/pkg/server/service.go b/flyteadmin/pkg/server/service.go index 587ea86e3b..840d0d9f17 100644 --- a/flyteadmin/pkg/server/service.go +++ b/flyteadmin/pkg/server/service.go @@ -43,6 +43,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/clients/go/assets" grpcService "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/gateway/flyteidl/service" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/secretmanager" "github.com/flyteorg/flyte/flytestdlib/contextutils" "github.com/flyteorg/flyte/flytestdlib/logger" @@ -82,7 +83,7 @@ func SetMetricKeys(appConfig *runtimeIfaces.ApplicationConfig) { // Creates a new gRPC Server with all the configuration func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig, storageCfg *storage.Config, authCtx interfaces.AuthenticationContext, - scope promutils.Scope, opts ...grpc.ServerOption) (*grpc.Server, error) { + scope promutils.Scope, sm core.SecretManager, opts ...grpc.ServerOption) (*grpc.Server, error) { logger.Infof(ctx, "Registering default middleware with blanket auth validation") pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer( @@ -152,7 +153,7 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c } configuration := runtime2.NewConfigurationProvider() - 
adminServer := adminservice.NewAdminServer(ctx, pluginRegistry, configuration, cfg.KubeConfig, cfg.Master, dataStorageClient, adminScope) + adminServer := adminservice.NewAdminServer(ctx, pluginRegistry, configuration, cfg.KubeConfig, cfg.Master, dataStorageClient, adminScope, sm) grpcService.RegisterAdminServiceServer(grpcServer, adminServer) if cfg.Security.UseAuth { grpcService.RegisterAuthMetadataServiceServer(grpcServer, authCtx.AuthMetadataService()) @@ -339,12 +340,15 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, // This will parse configuration and create the necessary objects for dealing with auth var authCtx interfaces.AuthenticationContext var err error + + sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) + // This code is here to support authentication without SSL. This setup supports a network topology where // Envoy does the SSL termination. The final hop is made over localhost only on a trusted machine. // Warning: Running authentication without SSL in any other topology is a severe security flaw. // See the auth.Config object for additional settings as well. if cfg.Security.UseAuth { - sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) + var oauth2Provider interfaces.OAuth2Provider var oauth2ResourceServer interfaces.OAuth2ResourceServer if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf { @@ -373,7 +377,7 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, } } - grpcServer, err := newGRPCServer(ctx, pluginRegistry, cfg, storageConfig, authCtx, scope) + grpcServer, err := newGRPCServer(ctx, pluginRegistry, cfg, storageConfig, authCtx, scope, sm) if err != nil { return fmt.Errorf("failed to create a newGRPCServer. 
Error: %w", err) } @@ -448,13 +452,14 @@ func serveGatewaySecure(ctx context.Context, pluginRegistry *plugins.Registry, c additionalHandlers map[string]func(http.ResponseWriter, *http.Request), scope promutils.Scope) error { certPool, cert, err := GetSslCredentials(ctx, cfg.Security.Ssl.CertificateFile, cfg.Security.Ssl.KeyFile) + sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) + if err != nil { return err } // This will parse configuration and create the necessary objects for dealing with auth var authCtx interfaces.AuthenticationContext if cfg.Security.UseAuth { - sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) var oauth2Provider interfaces.OAuth2Provider var oauth2ResourceServer interfaces.OAuth2ResourceServer if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf { @@ -483,7 +488,7 @@ func serveGatewaySecure(ctx context.Context, pluginRegistry *plugins.Registry, c } } - grpcServer, err := newGRPCServer(ctx, pluginRegistry, cfg, storageCfg, authCtx, scope, grpc.Creds(credentials.NewServerTLSFromCert(cert))) + grpcServer, err := newGRPCServer(ctx, pluginRegistry, cfg, storageCfg, authCtx, scope, sm, grpc.Creds(credentials.NewServerTLSFromCert(cert))) if err != nil { return fmt.Errorf("failed to create a newGRPCServer. 
Error: %w", err) } From 7989209e15600b56fcf0f4c4a7c9af7bfeab6f3e Mon Sep 17 00:00:00 2001 From: Prafulla Mahindrakar Date: Mon, 16 Sep 2024 09:52:43 -0700 Subject: [PATCH 4/6] Added literal offloading for array node map tasks (#5697) * Added literal offloading for array node map tasks Signed-off-by: pmahindrakar-oss * fix Signed-off-by: pmahindrakar-oss * feedback Signed-off-by: pmahindrakar-oss * feedback Signed-off-by: pmahindrakar-oss * nit Signed-off-by: pmahindrakar-oss * feedback Signed-off-by: pmahindrakar-oss * add missing flag files Signed-off-by: pmahindrakar-oss * disabling flag until flytekit release is confirmed Signed-off-by: pmahindrakar-oss * nit Signed-off-by: pmahindrakar-oss --------- Signed-off-by: pmahindrakar-oss --- flyteadmin/go.mod | 1 + flyteadmin/go.sum | 2 + .../manager/impl/testutils/mock_requests.go | 22 +++ .../impl/validation/execution_validator.go | 8 +- .../validation/execution_validator_test.go | 34 ++++ .../pkg/manager/impl/validation/validation.go | 2 +- flytepropeller/go.mod | 1 + flytepropeller/go.sum | 2 + .../pkg/apis/flyteworkflow/v1alpha1/iface.go | 4 + .../pkg/compiler/transformers/k8s/inputs.go | 8 +- .../pkg/controller/config/config.go | 119 +++++++---- .../pkg/controller/config/config_flags.go | 4 + .../controller/config/config_flags_test.go | 56 ++++++ .../pkg/controller/config/config_test.go | 54 +++++ flytepropeller/pkg/controller/controller.go | 4 +- .../pkg/controller/nodes/array/handler.go | 17 +- .../controller/nodes/array/handler_test.go | 6 +- .../pkg/controller/nodes/common/utils.go | 98 +++++++++- .../pkg/controller/nodes/common/utils_test.go | 185 ++++++++++++++++++ .../pkg/controller/nodes/executor.go | 8 +- .../pkg/controller/nodes/executor_test.go | 32 +-- .../nodes/factory/handler_factory.go | 45 +++-- .../pkg/controller/workflow/executor_test.go | 18 +- go.mod | 1 + go.sum | 2 + 25 files changed, 642 insertions(+), 91 deletions(-) create mode 100644 flytepropeller/pkg/controller/config/config_test.go 
diff --git a/flyteadmin/go.mod b/flyteadmin/go.mod index 2eec0f8cf3..cfc2bfa010 100644 --- a/flyteadmin/go.mod +++ b/flyteadmin/go.mod @@ -82,6 +82,7 @@ require ( cloud.google.com/go/pubsub v1.34.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect + github.com/Masterminds/semver v1.5.0 // indirect github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 // indirect github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 // indirect github.com/beorn7/perks v1.0.1 // indirect diff --git a/flyteadmin/go.sum b/flyteadmin/go.sum index 049add4bbc..5b7d47b2a6 100644 --- a/flyteadmin/go.sum +++ b/flyteadmin/go.sum @@ -75,6 +75,8 @@ github.com/DataDog/datadog-go v3.4.1+incompatible/go.mod h1:LButxg5PwREeZtORoXG3 github.com/DataDog/datadog-go v4.0.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= github.com/DataDog/opencensus-go-exporter-datadog v0.0.0-20191210083620-6965a1cfed68/go.mod h1:gMGUEe16aZh0QN941HgDjwrdjU4iTthPoz2/AtDRADE= github.com/Masterminds/semver v1.4.2/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= +github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= +github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/Masterminds/semver/v3 v3.0.3/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA= github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA= diff --git a/flyteadmin/pkg/manager/impl/testutils/mock_requests.go b/flyteadmin/pkg/manager/impl/testutils/mock_requests.go index b868612269..b3d01897f1 100644 --- a/flyteadmin/pkg/manager/impl/testutils/mock_requests.go +++ b/flyteadmin/pkg/manager/impl/testutils/mock_requests.go @@ -241,6 +241,28 @@ func GetExecutionRequest() 
*admin.ExecutionCreateRequest { } } +func GetExecutionRequestWithOffloadedInputs(inputParam string, literalValue *core.Literal) *admin.ExecutionCreateRequest { + execReq := GetExecutionRequest() + execReq.Inputs = &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": { + Value: &core.Literal_OffloadedMetadata{ + OffloadedMetadata: &core.LiteralOffloadedMetadata{ + Uri: "s3://bucket/key", + SizeBytes: 100, + InferredType: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRING, + }, + }, + }, + }, + }, + }, + } + return execReq +} + func GetSampleWorkflowSpecForTest() *admin.WorkflowSpec { return &admin.WorkflowSpec{ Template: &core.WorkflowTemplate{ diff --git a/flyteadmin/pkg/manager/impl/validation/execution_validator.go b/flyteadmin/pkg/manager/impl/validation/execution_validator.go index 0a21165c93..f7b385b8a8 100644 --- a/flyteadmin/pkg/manager/impl/validation/execution_validator.go +++ b/flyteadmin/pkg/manager/impl/validation/execution_validator.go @@ -100,7 +100,13 @@ func CheckAndFetchInputsForExecution( } executionInputMap[name] = expectedInput.GetDefault() } else { - inputType := validators.LiteralTypeForLiteral(executionInputMap[name]) + var inputType *core.LiteralType + switch executionInputMap[name].GetValue().(type) { + case *core.Literal_OffloadedMetadata: + inputType = executionInputMap[name].GetOffloadedMetadata().GetInferredType() + default: + inputType = validators.LiteralTypeForLiteral(executionInputMap[name]) + } if !validators.AreTypesCastable(inputType, expectedInput.GetVar().GetType()) { return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid %s input wrong type. 
Expected %s, but got %s", name, expectedInput.GetVar().GetType(), inputType) } diff --git a/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go b/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go index 1329dc6f96..7e5f991788 100644 --- a/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/execution_validator_test.go @@ -105,6 +105,40 @@ func TestGetExecutionInputs(t *testing.T) { assert.EqualValues(t, expectedMap, actualInputs) } +func TestGetExecutionWithOffloadedInputs(t *testing.T) { + execLiteral := &core.Literal{ + Value: &core.Literal_OffloadedMetadata{ + OffloadedMetadata: &core.LiteralOffloadedMetadata{ + Uri: "s3://bucket/key", + SizeBytes: 100, + InferredType: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRING, + }, + }, + }, + }, + } + executionRequest := testutils.GetExecutionRequestWithOffloadedInputs("foo", execLiteral) + lpRequest := testutils.GetLaunchPlanRequest() + + actualInputs, err := CheckAndFetchInputsForExecution( + executionRequest.Inputs, + lpRequest.Spec.FixedInputs, + lpRequest.Spec.DefaultInputs, + ) + expectedMap := core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": execLiteral, + "bar": coreutils.MustMakeLiteral("bar-value"), + }, + } + assert.Nil(t, err) + assert.NotNil(t, actualInputs) + assert.EqualValues(t, expectedMap.GetLiterals()["foo"], actualInputs.Literals["foo"]) + assert.EqualValues(t, expectedMap.GetLiterals()["bar"], actualInputs.Literals["bar"]) +} + func TestValidateExecInputsWrongType(t *testing.T) { executionRequest := testutils.GetExecutionRequest() lpRequest := testutils.GetLaunchPlanRequest() diff --git a/flyteadmin/pkg/manager/impl/validation/validation.go b/flyteadmin/pkg/manager/impl/validation/validation.go index 6c9bd2fdbb..894eaee435 100644 --- a/flyteadmin/pkg/manager/impl/validation/validation.go +++ b/flyteadmin/pkg/manager/impl/validation/validation.go @@ -234,7 
+234,7 @@ func validateLiteralMap(inputMap *core.LiteralMap, fieldName string) error { if name == "" { return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "missing key in %s", fieldName) } - if fixedInput == nil || fixedInput.GetValue() == nil { + if fixedInput.GetValue() == nil && fixedInput.GetOffloadedMetadata() == nil { return errors.NewFlyteAdminErrorf(codes.InvalidArgument, "missing valid literal in %s %s", fieldName, name) } if isDateTime(fixedInput) { diff --git a/flytepropeller/go.mod b/flytepropeller/go.mod index 5d828f9e9b..f579049aff 100644 --- a/flytepropeller/go.mod +++ b/flytepropeller/go.mod @@ -4,6 +4,7 @@ go 1.22 require ( github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 + github.com/Masterminds/semver v1.5.0 github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/fatih/color v1.13.0 github.com/flyteorg/flyte/flyteidl v0.0.0-00010101000000-000000000000 diff --git a/flytepropeller/go.sum b/flytepropeller/go.sum index 8bbdd06eba..07a92b902b 100644 --- a/flytepropeller/go.sum +++ b/flytepropeller/go.sum @@ -64,6 +64,8 @@ github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 h1:xJ0dAkuxJXf github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295/go.mod h1:e0aH495YLkrsIe9fhedd6aSR6fgU/qhKvtroi6y7G/M= github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625 h1:cQyO5JQ2iuHnEcF3v24kdDMsgh04RjyFPDtuvD6PCE0= github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625/go.mod h1:6PnrZv6zUDkrNMw0mIoGRmGBR7i9LulhKPmxFq4rUiM= +github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= +github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/aws/aws-sdk-go v1.44.2 
h1:5VBk5r06bgxgRKVaUtm1/4NT/rtrnH2E4cnAYv5zgQc= diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go index bcd1064e67..486ac35a16 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -554,6 +554,10 @@ func GetOutputsFile(outputDir DataReference) DataReference { return outputDir + "/outputs.pb" } +func GetOutputsLiteralMetadataFile(literalKey string, outputDir DataReference) DataReference { + return outputDir + DataReference(fmt.Sprintf("/%s_offloaded_metadata.pb", literalKey)) +} + func GetInputsFile(inputDir DataReference) DataReference { return inputDir + "/inputs.pb" } diff --git a/flytepropeller/pkg/compiler/transformers/k8s/inputs.go b/flytepropeller/pkg/compiler/transformers/k8s/inputs.go index 2d967c560e..0976df669b 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/inputs.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/inputs.go @@ -35,7 +35,13 @@ func validateInputs(nodeID common.NodeID, iface *core.TypedInterface, inputs cor continue } - inputType := validators.LiteralTypeForLiteral(inputVal) + var inputType *core.LiteralType + switch inputVal.GetValue().(type) { + case *core.Literal_OffloadedMetadata: + inputType = inputVal.GetOffloadedMetadata().GetInferredType() + default: + inputType = validators.LiteralTypeForLiteral(inputVal) + } if !validators.AreTypesCastable(inputType, v.Type) { errs.Collect(errors.NewMismatchingTypesErr(nodeID, inputVar, v.Type.String(), inputType.String())) continue diff --git a/flytepropeller/pkg/controller/config/config.go b/flytepropeller/pkg/controller/config/config.go index a0217e186a..2d61c94970 100644 --- a/flytepropeller/pkg/controller/config/config.go +++ b/flytepropeller/pkg/controller/config/config.go @@ -34,12 +34,16 @@ package config import ( + "context" + "fmt" "time" + "github.com/Masterminds/semver" "k8s.io/apimachinery/pkg/types" 
"github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/contextutils" + "github.com/flyteorg/flyte/flytestdlib/logger" ) //go:generate pflags Config --default-var=defaultConfig @@ -120,6 +124,14 @@ var ( EventVersion: 0, DefaultParallelismBehavior: ParallelismBehaviorUnlimited, }, + LiteralOffloadingConfig: LiteralOffloadingConfig{ + Enabled: false, // Default keep this disabled and we will followup when flytekit is released with the offloaded changes. + SupportedSDKVersions: map[string]string{ // The key is the SDK name (matches the supported SDK in core.RuntimeMetadata_RuntimeType) and the value is the minimum supported version + "FLYTE_SDK": "1.13.5", // Expected release number with flytekit support from this PR https://github.com/flyteorg/flytekit/pull/2685 + }, + MinSizeInMBForOffloading: 10, // 10 MB is the default size for offloading + MaxSizeInMBForOffloading: 1000, // 1 GB is the default size before failing fast. + }, } ) @@ -127,40 +139,79 @@ var ( // the base configuration to start propeller // NOTE: when adding new fields, do not mark them as "omitempty" if it's desirable to read the value from env variables. type Config struct { - KubeConfigPath string `json:"kube-config" pflag:",Path to kubernetes client config file."` - MasterURL string `json:"master"` - Workers int `json:"workers" pflag:",Number of threads to process workflows"` - WorkflowReEval config.Duration `json:"workflow-reeval-duration" pflag:",Frequency of re-evaluating workflows"` - DownstreamEval config.Duration `json:"downstream-eval-duration" pflag:",Frequency of re-evaluating downstream tasks"` - LimitNamespace string `json:"limit-namespace" pflag:",Namespaces to watch for this propeller"` - ProfilerPort config.Port `json:"prof-port" pflag:",Profiler port"` - MetadataPrefix string `json:"metadata-prefix,omitempty" pflag:",MetadataPrefix should be used if all the metadata for Flyte executions should be stored under a specific prefix in CloudStorage. 
If not specified, the data will be stored in the base container directly."` - DefaultRawOutputPrefix string `json:"rawoutput-prefix" pflag:",a fully qualified storage path of the form s3://flyte/abc/..., where all data sandboxes should be stored."` - Queue CompositeQueueConfig `json:"queue,omitempty" pflag:",Workflow workqueue configuration, affects the way the work is consumed from the queue."` - MetricsPrefix string `json:"metrics-prefix" pflag:",An optional prefix for all published metrics."` - MetricKeys []string `json:"metrics-keys" pflag:",Metrics labels applied to prometheus metrics emitted by the service."` - EnableAdminLauncher bool `json:"enable-admin-launcher" pflag:"Enable remote Workflow launcher to Admin"` - MaxWorkflowRetries int `json:"max-workflow-retries" pflag:"Maximum number of retries per workflow"` - MaxTTLInHours int `json:"max-ttl-hours" pflag:"Maximum number of hours a completed workflow should be retained. Number between 1-23 hours"` - GCInterval config.Duration `json:"gc-interval" pflag:"Run periodic GC every 30 minutes"` - LeaderElection LeaderElectionConfig `json:"leader-election,omitempty" pflag:",Config for leader election."` - PublishK8sEvents bool `json:"publish-k8s-events" pflag:",Enable events publishing to K8s events API."` - MaxDatasetSizeBytes int64 `json:"max-output-size-bytes" pflag:",Deprecated! Use storage.limits.maxDownloadMBs instead"` - EnableGrpcLatencyMetrics bool `json:"enable-grpc-latency-metrics" pflag:",Enable grpc latency metrics. 
Note Histograms metrics can be expensive on Prometheus servers."` - KubeConfig KubeClientConfig `json:"kube-client-config" pflag:",Configuration to control the Kubernetes client"` - NodeConfig NodeConfig `json:"node-config,omitempty" pflag:",config for a workflow node"` - MaxStreakLength int `json:"max-streak-length" pflag:",Maximum number of consecutive rounds that one propeller worker can use for one workflow - >1 => turbo-mode is enabled."` - EventConfig EventConfig `json:"event-config,omitempty" pflag:",Configures execution event behavior."` - IncludeShardKeyLabel []string `json:"include-shard-key-label" pflag:",Include the specified shard key label in the k8s FlyteWorkflow CRD label selector"` - ExcludeShardKeyLabel []string `json:"exclude-shard-key-label" pflag:",Exclude the specified shard key label from the k8s FlyteWorkflow CRD label selector"` - IncludeProjectLabel []string `json:"include-project-label" pflag:",Include the specified project label in the k8s FlyteWorkflow CRD label selector"` - ExcludeProjectLabel []string `json:"exclude-project-label" pflag:",Exclude the specified project label from the k8s FlyteWorkflow CRD label selector"` - IncludeDomainLabel []string `json:"include-domain-label" pflag:",Include the specified domain label in the k8s FlyteWorkflow CRD label selector"` - ExcludeDomainLabel []string `json:"exclude-domain-label" pflag:",Exclude the specified domain label from the k8s FlyteWorkflow CRD label selector"` - ClusterID string `json:"cluster-id" pflag:",Unique cluster id running this flytepropeller instance with which to annotate execution events"` - CreateFlyteWorkflowCRD bool `json:"create-flyteworkflow-crd" pflag:",Enable creation of the FlyteWorkflow CRD on startup"` - NodeExecutionWorkerCount int `json:"node-execution-worker-count" pflag:",Number of workers to evaluate node executions, currently only used for array nodes"` - ArrayNode ArrayNodeConfig `json:"array-node-config,omitempty" pflag:",Configuration for array nodes"` 
+ KubeConfigPath string `json:"kube-config" pflag:",Path to kubernetes client config file."` + MasterURL string `json:"master"` + Workers int `json:"workers" pflag:",Number of threads to process workflows"` + WorkflowReEval config.Duration `json:"workflow-reeval-duration" pflag:",Frequency of re-evaluating workflows"` + DownstreamEval config.Duration `json:"downstream-eval-duration" pflag:",Frequency of re-evaluating downstream tasks"` + LimitNamespace string `json:"limit-namespace" pflag:",Namespaces to watch for this propeller"` + ProfilerPort config.Port `json:"prof-port" pflag:",Profiler port"` + MetadataPrefix string `json:"metadata-prefix,omitempty" pflag:",MetadataPrefix should be used if all the metadata for Flyte executions should be stored under a specific prefix in CloudStorage. If not specified, the data will be stored in the base container directly."` + DefaultRawOutputPrefix string `json:"rawoutput-prefix" pflag:",a fully qualified storage path of the form s3://flyte/abc/..., where all data sandboxes should be stored."` + Queue CompositeQueueConfig `json:"queue,omitempty" pflag:",Workflow workqueue configuration, affects the way the work is consumed from the queue."` + MetricsPrefix string `json:"metrics-prefix" pflag:",An optional prefix for all published metrics."` + MetricKeys []string `json:"metrics-keys" pflag:",Metrics labels applied to prometheus metrics emitted by the service."` + EnableAdminLauncher bool `json:"enable-admin-launcher" pflag:"Enable remote Workflow launcher to Admin"` + MaxWorkflowRetries int `json:"max-workflow-retries" pflag:"Maximum number of retries per workflow"` + MaxTTLInHours int `json:"max-ttl-hours" pflag:"Maximum number of hours a completed workflow should be retained. 
Number between 1-23 hours"` + GCInterval config.Duration `json:"gc-interval" pflag:"Run periodic GC every 30 minutes"` + LeaderElection LeaderElectionConfig `json:"leader-election,omitempty" pflag:",Config for leader election."` + PublishK8sEvents bool `json:"publish-k8s-events" pflag:",Enable events publishing to K8s events API."` + MaxDatasetSizeBytes int64 `json:"max-output-size-bytes" pflag:",Deprecated! Use storage.limits.maxDownloadMBs instead"` + EnableGrpcLatencyMetrics bool `json:"enable-grpc-latency-metrics" pflag:",Enable grpc latency metrics. Note Histograms metrics can be expensive on Prometheus servers."` + KubeConfig KubeClientConfig `json:"kube-client-config" pflag:",Configuration to control the Kubernetes client"` + NodeConfig NodeConfig `json:"node-config,omitempty" pflag:",config for a workflow node"` + MaxStreakLength int `json:"max-streak-length" pflag:",Maximum number of consecutive rounds that one propeller worker can use for one workflow - >1 => turbo-mode is enabled."` + EventConfig EventConfig `json:"event-config,omitempty" pflag:",Configures execution event behavior."` + IncludeShardKeyLabel []string `json:"include-shard-key-label" pflag:",Include the specified shard key label in the k8s FlyteWorkflow CRD label selector"` + ExcludeShardKeyLabel []string `json:"exclude-shard-key-label" pflag:",Exclude the specified shard key label from the k8s FlyteWorkflow CRD label selector"` + IncludeProjectLabel []string `json:"include-project-label" pflag:",Include the specified project label in the k8s FlyteWorkflow CRD label selector"` + ExcludeProjectLabel []string `json:"exclude-project-label" pflag:",Exclude the specified project label from the k8s FlyteWorkflow CRD label selector"` + IncludeDomainLabel []string `json:"include-domain-label" pflag:",Include the specified domain label in the k8s FlyteWorkflow CRD label selector"` + ExcludeDomainLabel []string `json:"exclude-domain-label" pflag:",Exclude the specified domain label from the k8s 
FlyteWorkflow CRD label selector"` + ClusterID string `json:"cluster-id" pflag:",Unique cluster id running this flytepropeller instance with which to annotate execution events"` + CreateFlyteWorkflowCRD bool `json:"create-flyteworkflow-crd" pflag:",Enable creation of the FlyteWorkflow CRD on startup"` + NodeExecutionWorkerCount int `json:"node-execution-worker-count" pflag:",Number of workers to evaluate node executions, currently only used for array nodes"` + ArrayNode ArrayNodeConfig `json:"array-node-config,omitempty" pflag:",Configuration for array nodes"` + LiteralOffloadingConfig LiteralOffloadingConfig `json:"literalOffloadingConfig" pflag:",config used for literal offloading."` +} + +type LiteralOffloadingConfig struct { + Enabled bool + // Maps flytekit and union SDK names to minimum supported version that can handle reading offloaded literals. + SupportedSDKVersions map[string]string + // Default, 10Mbs. Determines the size of a literal at which to trigger offloading + MinSizeInMBForOffloading int64 + // Fail fast threshold + MaxSizeInMBForOffloading int64 +} + +// IsSupportedSDKVersion returns true if the provided SDK and version are supported by the literal offloading config. +func (l LiteralOffloadingConfig) IsSupportedSDKVersion(sdk string, versionString string) bool { + if leastSupportedVersion, ok := l.SupportedSDKVersions[sdk]; ok { + c, err := semver.NewConstraint(fmt.Sprintf(">= %s", leastSupportedVersion)) + if err != nil { + // This should never happen + logger.Warnf(context.TODO(), "Failed to parse version constraint %s", leastSupportedVersion) + return false + } + version, err := semver.NewVersion(versionString) + if err != nil { + // This should never happen + logger.Warnf(context.TODO(), "Failed to parse version %s", versionString) + return false + } + return c.Check(version) + } + return false +} + +// GetSupportedSDKVersion returns the least supported version for the provided SDK. 
+func (l LiteralOffloadingConfig) GetSupportedSDKVersion(sdk string) string { + if leastSupportedVersion, ok := l.SupportedSDKVersions[sdk]; ok { + return leastSupportedVersion + } + return "" } // KubeClientConfig contains the configuration used by flytepropeller to configure its internal Kubernetes Client. diff --git a/flytepropeller/pkg/controller/config/config_flags.go b/flytepropeller/pkg/controller/config/config_flags.go index 858fc8a8ba..b2e88e88e6 100755 --- a/flytepropeller/pkg/controller/config/config_flags.go +++ b/flytepropeller/pkg/controller/config/config_flags.go @@ -112,5 +112,9 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "node-execution-worker-count"), defaultConfig.NodeExecutionWorkerCount, "Number of workers to evaluate node executions, currently only used for array nodes") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "array-node-config.event-version"), defaultConfig.ArrayNode.EventVersion, "ArrayNode eventing version. 
0 => legacy (drop-in replacement for maptask), 1 => new") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "array-node-config.default-parallelism-behavior"), defaultConfig.ArrayNode.DefaultParallelismBehavior, "Default parallelism behavior for array nodes") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.Enabled"), defaultConfig.LiteralOffloadingConfig.Enabled, "") + cmdFlags.StringToString(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.SupportedSDKVersions"), defaultConfig.LiteralOffloadingConfig.SupportedSDKVersions, "") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.MinSizeInMBForOffloading"), defaultConfig.LiteralOffloadingConfig.MinSizeInMBForOffloading, "") + cmdFlags.Int64(fmt.Sprintf("%v%v", prefix, "literalOffloadingConfig.MaxSizeInMBForOffloading"), defaultConfig.LiteralOffloadingConfig.MaxSizeInMBForOffloading, "") return cmdFlags } diff --git a/flytepropeller/pkg/controller/config/config_flags_test.go b/flytepropeller/pkg/controller/config/config_flags_test.go index 27e7b76efa..aadb24b36a 100755 --- a/flytepropeller/pkg/controller/config/config_flags_test.go +++ b/flytepropeller/pkg/controller/config/config_flags_test.go @@ -967,4 +967,60 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_literalOffloadingConfig.Enabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("literalOffloadingConfig.Enabled", testValue) + if vBool, err := cmdFlags.GetBool("literalOffloadingConfig.Enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.LiteralOffloadingConfig.Enabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_literalOffloadingConfig.SupportedSDKVersions", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "a=1,b=2" + + cmdFlags.Set("literalOffloadingConfig.SupportedSDKVersions", testValue) + if vStringToString, err := 
cmdFlags.GetStringToString("literalOffloadingConfig.SupportedSDKVersions"); err == nil { + testDecodeRaw_Config(t, vStringToString, &actual.LiteralOffloadingConfig.SupportedSDKVersions) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_literalOffloadingConfig.MinSizeInMBForOffloading", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("literalOffloadingConfig.MinSizeInMBForOffloading", testValue) + if vInt64, err := cmdFlags.GetInt64("literalOffloadingConfig.MinSizeInMBForOffloading"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt64), &actual.LiteralOffloadingConfig.MinSizeInMBForOffloading) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_literalOffloadingConfig.MaxSizeInMBForOffloading", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("literalOffloadingConfig.MaxSizeInMBForOffloading", testValue) + if vInt64, err := cmdFlags.GetInt64("literalOffloadingConfig.MaxSizeInMBForOffloading"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt64), &actual.LiteralOffloadingConfig.MaxSizeInMBForOffloading) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } diff --git a/flytepropeller/pkg/controller/config/config_test.go b/flytepropeller/pkg/controller/config/config_test.go new file mode 100644 index 0000000000..afc9ed2fea --- /dev/null +++ b/flytepropeller/pkg/controller/config/config_test.go @@ -0,0 +1,54 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsSupportedSDKVersion(t *testing.T) { + t.Run("supported version", func(t *testing.T) { + config := LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "0.16.0", + }, + } + assert.True(t, config.IsSupportedSDKVersion("flytekit", "0.16.0")) + }) + + t.Run("unsupported version", func(t *testing.T) { + config := LiteralOffloadingConfig{ + SupportedSDKVersions: 
map[string]string{ + "flytekit": "0.16.0", + }, + } + assert.False(t, config.IsSupportedSDKVersion("flytekit", "0.15.0")) + }) + + t.Run("unsupported SDK", func(t *testing.T) { + config := LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "0.16.0", + }, + } + assert.False(t, config.IsSupportedSDKVersion("unknown", "0.16.0")) + }) + + t.Run("invalid version", func(t *testing.T) { + config := LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "0.16.0", + }, + } + assert.False(t, config.IsSupportedSDKVersion("flytekit", "invalid")) + }) + + t.Run("invalid constraint", func(t *testing.T) { + config := LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + "flytekit": "invalid", + }, + } + assert.False(t, config.IsSupportedSDKVersion("flytekit", "0.16.0")) + }) +} diff --git a/flytepropeller/pkg/controller/controller.go b/flytepropeller/pkg/controller/controller.go index c59aa9745d..39047e811d 100644 --- a/flytepropeller/pkg/controller/controller.go +++ b/flytepropeller/pkg/controller/controller.go @@ -436,14 +436,14 @@ func New(ctx context.Context, cfg *config.Config, kubeClientset kubernetes.Inter recoveryClient := recovery.NewClient(adminClient) nodeHandlerFactory, err := factory.NewHandlerFactory(ctx, launchPlanActor, launchPlanActor, - kubeClient, kubeClientset, catalogClient, recoveryClient, &cfg.EventConfig, cfg.ClusterID, signalClient, scope) + kubeClient, kubeClientset, catalogClient, recoveryClient, &cfg.EventConfig, cfg.LiteralOffloadingConfig, cfg.ClusterID, signalClient, scope) if err != nil { return nil, errors.Wrapf(err, "failed to create node handler factory") } nodeExecutor, err := nodes.NewExecutor(ctx, cfg.NodeConfig, store, controller.enqueueWorkflowForNodeUpdates, eventSink, launchPlanActor, launchPlanActor, storage.DataReference(cfg.DefaultRawOutputPrefix), kubeClient, - catalogClient, recoveryClient, &cfg.EventConfig, cfg.ClusterID, signalClient, nodeHandlerFactory, 
scope) + catalogClient, recoveryClient, cfg.LiteralOffloadingConfig, &cfg.EventConfig, cfg.ClusterID, signalClient, nodeHandlerFactory, scope) if err != nil { return nil, errors.Wrapf(err, "Failed to create Controller.") } diff --git a/flytepropeller/pkg/controller/nodes/array/handler.go b/flytepropeller/pkg/controller/nodes/array/handler.go index a101ed5a30..5e9f910e14 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler.go +++ b/flytepropeller/pkg/controller/nodes/array/handler.go @@ -45,6 +45,7 @@ var ( // arrayNodeHandler is a handle implementation for processing array nodes type arrayNodeHandler struct { eventConfig *config.EventConfig + literalOffloadingConfig config.LiteralOffloadingConfig gatherOutputsRequestChannel chan *gatherOutputsRequest metrics metrics nodeExecutionRequestChannel chan *nodeExecutionRequest @@ -498,7 +499,6 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // attempt best effort at initializing outputLiterals with output variable names. currently // only TaskNode and WorkflowNode contain node interfaces. outputLiterals := make(map[string]*idlcore.Literal) - switch arrayNode.GetSubNodeSpec().GetKind() { case v1alpha1.NodeKindTask: taskID := *arrayNode.GetSubNodeSpec().TaskRef @@ -547,6 +547,18 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu return handler.UnknownTransition, fmt.Errorf("worker error(s) encountered: %s", workerErrorCollector.Summary(events.MaxErrorMessageLength)) } + // only offload literal if config is enabled for this feature. + if a.literalOffloadingConfig.Enabled { + for outputLiteralKey, outputLiteral := range outputLiterals { + // if the size of the output Literal is > threshold then we write the literal to the offloaded store and populate the literal with its zero value and update the offloaded url + // use the OffloadLargeLiteralKey to create {OffloadLargeLiteralKey}_offloaded_metadata.pb file in the datastore. 
+ // Update the url in the outputLiteral with the offloaded url and also update the size of the literal. + offloadedOutputFile := v1alpha1.GetOutputsLiteralMetadataFile(outputLiteralKey, nCtx.NodeStatus().GetOutputDir()) + if err := common.OffloadLargeLiteral(ctx, nCtx.DataStore(), offloadedOutputFile, outputLiteral, a.literalOffloadingConfig); err != nil { + return handler.UnknownTransition, err + } + } + } outputLiteralMap := &idlcore.LiteralMap{ Literals: outputLiterals, } @@ -649,7 +661,7 @@ func (a *arrayNodeHandler) Setup(_ context.Context, _ interfaces.SetupContext) e } // New initializes a new arrayNodeHandler -func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, scope promutils.Scope) (interfaces.NodeHandler, error) { +func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, literalOffloadingConfig config.LiteralOffloadingConfig, scope promutils.Scope) (interfaces.NodeHandler, error) { // create k8s PluginState byte mocks to reuse instead of creating for each subNode evaluation pluginStateBytesNotStarted, err := bytesFromK8sPluginState(k8s.PluginState{Phase: k8s.PluginPhaseNotStarted}) if err != nil { @@ -676,6 +688,7 @@ func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, scope pr arrayScope := scope.NewSubScope("array") return &arrayNodeHandler{ eventConfig: deepCopiedEventConfig, + literalOffloadingConfig: literalOffloadingConfig, gatherOutputsRequestChannel: make(chan *gatherOutputsRequest), metrics: newMetrics(arrayScope), nodeExecutionRequestChannel: make(chan *nodeExecutionRequest), diff --git a/flytepropeller/pkg/controller/nodes/array/handler_test.go b/flytepropeller/pkg/controller/nodes/array/handler_test.go index 648d70e36c..cb2f2898a6 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/array/handler_test.go @@ -52,6 +52,8 @@ func createArrayNodeHandler(ctx context.Context, t *testing.T, nodeHandler inter adminClient := 
launchplan.NewFailFastLaunchPlanExecutor() enqueueWorkflowFunc := func(workflowID v1alpha1.WorkflowID) {} eventConfig := &config.EventConfig{ErrorOnAlreadyExists: true} + offloadingConfig := config.LiteralOffloadingConfig{Enabled: false} + literalOffloadingConfig := config.LiteralOffloadingConfig{Enabled: true, MinSizeInMBForOffloading: 1024, MaxSizeInMBForOffloading: 1024 * 1024} mockEventSink := eventmocks.NewMockEventSink() mockHandlerFactory := &mocks.HandlerFactory{} mockHandlerFactory.OnGetHandlerMatch(mock.Anything).Return(nodeHandler, nil) @@ -62,11 +64,11 @@ func createArrayNodeHandler(ctx context.Context, t *testing.T, nodeHandler inter // create node executor nodeExecutor, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, dataStore, enqueueWorkflowFunc, mockEventSink, adminClient, - adminClient, "s3://bucket/", mockKubeClient, noopCatalogClient, mockRecoveryClient, eventConfig, "clusterID", mockSignalClient, mockHandlerFactory, scope) + adminClient, "s3://bucket/", mockKubeClient, noopCatalogClient, mockRecoveryClient, offloadingConfig, eventConfig, "clusterID", mockSignalClient, mockHandlerFactory, scope) assert.NoError(t, err) // return ArrayNodeHandler - arrayNodeHandler, err := New(nodeExecutor, eventConfig, scope) + arrayNodeHandler, err := New(nodeExecutor, eventConfig, literalOffloadingConfig, scope) if err != nil { return nil, err } diff --git a/flytepropeller/pkg/controller/nodes/common/utils.go b/flytepropeller/pkg/controller/nodes/common/utils.go index 04ddc5183d..89bb0afe2e 100644 --- a/flytepropeller/pkg/controller/nodes/common/utils.go +++ b/flytepropeller/pkg/controller/nodes/common/utils.go @@ -2,17 +2,28 @@ package common import ( "context" + "fmt" "strconv" + "github.com/golang/protobuf/proto" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/encoding" 
"github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/validators" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/config" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/handler" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flyte/flytestdlib/logger" + "github.com/flyteorg/flyte/flytestdlib/storage" ) -const maxUniqueIDLength = 20 +const ( + maxUniqueIDLength = 20 + MB = 1024 * 1024 // 1 MB in bytes (1 MiB) +) // GenerateUniqueID is the UniqueId of a node is unique within a given workflow execution. // In order to achieve that we track the lineage of the node. @@ -67,3 +78,88 @@ func GetTargetEntity(ctx context.Context, nCtx interfaces.NodeExecutionContext) } return targetEntity } + +// OffloadLargeLiteral offloads the large literal if meets the threshold conditions +func OffloadLargeLiteral(ctx context.Context, datastore *storage.DataStore, dataReference storage.DataReference, + toBeOffloaded *idlcore.Literal, literalOffloadingConfig config.LiteralOffloadingConfig) error { + literalSizeBytes := int64(proto.Size(toBeOffloaded)) + literalSizeMB := literalSizeBytes / MB + // check if the literal is large + if literalSizeMB >= literalOffloadingConfig.MaxSizeInMBForOffloading { + errString := fmt.Sprintf("Literal size [%d] MB is larger than the max size [%d] MB for offloading", literalSizeMB, literalOffloadingConfig.MaxSizeInMBForOffloading) + logger.Errorf(ctx, errString) + return fmt.Errorf(errString) + } + if literalSizeMB < literalOffloadingConfig.MinSizeInMBForOffloading { + logger.Debugf(ctx, "Literal size [%d] MB is smaller than the min size [%d] MB for offloading", literalSizeMB, literalOffloadingConfig.MinSizeInMBForOffloading) + return nil + } + + inferredType := validators.LiteralTypeForLiteral(toBeOffloaded) + if inferredType == nil { + 
errString := "Failed to determine literal type for offloaded literal" + logger.Errorf(ctx, errString) + return fmt.Errorf(errString) + } + + // offload the literal + if err := datastore.WriteProtobuf(ctx, dataReference, storage.Options{}, toBeOffloaded); err != nil { + logger.Errorf(ctx, "Failed to offload literal at location [%s] with error [%s]", dataReference, err) + return err + } + + // update the literal with the offloaded URI, size and inferred type + toBeOffloaded.Value = &idlcore.Literal_OffloadedMetadata{ + OffloadedMetadata: &idlcore.LiteralOffloadedMetadata{ + Uri: dataReference.String(), + SizeBytes: uint64(literalSizeBytes), + InferredType: inferredType, + }, + } + logger.Infof(ctx, "Offloaded literal at location [%s] with size [%d] MB and inferred type [%s]", dataReference, literalSizeMB, inferredType) + return nil +} + +// CheckOffloadingCompat checks if the upstream and downstream nodes are compatible with the literal offloading feature and returns an error if not contained in phase info object +func CheckOffloadingCompat(ctx context.Context, nCtx interfaces.NodeExecutionContext, inputLiterals map[string]*core.Literal, node v1alpha1.ExecutableNode, literalOffloadingConfig config.LiteralOffloadingConfig) *handler.PhaseInfo { + consumesOffloadLiteral := false + for _, val := range inputLiterals { + if val != nil && val.GetOffloadedMetadata() != nil { + consumesOffloadLiteral = true + break + } + } + if !consumesOffloadLiteral { + return nil + } + var phaseInfo handler.PhaseInfo + + // Return early if the node is not of type NodeKindTask + if node.GetKind() != v1alpha1.NodeKindTask { + return nil + } + + // Process NodeKindTask + taskID := *node.GetTaskID() + taskNode, err := nCtx.ExecutionContext().GetTask(taskID) + if err != nil { + phaseInfo = handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "GetTaskIDFailure", err.Error(), nil) + return &phaseInfo + } + runtimeData := taskNode.CoreTask().GetMetadata().GetRuntime() + if 
!literalOffloadingConfig.IsSupportedSDKVersion(runtimeData.GetType().String(), runtimeData.GetVersion()) { + if !literalOffloadingConfig.Enabled { + errMsg := fmt.Sprintf("task [%s] is trying to consume offloaded literals but feature is not enabled", taskID) + logger.Errorf(ctx, errMsg) + phaseInfo = handler.PhaseInfoFailure(core.ExecutionError_USER, "LiteralOffloadingDisabled", errMsg, nil) + return &phaseInfo + } + leastSupportedVersion := literalOffloadingConfig.GetSupportedSDKVersion(runtimeData.GetType().String()) + errMsg := fmt.Sprintf("Literal offloading is not supported for this task as its registered with SDK version [%s] which is less than the least supported version [%s] for this feature", runtimeData.GetVersion(), leastSupportedVersion) + logger.Errorf(ctx, errMsg) + phaseInfo = handler.PhaseInfoFailure(core.ExecutionError_USER, "LiteralOffloadingNotSupported", errMsg, nil) + return &phaseInfo + } + + return nil +} diff --git a/flytepropeller/pkg/controller/nodes/common/utils_test.go b/flytepropeller/pkg/controller/nodes/common/utils_test.go index 9e451da69a..7d5ce1e372 100644 --- a/flytepropeller/pkg/controller/nodes/common/utils_test.go +++ b/flytepropeller/pkg/controller/nodes/common/utils_test.go @@ -1,11 +1,22 @@ package common import ( + "context" "testing" "github.com/stretchr/testify/assert" + idlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/validators" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/config" + executorMocks "github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors/mocks" + nodeMocks "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/interfaces/mocks" + "github.com/flyteorg/flyte/flytestdlib/contextutils" + "github.com/flyteorg/flyte/flytestdlib/promutils" + 
"github.com/flyteorg/flyte/flytestdlib/promutils/labeled" + "github.com/flyteorg/flyte/flytestdlib/storage" ) type ParentInfo struct { @@ -66,3 +77,177 @@ func TestCreateParentInfoNil(t *testing.T) { assert.Equal(t, uint32(1), parent.CurrentAttempt()) assert.True(t, parent.IsInDynamicChain()) } + +func init() { + labeled.SetMetricKeys(contextutils.AppNameKey) +} + +func TestOffloadLargeLiteral(t *testing.T) { + t.Run("offload successful with valid size", func(t *testing.T) { + ctx := context.Background() + datastore, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + dataReference := storage.DataReference("foo/bar") + toBeOffloaded := &idlCore.Literal{ + Value: &idlCore.Literal_Scalar{ + Scalar: &idlCore.Scalar{ + Value: &idlCore.Scalar_Primitive{ + Primitive: &idlCore.Primitive{ + Value: &idlCore.Primitive_Integer{ + Integer: 1, + }, + }, + }, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + MinSizeInMBForOffloading: 0, + MaxSizeInMBForOffloading: 1, + } + inferredType := validators.LiteralTypeForLiteral(toBeOffloaded) + err := OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, literalOffloadingConfig) + assert.NoError(t, err) + assert.Equal(t, "foo/bar", toBeOffloaded.GetOffloadedMetadata().GetUri()) + assert.Equal(t, uint64(6), toBeOffloaded.GetOffloadedMetadata().GetSizeBytes()) + assert.Equal(t, inferredType.GetSimple(), toBeOffloaded.GetOffloadedMetadata().InferredType.GetSimple()) + + }) + + t.Run("offload fails with size larger than max", func(t *testing.T) { + ctx := context.Background() + datastore, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + dataReference := storage.DataReference("foo/bar") + toBeOffloaded := &idlCore.Literal{ + Value: &idlCore.Literal_Scalar{ + Scalar: &idlCore.Scalar{ + Value: &idlCore.Scalar_Primitive{ + Primitive: &idlCore.Primitive{ + Value: &idlCore.Primitive_Integer{ + Integer: 1, + }, + }, 
+ }, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + MinSizeInMBForOffloading: 0, + MaxSizeInMBForOffloading: 0, + } + err := OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, literalOffloadingConfig) + assert.Error(t, err) + }) + + t.Run("offload not attempted with size smaller than min", func(t *testing.T) { + ctx := context.Background() + datastore, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + dataReference := storage.DataReference("foo/bar") + toBeOffloaded := &idlCore.Literal{ + Value: &idlCore.Literal_Scalar{ + Scalar: &idlCore.Scalar{ + Value: &idlCore.Scalar_Primitive{ + Primitive: &idlCore.Primitive{ + Value: &idlCore.Primitive_Integer{ + Integer: 1, + }, + }, + }, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + MinSizeInMBForOffloading: 2, + MaxSizeInMBForOffloading: 3, + } + err := OffloadLargeLiteral(ctx, datastore, dataReference, toBeOffloaded, literalOffloadingConfig) + assert.NoError(t, err) + assert.Nil(t, toBeOffloaded.GetOffloadedMetadata()) + }) +} + +func TestCheckOffloadingCompat(t *testing.T) { + ctx := context.Background() + nCtx := &nodeMocks.NodeExecutionContext{} + executionContext := &executorMocks.ExecutionContext{} + executableTask := &mocks.ExecutableTask{} + node := &mocks.ExecutableNode{} + node.OnGetKind().Return(v1alpha1.NodeKindTask) + nCtx.OnExecutionContext().Return(executionContext) + executionContext.OnGetTask("task1").Return(executableTask, nil) + executableTask.OnCoreTask().Return(&idlCore.TaskTemplate{ + Metadata: &idlCore.TaskMetadata{ + Runtime: &idlCore.RuntimeMetadata{ + Type: idlCore.RuntimeMetadata_FLYTE_SDK, + Version: "0.16.0", + }, + }, + }) + taskID := "task1" + node.OnGetTaskID().Return(&taskID) + t.Run("supported version success", func(t *testing.T) { + inputLiterals := map[string]*idlCore.Literal{ + "foo": { + Value: &idlCore.Literal_OffloadedMetadata{ + OffloadedMetadata: 
&idlCore.LiteralOffloadedMetadata{}, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + idlCore.RuntimeMetadata_FLYTE_SDK.String(): "0.16.0", + }, + Enabled: true, + } + phaseInfo := CheckOffloadingCompat(ctx, nCtx, inputLiterals, node, literalOffloadingConfig) + assert.Nil(t, phaseInfo) + }) + t.Run("unsupported version", func(t *testing.T) { + inputLiterals := map[string]*idlCore.Literal{ + "foo": { + Value: &idlCore.Literal_OffloadedMetadata{ + OffloadedMetadata: &idlCore.LiteralOffloadedMetadata{}, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + SupportedSDKVersions: map[string]string{ + idlCore.RuntimeMetadata_FLYTE_SDK.String(): "0.17.0", + }, + Enabled: true, + } + phaseInfo := CheckOffloadingCompat(ctx, nCtx, inputLiterals, node, literalOffloadingConfig) + assert.NotNil(t, phaseInfo) + assert.Equal(t, idlCore.ExecutionError_USER, phaseInfo.GetErr().GetKind()) + assert.Equal(t, "LiteralOffloadingNotSupported", phaseInfo.GetErr().GetCode()) + }) + t.Run("offloading config disabled with offloaded data", func(t *testing.T) { + inputLiterals := map[string]*idlCore.Literal{ + "foo": { + Value: &idlCore.Literal_OffloadedMetadata{ + OffloadedMetadata: &idlCore.LiteralOffloadedMetadata{}, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + Enabled: false, + } + phaseInfo := CheckOffloadingCompat(ctx, nCtx, inputLiterals, node, literalOffloadingConfig) + assert.NotNil(t, phaseInfo) + assert.Equal(t, idlCore.ExecutionError_USER, phaseInfo.GetErr().GetKind()) + assert.Equal(t, "LiteralOffloadingDisabled", phaseInfo.GetErr().GetCode()) + }) + t.Run("offloading config enabled with no offloaded data", func(t *testing.T) { + inputLiterals := map[string]*idlCore.Literal{ + "foo": { + Value: &idlCore.Literal_Scalar{ + Scalar: &idlCore.Scalar{}, + }, + }, + } + literalOffloadingConfig := config.LiteralOffloadingConfig{ + Enabled: true, + } + phaseInfo := 
CheckOffloadingCompat(ctx, nCtx, inputLiterals, node, literalOffloadingConfig) + assert.Nil(t, phaseInfo) + }) +} diff --git a/flytepropeller/pkg/controller/nodes/executor.go b/flytepropeller/pkg/controller/nodes/executor.go index 47c91edc51..2c3103e4ad 100644 --- a/flytepropeller/pkg/controller/nodes/executor.go +++ b/flytepropeller/pkg/controller/nodes/executor.go @@ -491,6 +491,7 @@ type nodeExecutor struct { defaultExecutionDeadline time.Duration enqueueWorkflow v1alpha1.EnqueueWorkflow eventConfig *config.EventConfig + literalOffloadingConfig config.LiteralOffloadingConfig interruptibleFailureThreshold int32 maxNodeRetriesForSystemFailures uint32 metrics *nodeMetrics @@ -764,6 +765,10 @@ func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructur } if nodeInputs != nil { + p := common.CheckOffloadingCompat(ctx, nCtx, nodeInputs.Literals, node, c.literalOffloadingConfig) + if p != nil { + return *p, nil + } inputsFile := v1alpha1.GetInputsFile(dataDir) if err := c.store.WriteProtobuf(ctx, inputsFile, storage.Options{}, nodeInputs); err != nil { c.metrics.InputsWriteFailure.Inc(ctx) @@ -1417,7 +1422,7 @@ func (c *nodeExecutor) HandleNode(ctx context.Context, dag executors.DAGStructur func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, eventSink events.EventSink, workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, defaultRawOutputPrefix storage.DataReference, kubeClient executors.Client, - catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, + catalogClient catalog.Client, recoveryClient recovery.Client, literalOffloadingConfig config.LiteralOffloadingConfig, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, nodeHandlerFactory interfaces.HandlerFactory, scope promutils.Scope) (interfaces.Node, error) 
{ // TODO we may want to make this configurable. @@ -1469,6 +1474,7 @@ func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *stora defaultExecutionDeadline: nodeConfig.DefaultDeadlines.DefaultNodeExecutionDeadline.Duration, enqueueWorkflow: enQWorkflow, eventConfig: eventConfig, + literalOffloadingConfig: literalOffloadingConfig, interruptibleFailureThreshold: nodeConfig.InterruptibleFailureThreshold, maxNodeRetriesForSystemFailures: uint32(nodeConfig.MaxNodeRetriesOnSystemFailures), metrics: metrics, diff --git a/flytepropeller/pkg/controller/nodes/executor_test.go b/flytepropeller/pkg/controller/nodes/executor_test.go index ea7da42112..7fc4c05992 100644 --- a/flytepropeller/pkg/controller/nodes/executor_test.go +++ b/flytepropeller/pkg/controller/nodes/executor_test.go @@ -69,7 +69,7 @@ func TestSetInputsForStartNode(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) exec, err := NewExecutor(ctx, config.GetConfig().NodeConfig, mockStorage, enQWf, eventMocks.NewMockEventSink(), adminClient, - adminClient, "s3://bucket/", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + adminClient, "s3://bucket/", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) inputs := &core.LiteralMap{ Literals: map[string]*core.Literal{ @@ -116,7 +116,7 @@ func TestSetInputsForStartNode(t *testing.T) { failStorage := createFailingDatastore(t, testScope.NewSubScope("failing")) execFail, err := NewExecutor(ctx, config.GetConfig().NodeConfig, failStorage, enQWf, eventMocks.NewMockEventSink(), adminClient, - adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + adminClient, "s3://bucket", fakeKubeClient, 
catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) t.Run("StorageFailure", func(t *testing.T) { w := createDummyBaseWorkflow(mockStorage) @@ -145,7 +145,7 @@ func TestNodeExecutor_Initialize(t *testing.T) { hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -156,7 +156,7 @@ func TestNodeExecutor_Initialize(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("error")) - execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -176,7 +176,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, 
adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -281,7 +281,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) - execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -696,7 +696,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { nodeConfig := config.GetConfig().NodeConfig nodeConfig.EnableCRDebugMetadata = test.enableCRDebugMetadata execIface, err := NewExecutor(ctx, nodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -771,7 +771,7 @@ func 
TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -885,7 +885,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -952,7 +952,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, 
promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -983,7 +983,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -1018,7 +1018,7 @@ func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -1131,7 +1131,7 @@ func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, 
recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -1249,7 +1249,7 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) // Node not yet started @@ -1889,7 +1889,7 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { hf := &nodemocks.HandlerFactory{} hf.On("Setup", mock.Anything, mock.Anything, mock.Anything).Return(nil) execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, hf, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*recursiveNodeExecutor) @@ -2666,7 +2666,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Cache(t *testing.T) { mockHandlerFactory.OnGetHandler(v1alpha1.NodeKindTask).Return(mockHandler, nil) nodeExecutor, err := NewExecutor(ctx, nodeConfig, dataStore, enqueueWorkflow, mockEventSink, adminClient, adminClient, rawOutputPrefix, fakeKubeClient, catalogClient, - recoveryClient, eventConfig, testClusterID, signalClient, 
mockHandlerFactory, testScope) + recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, mockHandlerFactory, testScope) assert.NoError(t, err) return nodeExecutor diff --git a/flytepropeller/pkg/controller/nodes/factory/handler_factory.go b/flytepropeller/pkg/controller/nodes/factory/handler_factory.go index 424bd15f10..72dcff5310 100644 --- a/flytepropeller/pkg/controller/nodes/factory/handler_factory.go +++ b/flytepropeller/pkg/controller/nodes/factory/handler_factory.go @@ -28,16 +28,17 @@ import ( type handlerFactory struct { handlers map[v1alpha1.NodeKind]interfaces.NodeHandler - workflowLauncher launchplan.Executor - launchPlanReader launchplan.Reader - kubeClient executors.Client - kubeClientset kubernetes.Interface - catalogClient catalog.Client - recoveryClient recovery.Client - eventConfig *config.EventConfig - clusterID string - signalClient service.SignalServiceClient - scope promutils.Scope + workflowLauncher launchplan.Executor + launchPlanReader launchplan.Reader + kubeClient executors.Client + kubeClientset kubernetes.Interface + catalogClient catalog.Client + recoveryClient recovery.Client + eventConfig *config.EventConfig + literalOffloadingConfig config.LiteralOffloadingConfig + clusterID string + signalClient service.SignalServiceClient + scope promutils.Scope } func (f *handlerFactory) GetHandler(kind v1alpha1.NodeKind) (interfaces.NodeHandler, error) { @@ -54,7 +55,7 @@ func (f *handlerFactory) Setup(ctx context.Context, executor interfaces.Node, se return err } - arrayHandler, err := array.New(executor, f.eventConfig, f.scope) + arrayHandler, err := array.New(executor, f.eventConfig, f.literalOffloadingConfig, f.scope) if err != nil { return err } @@ -79,18 +80,20 @@ func (f *handlerFactory) Setup(ctx context.Context, executor interfaces.Node, se func NewHandlerFactory(ctx context.Context, workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, kubeClient executors.Client, kubeClientset 
kubernetes.Interface, catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, + literalOffloadingConfig config.LiteralOffloadingConfig, clusterID string, signalClient service.SignalServiceClient, scope promutils.Scope) (interfaces.HandlerFactory, error) { return &handlerFactory{ - workflowLauncher: workflowLauncher, - launchPlanReader: launchPlanReader, - kubeClient: kubeClient, - kubeClientset: kubeClientset, - catalogClient: catalogClient, - recoveryClient: recoveryClient, - eventConfig: eventConfig, - clusterID: clusterID, - signalClient: signalClient, - scope: scope, + workflowLauncher: workflowLauncher, + launchPlanReader: launchPlanReader, + kubeClient: kubeClient, + kubeClientset: kubeClientset, + catalogClient: catalogClient, + recoveryClient: recoveryClient, + eventConfig: eventConfig, + literalOffloadingConfig: literalOffloadingConfig, + clusterID: clusterID, + signalClient: signalClient, + scope: scope, }, nil } diff --git a/flytepropeller/pkg/controller/workflow/executor_test.go b/flytepropeller/pkg/controller/workflow/executor_test.go index 85667b0e26..a3d028e94b 100644 --- a/flytepropeller/pkg/controller/workflow/executor_test.go +++ b/flytepropeller/pkg/controller/workflow/executor_test.go @@ -242,11 +242,11 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Error(t *testing.T) { recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() - handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, config.LiteralOffloadingConfig{}, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) nodeExec, err := nodes.NewExecutor(ctx, 
config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) execStatsHolder, err := execStats.NewExecutionStatsHolder() @@ -328,11 +328,11 @@ func TestWorkflowExecutor_HandleFlyteWorkflow(t *testing.T) { recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() - handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, config.LiteralOffloadingConfig{}, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) execStatsHolder, err := execStats.NewExecutionStatsHolder() @@ -398,7 +398,7 @@ func BenchmarkWorkflowExecutor(b *testing.B) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() handlerFactory := &nodemocks.HandlerFactory{} nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, 
adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, scope) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, scope) assert.NoError(b, err) execStatsHolder, err := execStats.NewExecutionStatsHolder() @@ -512,7 +512,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) { handlerFactory.OnGetHandlerMatch(mock.Anything).Return(h, nil) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) execStatsHolder, err := execStats.NewExecutionStatsHolder() @@ -613,11 +613,11 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Events(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() recoveryClient := &recoveryMocks.Client{} - handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) + handlerFactory, err := factory.NewHandlerFactory(ctx, adminClient, adminClient, fakeKubeClient, mockClientset, catalogClient, recoveryClient, eventConfig, config.LiteralOffloadingConfig{}, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, 
signalClient, handlerFactory, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) execStatsHolder, err := execStats.NewExecutionStatsHolder() assert.NoError(t, err) @@ -685,7 +685,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_EventFailure(t *testing.T) { handlerFactory.OnSetupMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) handlerFactory.OnGetHandlerMatch(mock.Anything).Return(h, nil) nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, nodeEventSink, adminClient, adminClient, - "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) + "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, config.LiteralOffloadingConfig{}, eventConfig, testClusterID, signalClient, handlerFactory, promutils.NewTestScope()) assert.NoError(t, err) t.Run("EventAlreadyInTerminalStateError", func(t *testing.T) { diff --git a/go.mod b/go.mod index 8fd55ed61a..8c8053def6 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.4.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 // indirect github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625 // indirect + github.com/Masterminds/semver v1.5.0 // indirect github.com/NYTimes/gizmo v1.3.6 // indirect github.com/Shopify/sarama v1.26.4 // indirect github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 // indirect diff --git a/go.sum b/go.sum index ae60f26800..68eebb1fde 100644 --- a/go.sum +++ b/go.sum @@ -77,6 +77,8 @@ github.com/DataDog/opencensus-go-exporter-datadog v0.0.0-20191210083620-6965a1cf github.com/GoogleCloudPlatform/spark-on-k8s-operator 
v0.0.0-20200723154620-6f35a1152625 h1:cQyO5JQ2iuHnEcF3v24kdDMsgh04RjyFPDtuvD6PCE0= github.com/GoogleCloudPlatform/spark-on-k8s-operator v0.0.0-20200723154620-6f35a1152625/go.mod h1:6PnrZv6zUDkrNMw0mIoGRmGBR7i9LulhKPmxFq4rUiM= github.com/Masterminds/semver v1.4.2/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= +github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= +github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/Masterminds/semver/v3 v3.0.3/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA= github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA= From 367459b4256c46234274bbf83087382110594711 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 18 Sep 2024 00:18:44 -0700 Subject: [PATCH 5/6] pytorch object.inv moved (#5755) Signed-off-by: Yee Hing Tong --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index 24f6feb97e..2be3b0185f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -287,7 +287,7 @@ "python": ("https://docs.python.org/3", None), "numpy": ("https://numpy.org/doc/stable", None), "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), - "torch": ("https://pytorch.org/docs/master/", None), + "torch": ("https://pytorch.org/docs/main/", None), "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), "matplotlib": ("https://matplotlib.org", None), "pandera": ("https://pandera.readthedocs.io/en/stable/", None), From 312910d7ae308818a9d294dd76697a666e40fc0c Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 18 Sep 2024 19:15:34 +0800 Subject: [PATCH 6/6] [RFC] Binary IDL With MessagePack Bytes (#5742) --- .../5741-binary-idl-with-message-pack.md | 620 ++++++++++++++++++ 1 file changed, 620 insertions(+) create mode 100644 
rfc/system/5741-binary-idl-with-message-pack.md diff --git a/rfc/system/5741-binary-idl-with-message-pack.md b/rfc/system/5741-binary-idl-with-message-pack.md new file mode 100644 index 0000000000..ae04b0903f --- /dev/null +++ b/rfc/system/5741-binary-idl-with-message-pack.md @@ -0,0 +1,620 @@ +# Binary IDL With MessagePack Bytes + +**Authors:** + +- [@Han-Ru](https://github.com/future-outlier) +- [@Yee Hing Tong](https://github.com/wild-endeavor) +- [@Ping-Su](https://github.com/pingsutw) +- [@Eduardo Apolinario](https://github.com/eapolinario) +- [@Haytham Abuelfutuh](https://github.com/EngHabu) +- [@Ketan Umare](https://github.com/kumare3) + +## 1 Executive Summary +### Literal Value +Literal Value will be `Binary`. + +Use `bytes` in `Binary` instead of `Protobuf struct`. + +- To Literal + +| Before | Now | +|-----------------------------------|----------------------------------------------| +| Python Val -> JSON String -> Protobuf Struct | Python Val -> (Dict ->) Bytes -> Binary (value: MessagePack Bytes, tag: msgpack) IDL Object | + +- To Python Value + +| Before | Now | +|-----------------------------------|----------------------------------------------| +| Protobuf Struct -> JSON String -> Python Val | Binary (value: MessagePack Bytes, tag: msgpack) IDL Object -> Bytes -> (Dict ->) -> Python Val | + + +Note: + +1. If a Python value can't be directly converted to `MessagePack Bytes`, we can first convert it to a `Dict`, and then convert it to `MessagePack Bytes`. + + - **For example:** The Pydantic-to-literal function workflow will be: + `BaseModel` -> `dict` -> `MessagePack Bytes` -> `Binary (value: MessagePack Bytes, tag: msgpack) IDL Object`. + + - **For pure `dict` in Python:** The to-literal function workflow will be: + `dict` -> `MessagePack Bytes` -> `Binary (value: MessagePack Bytes, tag: msgpack) IDL Object`. + +2. There is **NO JSON** involved in the new type at all. Only **JSON Schema** is used to construct `DataClass` or `Pydantic BaseModel`. 
+ + +### Literal Type +Literal Type will be `SimpleType.STRUCT`. +`Json Schema` will be stored in `Literal Type's metadata`. + +1. Dataclass, Pydantic BaseModel and pure dict in python will all use `SimpleType.STRUCT`. +2. We will put `Json Schema` in Literal Type's `metadata` field, this will be used in flytekit remote api to construct dataclass/Pydantic BaseModel by `Json Schema`. +3. We will use libraries written in golang to compare `Json Schema` to solve this issue: ["[BUG] Union types fail for e.g. two different dataclasses"](https://github.com/flyteorg/flyte/issues/5489). + +Note: The `metadata` of `Literal Type` and `Literal Value` are not the same. + +## 2 Motivation + +Prior to this RFC, in flytekit, when handling dataclasses, Pydantic base models, and dictionaries, we store data using a JSON string within Protobuf struct datatype. + +This approach causes issues with integers, as Protobuf struct does not support int types, leading to their conversion to floats. + +This results in performance issues since we need to recursively iterate through all attributes/keys in dataclasses and dictionaries to ensure float types are converted to int. + +In addition to performance issues, the required code is complicated and error-prone. + +Note: We have more than 10 issues about dict, dataclass and Pydantic. + +This feature can solve them all. + +## 3 Proposed Implementation +### Before +```python +@task +def t1() -> dict: + ... + return {"a": 1} # Protobuf Struct {"a": 1.0} + +@task +def t2(a: dict): + print(a["a"]) # wrong, will be a float +``` +### After +```python +@task +def t1() -> dict: # Literal(scalar=Scalar(binary=Binary(value=b'msgpack_bytes', tag="msgpack"))) + ... + return {"a": 1} # Protobuf Binary value=b'\x81\xa1a\x01', produced by msgpack + +@task +def t2(a: dict): + print(a["a"]) # correct, it will be an integer +``` + +#### Note +- We will implement `to_python_value` in every type transformer to ensure backward compatibility. 
+For example, `Binary IDL Object` -> python value and `Protobuf Struct IDL Object` -> python value are both supported. + +### How to turn a value to bytes? +#### Use MsgPack to convert a value into bytes +##### Python +```python +import msgpack + +# Encode +def to_literal(): + msgpack_bytes = msgpack.dumps(python_val) + return Literal(scalar=Scalar(binary=Binary(value=b'msgpack_bytes', tag="msgpack"))) + +# Decode +def to_python_value(): + # lv: literal value + if lv.scalar.binary.tag == "msgpack": + msgpack_bytes = lv.scalar.binary.value + else: + raise ValueError(f"{tag} is not supported to decode this Binary Literal: {lv.scalar.binary}.") + return msgpack.loads(msgpack_bytes) +``` +reference: https://github.com/msgpack/msgpack-python + +##### Golang +```go +package main + +import ( + "fmt" + "github.com/shamaton/msgpack/v2" +) + +func main() { + // Example data to encode + data := map[string]int{"a": 1} + + // Encode the data + encodedData, err := msgpack.Marshal(data) + if err != nil { + panic(err) + } + + // Print the encoded data + fmt.Printf("Encoded data: %x\n", encodedData) // Output: 81a16101 + + // Decode the data + var decodedData map[string]int + err = msgpack.Unmarshal(encodedData, &decodedData) + if err != nil { + panic(err) + } + + // Print the decoded data + fmt.Printf("Decoded data: %+v\n", decodedData) // Output: map[a:1] +} +``` + +reference: [shamaton/msgpack GitHub Repository](https://github.com/shamaton/msgpack) + +Notes: + +1. **MessagePack Implementations**: + - We can explore all MessagePack implementations for Golang at the [MessagePack official website](https://msgpack.org/index.html). + +2. **Library Comparison**: + - The library [github.com/vmihailenco/msgpack](https://github.com/vmihailenco/msgpack) doesn't support strict type deserialization (for example, `map[int]string`), but [github.com/shamaton/msgpack/v2](https://github.com/shamaton/msgpack) supports this feature. This is super important for backward compatibility. + +3. 
**Library Popularity**: + - While [github.com/shamaton/msgpack/v2](https://github.com/shamaton/msgpack) has fewer stars on GitHub, it has proven to be reliable in various test cases. All cases created by me have passed successfully, which you can find in this [pull request](https://github.com/flyteorg/flytekit/pull/2751). + +4. **Project Activity**: + - [github.com/shamaton/msgpack/v2](https://github.com/shamaton/msgpack) is still an actively maintained project. The author responds quickly to issues and questions, making it a well-supported choice for projects requiring ongoing maintenance and active support. + +5. **Testing Process**: + - I initially started with [github.com/vmihailenco/msgpack](https://github.com/vmihailenco/msgpack) but switched to [github.com/shamaton/msgpack/v2](https://github.com/shamaton/msgpack) due to its better support for strict typing and the active support provided by the author. + + +##### JavaScript +```javascript +import { encode, decode } from '@msgpack/msgpack'; + +// Example data to encode +const data = { a: 1 }; + +// Encode the data +const encodedData = encode(data); + +// Print the encoded data +console.log(encodedData); // + +// Decode the data +const decodedData = decode(encodedData); + +// Print the decoded data +console.log(decodedData); // { a: 1 } +``` +reference: https://github.com/msgpack/msgpack-javascript + +### FlyteIDL +#### Literal Value + +Here is the [IDL definition](https://github.com/flyteorg/flyte/blob/7989209e15600b56fcf0f4c4a7c9af7bfeab6f3e/flyteidl/protos/flyteidl/core/literals.proto#L42-L47). + +The `bytes` field is used for serialized data, and the `tag` field specifies the serialization format identifier. +#### Literal Type +```proto +import "google/protobuf/struct.proto"; + +enum SimpleType { + NONE = 0; + INTEGER = 1; + FLOAT = 2; + STRING = 3; + BOOLEAN = 4; + DATETIME = 5; + DURATION = 6; + BINARY = 7; + ERROR = 8; + STRUCT = 9; // Use this one. 
+} +message LiteralType { + SimpleType simple = 1; // Use this one. + google.protobuf.Struct metadata = 6; // Store Json Schema to differentiate different dataclass. +} +``` + +### FlytePropeller +1. Attribute Access for dictionary, Dataclass, and Pydantic in workflow. +Dict[type, type] is supported already, we have to support Dataclass, Pydantic and dict now. +```python +from flytekit import task, workflow +from dataclasses import dataclass + +@dataclass +class DC: + a: int + +@task +def t1() -> DC: + return DC(a=1) + +@task +def t2(x: int): + print("x:", x) + return + +@workflow +def wf(): + o = t1() + t2(x=o.a) +``` +2. Create a Literal Type for Scalar when doing type validation. +```go +func literalTypeForScalar(scalar *core.Scalar) *core.LiteralType { + ... + case *core.Scalar_Binary: + literalType = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_BINARY}} + ... + return literalType +} +``` +3. Support input and default input. +```go +// Literal Input +func ExtractFromLiteral(literal *core.Literal) (interface{}, error) { + switch literalValue := literal.Value.(type) { + case *core.Literal_Scalar: + ... + case *core.Scalar_Binary: + return scalarValue.Binary, nil + } +} +// Default Input +func MakeDefaultLiteralForType(typ *core.LiteralType) (*core.Literal, error) { + switch t := typ.GetType().(type) { + case *core.LiteralType_Simple: + case core.SimpleType_BINARY: + return MakeLiteral([]byte{}) + } +} +// Use Message Pack as Default Tag for deserialization. +// "tag" will default be "msgpack" +func MakeBinaryLiteral(v []byte, tag string) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: v, + Tag: tag, + }, + }, + }, + }, + } +} +``` +4. 
Compiler +```go +func (t trivialChecker) CastsFrom(upstreamType *flyte.LiteralType) bool { + if upstreamType.GetEnumType() != nil { + if t.literalType.GetSimple() == flyte.SimpleType_STRING { + return true + } + } + + if t.literalType.GetEnumType() != nil { + if upstreamType.GetSimple() == flyte.SimpleType_STRING { + return true + } + } + + if GetTagForType(upstreamType) != "" && GetTagForType(t.literalType) != GetTagForType(upstreamType) { + return false + } + + // Here is the new way to check if dataclass/pydantic BaseModel are castable or not. + if upstreamType.GetSimple() == flyte.SimpleType_STRUCT && + t.literalType.GetSimple() == flyte.SimpleType_STRUCT { + // Json Schema is stored in Metadata + upstreamMetadata := upstreamType.GetMetadata() + downstreamMetadata := t.literalType.GetMetadata() + + // An earlier bug in flytekit's dataclass transformer means the JSON Schema may be missing + // in some cases (nil metadata), so we accept the cast to stay backward compatible. + // (Reference tasks should be supported.) + if upstreamMetadata == nil || downstreamMetadata == nil { + return true + } + + return isSameTypeInJSON(upstreamMetadata, downstreamMetadata) || + isSuperTypeInJSON(upstreamMetadata, downstreamMetadata) + } + + upstreamTypeCopy := *upstreamType + downstreamTypeCopy := *t.literalType + upstreamTypeCopy.Structure = &flyte.TypeStructure{} + downstreamTypeCopy.Structure = &flyte.TypeStructure{} + upstreamTypeCopy.Metadata = &structpb.Struct{} + downstreamTypeCopy.Metadata = &structpb.Struct{} + upstreamTypeCopy.Annotation = &flyte.TypeAnnotation{} + downstreamTypeCopy.Annotation = &flyte.TypeAnnotation{} + return upstreamTypeCopy.String() == downstreamTypeCopy.String() +} +``` +### FlyteKit +#### Attribute Access + +In all transformers, we should implement a function called `from_binary_idl` to convert the Binary IDL Object into the desired type. 
+ +A base method can be added to the `TypeTransformer` class, allowing child classes to override it as needed. + +During attribute access, Flyte Propeller will deserialize the msgpack bytes into a map object in golang, retrieve the specific attribute, and then serialize it back into msgpack bytes (resulting in a Binary IDL Object containing msgpack bytes). + +This implies that when converting a literal to a Python value, we will receive `msgpack bytes` instead of the `expected Python type`. + +```python +# In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True. +# This is relevant for cases like Dict[int, str]. +# If strict_map_key=False is not used, the decoder will raise an error when trying to decode keys that are not strictly typed. +def _default_flytekit_decoder(data: bytes) -> Any: + return msgpack.unpackb(data, raw=False, strict_map_key=False) + + +def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]: + # Handle msgpack serialization + if binary_idl_object.tag == "msgpack": + try: + # Retrieve the existing decoder for the expected type + decoder = self._msgpack_decoder[expected_python_type] + except KeyError: + # Create a new decoder if not already cached + decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_flytekit_decoder) + self._msgpack_decoder[expected_python_type] = decoder + # Decode the binary IDL object into the expected Python type + return decoder.decode(binary_idl_object.value) + else: + # Raise an error if the binary format is not supported + raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}") +``` + +Note: +1. This base method can handle primitive types, nested typed dictionaries, nested typed lists, and combinations of nested typed dictionaries and lists. + +2. 
Dataclass transformer needs its own `from_binary_idl` method to handle specific cases such as [discriminated classes](https://github.com/flyteorg/flyte/issues/5588). + +3. Flyte types (e.g., FlyteFile, FlyteDirectory, StructuredDataset, and FlyteSchema) will need their own `from_binary_idl` methods, as they must handle downloading files from remote object storage when converting literals to Python values. + +For example, see the FlyteFile implementation: https://github.com/flyteorg/flytekit/pull/2751/files#diff-22cf9c7153b54371b4a77331ddf276a082cf4b3c5e7bd1595dd67232288594fdR522-R552 + +#### pyflyte run +The behavior will remain unchanged. + +We will pass the value to our class, which inherits from `click.ParamType`, and use the corresponding type transformer to convert the input to the correct type. + +### Dict Transformer +There are 2 cases in Dict Transformer, `Dict[type, type]` and `dict`. + +For `Dict[type, type]`, we will keep everything the same as before. + +#### Literal Value +For `dict`, its life cycle will be as follows. + +Before: +- `to_literal`: `dict` -> `JSON String` -> `Protobuf Struct` +- `to_python_val`: `Protobuf Struct` -> `JSON String` -> `dict` + +After: +- `to_literal`: `dict` -> `msgpack bytes` -> `Binary(value=b'msgpack_bytes', tag="msgpack")` +- `to_python_val`: `Binary(value=b'msgpack_bytes', tag="msgpack")` -> `msgpack bytes` -> `dict` + +#### JSON Schema +The JSON Schema of `dict` will be empty. +### Dataclass Transformer +#### Literal Value +Before: +- `to_literal`: `dataclass` -> `JSON String` -> `Protobuf Struct` +- `to_python_val`: `Protobuf Struct` -> `JSON String` -> `dataclass` + +After: +- `to_literal`: `dataclass` -> `msgpack bytes` -> `Binary(value=b'msgpack_bytes', tag="msgpack")` +- `to_python_val`: `Binary(value=b'msgpack_bytes', tag="msgpack")` -> `msgpack bytes` -> `dataclass` + +Note: We will use mashumaro's `MessagePackEncoder` and `MessagePackDecoder` to serialize and deserialize dataclass values in Python. 
+```python +from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder +``` + +#### Literal Type's TypeStructure's dataclass_type +This is used for compiling dataclass attribute access. + +With it, we can retrieve the literal type of an attribute and validate it in Flyte's propeller compiler. + +For more details, check here: https://github.com/flyteorg/flytekit/blob/fb55841f8660b2a31e99381dd06e42f8cd22758e/flytekit/core/type_engine.py#L454-L525 + +#### JSON Schema +The JSON Schema of `dataclass` will be generated by `marshmallow` or `mashumaro`. +Check here: https://github.com/flyteorg/flytekit/blob/8c6f6f0f17d113447e1b10b03e25a34bad79685c/flytekit/core/type_engine.py#L442-L474 + + +### Pydantic Transformer +#### Literal Value +Pydantic can't be serialized to `msgpack bytes` directly. +But `dict` can. + +- `to_literal`: `BaseModel` -> `dict` -> `msgpack bytes` -> `Binary(value=b'msgpack_bytes', tag="msgpack")` +- `to_python_val`: `Binary(value=b'msgpack_bytes', tag="msgpack")` -> `msgpack bytes` -> `dict` -> `BaseModel` + +Note: Pydantic BaseModel can't be serialized directly by `msgpack`, but this implementation will still ensure 100% correctness. + +```python +@dataclass +class DC_inside: + a: int + b: float + +@dataclass +class DC: + a: int + b: float + c: str + d: Dict[str, int] + e: DC_inside + +class MyDCModel(BaseModel): + dc: DC + +my_dc = MyDCModel(dc=DC(a=1, b=2.0, c="3", d={"4": 5}, e=DC_inside(a=6, b=7.0))) +# {'dc': {'a': 1, 'b': 2.0, 'c': '3', 'd': {'4': 5}, 'e': {'a': 6, 'b': 7.0}}} +``` + +#### Literal Type's TypeStructure's dataclass_type +This is used for compiling Pydantic BaseModel attribute access. + +With it, we can retrieve an attribute's literal type and validate it in Flyte's propeller compiler. + +Although this feature is not currently implemented, it will function similarly to the dataclass transformer in the future. + +#### JSON Schema +The JSON Schema of `BaseModel` will be generated by Pydantic's API. 
+```python +@dataclass +class DC_inside: + a: int + b: float + +@dataclass +class DC: + a: int + b: float + c: str + d: Dict[str, int] + e: DC_inside + +class MyDCModel(BaseModel): + dc: DC + +my_dc = MyDCModel(dc=DC(a=1, b=2.0, c="3", d={"4": 5}, e=DC_inside(a=6, b=7.0))) +my_dc.model_json_schema() +""" +{'$defs': {'DC': {'properties': {'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'number'}, 'c': {'title': 'C', 'type': 'string'}, 'd': {'additionalProperties': {'type': 'integer'}, 'title': 'D', 'type': 'object'}, 'e': {'$ref': '#/$defs/DC_inside'}}, 'required': ['a', 'b', 'c', 'd', 'e'], 'title': 'DC', 'type': 'object'}, 'DC_inside': {'properties': {'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'number'}}, 'required': ['a', 'b'], 'title': 'DC_inside', 'type': 'object'}}, 'properties': {'dc': {'$ref': '#/$defs/DC'}}, 'required': ['dc'], 'title': 'MyDCModel', 'type': 'object'} +""" +``` + +### FlyteCtl + +In FlyteCtl, we can construct input for the execution. 
+ +When we receive `SimpleType.STRUCT`, we can construct a `Binary IDL Object` using the following logic in `flyteidl/clients/go/coreutils/literals.go`: + +```go +if newT.Simple == core.SimpleType_STRUCT { + if _, isValueStringType := v.(string); !isValueStringType { + byteValue, err := msgpack.Marshal(v) + if err != nil { + return nil, fmt.Errorf("unable to marshal to json string for struct value %v", v) + } + strValue = string(byteValue) + } +} +``` + +This is how users can create an execution by using a YAML file: +```bash +flytectl create execution --execFile ./flytectl/create_dataclass_task.yaml -p flytesnacks -d development +``` + +Example YAML file (`create_dataclass_task.yaml`): +```yaml +iamRoleARN: "" +inputs: + input: + a: 1 + b: 3.14 + c: example_string + d: + "1": 100 + "2": 200 + e: + a: 1 + b: 3.14 +envs: {} +kubeServiceAcct: "" +targetDomain: "" +targetProject: "" +task: dataclass_example.dataclass_task +version: OSyTikiBTAkjBgrL5JVOVw +``` + +### FlyteCopilot + +When we need to pass an attribute access value to a copilot task, we must modify the code to convert a Binary Literal value with the `msgpack` tag into a primitive value. + +(Currently, we will only support primitive values.) + +You can reference the relevant section of code here: + +[FlyteCopilot - Data Download](https://github.com/flyteorg/flyte/blob/7989209e15600b56fcf0f4c4a7c9af7bfeab6f3e/flytecopilot/data/download.go#L88-L95) + +### FlyteConsole +#### How users input into launch form? +When FlyteConsole receives a literal type of `SimpleType.STRUCT`, the input method depends on the availability of a JSON schema: + +1. No JSON Schema provided: + +Input is expected as `a Javascript Object` (e.g., `{"a": 1}`). + +2. JSON Schema provided: + +Users can input values based on the schema's expected type and construct an appropriate `Javascript Object`. 
+ +Note: + +For `dataclass` and Pydantic `BaseModel`, a JSON schema will be provided in their literal type, and the input form will be constructed accordingly. + +##### What happens after the user enters data? + +Input values -> Javascript Object -> msgpack bytes -> Binary IDL With MessagePack Bytes + +#### Displaying Inputs/Outputs in the Console +Use `msgpack` to deserialize bytes into an Object and display it in Flyte Console. + +#### Copying Inputs/Outputs in the Console +Allow users to copy the `Object` to the clipboard, as currently implemented. + +#### Pasting and Copying from FlyteConsole +Currently, we should support JSON pasting if the content is a JavaScript object. However, there's a question of how we might handle other formats like YAML or MsgPack bytes, especially if copied from a binary file. + +For now, focusing on supporting JSON pasting makes sense. However, adding support for YAML and MsgPack bytes could be valuable future enhancements. + +## 4 Metrics & Dashboards + +None + +## 5 Drawbacks + +None + +## 6 Alternatives + +MsgPack is a good choice because it's smaller and faster than UTF-8 Encoded JSON String. + +You can see the performance comparison here: https://github.com/flyteorg/flyte/pull/5607#issuecomment-2333174325 + +We will use `msgpack` to do it. + +## 7 Potential Impact and Dependencies +None. + +## 8 Unresolved Questions +### Conditional Branch +Currently, our support for `DataClass/BaseModel/Dict[type, type]` within conditional branches is incomplete. At present, we only support comparisons of primitive types. However, there are two key challenges when attempting to handle these more complex types: + +1. **Primitive Type Comparison vs. Binary IDL Object:** + - In conditional branches, we receive a `Binary IDL Object` during type comparison, which needs to be converted into a `Primitive IDL Object`. + - The issue is that we don't know the expected Python type or primitive type beforehand, making this conversion ambiguous.
+ +2. **MsgPack Incompatibility with `Primitive_Datetime` and `Primitive_Duration`:** + - MsgPack does not natively support the `Primitive_Datetime` and `Primitive_Duration` types, and instead converts them to strings. + - This can lead to inconsistencies in comparison logic. One potential workaround is to convert both types to strings for comparison. However, it is uncertain whether this approach is the best solution. + +## 9 Conclusion + +1. Binary IDL with MessagePack Bytes provides a better representation for dataclasses, Pydantic BaseModels, and untyped dictionaries in Flyte. + +2. This approach ensures 100% accuracy of each attribute and enables attribute access.