From dbc36a0b5c89fed181ce22f69357957d7ea8e4e6 Mon Sep 17 00:00:00 2001 From: Fabio Graetz Date: Wed, 17 Apr 2024 19:06:58 +0000 Subject: [PATCH] Fix spark tests Signed-off-by: Fabio Graetz --- .../go/tasks/plugins/k8s/spark/spark_test.go | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index 561901226a..7feeb80e06 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -3,6 +3,7 @@ package spark import ( "context" "os" + "reflect" "strconv" "testing" @@ -118,9 +119,14 @@ func TestGetEventInfo(t *testing.T) { assert.Equal(t, expectedLinks, generatedLinks) info, err = getEventInfoForSpark(taskCtx, dummySparkApplication(sj.SubmittedState)) + generatedLinks = make([]string, 0, len(info.Logs)) + for _, l := range info.Logs { + generatedLinks = append(generatedLinks, l.Uri) + } assert.NoError(t, err) - assert.Len(t, info.Logs, 1) - assert.Equal(t, "https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/kubernetes/flyte;prefix=var.log.containers.spark-app-name;streamFilter=typeLogStreamPrefix", info.Logs[0].Uri) + assert.Len(t, info.Logs, 5) + assert.Equal(t, expectedLinks[:5], generatedLinks) // No Spark Driver UI for Submitted state + assert.True(t, info.Logs[4].ShowWhilePending) // All User Logs should be shown while pending assert.NoError(t, setSparkConfig(&Config{ SparkHistoryServerURL: "spark-history.flyte", @@ -406,6 +412,19 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool) taskExecutionMetadata.On("GetOverrides").Return(overrides) taskExecutionMetadata.On("GetK8sServiceAccount").Return("new-val") taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata) + + inputState := k8s.PluginState{} + pluginStateReaderMock := mocks.PluginStateReader{} + pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return( + func(v interface{}) uint8 { + *(v.(*k8s.PluginState)) = inputState + return 0 + }, + func(v interface{}) error { + return nil + }) + + taskCtx.OnPluginStateReader().Return(&pluginStateReaderMock) return taskCtx }