diff --git a/flyteplugins/go/tasks/logs/logging_utils.go b/flyteplugins/go/tasks/logs/logging_utils.go index 0ca515d7c8..6af1889e9f 100644 --- a/flyteplugins/go/tasks/logs/logging_utils.go +++ b/flyteplugins/go/tasks/logs/logging_utils.go @@ -8,6 +8,7 @@ import ( v1 "k8s.io/api/core/v1" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flytestdlib/logger" ) @@ -18,7 +19,7 @@ type logPlugin struct { } // Internal -func GetLogsForContainerInPod(ctx context.Context, logPlugin tasklog.Plugin, taskExecID *core.TaskExecutionIdentifier, pod *v1.Pod, index uint32, nameSuffix string, extraLogTemplateVarsByScheme *tasklog.TemplateVarsByScheme) ([]*core.TaskLog, error) { +func GetLogsForContainerInPod(ctx context.Context, logPlugin tasklog.Plugin, taskExecID pluginsCore.TaskExecutionID, pod *v1.Pod, index uint32, nameSuffix string, extraLogTemplateVarsByScheme *tasklog.TemplateVarsByScheme) ([]*core.TaskLog, error) { if logPlugin == nil { return nil, nil } @@ -53,7 +54,7 @@ func GetLogsForContainerInPod(ctx context.Context, logPlugin tasklog.Plugin, tas PodRFC3339FinishTime: time.Unix(finishTime, 0).Format(time.RFC3339), PodUnixStartTime: startTime, PodUnixFinishTime: finishTime, - TaskExecutionIdentifier: taskExecID, + TaskExecutionID: taskExecID, ExtraTemplateVarsByScheme: extraLogTemplateVarsByScheme, }, ) diff --git a/flyteplugins/go/tasks/logs/logging_utils_test.go b/flyteplugins/go/tasks/logs/logging_utils_test.go index fbf86b9933..066fdd96c8 100644 --- a/flyteplugins/go/tasks/logs/logging_utils_test.go +++ b/flyteplugins/go/tasks/logs/logging_utils_test.go @@ -10,34 +10,41 @@ import ( v12 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + pluginCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + coreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" ) const podName = "PodName" -var dummyTaskExecID = &core.TaskExecutionIdentifier{ - TaskId: &core.Identifier{ - ResourceType: core.ResourceType_TASK, - Name: "my-task-name", - Project: "my-task-project", - Domain: "my-task-domain", - Version: "1", - }, - NodeExecutionId: &core.NodeExecutionIdentifier{ - NodeId: "n0", - ExecutionId: &core.WorkflowExecutionIdentifier{ - Name: "my-execution-name", - Project: "my-execution-project", - Domain: "my-execution-domain", +func dummyTaskExecID() pluginCore.TaskExecutionID { + tID := &coreMocks.TaskExecutionID{} + tID.OnGetGeneratedName().Return("generated-name") + tID.OnGetID().Return(core.TaskExecutionIdentifier{ + TaskId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Name: "my-task-name", + Project: "my-task-project", + Domain: "my-task-domain", + Version: "1", }, - }, - RetryAttempt: 1, + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "my-execution-name", + Project: "my-execution-project", + Domain: "my-execution-domain", + }, + }, + RetryAttempt: 1, + }) + tID.OnGetUniqueNodeID().Return("n0-0-n0") + return tID } func TestGetLogsForContainerInPod_NoPlugins(t *testing.T) { logPlugin, err := InitializeLogPlugins(&LogConfig{}) assert.NoError(t, err) - l, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, nil, 0, " Suffix", nil) + l, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), nil, 0, " Suffix", nil) assert.NoError(t, err) assert.Nil(t, l) } @@ -49,7 +56,7 @@ func TestGetLogsForContainerInPod_NoLogs(t *testing.T) { CloudwatchLogGroup: "/kubernetes/flyte-production", }) assert.NoError(t, err) - p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, nil, 0, " Suffix", nil) + p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), nil, 0, " Suffix", nil) assert.NoError(t, err) assert.Nil(t, p) } @@ -80,7 +87,7 @@ func TestGetLogsForContainerInPod_BadIndex(t *testing.T) { } pod.Name = podName - p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, pod, 1, " Suffix", nil) + p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), pod, 1, " Suffix", nil) assert.NoError(t, err) assert.Nil(t, p) } @@ -105,7 +112,7 @@ func TestGetLogsForContainerInPod_MissingStatus(t *testing.T) { } pod.Name = podName - p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, pod, 1, " Suffix", nil) + p, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), pod, 1, " Suffix", nil) assert.NoError(t, err) assert.Nil(t, p) } @@ -135,7 +142,7 @@ func TestGetLogsForContainerInPod_Cloudwatch(t *testing.T) { } pod.Name = podName - logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, pod, 0, " Suffix", nil) + logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), pod, 0, " Suffix", nil) assert.Nil(t, err) assert.Len(t, logs, 1) } @@ -165,7 +172,7 @@ func TestGetLogsForContainerInPod_K8s(t *testing.T) { } pod.Name = podName - logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, pod, 0, " Suffix", nil) + logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), pod, 0, " Suffix", nil) assert.Nil(t, err) assert.Len(t, logs, 1) } @@ -198,7 +205,7 @@ func TestGetLogsForContainerInPod_All(t *testing.T) { } pod.Name = podName - logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, pod, 0, " Suffix", nil) + logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), pod, 0, " Suffix", nil) assert.Nil(t, err) assert.Len(t, logs, 2) } @@ -229,7 +236,7 @@ func TestGetLogsForContainerInPod_Stackdriver(t *testing.T) { } pod.Name = podName - logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, pod, 0, " Suffix", nil) + logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), pod, 0, " Suffix", nil) assert.Nil(t, err) assert.Len(t, logs, 1) } @@ -303,7 +310,7 @@ func assertTestSucceeded(tb testing.TB, config *LogConfig, expectedTaskLogs []*c }, } - logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID, pod, 0, " my-Suffix", nil) + logs, err := GetLogsForContainerInPod(context.TODO(), logPlugin, dummyTaskExecID(), pod, 0, " my-Suffix", nil) assert.Nil(tb, err) assert.Len(tb, logs, len(expectedTaskLogs)) if diff := deep.Equal(logs, expectedTaskLogs); len(diff) > 0 { @@ -337,7 +344,7 @@ func TestGetLogsForContainerInPod_Templates(t *testing.T) { Name: "StackDriver my-Suffix", }, { - Uri: "https://flyte.corp.net/console/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/n0/taskId/my-task-name/attempt/1/view/logs", + Uri: "https://flyte.corp.net/console/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/n0-0-n0/taskId/my-task-name/attempt/1/view/logs", MessageFormat: core.TaskLog_JSON, Name: "Internal my-Suffix", }, diff --git a/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go b/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go index 8517b9c385..9ac650baaa 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go @@ -27,6 +27,10 @@ type TaskExecutionID interface { // GetID returns the underlying idl task identifier. GetID() core.TaskExecutionIdentifier + + // GetUniqueNodeID returns the fully-qualified Node ID that is unique within a + // given workflow execution. + GetUniqueNodeID() string } // TaskExecutionMetadata represents any execution information for a Task. It is used to communicate meta information about the diff --git a/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_id.go b/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_id.go index 7db5590170..44596bf82f 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_id.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_id.go @@ -114,3 +114,35 @@ func (_m *TaskExecutionID) GetID() flyteidlcore.TaskExecutionIdentifier { return r0 } + +type TaskExecutionID_GetUniqueNodeID struct { + *mock.Call +} + +func (_m TaskExecutionID_GetUniqueNodeID) Return(_a0 string) *TaskExecutionID_GetUniqueNodeID { + return &TaskExecutionID_GetUniqueNodeID{Call: _m.Call.Return(_a0)} +} + +func (_m *TaskExecutionID) OnGetUniqueNodeID() *TaskExecutionID_GetUniqueNodeID { + c_call := _m.On("GetUniqueNodeID") + return &TaskExecutionID_GetUniqueNodeID{Call: c_call} +} + +func (_m *TaskExecutionID) OnGetUniqueNodeIDMatch(matchers ...interface{}) *TaskExecutionID_GetUniqueNodeID { + c_call := _m.On("GetUniqueNodeID", matchers...) + return &TaskExecutionID_GetUniqueNodeID{Call: c_call} +} + +// GetUniqueNodeID provides a mock function with given fields: +func (_m *TaskExecutionID) GetUniqueNodeID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go index 5703b88d81..60a3833397 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go @@ -227,6 +227,10 @@ func (m mockTaskExecutionIdentifier) GetGeneratedName() string { return "task-exec-name" } +func (m mockTaskExecutionIdentifier) GetUniqueNodeID() string { + return "unique-node-id" +} + func TestDecorateEnvVars(t *testing.T) { ctx := context.Background() ctx = contextutils.WithWorkflowID(ctx, "fake_workflow") diff --git a/flyteplugins/go/tasks/pluginmachinery/tasklog/plugin.go b/flyteplugins/go/tasks/pluginmachinery/tasklog/plugin.go index 0ca91c3370..b812221f6d 100644 --- a/flyteplugins/go/tasks/pluginmachinery/tasklog/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/tasklog/plugin.go @@ -4,6 +4,7 @@ import ( "regexp" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" ) //go:generate enumer --type=TemplateScheme --trimprefix=TemplateScheme -json -yaml @@ -42,7 +43,7 @@ type Input struct { PodUnixStartTime int64 PodUnixFinishTime int64 PodUID string - TaskExecutionIdentifier *core.TaskExecutionIdentifier + TaskExecutionID pluginsCore.TaskExecutionID ExtraTemplateVarsByScheme *TemplateVarsByScheme } diff --git a/flyteplugins/go/tasks/pluginmachinery/tasklog/template.go b/flyteplugins/go/tasks/pluginmachinery/tasklog/template.go index 2a68f42cff..77c49d2695 100644 --- a/flyteplugins/go/tasks/pluginmachinery/tasklog/template.go +++ b/flyteplugins/go/tasks/pluginmachinery/tasklog/template.go @@ -114,55 +114,55 @@ func (input Input) templateVarsForScheme(scheme TemplateScheme) TemplateVars { vars = append(vars, input.ExtraTemplateVarsByScheme.Pod...) } case TemplateSchemeTaskExecution: - if input.TaskExecutionIdentifier != nil { - vars = append(vars, TemplateVar{ + taskExecutionIdentifier := input.TaskExecutionID.GetID() + vars = append( + vars, + TemplateVar{ + defaultRegexes.NodeID, + input.TaskExecutionID.GetUniqueNodeID(), + }, + TemplateVar{ defaultRegexes.TaskRetryAttempt, - strconv.FormatUint(uint64(input.TaskExecutionIdentifier.RetryAttempt), 10), - }) - if input.TaskExecutionIdentifier.TaskId != nil { - vars = append( - vars, - TemplateVar{ - defaultRegexes.TaskID, - input.TaskExecutionIdentifier.TaskId.Name, - }, - TemplateVar{ - defaultRegexes.TaskVersion, - input.TaskExecutionIdentifier.TaskId.Version, - }, - TemplateVar{ - defaultRegexes.TaskProject, - input.TaskExecutionIdentifier.TaskId.Project, - }, - TemplateVar{ - defaultRegexes.TaskDomain, - input.TaskExecutionIdentifier.TaskId.Domain, - }, - ) - } - if input.TaskExecutionIdentifier.NodeExecutionId != nil { - vars = append(vars, TemplateVar{ - defaultRegexes.NodeID, - input.TaskExecutionIdentifier.NodeExecutionId.NodeId, - }) - if input.TaskExecutionIdentifier.NodeExecutionId.ExecutionId != nil { - vars = append( - vars, - TemplateVar{ - defaultRegexes.ExecutionName, - input.TaskExecutionIdentifier.NodeExecutionId.ExecutionId.Name, - }, - TemplateVar{ - defaultRegexes.ExecutionProject, - input.TaskExecutionIdentifier.NodeExecutionId.ExecutionId.Project, - }, - TemplateVar{ - defaultRegexes.ExecutionDomain, - input.TaskExecutionIdentifier.NodeExecutionId.ExecutionId.Domain, - }, - ) - } - } + strconv.FormatUint(uint64(taskExecutionIdentifier.RetryAttempt), 10), + }, + ) + if taskExecutionIdentifier.TaskId != nil { + vars = append( + vars, + TemplateVar{ + defaultRegexes.TaskID, + taskExecutionIdentifier.TaskId.Name, + }, + TemplateVar{ + defaultRegexes.TaskVersion, + taskExecutionIdentifier.TaskId.Version, + }, + TemplateVar{ + defaultRegexes.TaskProject, + taskExecutionIdentifier.TaskId.Project, + }, + TemplateVar{ + defaultRegexes.TaskDomain, + taskExecutionIdentifier.TaskId.Domain, + }, + ) + } + if taskExecutionIdentifier.NodeExecutionId != nil && taskExecutionIdentifier.NodeExecutionId.ExecutionId != nil { + vars = append( + vars, + TemplateVar{ + defaultRegexes.ExecutionName, + taskExecutionIdentifier.NodeExecutionId.ExecutionId.Name, + }, + TemplateVar{ + defaultRegexes.ExecutionProject, + taskExecutionIdentifier.NodeExecutionId.ExecutionId.Project, + }, + TemplateVar{ + defaultRegexes.ExecutionDomain, + taskExecutionIdentifier.NodeExecutionId.ExecutionId.Domain, + }, + ) } if gotExtraTemplateVars { vars = append(vars, input.ExtraTemplateVarsByScheme.TaskExecution...) diff --git a/flyteplugins/go/tasks/pluginmachinery/tasklog/template_test.go b/flyteplugins/go/tasks/pluginmachinery/tasklog/template_test.go index e3f03047aa..320ece05a4 100644 --- a/flyteplugins/go/tasks/pluginmachinery/tasklog/template_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/tasklog/template_test.go @@ -9,6 +9,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + pluginCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + coreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" ) func TestTemplateLog(t *testing.T) { @@ -38,6 +40,30 @@ func Benchmark_initDefaultRegexes(b *testing.B) { } } +func dummyTaskExecID() pluginCore.TaskExecutionID { + tID := &coreMocks.TaskExecutionID{} + tID.OnGetGeneratedName().Return("generated-name") + tID.OnGetID().Return(core.TaskExecutionIdentifier{ + TaskId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Name: "my-task-name", + Project: "my-task-project", + Domain: "my-task-domain", + Version: "1", + }, + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "my-execution-name", + Project: "my-execution-project", + Domain: "my-execution-domain", + }, + }, + RetryAttempt: 1, + }) + tID.OnGetUniqueNodeID().Return("n0-0-n0") + return tID +} + func Test_Input_templateVarsForScheme(t *testing.T) { testRegexes := struct { Foo *regexp.Regexp @@ -66,25 +92,8 @@ func Test_Input_templateVarsForScheme(t *testing.T) { PodUnixFinishTime: 12345, } taskExecutionBase := Input{ - LogName: "main_logs", - TaskExecutionIdentifier: &core.TaskExecutionIdentifier{ - TaskId: &core.Identifier{ - ResourceType: core.ResourceType_TASK, - Name: "my-task-name", - Project: "my-task-project", - Domain: "my-task-domain", - Version: "1", - }, - NodeExecutionId: &core.NodeExecutionIdentifier{ - NodeId: "n0", - ExecutionId: &core.WorkflowExecutionIdentifier{ - Name: "my-execution-name", - Project: "my-execution-project", - Domain: "my-execution-domain", - }, - }, - RetryAttempt: 0, - }, + LogName: "main_logs", + TaskExecutionID: dummyTaskExecID(), } tests := []struct { @@ -162,12 +171,12 @@ func Test_Input_templateVarsForScheme(t *testing.T) { nil, TemplateVars{ {defaultRegexes.LogName, "main_logs"}, - {defaultRegexes.TaskRetryAttempt, "0"}, + {defaultRegexes.NodeID, "n0-0-n0"}, + {defaultRegexes.TaskRetryAttempt, "1"}, {defaultRegexes.TaskID, "my-task-name"}, {defaultRegexes.TaskVersion, "1"}, {defaultRegexes.TaskProject, "my-task-project"}, {defaultRegexes.TaskDomain, "my-task-domain"}, - {defaultRegexes.NodeID, "n0"}, {defaultRegexes.ExecutionName, "my-execution-name"}, {defaultRegexes.ExecutionProject, "my-execution-project"}, {defaultRegexes.ExecutionDomain, "my-execution-domain"}, @@ -484,30 +493,13 @@ func TestTemplateLogPlugin_NewTaskLog(t *testing.T) { PodRFC3339FinishTime: "1970-01-01T04:25:45+01:00", PodUnixStartTime: 123, PodUnixFinishTime: 12345, - TaskExecutionIdentifier: &core.TaskExecutionIdentifier{ - TaskId: &core.Identifier{ - ResourceType: core.ResourceType_TASK, - Name: "my-task-name", - Project: "my-task-project", - Domain: "my-task-domain", - Version: "1", - }, - NodeExecutionId: &core.NodeExecutionIdentifier{ - NodeId: "n0", - ExecutionId: &core.WorkflowExecutionIdentifier{ - Name: "my-execution-name", - Project: "my-execution-project", - Domain: "my-execution-domain", - }, - }, - RetryAttempt: 0, - }, + TaskExecutionID: dummyTaskExecID(), }, }, Output{ TaskLogs: []*core.TaskLog{ { - Uri: "https://flyte.corp.net/console/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/n0/taskId/my-task-name/attempt/0/view/logs", + Uri: "https://flyte.corp.net/console/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/n0-0-n0/taskId/my-task-name/attempt/1/view/logs", MessageFormat: core.TaskLog_JSON, Name: "main_logs", }, @@ -534,24 +526,7 @@ func TestTemplateLogPlugin_NewTaskLog(t *testing.T) { PodRFC3339FinishTime: "1970-01-01T04:25:45+01:00", PodUnixStartTime: 123, PodUnixFinishTime: 12345, - TaskExecutionIdentifier: &core.TaskExecutionIdentifier{ - TaskId: &core.Identifier{ - ResourceType: core.ResourceType_TASK, - Name: "my-task-name", - Project: "my-task-project", - Domain: "my-task-domain", - Version: "1", - }, - NodeExecutionId: &core.NodeExecutionIdentifier{ - NodeId: "n0", - ExecutionId: &core.WorkflowExecutionIdentifier{ - Name: "my-execution-name", - Project: "my-execution-project", - Domain: "my-execution-domain", - }, - }, - RetryAttempt: 0, - }, + TaskExecutionID: dummyTaskExecID(), ExtraTemplateVarsByScheme: &TemplateVarsByScheme{ TaskExecution: TemplateVars{ {MustCreateRegex("subtaskExecutionIndex"), "1"}, @@ -564,7 +539,7 @@ func TestTemplateLogPlugin_NewTaskLog(t *testing.T) { Output{ TaskLogs: []*core.TaskLog{ { - Uri: "https://flyte.corp.net/console/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/n0/taskId/my-task-name/attempt/0/mappedIndex/1/mappedAttempt/1/view/logs", + Uri: "https://flyte.corp.net/console/projects/my-execution-project/domains/my-execution-domain/executions/my-execution-name/nodeId/n0-0-n0/taskId/my-task-name/attempt/0/mappedIndex/1/mappedAttempt/1/view/logs", MessageFormat: core.TaskLog_JSON, Name: "main_logs", }, diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go index 65050f5bb2..eb27aec3ce 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go @@ -298,13 +298,13 @@ func (p daskResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s status == daskAPI.DaskJobClusterCreated if !isQueued { - taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetID() + taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() o, err := logPlugin.GetTaskLogs( tasklog.Input{ - Namespace: job.ObjectMeta.Namespace, - PodName: job.Status.JobRunnerPodName, - LogName: "(User logs)", - TaskExecutionIdentifier: &taskExecID, + Namespace: job.ObjectMeta.Namespace, + PodName: job.Status.JobRunnerPodName, + LogName: "(User logs)", + TaskExecutionID: taskExecID, }, ) if err != nil { diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index e0903d02a3..594767b4b4 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -98,7 +98,7 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v namespace := objectMeta.Namespace taskLogs := make([]*core.TaskLog, 0, 10) - taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetID() + taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() logPlugin, err := logs.InitializeLogPlugins(logs.GetLogConfig()) @@ -120,14 +120,14 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v if taskType == PytorchTaskType && hasMaster { masterTaskLog, masterErr := logPlugin.GetTaskLogs( tasklog.Input{ - PodName: name + "-master-0", - Namespace: namespace, - LogName: "master", - PodRFC3339StartTime: RFC3999StartTime, - PodRFC3339FinishTime: RFC3999FinishTime, - PodUnixStartTime: startTime, - PodUnixFinishTime: finishTime, - TaskExecutionIdentifier: &taskExecID, + PodName: name + "-master-0", + Namespace: namespace, + LogName: "master", + PodRFC3339StartTime: RFC3999StartTime, + PodRFC3339FinishTime: RFC3999FinishTime, + PodUnixStartTime: startTime, + PodUnixFinishTime: finishTime, + TaskExecutionID: taskExecID, }, ) if masterErr != nil { @@ -139,13 +139,13 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v // get all workers log for workerIndex := int32(0); workerIndex < workersCount; workerIndex++ { workerLog, err := logPlugin.GetTaskLogs(tasklog.Input{ - PodName: name + fmt.Sprintf("-worker-%d", workerIndex), - Namespace: namespace, - PodRFC3339StartTime: RFC3999StartTime, - PodRFC3339FinishTime: RFC3999FinishTime, - PodUnixStartTime: startTime, - PodUnixFinishTime: finishTime, - TaskExecutionIdentifier: &taskExecID, + PodName: name + fmt.Sprintf("-worker-%d", workerIndex), + Namespace: namespace, + PodRFC3339StartTime: RFC3999StartTime, + PodRFC3339FinishTime: RFC3999FinishTime, + PodUnixStartTime: startTime, + PodUnixFinishTime: finishTime, + TaskExecutionID: taskExecID, }) if err != nil { return nil, err @@ -160,9 +160,9 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v // get all parameter servers logs for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ { psReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{ - PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex), - Namespace: namespace, - TaskExecutionIdentifier: &taskExecID, + PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex), + Namespace: namespace, + TaskExecutionID: taskExecID, }) if err != nil { return nil, err @@ -172,9 +172,9 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v // get chief worker log, and the max number of chief worker is 1 if chiefReplicasCount != 0 { chiefReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{ - PodName: name + fmt.Sprintf("-chiefReplica-%d", 0), - Namespace: namespace, - TaskExecutionIdentifier: &taskExecID, + PodName: name + fmt.Sprintf("-chiefReplica-%d", 0), + Namespace: namespace, + TaskExecutionID: taskExecID, }) if err != nil { return nil, err @@ -184,9 +184,9 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v // get evaluator log, and the max number of evaluator is 1 if evaluatorReplicasCount != 0 { evaluatorReplicasCount, err := logPlugin.GetTaskLogs(tasklog.Input{ - PodName: name + fmt.Sprintf("-evaluatorReplica-%d", 0), - Namespace: namespace, - TaskExecutionIdentifier: &taskExecID, + PodName: name + fmt.Sprintf("-evaluatorReplica-%d", 0), + Namespace: namespace, + TaskExecutionID: taskExecID, }) if err != nil { return nil, err diff --git a/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go b/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go index d1ba98bcaa..11de877021 100644 --- a/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go +++ b/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go @@ -164,9 +164,9 @@ func (plugin) GetTaskPhaseWithLogs(ctx context.Context, pluginContext k8s.Plugin ReportedAt: &reportedAt, } - taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetID() + taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() if pod.Status.Phase != v1.PodPending && pod.Status.Phase != v1.PodUnknown { - taskLogs, err := logs.GetLogsForContainerInPod(ctx, logPlugin, &taskExecID, pod, 0, logSuffix, extraLogTemplateVarsByScheme) + taskLogs, err := logs.GetLogsForContainerInPod(ctx, logPlugin, taskExecID, pod, 0, logSuffix, extraLogTemplateVarsByScheme) if err != nil { return pluginsCore.PhaseInfoUndefined, err } @@ -211,7 +211,7 @@ func (plugin) GetTaskPhaseWithLogs(ctx context.Context, pluginContext k8s.Plugin } else { // if the primary container annotation exists, we use the status of the specified container phaseInfo = flytek8s.DeterminePrimaryContainerPhase(primaryContainerName, pod.Status.ContainerStatuses, &info) - if phaseInfo.Phase() == pluginsCore.PhasePermanentFailure && phaseInfo.Err() != nil && + if phaseInfo.Phase() == pluginsCore.PhasePermanentFailure && phaseInfo.Err() != nil && phaseInfo.Err().GetCode() == flytek8s.PrimaryContainerNotFound { // if the primary container status is not found ensure that the primary container exists. // note: it should be impossible for the primary container to not exist at this point. diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index c1dcc2b8e2..cc8d198334 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -444,10 +444,10 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon // TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs // RayJob CRD does not include the name of the worker or head pod for now - taskID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetID() + taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() logOutput, err := logPlugin.GetTaskLogs(tasklog.Input{ - Namespace: rayJob.Namespace, - TaskExecutionIdentifier: &taskID, + Namespace: rayJob.Namespace, + TaskExecutionID: taskExecID, }) if err != nil { diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index d0506ccfb5..e5fd14478a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -329,7 +329,7 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl sparkConfig := GetSparkConfig() taskLogs := make([]*core.TaskLog, 0, 3) - taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetID() + taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() if !isQueued { if sj.Status.DriverInfo.PodName != "" { @@ -340,10 +340,10 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl if p != nil { o, err := p.GetTaskLogs(tasklog.Input{ - PodName: sj.Status.DriverInfo.PodName, - Namespace: sj.Namespace, - LogName: "(Driver Logs)", - TaskExecutionIdentifier: &taskExecID, + PodName: sj.Status.DriverInfo.PodName, + Namespace: sj.Namespace, + LogName: "(Driver Logs)", + TaskExecutionID: taskExecID, }) if err != nil { @@ -361,10 +361,10 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl if p != nil { o, err := p.GetTaskLogs(tasklog.Input{ - PodName: sj.Status.DriverInfo.PodName, - Namespace: sj.Namespace, - LogName: "(User Logs)", - TaskExecutionIdentifier: &taskExecID, + PodName: sj.Status.DriverInfo.PodName, + Namespace: sj.Namespace, + LogName: "(User Logs)", + TaskExecutionID: taskExecID, }) if err != nil { @@ -381,10 +381,10 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl if p != nil { o, err := p.GetTaskLogs(tasklog.Input{ - PodName: sj.Name, - Namespace: sj.Namespace, - LogName: "(System Logs)", - TaskExecutionIdentifier: &taskExecID, + PodName: sj.Name, + Namespace: sj.Namespace, + LogName: "(System Logs)", + TaskExecutionID: taskExecID, }) if err != nil { @@ -402,10 +402,10 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl if p != nil { o, err := p.GetTaskLogs(tasklog.Input{ - PodName: sj.Name, - Namespace: sj.Namespace, - LogName: "(Spark-Submit/All User Logs)", - TaskExecutionIdentifier: &taskExecID, + PodName: sj.Name, + Namespace: sj.Namespace, + LogName: "(Spark-Submit/All User Logs)", + TaskExecutionID: taskExecID, }) if err != nil { diff --git a/flyteplugins/tests/end_to_end.go b/flyteplugins/tests/end_to_end.go index 603b4d3a30..037ae877d9 100644 --- a/flyteplugins/tests/end_to_end.go +++ b/flyteplugins/tests/end_to_end.go @@ -136,6 +136,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i }, RetryAttempt: 0, }) + tID.OnGetUniqueNodeID().Return("unique-node-id") overrides := &coreMocks.TaskOverrides{} overrides.OnGetConfig().Return(&v1.ConfigMap{Data: map[string]string{ diff --git a/flytepropeller/pkg/controller/nodes/task/taskexec_context.go b/flytepropeller/pkg/controller/nodes/task/taskexec_context.go index 8b819c79eb..4a6f9750a2 100644 --- a/flytepropeller/pkg/controller/nodes/task/taskexec_context.go +++ b/flytepropeller/pkg/controller/nodes/task/taskexec_context.go @@ -33,14 +33,19 @@ var ( const IDMaxLength = 50 type taskExecutionID struct { - execName string - id *core.TaskExecutionIdentifier + execName string + id *core.TaskExecutionIdentifier + uniqueNodeID string } func (te taskExecutionID) GetID() core.TaskExecutionIdentifier { return *te.id } +func (te taskExecutionID) GetUniqueNodeID() string { + return te.uniqueNodeID +} + func (te taskExecutionID) GetGeneratedName() string { return te.execName } @@ -291,11 +296,15 @@ func (t *Handler) newTaskExecutionContext(ctx context.Context, nCtx interfaces.N NodeExecutionContext: nCtx, tm: taskExecutionMetadata{ NodeExecutionMetadata: nCtx.NodeExecutionMetadata(), - taskExecID: taskExecutionID{execName: uniqueID, id: id}, - o: nCtx.Node(), - maxAttempts: maxAttempts, - platformResources: convertTaskResourcesToRequirements(nCtx.ExecutionContext().GetExecutionConfig().TaskResources), - environmentVariables: nCtx.ExecutionContext().GetExecutionConfig().EnvironmentVariables, + taskExecID: taskExecutionID{ + execName: uniqueID, + id: id, + uniqueNodeID: currentNodeUniqueID, + }, + o: nCtx.Node(), + maxAttempts: maxAttempts, + platformResources: convertTaskResourcesToRequirements(nCtx.ExecutionContext().GetExecutionConfig().TaskResources), + environmentVariables: nCtx.ExecutionContext().GetExecutionConfig().EnvironmentVariables, }, rm: resourcemanager.GetTaskResourceManager( t.resourceManager, resourceNamespacePrefix, id), diff --git a/flytepropeller/pkg/controller/nodes/task/taskexec_context_test.go b/flytepropeller/pkg/controller/nodes/task/taskexec_context_test.go index a8cbc7d6d3..b0106e8ff9 100644 --- a/flytepropeller/pkg/controller/nodes/task/taskexec_context_test.go +++ b/flytepropeller/pkg/controller/nodes/task/taskexec_context_test.go @@ -20,8 +20,10 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" flyteMocks "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors" mocks2 "github.com/flyteorg/flyte/flytepropeller/pkg/controller/executors/mocks" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/handler" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/interfaces" nodeMocks "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/interfaces/mocks" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/codex" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/resourcemanager" @@ -31,15 +33,26 @@ import ( "github.com/flyteorg/flyte/flytestdlib/storage" ) -func TestHandler_newTaskExecutionContext(t *testing.T) { - wfExecID := &core.WorkflowExecutionIdentifier{ +type dummyPluginState struct { + A int +} + +var ( + wfExecID = &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", Name: "name", } + taskID = &core.Identifier{} + nodeID = "n1" + resources = &corev1.ResourceRequirements{ + Requests: make(corev1.ResourceList), + Limits: make(corev1.ResourceList), + } + dummyPluginStateA = 45 +) - nodeID := "n1" - +func dummyNodeExecutionContext(t *testing.T, parentInfo executors.ImmutableParentInfo, eventVersion v1alpha1.EventVersion) interfaces.NodeExecutionContext { nm := &nodeMocks.NodeExecutionMetadata{} nm.OnGetAnnotations().Return(map[string]string{}) nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ @@ -55,7 +68,6 @@ func TestHandler_newTaskExecutionContext(t *testing.T) { Name: "name", }) - taskID := &core.Identifier{} tr := &nodeMocks.TaskReader{} tr.OnGetTaskID().Return(taskID) @@ -63,12 +75,8 @@ func TestHandler_newTaskExecutionContext(t *testing.T) { ns.OnGetDataDir().Return("data-dir") ns.OnGetOutputDir().Return("output-dir") - res := &corev1.ResourceRequirements{ - Requests: make(corev1.ResourceList), - Limits: make(corev1.ResourceList), - } n := &flyteMocks.ExecutableNode{} - n.OnGetResources().Return(res) + n.OnGetResources().Return(resources) ma := 5 n.OnGetRetryStrategy().Return(&v1alpha1.RetryStrategy{MinAttempts: &ma}) @@ -87,8 +95,8 @@ func TestHandler_newTaskExecutionContext(t *testing.T) { executionContext := &mocks2.ExecutionContext{} executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{}) - executionContext.OnGetParentInfo().Return(nil) - executionContext.OnGetEventVersion().Return(v1alpha1.EventVersion0) + executionContext.OnGetParentInfo().Return(parentInfo) + executionContext.OnGetEventVersion().Return(eventVersion) nCtx.OnExecutionContext().Return(executionContext) ds, err := storage.NewDataStore( @@ -101,12 +109,8 @@ func TestHandler_newTaskExecutionContext(t *testing.T) { nCtx.OnDataStore().Return(ds) st := bytes.NewBuffer([]byte{}) - a := 45 - type test struct { - A int - } codex := codex.GobStateCodec{} - assert.NoError(t, codex.Encode(test{A: a}, st)) + assert.NoError(t, codex.Encode(dummyPluginState{A: dummyPluginStateA}, st)) nr := &nodeMocks.NodeStateReader{} nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ PluginState: st.Bytes(), @@ -114,28 +118,39 @@ func TestHandler_newTaskExecutionContext(t *testing.T) { nCtx.OnNodeStateReader().Return(nr) nCtx.OnRawOutputPrefix().Return("s3://sandbox/") nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"})) + return nCtx +} - noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope()) +func dummyPlugin() pluginCore.Plugin { + p := &pluginCoreMocks.Plugin{} + p.On("GetID").Return("plugin1") + p.OnGetProperties().Return(pluginCore.PluginProperties{}) + return p +} +func dummyHandler() *Handler { + noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope()) c := &mocks.Client{} - tk := &Handler{ + return &Handler{ catalog: c, secretManager: secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()), resourceManager: noopRm, } +} - p := &pluginCoreMocks.Plugin{} - p.On("GetID").Return("plugin1") - p.OnGetProperties().Return(pluginCore.PluginProperties{}) +func TestHandler_newTaskExecutionContext(t *testing.T) { + nCtx := dummyNodeExecutionContext(t, nil, v1alpha1.EventVersion0) + p := dummyPlugin() + tk := dummyHandler() got, err := tk.newTaskExecutionContext(context.TODO(), nCtx, p) assert.NoError(t, err) assert.NotNil(t, got) - f := &test{} + f := &dummyPluginState{} v, err := got.PluginStateReader().Get(f) assert.NoError(t, err) assert.Equal(t, v, uint8(0)) - assert.Equal(t, f.A, a) + assert.Equal(t, f.A, dummyPluginStateA) // Try writing new state type test2 struct { @@ -151,13 +166,14 @@ func TestHandler_newTaskExecutionContext(t *testing.T) { assert.NotNil(t, got.SecretManager()) assert.NotNil(t, got.OutputWriter()) - assert.Equal(t, got.TaskExecutionMetadata().GetOverrides().GetResources(), res) + assert.Equal(t, got.TaskExecutionMetadata().GetOverrides().GetResources(), resources) assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), "name-n1-1") assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetID().TaskId, taskID) assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetID().RetryAttempt, uint32(1)) assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetID().NodeExecutionId.GetNodeId(), nodeID) assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetID().NodeExecutionId.GetExecutionId(), wfExecID) + assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetUniqueNodeID(), nodeID) assert.EqualValues(t, got.ResourceManager().(resourcemanager.TaskResourceManager).GetResourcePoolInfo(), make([]*event.ResourcePoolInfo, 0)) @@ -195,6 +211,22 @@ func TestHandler_newTaskExecutionContext(t *testing.T) { assert.NotNil(t, anotherTaskExecCtx.tr) } +func TestHandler_newTaskExecutionContext_taskExecutionID_WithParentInfo(t *testing.T) { + parentInfo := &mocks2.ImmutableParentInfo{} + parentInfo.OnGetUniqueID().Return("n0") + parentInfo.OnCurrentAttempt().Return(uint32(2)) + + nCtx := dummyNodeExecutionContext(t, parentInfo, v1alpha1.EventVersion1) + p := dummyPlugin() + tk := dummyHandler() + got, err := tk.newTaskExecutionContext(context.TODO(), nCtx, p) + assert.NoError(t, err) + assert.NotNil(t, got) + + assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), "name-n0-2-n1-1") + assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetUniqueNodeID(), "n0-2-n1") +} + func TestGetGeneratedNameWith(t *testing.T) { t.Run("length 0", func(t *testing.T) { tCtx := taskExecutionID{ diff --git a/flytepropeller/pkg/controller/nodes/task/transformer_test.go b/flytepropeller/pkg/controller/nodes/task/transformer_test.go index a26705baee..db89dda3e6 100644 --- a/flytepropeller/pkg/controller/nodes/task/transformer_test.go +++ b/flytepropeller/pkg/controller/nodes/task/transformer_test.go @@ -67,6 +67,7 @@ func TestToTaskExecutionEvent(t *testing.T) { generatedName := "generated_name" tID.OnGetGeneratedName().Return(generatedName) tID.OnGetID().Return(*id) + tID.OnGetUniqueNodeID().Return("unique-node-id") tMeta := &pluginMocks.TaskExecutionMetadata{} tMeta.OnGetTaskExecutionID().Return(tID) @@ -261,6 +262,7 @@ func TestToTaskExecutionEventWithParent(t *testing.T) { generatedName := "generated_name" tID.OnGetGeneratedName().Return(generatedName) tID.OnGetID().Return(*id) + tID.OnGetUniqueNodeID().Return("unique-node-id") tMeta := &pluginMocks.TaskExecutionMetadata{} tMeta.OnGetTaskExecutionID().Return(tID)