Skip to content

Commit

Permalink
Fix Kubeflow TF Operator GetTaskPhase Bug (#4469)
Browse files Browse the repository at this point in the history
Signed-off-by: Future Outlier <[email protected]>
Co-authored-by: Future Outlier <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
3 people authored Nov 23, 2023
1 parent 568e686 commit ea72bbd
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -217,16 +217,24 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task
return job, nil
}

func getReplicaCount(specs map[commonOp.ReplicaType]*commonOp.ReplicaSpec, replicaType commonOp.ReplicaType) *int32 {
if spec, ok := specs[replicaType]; ok && spec.Replicas != nil {
return spec.Replicas
}

return new(int32) // return 0 as default value
}

// Analyses the k8s resource and reports the status as TaskPhase. This call is expected to be relatively fast,
// any operations that might take a long time (limits are configured system-wide) should be offloaded to the
// background.
func (tensorflowOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
app := resource.(*kubeflowv1.TFJob)

workersCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas
psReplicasCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas
chiefCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas
evaluatorReplicasCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeEval].Replicas
workersCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker)
psReplicasCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS)
chiefCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief)
evaluatorReplicasCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval)

taskLogs, err := common.GetLogs(pluginContext, common.TensorflowTaskType, app.ObjectMeta, false,
*workersCount, *psReplicasCount, *chiefCount, *evaluatorReplicasCount)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,22 @@ func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorReso
}
}

func TestGetReplicaCount(t *testing.T) {
tensorflowResourceHandler := tensorflowOperatorResourceHandler{}
tfObj := dummyTensorFlowCustomObj(1, 0, 0, 0)
taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj)
resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil))
assert.NoError(t, err)
assert.NotNil(t, resource)
tensorflowJob, ok := resource.(*kubeflowv1.TFJob)
assert.True(t, ok)

assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker))
assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS))
assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief))
assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval))
}

func TestBuildResourceTensorFlow(t *testing.T) {
tensorflowResourceHandler := tensorflowOperatorResourceHandler{}

Expand Down
1 change: 0 additions & 1 deletion flyteplugins/go/tasks/plugins/webapi/athena/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ func TestCreateTaskInfo(t *testing.T) {
assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "query_id")
}


func TestCreateTaskInfoGovAWS(t *testing.T) {
taskInfo := createTaskInfo("query_id", awsSdk.Config{
Region: "us-gov-east-1",
Expand Down

0 comments on commit ea72bbd

Please sign in to comment.