diff --git a/flyteadmin/pkg/repositories/transformers/task_execution.go b/flyteadmin/pkg/repositories/transformers/task_execution.go index cedf7c1f13..9f24ed2aa4 100644 --- a/flyteadmin/pkg/repositories/transformers/task_execution.go +++ b/flyteadmin/pkg/repositories/transformers/task_execution.go @@ -372,6 +372,31 @@ func mergeMetadata(existing, latest *event.TaskExecutionMetadata) *event.TaskExe return existing } +func filterExternalResourceLogsByPhase(externalResources []*event.ExternalResourceInfo, phase core.TaskExecution_Phase) { + for _, externalResource := range externalResources { + externalResource.Logs = filterLogsByPhase(externalResource.Logs, phase) + } +} + +func filterLogsByPhase(logs []*core.TaskLog, phase core.TaskExecution_Phase) []*core.TaskLog { + filteredLogs := make([]*core.TaskLog, 0, len(logs)) + + for _, l := range logs { + if common.IsTaskExecutionTerminal(phase) && l.HideOnceFinished { + continue + } + // Some plugins like e.g. Dask, Ray start with or very quickly transition to core.TaskExecution_INITIALIZING + // once the CR has been created even though the underlying pods are still pending. We thus treat queued and + // initializing the same here. + if (phase == core.TaskExecution_QUEUED || phase == core.TaskExecution_INITIALIZING) && !l.ShowWhilePending { + continue + } + filteredLogs = append(filteredLogs, l) + + } + return filteredLogs +} + func UpdateTaskExecutionModel(ctx context.Context, request *admin.TaskExecutionEventRequest, taskExecutionModel *models.TaskExecution, inlineEventDataPolicy interfaces.InlineEventDataPolicy, storageClient *storage.DataStore) error { err := handleTaskExecutionInputs(ctx, taskExecutionModel, request, storageClient) @@ -384,6 +409,7 @@ func UpdateTaskExecutionModel(ctx context.Context, request *admin.TaskExecutionE return errors.NewFlyteAdminErrorf(codes.Internal, "failed to unmarshal task execution closure with error: %+v", err) } + isPhaseChange := taskExecutionModel.Phase != request.Event.Phase.String() existingTaskPhase := taskExecutionModel.Phase taskExecutionModel.Phase = request.Event.Phase.String() taskExecutionModel.PhaseVersion = request.Event.PhaseVersion @@ -393,7 +419,11 @@ func UpdateTaskExecutionModel(ctx context.Context, request *admin.TaskExecutionE reportedAt = request.Event.OccurredAt } taskExecutionClosure.UpdatedAt = reportedAt - taskExecutionClosure.Logs = mergeLogs(taskExecutionClosure.Logs, request.Event.Logs) + + mergedLogs := mergeLogs(taskExecutionClosure.Logs, request.Event.Logs) + filteredLogs := filterLogsByPhase(mergedLogs, request.Event.Phase) + taskExecutionClosure.Logs = filteredLogs + if len(request.Event.Reasons) > 0 { for _, reason := range request.Event.Reasons { taskExecutionClosure.Reasons = append( @@ -437,6 +467,11 @@ func UpdateTaskExecutionModel(ctx context.Context, request *admin.TaskExecutionE return errors.NewFlyteAdminErrorf(codes.Internal, "failed to merge task event custom_info with error: %v", err) } taskExecutionClosure.Metadata = mergeMetadata(taskExecutionClosure.Metadata, request.Event.Metadata) + + if isPhaseChange && taskExecutionClosure.Metadata != nil && len(taskExecutionClosure.Metadata.ExternalResources) > 0 { + filterExternalResourceLogsByPhase(taskExecutionClosure.Metadata.ExternalResources, request.Event.Phase) + } + if request.Event.EventVersion > taskExecutionClosure.EventVersion { taskExecutionClosure.EventVersion = request.Event.EventVersion } diff --git a/flyteadmin/pkg/repositories/transformers/task_execution_test.go b/flyteadmin/pkg/repositories/transformers/task_execution_test.go index 71ff3c60e3..e1e0fd973e 100644 --- a/flyteadmin/pkg/repositories/transformers/task_execution_test.go +++ b/flyteadmin/pkg/repositories/transformers/task_execution_test.go @@ -652,6 +652,183 @@ func TestUpdateTaskExecutionModelRunningToFailed(t *testing.T) { } +func TestUpdateTaskExecutionModelFilterLogLinks(t *testing.T) { + existingClosure := &admin.TaskExecutionClosure{ + Phase: core.TaskExecution_QUEUED, + StartedAt: taskEventOccurredAtProto, + CreatedAt: taskEventOccurredAtProto, + UpdatedAt: taskEventOccurredAtProto, + Logs: []*core.TaskLog{}, + Reason: "task submitted to k8s", + Reasons: []*admin.Reason{ + { + OccurredAt: taskEventOccurredAtProto, + Message: "task submitted to k8s", + }, + }, + } + + closureBytes, err := proto.Marshal(existingClosure) + assert.Nil(t, err) + + existingTaskExecution := models.TaskExecution{ + TaskExecutionKey: models.TaskExecutionKey{ + TaskKey: models.TaskKey{ + Project: sampleTaskID.Project, + Domain: sampleTaskID.Domain, + Name: sampleTaskID.Name, + Version: sampleTaskID.Version, + }, + NodeExecutionKey: models.NodeExecutionKey{ + NodeID: sampleNodeExecID.NodeId, + ExecutionKey: models.ExecutionKey{ + Project: sampleNodeExecID.ExecutionId.Project, + Domain: sampleNodeExecID.ExecutionId.Domain, + Name: sampleNodeExecID.ExecutionId.Name, + }, + }, + RetryAttempt: &retryAttemptValue, + }, + Phase: "TaskExecutionPhase_TASK_PHASE_QUEUED", + InputURI: "input uri", + Closure: closureBytes, + StartedAt: &taskEventOccurredAt, + TaskExecutionCreatedAt: &taskEventOccurredAt, + TaskExecutionUpdatedAt: &taskEventOccurredAt, + } + + occuredAt := taskEventOccurredAt.Add(time.Minute) + occuredAtProto, err := ptypes.TimestampProto(occuredAt) + assert.Nil(t, err) + + updatedEventRequest := &admin.TaskExecutionEventRequest{ + Event: &event.TaskExecutionEvent{ + TaskId: sampleTaskID, + ParentNodeExecutionId: sampleNodeExecID, + Phase: core.TaskExecution_QUEUED, + OccurredAt: occuredAtProto, + Logs: []*core.TaskLog{ + { + Uri: "uri-show-pending", + ShowWhilePending: true, + }, + { + Uri: "uri-default", + }, + }, + Reason: "task update", + }, + } + + err = UpdateTaskExecutionModel(context.TODO(), updatedEventRequest, &existingTaskExecution, + interfaces.InlineEventDataPolicyStoreInline, commonMocks.GetMockStorageClient()) + assert.Nil(t, err) + + updatedClosure := &admin.TaskExecutionClosure{} + err = proto.Unmarshal(existingTaskExecution.Closure, updatedClosure) + assert.Nil(t, err) + + assert.Equal(t, updatedClosure.Logs, []*core.TaskLog{ + { + Uri: "uri-show-pending", + ShowWhilePending: true, + }, + }, + ) + +} + +func TestUpdateTaskExecutionModelFilterLogLinksArray(t *testing.T) { + existingClosure := &admin.TaskExecutionClosure{ + Phase: core.TaskExecution_RUNNING, + StartedAt: taskEventOccurredAtProto, + CreatedAt: taskEventOccurredAtProto, + UpdatedAt: taskEventOccurredAtProto, + Logs: []*core.TaskLog{}, + Reason: "task started", + Reasons: []*admin.Reason{ + { + OccurredAt: taskEventOccurredAtProto, + Message: "task started", + }, + }, + Metadata: &event.TaskExecutionMetadata{ + ExternalResources: []*event.ExternalResourceInfo{ + { + Logs: []*core.TaskLog{ + { + Uri: "uri-default", + }, + { + Uri: "uri-hide-finished", + HideOnceFinished: true, + }, + }, + }, + }, + }, + } + + closureBytes, err := proto.Marshal(existingClosure) + assert.Nil(t, err) + + existingTaskExecution := models.TaskExecution{ + TaskExecutionKey: models.TaskExecutionKey{ + TaskKey: models.TaskKey{ + Project: sampleTaskID.Project, + Domain: sampleTaskID.Domain, + Name: sampleTaskID.Name, + Version: sampleTaskID.Version, + }, + NodeExecutionKey: models.NodeExecutionKey{ + NodeID: sampleNodeExecID.NodeId, + ExecutionKey: models.ExecutionKey{ + Project: sampleNodeExecID.ExecutionId.Project, + Domain: sampleNodeExecID.ExecutionId.Domain, + Name: sampleNodeExecID.ExecutionId.Name, + }, + }, + RetryAttempt: &retryAttemptValue, + }, + Phase: "TaskExecutionPhase_TASK_PHASE_RUNNING", + InputURI: "input uri", + Closure: closureBytes, + StartedAt: &taskEventOccurredAt, + TaskExecutionCreatedAt: &taskEventOccurredAt, + TaskExecutionUpdatedAt: &taskEventOccurredAt, + } + + occuredAt := taskEventOccurredAt.Add(time.Minute) + occuredAtProto, err := ptypes.TimestampProto(occuredAt) + assert.Nil(t, err) + + failedEventRequest := &admin.TaskExecutionEventRequest{ + Event: &event.TaskExecutionEvent{ + TaskId: sampleTaskID, + ParentNodeExecutionId: sampleNodeExecID, + Phase: core.TaskExecution_FAILED, + OccurredAt: occuredAtProto, + Reason: "something went wrong", + }, + } + + err = UpdateTaskExecutionModel(context.TODO(), failedEventRequest, &existingTaskExecution, + interfaces.InlineEventDataPolicyStoreInline, commonMocks.GetMockStorageClient()) + assert.Nil(t, err) + + updatedClosure := &admin.TaskExecutionClosure{} + err = proto.Unmarshal(existingTaskExecution.Closure, updatedClosure) + assert.Nil(t, err) + + assert.Equal(t, updatedClosure.Metadata.ExternalResources[0].Logs, []*core.TaskLog{ + { + Uri: "uri-default", + }, + }, + ) + +} + func TestUpdateTaskExecutionModelSingleEvents(t *testing.T) { existingClosure := &admin.TaskExecutionClosure{ Phase: core.TaskExecution_RUNNING, @@ -1208,6 +1385,125 @@ func TestMergeLogs(t *testing.T) { } } +func TestFilterLogsByPhase(t *testing.T) { + type testCase struct { + existing []*core.TaskLog + expected []*core.TaskLog + phase core.TaskExecution_Phase + name string + } + + testCases := []testCase{ + { + existing: []*core.TaskLog{ + { + Uri: "default-uri", + ShowWhilePending: false, + HideOnceFinished: false, + }, + { + Uri: "show-pending-uri", + ShowWhilePending: true, + HideOnceFinished: false, + }, + { + Uri: "hide-finished-uri", + ShowWhilePending: false, + HideOnceFinished: true, + }, + }, + expected: []*core.TaskLog{ + { + Uri: "show-pending-uri", + ShowWhilePending: true, + HideOnceFinished: false, + }, + }, + phase: core.TaskExecution_QUEUED, + name: "Filtered logs in QUEUED phase", + }, + { + existing: []*core.TaskLog{ + { + Uri: "default-uri", + ShowWhilePending: false, + HideOnceFinished: false, + }, + { + Uri: "show-pending-uri", + ShowWhilePending: true, + HideOnceFinished: false, + }, + { + Uri: "hide-finished-uri", + ShowWhilePending: false, + HideOnceFinished: true, + }, + }, + expected: []*core.TaskLog{ + { + Uri: "default-uri", + ShowWhilePending: false, + HideOnceFinished: false, + }, + { + Uri: "show-pending-uri", + ShowWhilePending: true, + HideOnceFinished: false, + }, + { + Uri: "hide-finished-uri", + ShowWhilePending: false, + HideOnceFinished: true, + }, + }, + phase: core.TaskExecution_RUNNING, + name: "Filtered logs in RUNNING phase", + }, + { + existing: []*core.TaskLog{ + { + Uri: "default-uri", + ShowWhilePending: false, + HideOnceFinished: false, + }, + { + Uri: "show-pending-uri", + ShowWhilePending: true, + HideOnceFinished: false, + }, + { + Uri: "hide-finished-uri", + ShowWhilePending: false, + HideOnceFinished: true, + }, + }, + expected: []*core.TaskLog{ + { + Uri: "default-uri", + ShowWhilePending: false, + HideOnceFinished: false, + }, + { + Uri: "show-pending-uri", + ShowWhilePending: true, + HideOnceFinished: false, + }, + }, + phase: core.TaskExecution_SUCCEEDED, + name: "Filtered logs in terminated phase", + }, + } + for _, filterTestCase := range testCases { + filteredLogs := filterLogsByPhase(filterTestCase.existing, filterTestCase.phase) + + assert.Equal(t, len(filterTestCase.expected), len(filteredLogs), fmt.Sprintf("%s failed", filterTestCase.name)) + for idx, expectedLog := range filterTestCase.expected { + assert.True(t, proto.Equal(expectedLog, filteredLogs[idx]), fmt.Sprintf("%s failed", filterTestCase.name)) + } + } +} + func TestMergeCustoms(t *testing.T) { t.Run("nothing to do", func(t *testing.T) { custom, err := mergeCustom(nil, nil) diff --git a/flyteidl/clients/go/assets/admin.swagger.json b/flyteidl/clients/go/assets/admin.swagger.json index 4e944181b5..e1f0b29579 100644 --- a/flyteidl/clients/go/assets/admin.swagger.json +++ b/flyteidl/clients/go/assets/admin.swagger.json @@ -8117,6 +8117,12 @@ }, "ttl": { "type": "string" + }, + "ShowWhilePending": { + "type": "boolean" + }, + "HideOnceFinished": { + "type": "boolean" } }, "title": "Log information for the task that is specific to a log sink\nWhen our log story is flushed out, we may have more metadata here like log link expiry" diff --git a/flyteidl/gen/pb-es/flyteidl/core/execution_pb.ts b/flyteidl/gen/pb-es/flyteidl/core/execution_pb.ts index e931e1a789..5283936b1f 100644 --- a/flyteidl/gen/pb-es/flyteidl/core/execution_pb.ts +++ b/flyteidl/gen/pb-es/flyteidl/core/execution_pb.ts @@ -427,6 +427,16 @@ export class TaskLog extends Message { */ ttl?: Duration; + /** + * @generated from field: bool ShowWhilePending = 5; + */ + ShowWhilePending = false; + + /** + * @generated from field: bool HideOnceFinished = 6; + */ + HideOnceFinished = false; + constructor(data?: PartialMessage) { super(); proto3.util.initPartial(data, this); @@ -439,6 +449,8 @@ export class TaskLog extends Message { { no: 2, name: "name", kind: "scalar", T: 9 /* ScalarType.STRING */ }, { no: 3, name: "message_format", kind: "enum", T: proto3.getEnumType(TaskLog_MessageFormat) }, { no: 4, name: "ttl", kind: "message", T: Duration }, + { no: 5, name: "ShowWhilePending", kind: "scalar", T: 8 /* ScalarType.BOOL */ }, + { no: 6, name: "HideOnceFinished", kind: "scalar", T: 8 /* ScalarType.BOOL */ }, ]); static fromBinary(bytes: Uint8Array, options?: Partial): TaskLog { diff --git a/flyteidl/gen/pb-go/flyteidl/core/execution.pb.go b/flyteidl/gen/pb-go/flyteidl/core/execution.pb.go index fe558cf94c..7befaca1ac 100644 --- a/flyteidl/gen/pb-go/flyteidl/core/execution.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/core/execution.pb.go @@ -583,10 +583,12 @@ type TaskLog struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Uri string `protobuf:"bytes,1,opt,name=uri,proto3" json:"uri,omitempty"` - Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` - MessageFormat TaskLog_MessageFormat `protobuf:"varint,3,opt,name=message_format,json=messageFormat,proto3,enum=flyteidl.core.TaskLog_MessageFormat" json:"message_format,omitempty"` - Ttl *durationpb.Duration `protobuf:"bytes,4,opt,name=ttl,proto3" json:"ttl,omitempty"` + Uri string `protobuf:"bytes,1,opt,name=uri,proto3" json:"uri,omitempty"` + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + MessageFormat TaskLog_MessageFormat `protobuf:"varint,3,opt,name=message_format,json=messageFormat,proto3,enum=flyteidl.core.TaskLog_MessageFormat" json:"message_format,omitempty"` + Ttl *durationpb.Duration `protobuf:"bytes,4,opt,name=ttl,proto3" json:"ttl,omitempty"` + ShowWhilePending bool `protobuf:"varint,5,opt,name=ShowWhilePending,proto3" json:"ShowWhilePending,omitempty"` + HideOnceFinished bool `protobuf:"varint,6,opt,name=HideOnceFinished,proto3" json:"HideOnceFinished,omitempty"` } func (x *TaskLog) Reset() { @@ -649,6 +651,20 @@ func (x *TaskLog) GetTtl() *durationpb.Duration { return nil } +func (x *TaskLog) GetShowWhilePending() bool { + if x != nil { + return x.ShowWhilePending + } + return false +} + +func (x *TaskLog) GetHideOnceFinished() bool { + if x != nil { + return x.HideOnceFinished + } + return false +} + // Represents customized execution run-time attributes. type QualityOfServiceSpec struct { state protoimpl.MessageState @@ -832,7 +848,7 @@ var file_flyteidl_core_execution_proto_rawDesc = []byte{ 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x22, 0x2e, 0x0a, 0x09, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x4b, 0x69, 0x6e, 0x64, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x55, 0x53, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x53, 0x59, 0x53, - 0x54, 0x45, 0x4d, 0x10, 0x02, 0x22, 0xda, 0x01, 0x0a, 0x07, 0x54, 0x61, 0x73, 0x6b, 0x4c, 0x6f, + 0x54, 0x45, 0x4d, 0x10, 0x02, 0x22, 0xb2, 0x02, 0x0a, 0x07, 0x54, 0x61, 0x73, 0x6b, 0x4c, 0x6f, 0x67, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x69, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x4b, 0x0a, 0x0e, 0x6d, 0x65, 0x73, 0x73, 0x61, @@ -843,40 +859,46 @@ var file_flyteidl_core_execution_proto_rawDesc = []byte{ 0x72, 0x6d, 0x61, 0x74, 0x12, 0x2b, 0x0a, 0x03, 0x74, 0x74, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x03, 0x74, 0x74, - 0x6c, 0x22, 0x2f, 0x0a, 0x0d, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x46, 0x6f, 0x72, 0x6d, - 0x61, 0x74, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, - 0x07, 0x0a, 0x03, 0x43, 0x53, 0x56, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x4a, 0x53, 0x4f, 0x4e, - 0x10, 0x02, 0x22, 0x5a, 0x0a, 0x14, 0x51, 0x75, 0x61, 0x6c, 0x69, 0x74, 0x79, 0x4f, 0x66, 0x53, - 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x53, 0x70, 0x65, 0x63, 0x12, 0x42, 0x0a, 0x0f, 0x71, 0x75, - 0x65, 0x75, 0x65, 0x69, 0x6e, 0x67, 0x5f, 0x62, 0x75, 0x64, 0x67, 0x65, 0x74, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0e, - 0x71, 0x75, 0x65, 0x75, 0x65, 0x69, 0x6e, 0x67, 0x42, 0x75, 0x64, 0x67, 0x65, 0x74, 0x22, 0xce, - 0x01, 0x0a, 0x10, 0x51, 0x75, 0x61, 0x6c, 0x69, 0x74, 0x79, 0x4f, 0x66, 0x53, 0x65, 0x72, 0x76, - 0x69, 0x63, 0x65, 0x12, 0x3a, 0x0a, 0x04, 0x74, 0x69, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0e, 0x32, 0x24, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, - 0x65, 0x2e, 0x51, 0x75, 0x61, 0x6c, 0x69, 0x74, 0x79, 0x4f, 0x66, 0x53, 0x65, 0x72, 0x76, 0x69, - 0x63, 0x65, 0x2e, 0x54, 0x69, 0x65, 0x72, 0x48, 0x00, 0x52, 0x04, 0x74, 0x69, 0x65, 0x72, 0x12, - 0x39, 0x0a, 0x04, 0x73, 0x70, 0x65, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x23, 0x2e, - 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x51, 0x75, + 0x6c, 0x12, 0x2a, 0x0a, 0x10, 0x53, 0x68, 0x6f, 0x77, 0x57, 0x68, 0x69, 0x6c, 0x65, 0x50, 0x65, + 0x6e, 0x64, 0x69, 0x6e, 0x67, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x53, 0x68, 0x6f, + 0x77, 0x57, 0x68, 0x69, 0x6c, 0x65, 0x50, 0x65, 0x6e, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x2a, 0x0a, + 0x10, 0x48, 0x69, 0x64, 0x65, 0x4f, 0x6e, 0x63, 0x65, 0x46, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, + 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x48, 0x69, 0x64, 0x65, 0x4f, 0x6e, 0x63, + 0x65, 0x46, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x64, 0x22, 0x2f, 0x0a, 0x0d, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x46, 0x6f, 0x72, 0x6d, 0x61, 0x74, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, + 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x43, 0x53, 0x56, 0x10, 0x01, + 0x12, 0x08, 0x0a, 0x04, 0x4a, 0x53, 0x4f, 0x4e, 0x10, 0x02, 0x22, 0x5a, 0x0a, 0x14, 0x51, 0x75, 0x61, 0x6c, 0x69, 0x74, 0x79, 0x4f, 0x66, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x53, 0x70, - 0x65, 0x63, 0x48, 0x00, 0x52, 0x04, 0x73, 0x70, 0x65, 0x63, 0x22, 0x34, 0x0a, 0x04, 0x54, 0x69, - 0x65, 0x72, 0x12, 0x0d, 0x0a, 0x09, 0x55, 0x4e, 0x44, 0x45, 0x46, 0x49, 0x4e, 0x45, 0x44, 0x10, - 0x00, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x49, 0x47, 0x48, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x4d, - 0x45, 0x44, 0x49, 0x55, 0x4d, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x4c, 0x4f, 0x57, 0x10, 0x03, - 0x42, 0x0d, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x42, - 0xb4, 0x01, 0x0a, 0x11, 0x63, 0x6f, 0x6d, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, - 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x42, 0x0e, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, - 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x3a, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, - 0x63, 0x6f, 0x6d, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x6f, 0x72, 0x67, 0x2f, 0x66, 0x6c, 0x79, - 0x74, 0x65, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x67, 0x65, 0x6e, 0x2f, - 0x70, 0x62, 0x2d, 0x67, 0x6f, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x63, - 0x6f, 0x72, 0x65, 0xa2, 0x02, 0x03, 0x46, 0x43, 0x58, 0xaa, 0x02, 0x0d, 0x46, 0x6c, 0x79, 0x74, - 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x43, 0x6f, 0x72, 0x65, 0xca, 0x02, 0x0d, 0x46, 0x6c, 0x79, 0x74, - 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x43, 0x6f, 0x72, 0x65, 0xe2, 0x02, 0x19, 0x46, 0x6c, 0x79, 0x74, - 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x43, 0x6f, 0x72, 0x65, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, - 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x0e, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, - 0x3a, 0x3a, 0x43, 0x6f, 0x72, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x63, 0x12, 0x42, 0x0a, 0x0f, 0x71, 0x75, 0x65, 0x75, 0x65, 0x69, 0x6e, 0x67, 0x5f, 0x62, + 0x75, 0x64, 0x67, 0x65, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, + 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0e, 0x71, 0x75, 0x65, 0x75, 0x65, 0x69, 0x6e, 0x67, + 0x42, 0x75, 0x64, 0x67, 0x65, 0x74, 0x22, 0xce, 0x01, 0x0a, 0x10, 0x51, 0x75, 0x61, 0x6c, 0x69, + 0x74, 0x79, 0x4f, 0x66, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x3a, 0x0a, 0x04, 0x74, + 0x69, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x24, 0x2e, 0x66, 0x6c, 0x79, 0x74, + 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x51, 0x75, 0x61, 0x6c, 0x69, 0x74, + 0x79, 0x4f, 0x66, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x54, 0x69, 0x65, 0x72, 0x48, + 0x00, 0x52, 0x04, 0x74, 0x69, 0x65, 0x72, 0x12, 0x39, 0x0a, 0x04, 0x73, 0x70, 0x65, 0x63, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x23, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, + 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x51, 0x75, 0x61, 0x6c, 0x69, 0x74, 0x79, 0x4f, 0x66, 0x53, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x53, 0x70, 0x65, 0x63, 0x48, 0x00, 0x52, 0x04, 0x73, 0x70, + 0x65, 0x63, 0x22, 0x34, 0x0a, 0x04, 0x54, 0x69, 0x65, 0x72, 0x12, 0x0d, 0x0a, 0x09, 0x55, 0x4e, + 0x44, 0x45, 0x46, 0x49, 0x4e, 0x45, 0x44, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x49, 0x47, + 0x48, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x4d, 0x45, 0x44, 0x49, 0x55, 0x4d, 0x10, 0x02, 0x12, + 0x07, 0x0a, 0x03, 0x4c, 0x4f, 0x57, 0x10, 0x03, 0x42, 0x0d, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x69, + 0x67, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x42, 0xb4, 0x01, 0x0a, 0x11, 0x63, 0x6f, 0x6d, 0x2e, + 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x42, 0x0e, 0x45, + 0x78, 0x65, 0x63, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, + 0x3a, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x66, 0x6c, 0x79, 0x74, + 0x65, 0x6f, 0x72, 0x67, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, + 0x69, 0x64, 0x6c, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x62, 0x2d, 0x67, 0x6f, 0x2f, 0x66, 0x6c, + 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0xa2, 0x02, 0x03, 0x46, 0x43, + 0x58, 0xaa, 0x02, 0x0d, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x43, 0x6f, 0x72, + 0x65, 0xca, 0x02, 0x0d, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x43, 0x6f, 0x72, + 0x65, 0xe2, 0x02, 0x19, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x43, 0x6f, 0x72, + 0x65, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x0e, + 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x3a, 0x3a, 0x43, 0x6f, 0x72, 0x65, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json b/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json index 4e944181b5..e1f0b29579 100644 --- a/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json +++ b/flyteidl/gen/pb-go/gateway/flyteidl/service/admin.swagger.json @@ -8117,6 +8117,12 @@ }, "ttl": { "type": "string" + }, + "ShowWhilePending": { + "type": "boolean" + }, + "HideOnceFinished": { + "type": "boolean" } }, "title": "Log information for the task that is specific to a log sink\nWhen our log story is flushed out, we may have more metadata here like log link expiry" diff --git a/flyteidl/gen/pb-go/gateway/flyteidl/service/agent.swagger.json b/flyteidl/gen/pb-go/gateway/flyteidl/service/agent.swagger.json index a2d429d019..a0e8cfed39 100644 --- a/flyteidl/gen/pb-go/gateway/flyteidl/service/agent.swagger.json +++ b/flyteidl/gen/pb-go/gateway/flyteidl/service/agent.swagger.json @@ -1724,6 +1724,12 @@ }, "ttl": { "type": "string" + }, + "ShowWhilePending": { + "type": "boolean" + }, + "HideOnceFinished": { + "type": "boolean" } }, "title": "Log information for the task that is specific to a log sink\nWhen our log story is flushed out, we may have more metadata here like log link expiry" diff --git a/flyteidl/gen/pb-js/flyteidl.d.ts b/flyteidl/gen/pb-js/flyteidl.d.ts index cceab76647..3485ae15fc 100644 --- a/flyteidl/gen/pb-js/flyteidl.d.ts +++ b/flyteidl/gen/pb-js/flyteidl.d.ts @@ -5743,6 +5743,12 @@ export namespace flyteidl { /** TaskLog ttl */ ttl?: (google.protobuf.IDuration|null); + + /** TaskLog ShowWhilePending */ + ShowWhilePending?: (boolean|null); + + /** TaskLog HideOnceFinished */ + HideOnceFinished?: (boolean|null); } /** Represents a TaskLog. */ @@ -5766,6 +5772,12 @@ export namespace flyteidl { /** TaskLog ttl. */ public ttl?: (google.protobuf.IDuration|null); + /** TaskLog ShowWhilePending. */ + public ShowWhilePending: boolean; + + /** TaskLog HideOnceFinished. */ + public HideOnceFinished: boolean; + /** * Creates a new TaskLog instance using the specified properties. * @param [properties] Properties to set diff --git a/flyteidl/gen/pb-js/flyteidl.js b/flyteidl/gen/pb-js/flyteidl.js index e29f5fc9ff..6d7ae28690 100644 --- a/flyteidl/gen/pb-js/flyteidl.js +++ b/flyteidl/gen/pb-js/flyteidl.js @@ -13800,6 +13800,8 @@ * @property {string|null} [name] TaskLog name * @property {flyteidl.core.TaskLog.MessageFormat|null} [messageFormat] TaskLog messageFormat * @property {google.protobuf.IDuration|null} [ttl] TaskLog ttl + * @property {boolean|null} [ShowWhilePending] TaskLog ShowWhilePending + * @property {boolean|null} [HideOnceFinished] TaskLog HideOnceFinished */ /** @@ -13849,6 +13851,22 @@ */ TaskLog.prototype.ttl = null; + /** + * TaskLog ShowWhilePending. + * @member {boolean} ShowWhilePending + * @memberof flyteidl.core.TaskLog + * @instance + */ + TaskLog.prototype.ShowWhilePending = false; + + /** + * TaskLog HideOnceFinished. + * @member {boolean} HideOnceFinished + * @memberof flyteidl.core.TaskLog + * @instance + */ + TaskLog.prototype.HideOnceFinished = false; + /** * Creates a new TaskLog instance using the specified properties. * @function create @@ -13881,6 +13899,10 @@ writer.uint32(/* id 3, wireType 0 =*/24).int32(message.messageFormat); if (message.ttl != null && message.hasOwnProperty("ttl")) $root.google.protobuf.Duration.encode(message.ttl, writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); + if (message.ShowWhilePending != null && message.hasOwnProperty("ShowWhilePending")) + writer.uint32(/* id 5, wireType 0 =*/40).bool(message.ShowWhilePending); + if (message.HideOnceFinished != null && message.hasOwnProperty("HideOnceFinished")) + writer.uint32(/* id 6, wireType 0 =*/48).bool(message.HideOnceFinished); return writer; }; @@ -13914,6 +13936,12 @@ case 4: message.ttl = $root.google.protobuf.Duration.decode(reader, reader.uint32()); break; + case 5: + message.ShowWhilePending = reader.bool(); + break; + case 6: + message.HideOnceFinished = reader.bool(); + break; default: reader.skipType(tag & 7); break; @@ -13953,6 +13981,12 @@ if (error) return "ttl." + error; } + if (message.ShowWhilePending != null && message.hasOwnProperty("ShowWhilePending")) + if (typeof message.ShowWhilePending !== "boolean") + return "ShowWhilePending: boolean expected"; + if (message.HideOnceFinished != null && message.hasOwnProperty("HideOnceFinished")) + if (typeof message.HideOnceFinished !== "boolean") + return "HideOnceFinished: boolean expected"; return null; }; diff --git a/flyteidl/gen/pb_python/flyteidl/core/execution_pb2.py b/flyteidl/gen/pb_python/flyteidl/core/execution_pb2.py index c2c9810083..2d59497e3a 100644 --- a/flyteidl/gen/pb_python/flyteidl/core/execution_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/core/execution_pb2.py @@ -14,7 +14,7 @@ from google.protobuf import duration_pb2 as google_dot_protobuf_dot_duration__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1d\x66lyteidl/core/execution.proto\x12\rflyteidl.core\x1a\x1egoogle/protobuf/duration.proto\"\xa7\x01\n\x11WorkflowExecution\"\x91\x01\n\x05Phase\x12\r\n\tUNDEFINED\x10\x00\x12\n\n\x06QUEUED\x10\x01\x12\x0b\n\x07RUNNING\x10\x02\x12\x0e\n\nSUCCEEDING\x10\x03\x12\r\n\tSUCCEEDED\x10\x04\x12\x0b\n\x07\x46\x41ILING\x10\x05\x12\n\n\x06\x46\x41ILED\x10\x06\x12\x0b\n\x07\x41\x42ORTED\x10\x07\x12\r\n\tTIMED_OUT\x10\x08\x12\x0c\n\x08\x41\x42ORTING\x10\t\"\xb6\x01\n\rNodeExecution\"\xa4\x01\n\x05Phase\x12\r\n\tUNDEFINED\x10\x00\x12\n\n\x06QUEUED\x10\x01\x12\x0b\n\x07RUNNING\x10\x02\x12\r\n\tSUCCEEDED\x10\x03\x12\x0b\n\x07\x46\x41ILING\x10\x04\x12\n\n\x06\x46\x41ILED\x10\x05\x12\x0b\n\x07\x41\x42ORTED\x10\x06\x12\x0b\n\x07SKIPPED\x10\x07\x12\r\n\tTIMED_OUT\x10\x08\x12\x13\n\x0f\x44YNAMIC_RUNNING\x10\t\x12\r\n\tRECOVERED\x10\n\"\x96\x01\n\rTaskExecution\"\x84\x01\n\x05Phase\x12\r\n\tUNDEFINED\x10\x00\x12\n\n\x06QUEUED\x10\x01\x12\x0b\n\x07RUNNING\x10\x02\x12\r\n\tSUCCEEDED\x10\x03\x12\x0b\n\x07\x41\x42ORTED\x10\x04\x12\n\n\x06\x46\x41ILED\x10\x05\x12\x10\n\x0cINITIALIZING\x10\x06\x12\x19\n\x15WAITING_FOR_RESOURCES\x10\x07\"\xc8\x01\n\x0e\x45xecutionError\x12\x12\n\x04\x63ode\x18\x01 \x01(\tR\x04\x63ode\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12\x1b\n\terror_uri\x18\x03 \x01(\tR\x08\x65rrorUri\x12;\n\x04kind\x18\x04 \x01(\x0e\x32\'.flyteidl.core.ExecutionError.ErrorKindR\x04kind\".\n\tErrorKind\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x08\n\x04USER\x10\x01\x12\n\n\x06SYSTEM\x10\x02\"\xda\x01\n\x07TaskLog\x12\x10\n\x03uri\x18\x01 \x01(\tR\x03uri\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12K\n\x0emessage_format\x18\x03 \x01(\x0e\x32$.flyteidl.core.TaskLog.MessageFormatR\rmessageFormat\x12+\n\x03ttl\x18\x04 \x01(\x0b\x32\x19.google.protobuf.DurationR\x03ttl\"/\n\rMessageFormat\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x07\n\x03\x43SV\x10\x01\x12\x08\n\x04JSON\x10\x02\"Z\n\x14QualityOfServiceSpec\x12\x42\n\x0fqueueing_budget\x18\x01 \x01(\x0b\x32\x19.google.protobuf.DurationR\x0equeueingBudget\"\xce\x01\n\x10QualityOfService\x12:\n\x04tier\x18\x01 \x01(\x0e\x32$.flyteidl.core.QualityOfService.TierH\x00R\x04tier\x12\x39\n\x04spec\x18\x02 \x01(\x0b\x32#.flyteidl.core.QualityOfServiceSpecH\x00R\x04spec\"4\n\x04Tier\x12\r\n\tUNDEFINED\x10\x00\x12\x08\n\x04HIGH\x10\x01\x12\n\n\x06MEDIUM\x10\x02\x12\x07\n\x03LOW\x10\x03\x42\r\n\x0b\x64\x65signationB\xb4\x01\n\x11\x63om.flyteidl.coreB\x0e\x45xecutionProtoP\x01Z:github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core\xa2\x02\x03\x46\x43X\xaa\x02\rFlyteidl.Core\xca\x02\rFlyteidl\\Core\xe2\x02\x19\x46lyteidl\\Core\\GPBMetadata\xea\x02\x0e\x46lyteidl::Coreb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1d\x66lyteidl/core/execution.proto\x12\rflyteidl.core\x1a\x1egoogle/protobuf/duration.proto\"\xa7\x01\n\x11WorkflowExecution\"\x91\x01\n\x05Phase\x12\r\n\tUNDEFINED\x10\x00\x12\n\n\x06QUEUED\x10\x01\x12\x0b\n\x07RUNNING\x10\x02\x12\x0e\n\nSUCCEEDING\x10\x03\x12\r\n\tSUCCEEDED\x10\x04\x12\x0b\n\x07\x46\x41ILING\x10\x05\x12\n\n\x06\x46\x41ILED\x10\x06\x12\x0b\n\x07\x41\x42ORTED\x10\x07\x12\r\n\tTIMED_OUT\x10\x08\x12\x0c\n\x08\x41\x42ORTING\x10\t\"\xb6\x01\n\rNodeExecution\"\xa4\x01\n\x05Phase\x12\r\n\tUNDEFINED\x10\x00\x12\n\n\x06QUEUED\x10\x01\x12\x0b\n\x07RUNNING\x10\x02\x12\r\n\tSUCCEEDED\x10\x03\x12\x0b\n\x07\x46\x41ILING\x10\x04\x12\n\n\x06\x46\x41ILED\x10\x05\x12\x0b\n\x07\x41\x42ORTED\x10\x06\x12\x0b\n\x07SKIPPED\x10\x07\x12\r\n\tTIMED_OUT\x10\x08\x12\x13\n\x0f\x44YNAMIC_RUNNING\x10\t\x12\r\n\tRECOVERED\x10\n\"\x96\x01\n\rTaskExecution\"\x84\x01\n\x05Phase\x12\r\n\tUNDEFINED\x10\x00\x12\n\n\x06QUEUED\x10\x01\x12\x0b\n\x07RUNNING\x10\x02\x12\r\n\tSUCCEEDED\x10\x03\x12\x0b\n\x07\x41\x42ORTED\x10\x04\x12\n\n\x06\x46\x41ILED\x10\x05\x12\x10\n\x0cINITIALIZING\x10\x06\x12\x19\n\x15WAITING_FOR_RESOURCES\x10\x07\"\xc8\x01\n\x0e\x45xecutionError\x12\x12\n\x04\x63ode\x18\x01 \x01(\tR\x04\x63ode\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12\x1b\n\terror_uri\x18\x03 \x01(\tR\x08\x65rrorUri\x12;\n\x04kind\x18\x04 \x01(\x0e\x32\'.flyteidl.core.ExecutionError.ErrorKindR\x04kind\".\n\tErrorKind\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x08\n\x04USER\x10\x01\x12\n\n\x06SYSTEM\x10\x02\"\xb2\x02\n\x07TaskLog\x12\x10\n\x03uri\x18\x01 \x01(\tR\x03uri\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12K\n\x0emessage_format\x18\x03 \x01(\x0e\x32$.flyteidl.core.TaskLog.MessageFormatR\rmessageFormat\x12+\n\x03ttl\x18\x04 \x01(\x0b\x32\x19.google.protobuf.DurationR\x03ttl\x12*\n\x10ShowWhilePending\x18\x05 \x01(\x08R\x10ShowWhilePending\x12*\n\x10HideOnceFinished\x18\x06 \x01(\x08R\x10HideOnceFinished\"/\n\rMessageFormat\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x07\n\x03\x43SV\x10\x01\x12\x08\n\x04JSON\x10\x02\"Z\n\x14QualityOfServiceSpec\x12\x42\n\x0fqueueing_budget\x18\x01 \x01(\x0b\x32\x19.google.protobuf.DurationR\x0equeueingBudget\"\xce\x01\n\x10QualityOfService\x12:\n\x04tier\x18\x01 \x01(\x0e\x32$.flyteidl.core.QualityOfService.TierH\x00R\x04tier\x12\x39\n\x04spec\x18\x02 \x01(\x0b\x32#.flyteidl.core.QualityOfServiceSpecH\x00R\x04spec\"4\n\x04Tier\x12\r\n\tUNDEFINED\x10\x00\x12\x08\n\x04HIGH\x10\x01\x12\n\n\x06MEDIUM\x10\x02\x12\x07\n\x03LOW\x10\x03\x42\r\n\x0b\x64\x65signationB\xb4\x01\n\x11\x63om.flyteidl.coreB\x0e\x45xecutionProtoP\x01Z:github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core\xa2\x02\x03\x46\x43X\xaa\x02\rFlyteidl.Core\xca\x02\rFlyteidl\\Core\xe2\x02\x19\x46lyteidl\\Core\\GPBMetadata\xea\x02\x0e\x46lyteidl::Coreb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -40,13 +40,13 @@ _globals['_EXECUTIONERROR_ERRORKIND']._serialized_start=743 _globals['_EXECUTIONERROR_ERRORKIND']._serialized_end=789 _globals['_TASKLOG']._serialized_start=792 - _globals['_TASKLOG']._serialized_end=1010 - _globals['_TASKLOG_MESSAGEFORMAT']._serialized_start=963 - _globals['_TASKLOG_MESSAGEFORMAT']._serialized_end=1010 - _globals['_QUALITYOFSERVICESPEC']._serialized_start=1012 - _globals['_QUALITYOFSERVICESPEC']._serialized_end=1102 - _globals['_QUALITYOFSERVICE']._serialized_start=1105 - _globals['_QUALITYOFSERVICE']._serialized_end=1311 - _globals['_QUALITYOFSERVICE_TIER']._serialized_start=1244 - _globals['_QUALITYOFSERVICE_TIER']._serialized_end=1296 + _globals['_TASKLOG']._serialized_end=1098 + _globals['_TASKLOG_MESSAGEFORMAT']._serialized_start=1051 + _globals['_TASKLOG_MESSAGEFORMAT']._serialized_end=1098 + _globals['_QUALITYOFSERVICESPEC']._serialized_start=1100 + _globals['_QUALITYOFSERVICESPEC']._serialized_end=1190 + _globals['_QUALITYOFSERVICE']._serialized_start=1193 + _globals['_QUALITYOFSERVICE']._serialized_end=1399 + _globals['_QUALITYOFSERVICE_TIER']._serialized_start=1332 + _globals['_QUALITYOFSERVICE_TIER']._serialized_end=1384 # @@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/core/execution_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/core/execution_pb2.pyi index 2508c1b4ac..5c28a55418 100644 --- a/flyteidl/gen/pb_python/flyteidl/core/execution_pb2.pyi +++ b/flyteidl/gen/pb_python/flyteidl/core/execution_pb2.pyi @@ -103,7 +103,7 @@ class ExecutionError(_message.Message): def __init__(self, code: _Optional[str] = ..., message: _Optional[str] = ..., error_uri: _Optional[str] = ..., kind: _Optional[_Union[ExecutionError.ErrorKind, str]] = ...) -> None: ... class TaskLog(_message.Message): - __slots__ = ["uri", "name", "message_format", "ttl"] + __slots__ = ["uri", "name", "message_format", "ttl", "ShowWhilePending", "HideOnceFinished"] class MessageFormat(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): __slots__ = [] UNKNOWN: _ClassVar[TaskLog.MessageFormat] @@ -116,11 +116,15 @@ class TaskLog(_message.Message): NAME_FIELD_NUMBER: _ClassVar[int] MESSAGE_FORMAT_FIELD_NUMBER: _ClassVar[int] TTL_FIELD_NUMBER: _ClassVar[int] + SHOWWHILEPENDING_FIELD_NUMBER: _ClassVar[int] + HIDEONCEFINISHED_FIELD_NUMBER: _ClassVar[int] uri: str name: str message_format: TaskLog.MessageFormat ttl: _duration_pb2.Duration - def __init__(self, uri: _Optional[str] = ..., name: _Optional[str] = ..., message_format: _Optional[_Union[TaskLog.MessageFormat, str]] = ..., ttl: _Optional[_Union[_duration_pb2.Duration, _Mapping]] = ...) -> None: ... + ShowWhilePending: bool + HideOnceFinished: bool + def __init__(self, uri: _Optional[str] = ..., name: _Optional[str] = ..., message_format: _Optional[_Union[TaskLog.MessageFormat, str]] = ..., ttl: _Optional[_Union[_duration_pb2.Duration, _Mapping]] = ..., ShowWhilePending: bool = ..., HideOnceFinished: bool = ...) -> None: ... class QualityOfServiceSpec(_message.Message): __slots__ = ["queueing_budget"] diff --git a/flyteidl/gen/pb_rust/flyteidl.core.rs b/flyteidl/gen/pb_rust/flyteidl.core.rs index 215fa60f82..2b8cad3ef2 100644 --- a/flyteidl/gen/pb_rust/flyteidl.core.rs +++ b/flyteidl/gen/pb_rust/flyteidl.core.rs @@ -2155,6 +2155,10 @@ pub struct TaskLog { pub message_format: i32, #[prost(message, optional, tag="4")] pub ttl: ::core::option::Option<::prost_types::Duration>, + #[prost(bool, tag="5")] + pub show_while_pending: bool, + #[prost(bool, tag="6")] + pub hide_once_finished: bool, } /// Nested message and enum types in `TaskLog`. pub mod task_log { diff --git a/flyteidl/go.mod b/flyteidl/go.mod index 17673db704..d80bbbfa6d 100644 --- a/flyteidl/go.mod +++ b/flyteidl/go.mod @@ -7,9 +7,11 @@ toolchain go1.21.3 require ( github.com/flyteorg/flyte/flytestdlib v0.0.0-00010101000000-000000000000 github.com/go-test/deep v1.0.7 + github.com/golang/glog v1.2.0 github.com/golang/protobuf v1.5.3 github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 + github.com/grpc-ecosystem/grpc-gateway v1.16.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 github.com/jinzhu/copier v0.3.5 github.com/mitchellh/mapstructure v1.5.0 @@ -19,6 +21,7 @@ require ( github.com/stretchr/testify v1.9.0 golang.org/x/net v0.23.0 golang.org/x/oauth2 v0.16.0 + google.golang.org/api v0.155.0 google.golang.org/genproto/googleapis/api v0.0.0-20240123012728-ef4313101c80 google.golang.org/grpc v1.62.1 google.golang.org/protobuf v1.33.0 @@ -100,7 +103,6 @@ require ( golang.org/x/term v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/time v0.5.0 // indirect - google.golang.org/api v0.155.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80 // indirect diff --git a/flyteidl/go.sum b/flyteidl/go.sum index bfcf19eb85..c407c8e14b 100644 --- a/flyteidl/go.sum +++ b/flyteidl/go.sum @@ -224,6 +224,7 @@ github.com/grpc-ecosystem/go-grpc-middleware v1.1.0 h1:THDBEeQ9xZ8JEaCLyLQqXMMdR github.com/grpc-ecosystem/go-grpc-middleware v1.1.0/go.mod h1:f5nM7jw/oeRSadq3xCzHAvxcr8HZnzsqU6ILg/0NiiE= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92BcuyuQ/YW4NSIpoGtfXNho= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= +github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 h1:BZHcxBETFHIdVyhyEfOvn/RdU/QGdLI4y34qQGjGWO0= github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0/go.mod h1:hgWBS7lorOAVIJEQMi4ZsPv9hVvWI6+ch50m39Pf2Ks= diff --git a/flyteidl/protos/flyteidl/core/execution.proto b/flyteidl/protos/flyteidl/core/execution.proto index d2eabdc577..4d55198955 100644 --- a/flyteidl/protos/flyteidl/core/execution.proto +++ b/flyteidl/protos/flyteidl/core/execution.proto @@ -89,6 +89,8 @@ message TaskLog { string name = 2; MessageFormat message_format = 3; google.protobuf.Duration ttl = 4; + bool ShowWhilePending = 5; + bool HideOnceFinished = 6; } // Represents customized execution run-time attributes. diff --git a/flyteplugins/go/tasks/logs/logging_utils.go b/flyteplugins/go/tasks/logs/logging_utils.go index 45d12624de..3322cc37d8 100644 --- a/flyteplugins/go/tasks/logs/logging_utils.go +++ b/flyteplugins/go/tasks/logs/logging_utils.go @@ -29,9 +29,11 @@ func GetLogsForContainerInPod(ctx context.Context, logPlugin tasklog.Plugin, tas return nil, nil } + containerID := v1.ContainerStatus{}.ContainerID if uint32(len(pod.Status.ContainerStatuses)) <= index { logger.Errorf(ctx, "containerStatus IndexOutOfBound, requested [%d], but total containerStatuses [%d] in pod phase [%v]", index, len(pod.Status.ContainerStatuses), pod.Status.Phase) - return nil, nil + } else { + containerID = pod.Status.ContainerStatuses[index].ContainerID } startTime := pod.CreationTimestamp.Unix() @@ -43,7 +45,7 @@ func GetLogsForContainerInPod(ctx context.Context, logPlugin tasklog.Plugin, tas PodUID: string(pod.GetUID()), Namespace: pod.Namespace, ContainerName: pod.Spec.Containers[index].Name, - ContainerID: pod.Status.ContainerStatuses[index].ContainerID, + ContainerID: containerID, LogName: nameSuffix, PodRFC3339StartTime: time.Unix(startTime, 0).Format(time.RFC3339), PodRFC3339FinishTime: time.Unix(finishTime, 0).Format(time.RFC3339), diff --git a/flyteplugins/go/tasks/pluginmachinery/core/phase.go b/flyteplugins/go/tasks/pluginmachinery/core/phase.go index 6c80cc4d24..376f261fac 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/phase.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/phase.go @@ -226,7 +226,7 @@ func PhaseInfoQueued(t time.Time, version uint32, reason string) PhaseInfo { return pi } -func PhaseInfoQueuedWithTaskInfo(version uint32, reason string, info *TaskInfo) PhaseInfo { +func PhaseInfoQueuedWithTaskInfo(t time.Time, version uint32, reason string, info *TaskInfo) PhaseInfo { pi := phaseInfo(PhaseQueued, version, nil, info, false) pi.reason = reason return pi diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index ab9612c96c..b910da5ee3 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -624,8 +624,8 @@ func BuildIdentityPod() *v1.Pod { // resources requested is beyond the capability of the system. for this we will rely on configuration // and hence input gates. We should not allow bad requests that Request for large number of resource through. // In the case it makes through, we will fail after timeout -func DemystifyPending(status v1.PodStatus) (pluginsCore.PhaseInfo, error) { - phaseInfo, t := demystifyPendingHelper(status) +func DemystifyPending(status v1.PodStatus, info pluginsCore.TaskInfo) (pluginsCore.PhaseInfo, error) { + phaseInfo, t := demystifyPendingHelper(status, info) if phaseInfo.Phase().IsTerminal() { return phaseInfo, nil @@ -642,13 +642,14 @@ func DemystifyPending(status v1.PodStatus) (pluginsCore.PhaseInfo, error) { return phaseInfo, nil } - return pluginsCore.PhaseInfoQueued(time.Now(), pluginsCore.DefaultPhaseVersion, "Scheduling"), nil + return pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginsCore.DefaultPhaseVersion, "Scheduling", phaseInfo.Info()), nil } -func demystifyPendingHelper(status v1.PodStatus) (pluginsCore.PhaseInfo, time.Time) { +func demystifyPendingHelper(status v1.PodStatus, info pluginsCore.TaskInfo) (pluginsCore.PhaseInfo, time.Time) { // Search over the difference conditions in the status object. Note that the 'Pending' this function is // demystifying is the 'phase' of the pod status. This is different than the PodReady condition type also used below - phaseInfo := pluginsCore.PhaseInfoUndefined + phaseInfo := pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginsCore.DefaultPhaseVersion, "Demistify Pending", &info) + t := time.Now() for _, c := range status.Conditions { t = c.LastTransitionTime.Time @@ -656,7 +657,7 @@ func demystifyPendingHelper(status v1.PodStatus) (pluginsCore.PhaseInfo, time.Ti case v1.PodScheduled: if c.Status == v1.ConditionFalse { // Waiting to be scheduled. This usually refers to inability to acquire resources. - return pluginsCore.PhaseInfoQueued(t, pluginsCore.DefaultPhaseVersion, fmt.Sprintf("%s:%s", c.Reason, c.Message)), t + return pluginsCore.PhaseInfoQueuedWithTaskInfo(t, pluginsCore.DefaultPhaseVersion, fmt.Sprintf("%s:%s", c.Reason, c.Message), phaseInfo.Info()), t } case v1.PodReasonUnschedulable: @@ -669,7 +670,7 @@ func demystifyPendingHelper(status v1.PodStatus) (pluginsCore.PhaseInfo, time.Ti // reason: Unschedulable // status: "False" // type: PodScheduled - return pluginsCore.PhaseInfoQueued(t, pluginsCore.DefaultPhaseVersion, fmt.Sprintf("%s:%s", c.Reason, c.Message)), t + return pluginsCore.PhaseInfoQueuedWithTaskInfo(t, pluginsCore.DefaultPhaseVersion, fmt.Sprintf("%s:%s", c.Reason, c.Message), phaseInfo.Info()), t case v1.PodReady: if c.Status == v1.ConditionFalse { diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index 7010247ee5..0c2e9ef5cc 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -1210,7 +1210,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s) + taskStatus, err := DemystifyPending(s, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskStatus.Phase()) }) @@ -1225,7 +1225,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s) + taskStatus, err := DemystifyPending(s, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskStatus.Phase()) }) @@ -1240,7 +1240,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s) + taskStatus, err := DemystifyPending(s, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskStatus.Phase()) }) @@ -1255,7 +1255,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s) + taskStatus, err := DemystifyPending(s, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskStatus.Phase()) }) @@ -1290,7 +1290,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s) + taskStatus, err := DemystifyPending(s, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseInitializing, taskStatus.Phase()) }) @@ -1307,7 +1307,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s) + taskStatus, err := DemystifyPending(s, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseInitializing, taskStatus.Phase()) }) @@ -1324,7 +1324,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s) + taskStatus, err := DemystifyPending(s, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseInitializing, taskStatus.Phase()) }) @@ -1343,7 +1343,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s2) + taskStatus, err := DemystifyPending(s2, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseInitializing, taskStatus.Phase()) }) @@ -1362,7 +1362,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s2) + taskStatus, err := DemystifyPending(s2, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseRetryableFailure, taskStatus.Phase()) assert.True(t, taskStatus.CleanupOnFailure()) @@ -1380,7 +1380,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s) + taskStatus, err := DemystifyPending(s, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhasePermanentFailure, taskStatus.Phase()) assert.True(t, taskStatus.CleanupOnFailure()) @@ -1398,7 +1398,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s) + taskStatus, err := DemystifyPending(s, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseRetryableFailure, taskStatus.Phase()) assert.True(t, taskStatus.CleanupOnFailure()) @@ -1416,7 +1416,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s) + taskStatus, err := DemystifyPending(s, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseRetryableFailure, taskStatus.Phase()) assert.True(t, taskStatus.CleanupOnFailure()) @@ -1436,7 +1436,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s2) + taskStatus, err := DemystifyPending(s2, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseInitializing, taskStatus.Phase()) }) @@ -1455,7 +1455,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s2) + taskStatus, err := DemystifyPending(s2, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhasePermanentFailure, taskStatus.Phase()) assert.True(t, taskStatus.CleanupOnFailure()) @@ -1475,7 +1475,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s2) + taskStatus, err := DemystifyPending(s2, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseInitializing, taskStatus.Phase()) }) @@ -1494,7 +1494,7 @@ func TestDemystifyPending(t *testing.T) { }, }, } - taskStatus, err := DemystifyPending(s2) + taskStatus, err := DemystifyPending(s2, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhasePermanentFailure, taskStatus.Phase()) assert.True(t, taskStatus.CleanupOnFailure()) @@ -1526,7 +1526,7 @@ func TestDemystifyPendingTimeout(t *testing.T) { s.Conditions[0].LastTransitionTime.Time = metav1.Now().Add(-config.GetK8sPluginConfig().PodPendingTimeout.Duration) t.Run("PodPendingExceedsTimeout", func(t *testing.T) { - taskStatus, err := DemystifyPending(s) + taskStatus, err := DemystifyPending(s, pluginsCore.TaskInfo{}) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseRetryableFailure, taskStatus.Phase()) assert.Equal(t, "PodPendingTimeout", taskStatus.Err().Code) @@ -1697,7 +1697,7 @@ func TestDemystifyPending_testcases(t *testing.T) { assert.NoError(t, err, "failed to read file %s", testFile) pod := &v1.Pod{} if assert.NoError(t, json.Unmarshal(data, pod), "failed to unmarshal json in %s. Expected of type v1.Pod", testFile) { - p, err := DemystifyPending(pod.Status) + p, err := DemystifyPending(pod.Status, pluginsCore.TaskInfo{}) if tt.isErr { assert.Error(t, err, "Error expected from method") } else { diff --git a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go index 1bbe07c02a..38a84f9b2b 100644 --- a/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/k8s/plugin.go @@ -165,3 +165,25 @@ func AbortBehaviorDelete(resource client.Object) AbortBehavior { DeleteResource: true, } } + +// if we have the same Phase as the previous evaluation and updated the Reason but not the PhaseVersion we must +// update the PhaseVersion so an event is sent to reflect the Reason update. this does not handle the Running +// Phase because the legacy used `DefaultPhaseVersion + 1` which will only increment to 1. + +func MaybeUpdatePhaseVersion(phaseInfo *pluginsCore.PhaseInfo, pluginState *PluginState) { + if phaseInfo.Phase() != pluginsCore.PhaseRunning && phaseInfo.Phase() == pluginState.Phase && + phaseInfo.Version() <= pluginState.PhaseVersion && phaseInfo.Reason() != pluginState.Reason { + + *phaseInfo = phaseInfo.WithVersion(pluginState.PhaseVersion + 1) + } +} + +func MaybeUpdatePhaseVersionFromPluginContext(phaseInfo *pluginsCore.PhaseInfo, pluginContext *PluginContext) error { + pluginState := PluginState{} + _, err := (*pluginContext).PluginStateReader().Get(&pluginState) + if err != nil { + return err + } + MaybeUpdatePhaseVersion(phaseInfo, &pluginState) + return nil +} diff --git a/flyteplugins/go/tasks/pluginmachinery/tasklog/plugin.go b/flyteplugins/go/tasks/pluginmachinery/tasklog/plugin.go index fa47fa4729..143cf02e43 100644 --- a/flyteplugins/go/tasks/pluginmachinery/tasklog/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/tasklog/plugin.go @@ -62,4 +62,6 @@ type TemplateLogPlugin struct { MessageFormat core.TaskLog_MessageFormat `json:"messageFormat" pflag:"-,Log Message Format."` // Deprecated: Please, do not use DeprecatedScheme TemplateScheme `json:"scheme" pflag:",Templating scheme to use. Supported values are Pod and TaskExecution."` + ShowWhilePending bool `json:"showWhilePending" pflag:",If true, the log link will be shown even if the task is in a pending state."` + HideOnceFinished bool `json:"hideOnceFinished" pflag:",If true, the log link will be hidden once the task has finished."` } diff --git a/flyteplugins/go/tasks/pluginmachinery/tasklog/template.go b/flyteplugins/go/tasks/pluginmachinery/tasklog/template.go index e5481ecfbd..19aae6ba7c 100644 --- a/flyteplugins/go/tasks/pluginmachinery/tasklog/template.go +++ b/flyteplugins/go/tasks/pluginmachinery/tasklog/template.go @@ -200,9 +200,11 @@ func (p TemplateLogPlugin) GetTaskLogs(input Input) (Output, error) { taskLogs := make([]*core.TaskLog, 0, len(p.TemplateURIs)) for _, templateURI := range p.TemplateURIs { taskLogs = append(taskLogs, &core.TaskLog{ - Uri: replaceAll(templateURI, templateVars), - Name: p.DisplayName + input.LogName, - MessageFormat: p.MessageFormat, + Uri: replaceAll(templateURI, templateVars), + Name: p.DisplayName + input.LogName, + MessageFormat: p.MessageFormat, + ShowWhilePending: p.ShowWhilePending, + HideOnceFinished: p.HideOnceFinished, }) } diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go index 8257f00341..d3b4ab32f1 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go @@ -291,42 +291,44 @@ func (p daskResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s OccurredAt: &occurredAt, } - // There is a short period between the `DaskJob` resource being created and `Status.JobStatus` being set by the `dask-operator`. - // In that period, the `JobStatus` will be an empty string. We're treating this as Initializing/Queuing. - isQueued := status == "" || - status == daskAPI.DaskJobCreated || - status == daskAPI.DaskJobClusterCreated - - if !isQueued { - taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() - o, err := logPlugin.GetTaskLogs( - tasklog.Input{ - Namespace: job.ObjectMeta.Namespace, - PodName: job.Status.JobRunnerPodName, - LogName: "(User logs)", - TaskExecutionID: taskExecID, - }, - ) - if err != nil { - return pluginsCore.PhaseInfoUndefined, err - } - info.Logs = o.TaskLogs + taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() + o, err := logPlugin.GetTaskLogs( + tasklog.Input{ + Namespace: job.ObjectMeta.Namespace, + PodName: job.Status.JobRunnerPodName, + LogName: "(User logs)", + TaskExecutionID: taskExecID, + }, + ) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err } + info.Logs = o.TaskLogs + + var phaseInfo pluginsCore.PhaseInfo switch status { case "": - return pluginsCore.PhaseInfoInitializing(occurredAt, pluginsCore.DefaultPhaseVersion, "unknown", &info), nil + phaseInfo = pluginsCore.PhaseInfoInitializing(occurredAt, pluginsCore.DefaultPhaseVersion, "unknown", &info) case daskAPI.DaskJobCreated: - return pluginsCore.PhaseInfoInitializing(occurredAt, pluginsCore.DefaultPhaseVersion, "job created", &info), nil + phaseInfo = pluginsCore.PhaseInfoInitializing(occurredAt, pluginsCore.DefaultPhaseVersion, "job created", &info) case daskAPI.DaskJobClusterCreated: - return pluginsCore.PhaseInfoInitializing(occurredAt, pluginsCore.DefaultPhaseVersion, "cluster created", &info), nil + phaseInfo = pluginsCore.PhaseInfoInitializing(occurredAt, pluginsCore.DefaultPhaseVersion, "cluster created", &info) case daskAPI.DaskJobFailed: reason := "Dask Job failed" - return pluginsCore.PhaseInfoRetryableFailure(errors.DownstreamSystemError, reason, &info), nil + phaseInfo = pluginsCore.PhaseInfoRetryableFailure(errors.DownstreamSystemError, reason, &info) case daskAPI.DaskJobSuccessful: - return pluginsCore.PhaseInfoSuccess(&info), nil + phaseInfo = pluginsCore.PhaseInfoSuccess(&info) + default: + phaseInfo = pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &info) } - return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &info), nil + + phaseVersionUpdateErr := k8s.MaybeUpdatePhaseVersionFromPluginContext(&phaseInfo, &pluginContext) + if phaseVersionUpdateErr != nil { + return phaseInfo, phaseVersionUpdateErr + } + + return phaseInfo, nil } func (daskResourceHandler) GetProperties() k8s.PluginProperties { diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go index 616312ca12..fdb3e74182 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go @@ -2,6 +2,7 @@ package dask import ( "context" + "reflect" "testing" "time" @@ -147,7 +148,7 @@ func dummyDaskTaskTemplate(customImage string, resources *core.Resources, podTem } } -func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources, isInterruptible bool) pluginsCore.TaskExecutionContext { +func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources, isInterruptible bool, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} @@ -198,6 +199,18 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc overrides.OnGetContainerImage().Return("") taskExecutionMetadata.OnGetOverrides().Return(overrides) taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata) + + pluginStateReaderMock := mocks.PluginStateReader{} + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( + func(v interface{}) uint8 { + *(v.(*k8s.PluginState)) = pluginState + return 0 + }, + func(v interface{}) error { + return nil + }) + + taskCtx.OnPluginStateReader().Return(&pluginStateReaderMock) return taskCtx } @@ -205,7 +218,7 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -316,7 +329,7 @@ func TestBuildResourceDaskCustomImages(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate(customImage, nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -349,7 +362,7 @@ func TestBuildResourceDaskDefaultResoureRequirements(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -406,7 +419,7 @@ func TestBuildResourcesDaskCustomResoureRequirements(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", &protobufResources, "") - taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -461,7 +474,7 @@ func TestBuildResourceDaskInterruptible(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, true) + taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, true, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -495,7 +508,7 @@ func TestBuildResouceDaskUsePodTemplate(t *testing.T) { flytek8s.DefaultPodTemplateStore.Store(podTemplate) daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", nil, podTemplateName) - taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false, k8s.PluginState{}) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -615,7 +628,7 @@ func TestBuildResourceDaskExtendedResources(t *testing.T) { t.Run(f.name, func(t *testing.T) { taskTemplate := dummyDaskTaskTemplate("", nil, "") taskTemplate.ExtendedResources = f.extendedResourcesBase - taskContext := dummyDaskTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, false) + taskContext := dummyDaskTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, false, k8s.PluginState{}) daskResourceHandler := daskResourceHandler{} r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -681,7 +694,7 @@ func TestBuildIdentityResourceDask(t *testing.T) { } taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{}) identityResources, err := daskResourceHandler.BuildIdentityResource(context.TODO(), taskContext.TaskExecutionMetadata()) if err != nil { panic(err) @@ -694,27 +707,27 @@ func TestGetTaskPhaseDask(t *testing.T) { ctx := context.TODO() taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false) + taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{}) taskPhase, err := daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob("")) assert.NoError(t, err) assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseInitializing) assert.NotNil(t, taskPhase.Info()) - assert.Nil(t, taskPhase.Info().Logs) + assert.NotNil(t, taskPhase.Info().Logs) assert.Nil(t, err) taskPhase, err = daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobCreated)) assert.NoError(t, err) assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseInitializing) assert.NotNil(t, taskPhase.Info()) - assert.Nil(t, taskPhase.Info().Logs) + assert.NotNil(t, taskPhase.Info().Logs) assert.Nil(t, err) taskPhase, err = daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobClusterCreated)) assert.NoError(t, err) assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseInitializing) assert.NotNil(t, taskPhase.Info()) - assert.Nil(t, taskPhase.Info().Logs) + assert.NotNil(t, taskPhase.Info().Logs) assert.Nil(t, err) taskPhase, err = daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobRunning)) @@ -738,3 +751,21 @@ func TestGetTaskPhaseDask(t *testing.T) { assert.NotNil(t, taskPhase.Info().Logs) assert.Nil(t, err) } + +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + daskResourceHandler := daskResourceHandler{} + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseInitializing, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + taskTemplate := dummyDaskTaskTemplate("", nil, "") + taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, pluginState) + + taskPhase, err := daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobCreated)) + + assert.NoError(t, err) + assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 9d2e4a5aec..44604bf3f7 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -49,11 +49,11 @@ func ExtractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.Jo func GetPhaseInfo(currentCondition commonOp.JobCondition, occurredAt time.Time, taskPhaseInfo pluginsCore.TaskInfo) (pluginsCore.PhaseInfo, error) { if len(currentCondition.Type) == 0 { - return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil + return pluginsCore.PhaseInfoQueuedWithTaskInfo(occurredAt, pluginsCore.DefaultPhaseVersion, "JobCreated", &taskPhaseInfo), nil } switch currentCondition.Type { case commonOp.JobCreated: - return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil + return pluginsCore.PhaseInfoQueuedWithTaskInfo(occurredAt, pluginsCore.DefaultPhaseVersion, "JobCreated", &taskPhaseInfo), nil case commonOp.JobRunning: return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil case commonOp.JobSucceeded: @@ -73,7 +73,7 @@ func GetMPIPhaseInfo(currentCondition commonOp.JobCondition, occurredAt time.Tim taskPhaseInfo pluginsCore.TaskInfo) (pluginsCore.PhaseInfo, error) { switch currentCondition.Type { case commonOp.JobCreated: - return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "New job name submitted to MPI operator"), nil + return pluginsCore.PhaseInfoQueuedWithTaskInfo(occurredAt, pluginsCore.DefaultPhaseVersion, "New job name submitted to MPI operator", &taskPhaseInfo), nil case commonOp.JobRunning: return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil case commonOp.JobSucceeded: diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 97199025a7..53e4d30ccb 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -186,7 +186,14 @@ func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext CustomInfo: statusDetails, } - return common.GetMPIPhaseInfo(currentCondition, occurredAt, taskPhaseInfo) + phaseInfo, err := common.GetPhaseInfo(currentCondition, occurredAt, taskPhaseInfo) + + phaseVersionUpdateErr := k8s.MaybeUpdatePhaseVersionFromPluginContext(&phaseInfo, &pluginContext) + if phaseVersionUpdateErr != nil { + return phaseInfo, phaseVersionUpdateErr + } + + return phaseInfo, err } func init() { diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 02224ec8a7..7db8269eaf 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -3,6 +3,7 @@ package mpi import ( "context" "fmt" + "reflect" "testing" "time" @@ -117,7 +118,7 @@ func dummyMPITaskTemplate(id string, args ...interface{}) *core.TaskTemplate { } } -func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources) pluginsCore.TaskExecutionContext { +func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return("/input/prefix") @@ -171,6 +172,18 @@ func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.Reso taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskExecutionMetadata.OnGetConsoleURL().Return("") taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) + + pluginStateReaderMock := mocks.PluginStateReader{} + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( + func(v interface{}) uint8 { + *(v.(*k8s.PluginState)) = pluginState + return 0 + }, + func(v interface{}) error { + return nil + }) + + taskCtx.OnPluginStateReader().Return(&pluginStateReaderMock) return taskCtx } @@ -276,7 +289,7 @@ func dummyMPIJobResource(mpiResourceHandler mpiOperatorResourceHandler, mpiObj := dummyMPICustomObj(workers, launcher, slots) taskTemplate := dummyMPITaskTemplate(mpiID, mpiObj) - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) if err != nil { panic(err) } @@ -303,7 +316,7 @@ func TestBuildResourceMPI(t *testing.T) { mpiObj := dummyMPICustomObj(100, 50, 1) taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj) - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -339,13 +352,13 @@ func TestBuildResourceMPIForWrongInput(t *testing.T) { mpiObj := dummyMPICustomObj(0, 0, 1) taskTemplate := dummyMPITaskTemplate(mpiID, mpiObj) - _, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + _, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.Error(t, err) mpiObj = dummyMPICustomObj(1, 1, 1) taskTemplate = dummyMPITaskTemplate(mpiID2, mpiObj) - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) app, ok := resource.(*kubeflowv1.MPIJob) assert.Nil(t, err) assert.Equal(t, true, ok) @@ -459,7 +472,7 @@ func TestBuildResourceMPIExtendedResources(t *testing.T) { mpiObj := dummyMPICustomObj(100, 50, 1) taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj) taskTemplate.ExtendedResources = f.extendedResourcesBase - taskContext := dummyMPITaskContext(taskTemplate, f.resources, f.extendedResourcesOverride) + taskContext := dummyMPITaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, k8s.PluginState{}) mpiResourceHandler := mpiOperatorResourceHandler{} r, err := mpiResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -491,7 +504,7 @@ func TestGetTaskPhase(t *testing.T) { return dummyMPIJobResource(mpiResourceHandler, 2, 1, 1, conditionType) } - taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(2, 1, 1)), resourceRequirements, nil) + taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(2, 1, 1)), resourceRequirements, nil, k8s.PluginState{}) taskPhase, err := mpiResourceHandler.GetTaskPhase(ctx, taskCtx, dummyMPIJobResourceCreator(mpiOp.JobCreated)) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) @@ -523,6 +536,23 @@ func TestGetTaskPhase(t *testing.T) { assert.Nil(t, err) } +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + mpiResourceHandler := mpiOperatorResourceHandler{} + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseQueued, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(2, 1, 1)), resourceRequirements, nil, pluginState) + + taskPhase, err := mpiResourceHandler.GetTaskPhase(ctx, taskCtx, dummyMPIJobResource(mpiResourceHandler, 2, 1, 1, mpiOp.JobCreated)) + + assert.NoError(t, err) + assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1) +} + func TestGetLogs(t *testing.T) { assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{ IsKubernetesEnabled: true, @@ -535,7 +565,7 @@ func TestGetLogs(t *testing.T) { mpiResourceHandler := mpiOperatorResourceHandler{} mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, mpiOp.JobRunning) - taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots)), resourceRequirements, nil) + taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots)), resourceRequirements, nil, k8s.PluginState{}) jobLogs, err := common.GetLogs(taskCtx, common.MPITaskType, mpiJob.ObjectMeta, false, workers, launcher, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) @@ -568,7 +598,7 @@ func TestReplicaCounts(t *testing.T) { mpiObj := dummyMPICustomObj(test.workerReplicaCount, test.launcherReplicaCount, 1) taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj) - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) if test.expectError { assert.Error(t, err) assert.Nil(t, resource) @@ -692,7 +722,7 @@ func TestBuildResourceMPIV1(t *testing.T) { taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -767,7 +797,7 @@ func TestBuildResourceMPIV1WithOnlyWorkerReplica(t *testing.T) { taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -867,7 +897,7 @@ func TestBuildResourceMPIV1ResourceTolerations(t *testing.T) { taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -883,7 +913,7 @@ func TestGetReplicaCount(t *testing.T) { mpiResourceHandler := mpiOperatorResourceHandler{} tfObj := dummyMPICustomObj(1, 1, 0) taskTemplate := dummyMPITaskTemplate("the job", tfObj) - resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) MPIJob, ok := resource.(*kubeflowv1.MPIJob) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 6d0bad4ecd..8084b75b4c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -205,7 +205,14 @@ func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginCont CustomInfo: statusDetails, } - return common.GetPhaseInfo(currentCondition, occurredAt, taskPhaseInfo) + phaseInfo, err := common.GetPhaseInfo(currentCondition, occurredAt, taskPhaseInfo) + + phaseVersionUpdateErr := k8s.MaybeUpdatePhaseVersionFromPluginContext(&phaseInfo, &pluginContext) + if phaseVersionUpdateErr != nil { + return phaseInfo, phaseVersionUpdateErr + } + + return phaseInfo, err } func init() { diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 70fcdcdc5c..6284b4d8f3 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -3,6 +3,7 @@ package pytorch import ( "context" "fmt" + "reflect" "testing" "time" @@ -123,7 +124,7 @@ func dummyPytorchTaskTemplate(id string, args ...interface{}) *core.TaskTemplate } } -func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage string) pluginsCore.TaskExecutionContext { +func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, containerImage string, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return("/input/prefix") @@ -177,6 +178,18 @@ func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate, resources *corev1. taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskExecutionMetadata.OnGetConsoleURL().Return("") taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) + + pluginStateReaderMock := mocks.PluginStateReader{} + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( + func(v interface{}) uint8 { + *(v.(*k8s.PluginState)) = pluginState + return 0 + }, + func(v interface{}) error { + return nil + }) + + taskCtx.OnPluginStateReader().Return(&pluginStateReaderMock) return taskCtx } @@ -281,7 +294,7 @@ func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandl ptObj := dummyPytorchCustomObj(workers) taskTemplate := dummyPytorchTaskTemplate("job1", ptObj) - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) if err != nil { panic(err) } @@ -309,7 +322,7 @@ func TestBuildResourcePytorchElastic(t *testing.T) { ptObj := dummyElasticPytorchCustomObj(2, plugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"}) taskTemplate := dummyPytorchTaskTemplate("job2", ptObj) - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -352,7 +365,7 @@ func TestBuildResourcePytorch(t *testing.T) { ptObj := dummyPytorchCustomObj(100) taskTemplate := dummyPytorchTaskTemplate("job3", ptObj) - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, res) @@ -434,7 +447,7 @@ func TestBuildResourcePytorchContainerImage(t *testing.T) { for _, f := range fixtures { t.Run(tCfg.name+" "+f.name, func(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job", tCfg.plugin) - taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride) + taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, nil, f.containerImageOverride, k8s.PluginState{}) pytorchResourceHandler := pytorchOperatorResourceHandler{} r, err := pytorchResourceHandler.BuildResource(context.TODO(), taskContext) assert.NoError(t, err) @@ -576,7 +589,7 @@ func TestBuildResourcePytorchExtendedResources(t *testing.T) { t.Run(tCfg.name+" "+f.name, func(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job", tCfg.plugin) taskTemplate.ExtendedResources = f.extendedResourcesBase - taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, "") + taskContext := dummyPytorchTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, "", k8s.PluginState{}) pytorchResourceHandler := pytorchOperatorResourceHandler{} r, err := pytorchResourceHandler.BuildResource(context.TODO(), taskContext) assert.NoError(t, err) @@ -609,7 +622,7 @@ func TestGetTaskPhase(t *testing.T) { return dummyPytorchJobResource(pytorchResourceHandler, 2, conditionType) } - taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(2)), resourceRequirements, nil, "") + taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(2)), resourceRequirements, nil, "", k8s.PluginState{}) taskPhase, err := pytorchResourceHandler.GetTaskPhase(ctx, taskCtx, dummyPytorchJobResourceCreator(commonOp.JobCreated)) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) @@ -641,6 +654,23 @@ func TestGetTaskPhase(t *testing.T) { assert.Nil(t, err) } +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + pytorchResourceHandler := pytorchOperatorResourceHandler{} + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseQueued, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(2)), resourceRequirements, nil, "", pluginState) + + taskPhase, err := pytorchResourceHandler.GetTaskPhase(ctx, taskCtx, dummyPytorchJobResource(pytorchResourceHandler, 4, commonOp.JobCreated)) + + assert.NoError(t, err) + assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1) +} + func TestGetLogs(t *testing.T) { assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{ IsKubernetesEnabled: true, @@ -652,7 +682,7 @@ func TestGetLogs(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) - taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "") + taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "", k8s.PluginState{}) jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) @@ -672,7 +702,7 @@ func TestGetLogsElastic(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) - taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "") + taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil, "", k8s.PluginState{}) jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) @@ -703,7 +733,7 @@ func TestReplicaCounts(t *testing.T) { ptObj := dummyPytorchCustomObj(test.workerReplicaCount) taskTemplate := dummyPytorchTaskTemplate("the job", ptObj) - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) if test.expectError { assert.Error(t, err) assert.Nil(t, res) @@ -821,7 +851,7 @@ func TestBuildResourcePytorchV1(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, res) @@ -883,7 +913,7 @@ func TestBuildResourcePytorchV1WithRunPolicy(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, res) @@ -965,7 +995,7 @@ func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, res) @@ -1073,7 +1103,7 @@ func TestBuildResourcePytorchV1ResourceTolerations(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) taskTemplate.TaskTypeVersion = 1 - res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, res) @@ -1108,7 +1138,7 @@ func TestBuildResourcePytorchV1WithElastic(t *testing.T) { taskTemplate.TaskTypeVersion = 1 pytorchResourceHandler := pytorchOperatorResourceHandler{} - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -1157,7 +1187,7 @@ func TestBuildResourcePytorchV1WithZeroWorker(t *testing.T) { taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) taskTemplate.TaskTypeVersion = 1 - _, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + _, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.Error(t, err) } } @@ -1175,7 +1205,7 @@ func TestGetReplicaCount(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} tfObj := dummyPytorchCustomObj(1) taskTemplate := dummyPytorchTaskTemplate("the job", tfObj) - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "")) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) PytorchJob, ok := resource.(*kubeflowv1.PyTorchJob) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index d69fd30b01..93b4d91cd2 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -185,7 +185,14 @@ func (tensorflowOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginC CustomInfo: statusDetails, } - return common.GetPhaseInfo(currentCondition, occurredAt, taskPhaseInfo) + phaseInfo, err := common.GetPhaseInfo(currentCondition, occurredAt, taskPhaseInfo) + + phaseVersionUpdateErr := k8s.MaybeUpdatePhaseVersionFromPluginContext(&phaseInfo, &pluginContext) + if phaseVersionUpdateErr != nil { + return phaseInfo, phaseVersionUpdateErr + } + + return phaseInfo, err } func init() { diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 0cf0f34fd2..8206bda130 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -3,6 +3,7 @@ package tensorflow import ( "context" "fmt" + "reflect" "testing" "time" @@ -118,7 +119,7 @@ func dummyTensorFlowTaskTemplate(id string, args ...interface{}) *core.TaskTempl } } -func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources) pluginsCore.TaskExecutionContext { +func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return("/input/prefix") @@ -172,6 +173,18 @@ func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate, resources *core taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskExecutionMetadata.OnGetConsoleURL().Return("") taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) + + pluginStateReaderMock := mocks.PluginStateReader{} + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( + func(v interface{}) uint8 { + *(v.(*k8s.PluginState)) = pluginState + return 0 + }, + func(v interface{}) error { + return nil + }) + + taskCtx.OnPluginStateReader().Return(&pluginStateReaderMock) return taskCtx } @@ -277,7 +290,7 @@ func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorReso tfObj := dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) if err != nil { panic(err) } @@ -302,7 +315,7 @@ func TestGetReplicaCount(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} tfObj := dummyTensorFlowCustomObj(1, 0, 0, 0) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) tensorflowJob, ok := resource.(*kubeflowv1.TFJob) @@ -320,7 +333,7 @@ func TestBuildResourceTensorFlow(t *testing.T) { tfObj := dummyTensorFlowCustomObj(100, 50, 1, 1) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -515,7 +528,7 @@ func TestBuildResourceTensorFlowExtendedResources(t *testing.T) { taskTemplate := *tCfg.taskTemplate taskTemplate.ExtendedResources = f.extendedResourcesBase tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - taskContext := dummyTensorFlowTaskContext(&taskTemplate, f.resources, f.extendedResourcesOverride) + taskContext := dummyTensorFlowTaskContext(&taskTemplate, f.resources, f.extendedResourcesOverride, k8s.PluginState{}) r, err := tensorflowResourceHandler.BuildResource(context.TODO(), taskContext) assert.NoError(t, err) assert.NotNil(t, r) @@ -548,7 +561,7 @@ func TestGetTaskPhase(t *testing.T) { return dummyTensorFlowJobResource(tensorflowResourceHandler, 2, 1, 1, 1, conditionType) } - taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(2, 1, 1, 1)), resourceRequirements, nil) + taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(2, 1, 1, 1)), resourceRequirements, nil, k8s.PluginState{}) taskPhase, err := tensorflowResourceHandler.GetTaskPhase(ctx, taskCtx, dummyTensorFlowJobResourceCreator(commonOp.JobCreated)) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) @@ -580,6 +593,23 @@ func TestGetTaskPhase(t *testing.T) { assert.Nil(t, err) } +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseQueued, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(2, 1, 1, 1)), resourceRequirements, nil, pluginState) + + taskPhase, err := tensorflowResourceHandler.GetTaskPhase(ctx, taskCtx, dummyTensorFlowJobResource(tensorflowResourceHandler, 2, 1, 1, 1, commonOp.JobCreated)) + + assert.NoError(t, err) + assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1) +} + func TestGetLogs(t *testing.T) { assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{ IsKubernetesEnabled: true, @@ -593,7 +623,7 @@ func TestGetLogs(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} tensorFlowJob := dummyTensorFlowJobResource(tensorflowResourceHandler, workers, psReplicas, chiefReplicas, evaluatorReplicas, commonOp.JobRunning) - taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas)), resourceRequirements, nil) + taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas)), resourceRequirements, nil, k8s.PluginState{}) jobLogs, err := common.GetLogs(taskCtx, common.TensorflowTaskType, tensorFlowJob.ObjectMeta, false, workers, psReplicas, chiefReplicas, evaluatorReplicas) assert.NoError(t, err) @@ -640,7 +670,7 @@ func TestReplicaCounts(t *testing.T) { tfObj := dummyTensorFlowCustomObj(test.workerReplicaCount, test.psReplicaCount, test.chiefReplicaCount, test.evaluatorReplicaCount) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) if test.expectError { assert.Error(t, err) assert.Nil(t, resource) @@ -855,7 +885,7 @@ func TestBuildResourceTensorFlowV1(t *testing.T) { taskTemplate := dummyTensorFlowTaskTemplate("v1", taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -944,7 +974,7 @@ func TestBuildResourceTensorFlowV1WithOnlyWorker(t *testing.T) { taskTemplate := dummyTensorFlowTaskTemplate("v1 with only worker replica", taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) @@ -1057,7 +1087,7 @@ func TestBuildResourceTensorFlowV1ResourceTolerations(t *testing.T) { taskTemplate := dummyTensorFlowTaskTemplate("v1", taskConfig) taskTemplate.TaskTypeVersion = 1 - resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{})) assert.NoError(t, err) assert.NotNil(t, resource) diff --git a/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go b/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go index f72d4eb1d7..2a08cd0e6c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go +++ b/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go @@ -172,7 +172,7 @@ func (plugin) GetTaskPhaseWithLogs(ctx context.Context, pluginContext k8s.Plugin } taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() - if pod.Status.Phase != v1.PodPending && pod.Status.Phase != v1.PodUnknown { + if pod.Status.Phase != v1.PodUnknown { taskLogs, err := logs.GetLogsForContainerInPod(ctx, logPlugin, taskExecID, pod, 0, logSuffix, extraLogTemplateVarsByScheme, taskTemplate) if err != nil { return pluginsCore.PhaseInfoUndefined, err @@ -187,9 +187,9 @@ func (plugin) GetTaskPhaseWithLogs(ctx context.Context, pluginContext k8s.Plugin case v1.PodFailed: phaseInfo, err = flytek8s.DemystifyFailure(pod.Status, info) case v1.PodPending: - phaseInfo, err = flytek8s.DemystifyPending(pod.Status) + phaseInfo, err = flytek8s.DemystifyPending(pod.Status, info) case v1.PodReasonUnschedulable: - phaseInfo = pluginsCore.PhaseInfoQueued(transitionOccurredAt, pluginsCore.DefaultPhaseVersion, "pod unschedulable") + phaseInfo = pluginsCore.PhaseInfoQueuedWithTaskInfo(transitionOccurredAt, pluginsCore.DefaultPhaseVersion, "pod unschedulable", &info) case v1.PodUnknown: // DO NOTHING default: @@ -236,15 +236,9 @@ func (plugin) GetTaskPhaseWithLogs(ctx context.Context, pluginContext k8s.Plugin if err != nil { return pluginsCore.PhaseInfoUndefined, err - } else if phaseInfo.Phase() != pluginsCore.PhaseRunning && phaseInfo.Phase() == pluginState.Phase && - phaseInfo.Version() <= pluginState.PhaseVersion && phaseInfo.Reason() != pluginState.Reason { - - // if we have the same Phase as the previous evaluation and updated the Reason but not the PhaseVersion we must - // update the PhaseVersion so an event is sent to reflect the Reason update. this does not handle the Running - // Phase because the legacy used `DefaultPhaseVersion + 1` which will only increment to 1. - phaseInfo = phaseInfo.WithVersion(pluginState.PhaseVersion + 1) } + k8s.MaybeUpdatePhaseVersion(&phaseInfo, &pluginState) return phaseInfo, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index ff0cfc6cd3..7774c50376 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -558,25 +558,34 @@ func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginCont } if len(rayJob.Status.JobDeploymentStatus) == 0 { - return pluginsCore.PhaseInfoQueued(time.Now(), pluginsCore.DefaultPhaseVersion, "Scheduling"), nil + return pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginsCore.DefaultPhaseVersion, "Scheduling", info), nil } + var phaseInfo pluginsCore.PhaseInfo + // KubeRay creates a Ray cluster first, and then submits a Ray job to the cluster switch rayJob.Status.JobDeploymentStatus { case rayv1.JobDeploymentStatusInitializing: - return pluginsCore.PhaseInfoInitializing(rayJob.CreationTimestamp.Time, pluginsCore.DefaultPhaseVersion, "cluster is creating", info), nil + phaseInfo, err = pluginsCore.PhaseInfoInitializing(rayJob.CreationTimestamp.Time, pluginsCore.DefaultPhaseVersion, "cluster is creating", info), nil case rayv1.JobDeploymentStatusRunning: - return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil + phaseInfo, err = pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil case rayv1.JobDeploymentStatusComplete: - return pluginsCore.PhaseInfoSuccess(info), nil + phaseInfo, err = pluginsCore.PhaseInfoSuccess(info), nil case rayv1.JobDeploymentStatusFailed: failInfo := fmt.Sprintf("Failed to run Ray job %s with error: [%s] %s", rayJob.Name, rayJob.Status.Reason, rayJob.Status.Message) - return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, failInfo, info), nil + phaseInfo, err = pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, failInfo, info), nil default: // We already handle all known deployment status, so this should never happen unless a future version of ray // introduced a new job status. - return pluginsCore.PhaseInfoUndefined, fmt.Errorf("unknown job deployment status: %s", rayJob.Status.JobDeploymentStatus) + phaseInfo, err = pluginsCore.PhaseInfoUndefined, fmt.Errorf("unknown job deployment status: %s", rayJob.Status.JobDeploymentStatus) + } + + phaseVersionUpdateErr := k8s.MaybeUpdatePhaseVersionFromPluginContext(&phaseInfo, &pluginContext) + if phaseVersionUpdateErr != nil { + return phaseInfo, phaseVersionUpdateErr } + + return phaseInfo, err } func init() { diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 28f4749625..65ccfac643 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -2,6 +2,7 @@ package ray import ( "context" + "reflect" "testing" "time" @@ -676,7 +677,7 @@ func TestInjectLogsSidecar(t *testing.T) { } } -func newPluginContext() k8s.PluginContext { +func newPluginContext(pluginState k8s.PluginState) k8s.PluginContext { plg := &mocks2.PluginContext{} taskExecID := &mocks.TaskExecutionID{} @@ -703,6 +704,19 @@ func newPluginContext() k8s.PluginContext { tskCtx := &mocks.TaskExecutionMetadata{} tskCtx.OnGetTaskExecutionID().Return(taskExecID) plg.OnTaskExecutionMetadata().Return(tskCtx) + + pluginStateReaderMock := mocks.PluginStateReader{} + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( + func(v interface{}) uint8 { + *(v.(*k8s.PluginState)) = pluginState + return 0 + }, + func(v interface{}) error { + return nil + }) + + plg.OnPluginStateReader().Return(&pluginStateReaderMock) + return plg } @@ -720,7 +734,7 @@ func init() { func TestGetTaskPhase(t *testing.T) { ctx := context.Background() rayJobResourceHandler := rayJobResourceHandler{} - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { rayJobPhase rayv1.JobDeploymentStatus @@ -751,8 +765,28 @@ func TestGetTaskPhase(t *testing.T) { } } +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + rayJobResourceHandler := rayJobResourceHandler{} + + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseInitializing, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + pluginCtx := newPluginContext(pluginState) + + rayObject := &rayv1.RayJob{} + rayObject.Status.JobDeploymentStatus = rayv1.JobDeploymentStatusInitializing + phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject) + + assert.NoError(t, err) + assert.Equal(t, phaseInfo.Version(), pluginsCore.DefaultPhaseVersion+1) +} + func TestGetEventInfo_LogTemplates(t *testing.T) { - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { name string rayJob rayv1.RayJob @@ -851,7 +885,7 @@ func TestGetEventInfo_LogTemplates(t *testing.T) { } func TestGetEventInfo_LogTemplates_V1(t *testing.T) { - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { name string rayJob rayv1.RayJob @@ -950,7 +984,7 @@ func TestGetEventInfo_LogTemplates_V1(t *testing.T) { } func TestGetEventInfo_DashboardURL(t *testing.T) { - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { name string rayJob rayv1.RayJob @@ -1002,7 +1036,7 @@ func TestGetEventInfo_DashboardURL(t *testing.T) { } func TestGetEventInfo_DashboardURL_V1(t *testing.T) { - pluginCtx := newPluginContext() + pluginCtx := newPluginContext(k8s.PluginState{}) testCases := []struct { name string rayJob rayv1.RayJob diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index 22240e9e45..8b766a391a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -322,39 +322,13 @@ func (sparkResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx p } func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkApplication) (*pluginsCore.TaskInfo, error) { - state := sj.Status.AppState.State - isQueued := state == sparkOp.NewState || - state == sparkOp.PendingSubmissionState || - state == sparkOp.SubmittedState sparkConfig := GetSparkConfig() taskLogs := make([]*core.TaskLog, 0, 3) taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() - if !isQueued { - if sj.Status.DriverInfo.PodName != "" { - p, err := logs.InitializeLogPlugins(&sparkConfig.LogConfig.Mixed) - if err != nil { - return nil, err - } - - if p != nil { - o, err := p.GetTaskLogs(tasklog.Input{ - PodName: sj.Status.DriverInfo.PodName, - Namespace: sj.Namespace, - LogName: "(Driver Logs)", - TaskExecutionID: taskExecID, - }) - - if err != nil { - return nil, err - } - - taskLogs = append(taskLogs, o.TaskLogs...) - } - } - - p, err := logs.InitializeLogPlugins(&sparkConfig.LogConfig.User) + if sj.Status.DriverInfo.PodName != "" { + p, err := logs.InitializeLogPlugins(&sparkConfig.LogConfig.Mixed) if err != nil { return nil, err } @@ -363,7 +337,7 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl o, err := p.GetTaskLogs(tasklog.Input{ PodName: sj.Status.DriverInfo.PodName, Namespace: sj.Namespace, - LogName: "(User Logs)", + LogName: "(Driver Logs)", TaskExecutionID: taskExecID, }) @@ -373,29 +347,49 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl taskLogs = append(taskLogs, o.TaskLogs...) } + } + + p, err := logs.InitializeLogPlugins(&sparkConfig.LogConfig.User) + if err != nil { + return nil, err + } + + if p != nil { + o, err := p.GetTaskLogs(tasklog.Input{ + PodName: sj.Status.DriverInfo.PodName, + Namespace: sj.Namespace, + LogName: "(User Logs)", + TaskExecutionID: taskExecID, + }) - p, err = logs.InitializeLogPlugins(&sparkConfig.LogConfig.System) if err != nil { return nil, err } - if p != nil { - o, err := p.GetTaskLogs(tasklog.Input{ - PodName: sj.Name, - Namespace: sj.Namespace, - LogName: "(System Logs)", - TaskExecutionID: taskExecID, - }) + taskLogs = append(taskLogs, o.TaskLogs...) + } - if err != nil { - return nil, err - } + p, err = logs.InitializeLogPlugins(&sparkConfig.LogConfig.System) + if err != nil { + return nil, err + } - taskLogs = append(taskLogs, o.TaskLogs...) + if p != nil { + o, err := p.GetTaskLogs(tasklog.Input{ + PodName: sj.Name, + Namespace: sj.Namespace, + LogName: "(System Logs)", + TaskExecutionID: taskExecID, + }) + + if err != nil { + return nil, err } + + taskLogs = append(taskLogs, o.TaskLogs...) } - p, err := logs.InitializeLogPlugins(&sparkConfig.LogConfig.AllUser) + p, err = logs.InitializeLogPlugins(&sparkConfig.LogConfig.AllUser) if err != nil { return nil, err } @@ -412,9 +406,13 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl return nil, err } + // "All user" logs are shown already in the queuing and initializing phase. + for _, log := range o.TaskLogs { + log.ShowWhilePending = true + } + taskLogs = append(taskLogs, o.TaskLogs...) } - customInfoMap := make(map[string]string) // Spark UI. @@ -464,21 +462,32 @@ func (sparkResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s. } occurredAt := time.Now() + + var phaseInfo pluginsCore.PhaseInfo + switch app.Status.AppState.State { case sparkOp.NewState: - return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "job queued"), nil + phaseInfo = pluginsCore.PhaseInfoQueuedWithTaskInfo(occurredAt, pluginsCore.DefaultPhaseVersion, "job queued", info) case sparkOp.SubmittedState, sparkOp.PendingSubmissionState: - return pluginsCore.PhaseInfoInitializing(occurredAt, pluginsCore.DefaultPhaseVersion, "job submitted", info), nil + phaseInfo = pluginsCore.PhaseInfoInitializing(occurredAt, pluginsCore.DefaultPhaseVersion, "job submitted", info) case sparkOp.FailedSubmissionState: reason := fmt.Sprintf("Spark Job Submission Failed with Error: %s", app.Status.AppState.ErrorMessage) - return pluginsCore.PhaseInfoRetryableFailure(errors.DownstreamSystemError, reason, info), nil + phaseInfo = pluginsCore.PhaseInfoRetryableFailure(errors.DownstreamSystemError, reason, info) case sparkOp.FailedState: reason := fmt.Sprintf("Spark Job Failed with Error: %s", app.Status.AppState.ErrorMessage) - return pluginsCore.PhaseInfoRetryableFailure(errors.DownstreamSystemError, reason, info), nil + phaseInfo = pluginsCore.PhaseInfoRetryableFailure(errors.DownstreamSystemError, reason, info) case sparkOp.CompletedState: - return pluginsCore.PhaseInfoSuccess(info), nil + phaseInfo = pluginsCore.PhaseInfoSuccess(info) + default: + phaseInfo = pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info) } - return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil + + phaseVersionUpdateErr := k8s.MaybeUpdatePhaseVersionFromPluginContext(&phaseInfo, &pluginContext) + if phaseVersionUpdateErr != nil { + return phaseInfo, phaseVersionUpdateErr + } + + return phaseInfo, nil } func init() { diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index 2b481834db..a560544228 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -3,6 +3,7 @@ package spark import ( "context" "os" + "reflect" "strconv" "testing" @@ -96,7 +97,7 @@ func TestGetEventInfo(t *testing.T) { }, }, })) - taskCtx := dummySparkTaskContext(dummySparkTaskTemplateContainer("blah-1", dummySparkConf), false) + taskCtx := dummySparkTaskContext(dummySparkTaskTemplateContainer("blah-1", dummySparkConf), false, k8s.PluginState{}) info, err := getEventInfoForSpark(taskCtx, dummySparkApplication(sj.RunningState)) assert.NoError(t, err) assert.Len(t, info.Logs, 6) @@ -118,9 +119,14 @@ func TestGetEventInfo(t *testing.T) { assert.Equal(t, expectedLinks, generatedLinks) info, err = getEventInfoForSpark(taskCtx, dummySparkApplication(sj.SubmittedState)) + generatedLinks = make([]string, 0, len(info.Logs)) + for _, l := range info.Logs { + generatedLinks = append(generatedLinks, l.Uri) + } assert.NoError(t, err) - assert.Len(t, info.Logs, 1) - assert.Equal(t, "https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/kubernetes/flyte;prefix=var.log.containers.spark-app-name;streamFilter=typeLogStreamPrefix", info.Logs[0].Uri) + assert.Len(t, info.Logs, 5) + assert.Equal(t, expectedLinks[:5], generatedLinks) // No Spark Driver UI for Submitted state + assert.True(t, info.Logs[4].ShowWhilePending) // All User Logs should be shown while pending assert.NoError(t, setSparkConfig(&Config{ SparkHistoryServerURL: "spark-history.flyte", @@ -166,7 +172,7 @@ func TestGetTaskPhase(t *testing.T) { sparkResourceHandler := sparkResourceHandler{} ctx := context.TODO() - taskCtx := dummySparkTaskContext(dummySparkTaskTemplateContainer("", dummySparkConf), false) + taskCtx := dummySparkTaskContext(dummySparkTaskTemplateContainer("", dummySparkConf), false, k8s.PluginState{}) taskPhase, err := sparkResourceHandler.GetTaskPhase(ctx, taskCtx, dummySparkApplication(sj.NewState)) assert.NoError(t, err) assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseQueued) @@ -228,6 +234,24 @@ func TestGetTaskPhase(t *testing.T) { assert.Nil(t, err) } +func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { + sparkResourceHandler := sparkResourceHandler{} + ctx := context.TODO() + + pluginState := k8s.PluginState{ + Phase: pluginsCore.PhaseInitializing, + PhaseVersion: pluginsCore.DefaultPhaseVersion, + Reason: "task submitted to K8s", + } + + taskCtx := dummySparkTaskContext(dummySparkTaskTemplateContainer("", dummySparkConf), false, pluginState) + + taskPhase, err := sparkResourceHandler.GetTaskPhase(ctx, taskCtx, dummySparkApplication(sj.SubmittedState)) + + assert.NoError(t, err) + assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1) +} + func dummySparkApplication(state sj.ApplicationStateType) *sj.SparkApplication { return &sj.SparkApplication{ @@ -347,7 +371,7 @@ func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec * } } -func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool) pluginsCore.TaskExecutionContext { +func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} inputReader.OnGetInputPrefixPath().Return("/input/prefix") @@ -407,6 +431,18 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool) taskExecutionMetadata.On("GetK8sServiceAccount").Return("new-val") taskExecutionMetadata.On("GetConsoleURL").Return("") taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata) + + pluginStateReaderMock := mocks.PluginStateReader{} + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return( + func(v interface{}) uint8 { + *(v.(*k8s.PluginState)) = pluginState + return 0 + }, + func(v interface{}) error { + return nil + }) + + taskCtx.OnPluginStateReader().Return(&pluginStateReaderMock) return taskCtx } @@ -558,7 +594,7 @@ func TestBuildResourceContainer(t *testing.T) { defaultConfig := defaultPluginConfig() assert.NoError(t, config.SetK8sPluginConfig(defaultConfig)) - resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, true)) + resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, true, k8s.PluginState{})) assert.Nil(t, err) assert.NotNil(t, resource) @@ -706,7 +742,7 @@ func TestBuildResourceContainer(t *testing.T) { dummyConfWithRequest["spark.kubernetes.executor.request.cores"] = "4" taskTemplate = dummySparkTaskTemplateContainer("blah-1", dummyConfWithRequest) - resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false)) + resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false, k8s.PluginState{})) assert.Nil(t, err) assert.NotNil(t, resource) sparkApp, ok = resource.(*sj.SparkApplication) @@ -716,7 +752,7 @@ func TestBuildResourceContainer(t *testing.T) { assert.Equal(t, dummyConfWithRequest["spark.kubernetes.executor.request.cores"], sparkApp.Spec.SparkConf["spark.kubernetes.executor.limit.cores"]) // Case 3: Interruptible False - resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false)) + resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false, k8s.PluginState{})) assert.Nil(t, err) assert.NotNil(t, resource) sparkApp, ok = resource.(*sj.SparkApplication) @@ -764,7 +800,7 @@ func TestBuildResourceContainer(t *testing.T) { // Case 4: Invalid Spark Task-Template taskTemplate.Custom = nil - resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false)) + resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false, k8s.PluginState{})) assert.NotNil(t, err) assert.Nil(t, resource) } @@ -784,7 +820,7 @@ func TestBuildResourcePodTemplate(t *testing.T) { taskTemplate.GetK8SPod() sparkResourceHandler := sparkResourceHandler{} - taskCtx := dummySparkTaskContext(taskTemplate, true) + taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{}) resource, err := sparkResourceHandler.BuildResource(context.TODO(), taskCtx) assert.Nil(t, err) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 99a83aeb3b..5470247ab7 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -265,7 +265,7 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase switch resource.Phase { case flyteIdl.TaskExecution_QUEUED: - return core.PhaseInfoQueuedWithTaskInfo(core.DefaultPhaseVersion, resource.Message, taskInfo), nil + return core.PhaseInfoQueuedWithTaskInfo(time.Now(), core.DefaultPhaseVersion, resource.Message, taskInfo), nil case flyteIdl.TaskExecution_WAITING_FOR_RESOURCES: return core.PhaseInfoWaitingForResourcesInfo(time.Now(), core.DefaultPhaseVersion, resource.Message, taskInfo), nil case flyteIdl.TaskExecution_INITIALIZING: diff --git a/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin.go b/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin.go index 6661550530..ad7da5f042 100644 --- a/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/bigquery/plugin.go @@ -278,7 +278,7 @@ func (p Plugin) Status(ctx context.Context, tCtx webapi.StatusContext) (phase co switch resource.Status.State { case bigqueryStatusPending: - return core.PhaseInfoQueuedWithTaskInfo(version, "Query is PENDING", taskInfo), nil + return core.PhaseInfoQueuedWithTaskInfo(time.Now(), version, "Query is PENDING", taskInfo), nil case bigqueryStatusRunning: return core.PhaseInfoRunning(version, taskInfo), nil diff --git a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager_test.go b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager_test.go index 73d00a6062..a2bcb57014 100644 --- a/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager_test.go +++ b/flytepropeller/pkg/controller/nodes/task/k8s/plugin_manager_test.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -761,13 +762,15 @@ func TestPluginManager_Handle_PluginState(t *testing.T) { }, } - phaseInfoQueued := pluginsCore.PhaseInfoQueuedWithTaskInfo(pluginStateQueued.K8sPluginState.PhaseVersion, pluginStateQueued.K8sPluginState.Reason, nil) + phaseInfoQueued := pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginStateQueued.K8sPluginState.PhaseVersion, pluginStateQueued.K8sPluginState.Reason, nil) phaseInfoQueuedVersion1 := pluginsCore.PhaseInfoQueuedWithTaskInfo( + time.Now(), pluginStateQueuedVersion1.K8sPluginState.PhaseVersion, pluginStateQueuedVersion1.K8sPluginState.Reason, nil, ) phaseInfoQueuedReasonBar := pluginsCore.PhaseInfoQueuedWithTaskInfo( + time.Now(), pluginStateQueuedReasonBar.K8sPluginState.PhaseVersion, pluginStateQueuedReasonBar.K8sPluginState.Reason, nil,