Skip to content

Commit

Permalink
Bump phase version in pytorch plugin
Browse files Browse the repository at this point in the history
Signed-off-by: Fabio Graetz <[email protected]>
  • Loading branch information
fg91 committed Jan 14, 2024
1 parent f39865e commit 1fa91aa
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ func ParseElasticConfig(elasticConfig ElasticConfig) *kubeflowv1.ElasticPolicy {
// any operations that might take a long time (limits are configured system-wide) should be offloaded to the
// background.
func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
pluginState := k8s.PluginState{}
_, err := pluginContext.PluginStateReader().Get(&pluginState)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}

app, ok := resource.(*kubeflowv1.PyTorchJob)
if !ok {
return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type")
Expand Down Expand Up @@ -205,7 +211,22 @@ func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginCont
CustomInfo: statusDetails,
}

return common.GetPhaseInfo(currentCondition, occurredAt, taskPhaseInfo)
phaseInfo, err := common.GetPhaseInfo(currentCondition, occurredAt, taskPhaseInfo)

// TODO: logic copied from pod/plugin.go
// Can we centralize this logic to not reproduce in every single plugin?????
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
} else if phaseInfo.Phase() != pluginsCore.PhaseRunning && phaseInfo.Phase() == pluginState.Phase &&
phaseInfo.Version() <= pluginState.PhaseVersion && phaseInfo.Reason() != pluginState.Reason {

// if we have the same Phase as the previous evaluation and updated the Reason but not the PhaseVersion we must
// update the PhaseVersion so an event is sent to reflect the Reason update. this does not handle the Running
// Phase because the legacy used `DefaultPhaseVersion + 1` which will only increment to 1.
phaseInfo = phaseInfo.WithVersion(pluginState.PhaseVersion + 1)
}

return phaseInfo, err
}

func init() {
Expand Down

0 comments on commit 1fa91aa

Please sign in to comment.