Skip to content

Commit

Permalink
Passthrough unique node ID in task execution ID for generating log template vars
Browse files Browse the repository at this point in the history

Signed-off-by: Jeev B <[email protected]>
  • Loading branch information
jeevb committed Nov 8, 2023
1 parent bec7bbb commit 0680e1f
Show file tree
Hide file tree
Showing 13 changed files with 161 additions and 111 deletions.
5 changes: 3 additions & 2 deletions flyteplugins/go/tasks/logs/logging_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
v1 "k8s.io/api/core/v1"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog"
"github.com/flyteorg/flyte/flytestdlib/logger"
)
Expand All @@ -18,7 +19,7 @@ type logPlugin struct {
}

// Internal
func GetLogsForContainerInPod(ctx context.Context, logPlugin tasklog.Plugin, taskExecID *core.TaskExecutionIdentifier, pod *v1.Pod, index uint32, nameSuffix string, extraLogTemplateVarsByScheme *tasklog.TemplateVarsByScheme) ([]*core.TaskLog, error) {
func GetLogsForContainerInPod(ctx context.Context, logPlugin tasklog.Plugin, taskExecID pluginsCore.TaskExecutionID, pod *v1.Pod, index uint32, nameSuffix string, extraLogTemplateVarsByScheme *tasklog.TemplateVarsByScheme) ([]*core.TaskLog, error) {
if logPlugin == nil {
return nil, nil
}
Expand Down Expand Up @@ -53,7 +54,7 @@ func GetLogsForContainerInPod(ctx context.Context, logPlugin tasklog.Plugin, tas
PodRFC3339FinishTime: time.Unix(finishTime, 0).Format(time.RFC3339),
PodUnixStartTime: startTime,
PodUnixFinishTime: finishTime,
TaskExecutionIdentifier: taskExecID,
TaskExecutionID: taskExecID,
ExtraTemplateVarsByScheme: extraLogTemplateVarsByScheme,
},
)
Expand Down
4 changes: 4 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ type TaskExecutionID interface {

// GetID returns the underlying idl task identifier.
GetID() core.TaskExecutionIdentifier

// GetUniqueNodeID returns the fully-qualified Node ID that is unique within a
// given workflow execution.
GetUniqueNodeID() string
}

// TaskExecutionMetadata represents any execution information for a Task. It is used to communicate meta information about the
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion flyteplugins/go/tasks/pluginmachinery/tasklog/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"regexp"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
)

//go:generate enumer --type=TemplateScheme --trimprefix=TemplateScheme -json -yaml
Expand Down Expand Up @@ -42,7 +43,7 @@ type Input struct {
PodUnixStartTime int64
PodUnixFinishTime int64
PodUID string
TaskExecutionIdentifier *core.TaskExecutionIdentifier
TaskExecutionID pluginsCore.TaskExecutionID
ExtraTemplateVarsByScheme *TemplateVarsByScheme
}

Expand Down
96 changes: 48 additions & 48 deletions flyteplugins/go/tasks/pluginmachinery/tasklog/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,55 +114,55 @@ func (input Input) templateVarsForScheme(scheme TemplateScheme) TemplateVars {
vars = append(vars, input.ExtraTemplateVarsByScheme.Pod...)
}
case TemplateSchemeTaskExecution:
if input.TaskExecutionIdentifier != nil {
vars = append(vars, TemplateVar{
taskExecutionIdentifier := input.TaskExecutionID.GetID()
vars = append(
vars,
TemplateVar{
defaultRegexes.NodeID,
input.TaskExecutionID.GetUniqueNodeID(),
},
TemplateVar{
defaultRegexes.TaskRetryAttempt,
strconv.FormatUint(uint64(input.TaskExecutionIdentifier.RetryAttempt), 10),
})
if input.TaskExecutionIdentifier.TaskId != nil {
vars = append(
vars,
TemplateVar{
defaultRegexes.TaskID,
input.TaskExecutionIdentifier.TaskId.Name,
},
TemplateVar{
defaultRegexes.TaskVersion,
input.TaskExecutionIdentifier.TaskId.Version,
},
TemplateVar{
defaultRegexes.TaskProject,
input.TaskExecutionIdentifier.TaskId.Project,
},
TemplateVar{
defaultRegexes.TaskDomain,
input.TaskExecutionIdentifier.TaskId.Domain,
},
)
}
if input.TaskExecutionIdentifier.NodeExecutionId != nil {
vars = append(vars, TemplateVar{
defaultRegexes.NodeID,
input.TaskExecutionIdentifier.NodeExecutionId.NodeId,
})
if input.TaskExecutionIdentifier.NodeExecutionId.ExecutionId != nil {
vars = append(
vars,
TemplateVar{
defaultRegexes.ExecutionName,
input.TaskExecutionIdentifier.NodeExecutionId.ExecutionId.Name,
},
TemplateVar{
defaultRegexes.ExecutionProject,
input.TaskExecutionIdentifier.NodeExecutionId.ExecutionId.Project,
},
TemplateVar{
defaultRegexes.ExecutionDomain,
input.TaskExecutionIdentifier.NodeExecutionId.ExecutionId.Domain,
},
)
}
}
strconv.FormatUint(uint64(taskExecutionIdentifier.RetryAttempt), 10),
},
)
if taskExecutionIdentifier.TaskId != nil {
vars = append(
vars,
TemplateVar{
defaultRegexes.TaskID,
taskExecutionIdentifier.TaskId.Name,
},
TemplateVar{
defaultRegexes.TaskVersion,
taskExecutionIdentifier.TaskId.Version,
},
TemplateVar{
defaultRegexes.TaskProject,
taskExecutionIdentifier.TaskId.Project,
},
TemplateVar{
defaultRegexes.TaskDomain,
taskExecutionIdentifier.TaskId.Domain,
},
)
}
if taskExecutionIdentifier.NodeExecutionId != nil && taskExecutionIdentifier.NodeExecutionId.ExecutionId != nil {
vars = append(
vars,
TemplateVar{
defaultRegexes.ExecutionName,
taskExecutionIdentifier.NodeExecutionId.ExecutionId.Name,
},
TemplateVar{
defaultRegexes.ExecutionProject,
taskExecutionIdentifier.NodeExecutionId.ExecutionId.Project,
},
TemplateVar{
defaultRegexes.ExecutionDomain,
taskExecutionIdentifier.NodeExecutionId.ExecutionId.Domain,
},
)
}
if gotExtraTemplateVars {
vars = append(vars, input.ExtraTemplateVarsByScheme.TaskExecution...)
Expand Down
10 changes: 5 additions & 5 deletions flyteplugins/go/tasks/plugins/k8s/dask/dask.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,13 @@ func (p daskResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s
status == daskAPI.DaskJobClusterCreated

if !isQueued {
taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetID()
taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID()
o, err := logPlugin.GetTaskLogs(
tasklog.Input{
Namespace: job.ObjectMeta.Namespace,
PodName: job.Status.JobRunnerPodName,
LogName: "(User logs)",
TaskExecutionIdentifier: &taskExecID,
Namespace: job.ObjectMeta.Namespace,
PodName: job.Status.JobRunnerPodName,
LogName: "(User logs)",
TaskExecutionID: taskExecID,
},
)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v
namespace := objectMeta.Namespace

taskLogs := make([]*core.TaskLog, 0, 10)
taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetID()
taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID()

logPlugin, err := logs.InitializeLogPlugins(logs.GetLogConfig())

Expand All @@ -120,14 +120,14 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v
if taskType == PytorchTaskType && hasMaster {
masterTaskLog, masterErr := logPlugin.GetTaskLogs(
tasklog.Input{
PodName: name + "-master-0",
Namespace: namespace,
LogName: "master",
PodRFC3339StartTime: RFC3999StartTime,
PodRFC3339FinishTime: RFC3999FinishTime,
PodUnixStartTime: startTime,
PodUnixFinishTime: finishTime,
TaskExecutionIdentifier: &taskExecID,
PodName: name + "-master-0",
Namespace: namespace,
LogName: "master",
PodRFC3339StartTime: RFC3999StartTime,
PodRFC3339FinishTime: RFC3999FinishTime,
PodUnixStartTime: startTime,
PodUnixFinishTime: finishTime,
TaskExecutionID: taskExecID,
},
)
if masterErr != nil {
Expand All @@ -139,13 +139,13 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v
// get all workers log
for workerIndex := int32(0); workerIndex < workersCount; workerIndex++ {
workerLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-worker-%d", workerIndex),
Namespace: namespace,
PodRFC3339StartTime: RFC3999StartTime,
PodRFC3339FinishTime: RFC3999FinishTime,
PodUnixStartTime: startTime,
PodUnixFinishTime: finishTime,
TaskExecutionIdentifier: &taskExecID,
PodName: name + fmt.Sprintf("-worker-%d", workerIndex),
Namespace: namespace,
PodRFC3339StartTime: RFC3999StartTime,
PodRFC3339FinishTime: RFC3999FinishTime,
PodUnixStartTime: startTime,
PodUnixFinishTime: finishTime,
TaskExecutionID: taskExecID,
})
if err != nil {
return nil, err
Expand All @@ -160,9 +160,9 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v
// get all parameter servers logs
for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ {
psReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex),
Namespace: namespace,
TaskExecutionIdentifier: &taskExecID,
PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex),
Namespace: namespace,
TaskExecutionID: taskExecID,
})
if err != nil {
return nil, err
Expand All @@ -172,9 +172,9 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v
// get chief worker log, and the max number of chief worker is 1
if chiefReplicasCount != 0 {
chiefReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-chiefReplica-%d", 0),
Namespace: namespace,
TaskExecutionIdentifier: &taskExecID,
PodName: name + fmt.Sprintf("-chiefReplica-%d", 0),
Namespace: namespace,
TaskExecutionID: taskExecID,
})
if err != nil {
return nil, err
Expand All @@ -184,9 +184,9 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v
// get evaluator log, and the max number of evaluator is 1
if evaluatorReplicasCount != 0 {
evaluatorReplicasCount, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-evaluatorReplica-%d", 0),
Namespace: namespace,
TaskExecutionIdentifier: &taskExecID,
PodName: name + fmt.Sprintf("-evaluatorReplica-%d", 0),
Namespace: namespace,
TaskExecutionID: taskExecID,
})
if err != nil {
return nil, err
Expand Down
6 changes: 3 additions & 3 deletions flyteplugins/go/tasks/plugins/k8s/pod/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,9 @@ func (plugin) GetTaskPhaseWithLogs(ctx context.Context, pluginContext k8s.Plugin
ReportedAt: &reportedAt,
}

taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetID()
taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID()
if pod.Status.Phase != v1.PodPending && pod.Status.Phase != v1.PodUnknown {
taskLogs, err := logs.GetLogsForContainerInPod(ctx, logPlugin, &taskExecID, pod, 0, logSuffix, extraLogTemplateVarsByScheme)
taskLogs, err := logs.GetLogsForContainerInPod(ctx, logPlugin, taskExecID, pod, 0, logSuffix, extraLogTemplateVarsByScheme)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}
Expand Down Expand Up @@ -211,7 +211,7 @@ func (plugin) GetTaskPhaseWithLogs(ctx context.Context, pluginContext k8s.Plugin
} else {
// if the primary container annotation exists, we use the status of the specified container
phaseInfo = flytek8s.DeterminePrimaryContainerPhase(primaryContainerName, pod.Status.ContainerStatuses, &info)
if phaseInfo.Phase() == pluginsCore.PhasePermanentFailure && phaseInfo.Err() != nil &&
if phaseInfo.Phase() == pluginsCore.PhasePermanentFailure && phaseInfo.Err() != nil &&
phaseInfo.Err().GetCode() == flytek8s.PrimaryContainerNotFound {
// if the primary container status is not found ensure that the primary container exists.
// note: it should be impossible for the primary container to not exist at this point.
Expand Down
6 changes: 3 additions & 3 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,10 +444,10 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon
// TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs
// RayJob CRD does not include the name of the worker or head pod for now

taskID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetID()
taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID()
logOutput, err := logPlugin.GetTaskLogs(tasklog.Input{
Namespace: rayJob.Namespace,
TaskExecutionIdentifier: &taskID,
Namespace: rayJob.Namespace,
TaskExecutionID: taskExecID,
})

if err != nil {
Expand Down
Loading

0 comments on commit 0680e1f

Please sign in to comment.