From 4793a35321556a31bcceeb6881c3a95a1641436f Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 23 Nov 2023 17:49:03 -0800 Subject: [PATCH 1/6] Improve the error message for databricks plugins Signed-off-by: Kevin Su --- Makefile | 4 +- .../plugins/webapi/athena/plugin_test.go | 1 - .../webapi/databricks/integration_test.go | 14 +- .../tasks/plugins/webapi/databricks/plugin.go | 249 ++++++++---------- .../plugins/webapi/databricks/plugin_test.go | 196 ++++++++++---- 5 files changed, 259 insertions(+), 205 deletions(-) diff --git a/Makefile b/Makefile index 595f91b16e..73d866e964 100644 --- a/Makefile +++ b/Makefile @@ -97,6 +97,6 @@ setup_local_dev: ## Sets up k3d cluster with Flyte dependencies for local develo .PHONY: build_native_flyte build_native_flyte: FLYTECONSOLE_VERSION := latest build_native_flyte: - docker build \ + docker build --platform linux/arm64 \ --build-arg FLYTECONSOLE_VERSION=$(FLYTECONSOLE_VERSION) \ - --tag flyte-binary:native . + --tag pingsutw/flyte-binary:t1 . diff --git a/flyteplugins/go/tasks/plugins/webapi/athena/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/athena/plugin_test.go index e19829447e..c3fc39f451 100644 --- a/flyteplugins/go/tasks/plugins/webapi/athena/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/athena/plugin_test.go @@ -23,7 +23,6 @@ func TestCreateTaskInfo(t *testing.T) { assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "query_id") } - func TestCreateTaskInfoGovAWS(t *testing.T) { taskInfo := createTaskInfo("query_id", awsSdk.Config{ Region: "us-gov-east-1", diff --git a/flyteplugins/go/tasks/plugins/webapi/databricks/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/databricks/integration_test.go index 651892f672..d18f4ba79e 100644 --- a/flyteplugins/go/tasks/plugins/webapi/databricks/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/databricks/integration_test.go @@ -109,8 +109,8 @@ func newFakeDatabricksServer() *httptest.Server { runID := "065168461" jobID := "019e7546" return httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - if request.URL.Path == fmt.Sprintf("%v/submit", databricksAPI) && request.Method == post { - writer.WriteHeader(202) + if request.URL.Path == fmt.Sprintf("%v/submit", databricksAPI) && request.Method == http.MethodPost { + writer.WriteHeader(http.StatusOK) bytes := []byte(fmt.Sprintf(`{ "run_id": "%v" }`, runID)) @@ -118,8 +118,8 @@ func newFakeDatabricksServer() *httptest.Server { return } - if request.URL.Path == fmt.Sprintf("%v/get", databricksAPI) && request.Method == get { - writer.WriteHeader(200) + if request.URL.Path == fmt.Sprintf("%v/get", databricksAPI) && request.Method == http.MethodGet { + writer.WriteHeader(http.StatusOK) bytes := []byte(fmt.Sprintf(`{ "job_id": "%v", "state": {"state_message": "execution in progress.", "life_cycle_state": "TERMINATED", "result_state": "SUCCESS"} @@ -128,12 +128,12 @@ func newFakeDatabricksServer() *httptest.Server { return } - if request.URL.Path == fmt.Sprintf("%v/cancel", databricksAPI) && request.Method == post { - writer.WriteHeader(200) + if request.URL.Path == fmt.Sprintf("%v/cancel", databricksAPI) && request.Method == http.MethodPost { + writer.WriteHeader(http.StatusOK) return } - writer.WriteHeader(500) + writer.WriteHeader(http.StatusInternalServerError) })) } diff --git a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go index 3bd03135dc..1862125c75 100644 --- a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go @@ -6,6 +6,7 @@ import ( "encoding/gob" "encoding/json" "fmt" + "io" "io/ioutil" "net/http" "time" @@ -26,20 +27,20 @@ import ( ) const ( - ErrSystem errors.ErrorCode = "System" - post string = "POST" - get string = "GET" - databricksAPI string = "/api/2.1/jobs/runs" - newCluster string = "new_cluster" - dockerImage string = "docker_image" - sparkConfig string = "spark_conf" - sparkPythonTask string = "spark_python_task" - pythonFile string = "python_file" - parameters string = "parameters" - url string = "url" + create string = "create" + get string = "get" + cancel string = "cancel" + databricksAPI string = "/api/2.1/jobs/runs" + newCluster string = "new_cluster" + dockerImage string = "docker_image" + sparkConfig string = "spark_conf" + sparkPythonTask string = "spark_python_task" + pythonFile string = "python_file" + parameters string = "parameters" + url string = "url" ) -// for mocking/testing purposes, and we'll override this method +// HTTPClient for mocking/testing purposes, and we'll override this method type HTTPClient interface { Do(req *http.Request) (*http.Response, error) } @@ -127,60 +128,31 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR } databricksJob[sparkPythonTask] = map[string]interface{}{pythonFile: p.cfg.EntrypointFile, parameters: modifiedArgs} - req, err := buildRequest(post, databricksJob, p.cfg.databricksEndpoint, - p.cfg.DatabricksInstance, token, "", false) + data, err := p.sendRequest(create, databricksJob, token, "") if err != nil { return nil, nil, err } - resp, err := p.client.Do(req) - if err != nil { - return nil, nil, err - } - defer resp.Body.Close() - data, err := buildResponse(resp) - if err != nil { - return nil, nil, err - } - if data["run_id"] == "" { - return nil, nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, - "Unable to fetch statementHandle from http response") - } runID := fmt.Sprintf("%.0f", data["run_id"]) - return ResourceMetaWrapper{runID, p.cfg.DatabricksInstance, token}, - ResourceWrapper{StatusCode: resp.StatusCode}, nil + return ResourceMetaWrapper{runID, p.cfg.DatabricksInstance, token}, nil, nil } func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { exec := taskCtx.ResourceMeta().(ResourceMetaWrapper) - req, err := buildRequest(get, nil, p.cfg.databricksEndpoint, - p.cfg.DatabricksInstance, exec.Token, exec.RunID, false) - if err != nil { - logger.Errorf(ctx, "Failed to build databricks job request [%v]", err) - return nil, err - } - resp, err := p.client.Do(req) - logger.Debugf(ctx, "Get databricks job response", "resp", resp) - if err != nil { - logger.Errorf(ctx, "Failed to get databricks job status [%v]", resp) - return nil, err - } - defer resp.Body.Close() - data, err := buildResponse(resp) + res, err := p.sendRequest(get, nil, exec.Token, exec.RunID) if err != nil { return nil, err } - if data == nil || data["state"] == nil { + if res == nil || res["state"] == nil { return nil, errors.Errorf("CorruptedPluginState", "can't get the job state") } - jobState := data["state"].(map[string]interface{}) + jobState := res["state"].(map[string]interface{}) message := fmt.Sprintf("%s", jobState["state_message"]) - jobID := fmt.Sprintf("%.0f", data["job_id"]) + jobID := fmt.Sprintf("%.0f", res["job_id"]) lifeCycleState := fmt.Sprintf("%s", jobState["life_cycle_state"]) resultState := fmt.Sprintf("%s", jobState["result_state"]) return ResourceWrapper{ - StatusCode: resp.StatusCode, JobID: jobID, LifeCycleState: lifeCycleState, ResultState: resultState, @@ -193,63 +165,116 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error return nil } exec := taskCtx.ResourceMeta().(ResourceMetaWrapper) - req, err := buildRequest(post, nil, p.cfg.databricksEndpoint, - p.cfg.DatabricksInstance, exec.Token, exec.RunID, true) + _, err := p.sendRequest(cancel, nil, exec.Token, exec.RunID) if err != nil { return err } + logger.Info(ctx, "Deleted Databricks job execution.") + + return nil +} + +func (p Plugin) sendRequest(method string, databricksJob map[string]interface{}, token string, runID string) (map[string]interface{}, error) { + var databricksURL string + // for mocking/testing purposes + if p.cfg.databricksEndpoint == "" { + databricksURL = fmt.Sprintf("https://%v%v", p.cfg.DatabricksInstance, databricksAPI) + } else { + databricksURL = fmt.Sprintf("%v%v", p.cfg.databricksEndpoint, databricksAPI) + } + + // build the request spec + var body io.Reader + var httpMethod string + switch method { + case create: + databricksURL += "/submit" + mJSON, err := json.Marshal(databricksJob) + if err != nil { + return nil, fmt.Errorf("failed to marshal the job request: %v", err) + } + body = bytes.NewBuffer(mJSON) + httpMethod = http.MethodPost + case get: + databricksURL += "/get?run_id=" + runID + httpMethod = http.MethodGet + case cancel: + databricksURL += "/cancel" + body = bytes.NewBuffer([]byte(fmt.Sprintf("{ \"run_id\": %v }", runID))) + httpMethod = http.MethodPost + } + + req, err := http.NewRequest(httpMethod, databricksURL, body) + if err != nil { + return nil, err + } + req.Header.Add("Authorization", "Bearer "+token) + req.Header.Add("Content-Type", "application/json") + + // Send the request resp, err := p.client.Do(req) if err != nil { - return err + return nil, fmt.Errorf("failed to send request to Databricks platform with err: [%v]", err) } defer resp.Body.Close() - logger.Info(ctx, "Deleted query execution [%v]", resp) - return nil + logger.Errorf(context.Background(), "resp.Body [%v]", resp) + // Parse the response body + responseBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + logger.Errorf(context.Background(), "responseBody [%v]", string(responseBody)) + var data map[string]interface{} + err = json.Unmarshal(responseBody, &data) + if err != nil { + return nil, fmt.Errorf("failed to parse response with err: [%v]", err) + } + if resp.StatusCode != http.StatusOK { + message := "" + if v, ok := data["message"]; ok { + message = v.(string) + } + return nil, fmt.Errorf("failed to %v Databricks job with error [%v]", method, message) + } + return data, nil } func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase core.PhaseInfo, err error) { exec := taskCtx.ResourceMeta().(ResourceMetaWrapper) resource := taskCtx.Resource().(ResourceWrapper) message := resource.Message - statusCode := resource.StatusCode jobID := resource.JobID lifeCycleState := resource.LifeCycleState resultState := resource.ResultState - if statusCode == 0 { - return core.PhaseInfoUndefined, errors.Errorf(ErrSystem, "No Status field set.") - } - + logger.Errorf(context.Background(), "lifeCycleStateeeee [%v]", lifeCycleState) taskInfo := createTaskInfo(exec.RunID, jobID, exec.DatabricksInstance) - switch statusCode { - // Job response format. https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit - case http.StatusAccepted: - return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, taskInfo), nil - case http.StatusOK: - if lifeCycleState == "TERMINATED" || lifeCycleState == "TERMINATING" || lifeCycleState == "INTERNAL_ERROR" { - if resultState == "SUCCESS" { - if err := writeOutput(ctx, taskCtx); err != nil { - pluginsCore.PhaseInfoFailure(string(rune(statusCode)), "failed to write output", taskInfo) - } - return pluginsCore.PhaseInfoSuccess(taskInfo), nil + switch lifeCycleState { + // Job response format. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runlifecyclestate + case "": + return core.PhaseInfoQueued(time.Now(), core.DefaultPhaseVersion, "job queued"), nil + case "PENDING": + return core.PhaseInfoInitializing(time.Now(), core.DefaultPhaseVersion, message, taskInfo), nil + case "RUNNING": + fallthrough + case "TERMINATING": + return core.PhaseInfoRunning(core.DefaultPhaseVersion, taskInfo), nil + case "TERMINATED": + if resultState == "SUCCESS" { + // Result state details. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate + if err := writeOutput(ctx, taskCtx); err != nil { + return core.PhaseInfoFailure(string(rune(http.StatusInternalServerError)), "failed to write output", taskInfo), nil } - return pluginsCore.PhaseInfoFailure(string(rune(statusCode)), message, taskInfo), nil - } - - if lifeCycleState == "PENDING" { - return core.PhaseInfoInitializing(time.Now(), core.DefaultPhaseVersion, message, taskInfo), nil + return core.PhaseInfoSuccess(taskInfo), nil } - - return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, taskInfo), nil - case http.StatusBadRequest: - fallthrough - case http.StatusInternalServerError: - fallthrough - case http.StatusUnauthorized: - return pluginsCore.PhaseInfoFailure(string(rune(statusCode)), message, taskInfo), nil + return core.PhaseInfoFailure(pluginErrors.TaskFailedWithError, message, taskInfo), nil + case "SKIPPED": + return core.PhaseInfoFailure(string(rune(http.StatusConflict)), message, taskInfo), nil + case "INTERNAL_ERROR": + return core.PhaseInfoFailure(string(rune(http.StatusInternalServerError)), message, taskInfo), nil } - return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", statusCode) + return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", lifeCycleState) } func writeOutput(ctx context.Context, taskCtx webapi.StatusContext) error { @@ -266,66 +291,6 @@ func writeOutput(ctx context.Context, taskCtx webapi.StatusContext) error { return taskCtx.OutputWriter().Put(ctx, outputReader) } -func buildRequest( - method string, - databricksJob map[string]interface{}, - databricksEndpoint string, - databricksInstance string, - token string, - runID string, - isCancel bool, -) (*http.Request, error) { - var databricksURL string - // for mocking/testing purposes - if databricksEndpoint == "" { - databricksURL = fmt.Sprintf("https://%v%v", databricksInstance, databricksAPI) - } else { - databricksURL = fmt.Sprintf("%v%v", databricksEndpoint, databricksAPI) - } - - var data []byte - var req *http.Request - var err error - if isCancel { - databricksURL += "/cancel" - data = []byte(fmt.Sprintf("{ \"run_id\": %v }", runID)) - } else if method == post { - databricksURL += "/submit" - mJSON, err := json.Marshal(databricksJob) - if err != nil { - return nil, err - } - data = []byte(string(mJSON)) - } else { - databricksURL += "/get?run_id=" + runID - } - - if data == nil { - req, err = http.NewRequest(method, databricksURL, nil) - } else { - req, err = http.NewRequest(method, databricksURL, bytes.NewBuffer(data)) - } - if err != nil { - return nil, err - } - req.Header.Add("Authorization", "Bearer "+token) - req.Header.Add("Content-Type", "application/json") - return req, nil -} - -func buildResponse(response *http.Response) (map[string]interface{}, error) { - responseBody, err := ioutil.ReadAll(response.Body) - if err != nil { - return nil, err - } - var data map[string]interface{} - err = json.Unmarshal(responseBody, &data) - if err != nil { - return nil, err - } - return data, nil -} - func createTaskInfo(runID, jobID, databricksInstance string) *core.TaskInfo { timeNow := time.Now() diff --git a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go index fda3ab61b0..b8b0bebb22 100644 --- a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go @@ -2,10 +2,8 @@ package databricks import ( "context" - "encoding/json" - "io/ioutil" + "errors" "net/http" - "strings" "testing" "time" @@ -13,21 +11,22 @@ import ( pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyte/flytestdlib/ioutils" "github.com/flyteorg/flyte/flytestdlib/promutils" ) type MockClient struct { + MockDo func(req *http.Request) (*http.Response, error) +} + +func (m MockClient) Do(req *http.Request) (*http.Response, error) { + return m.MockDo(req) } var ( - MockDo func(req *http.Request) (*http.Response, error) testInstance = "test-account.cloud.databricks.com" ) -func (m *MockClient) Do(req *http.Request) (*http.Response, error) { - return MockDo(req) -} - func TestPlugin(t *testing.T) { fakeSetupContext := pluginCoreMocks.SetupContext{} fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test")) @@ -35,7 +34,9 @@ func TestPlugin(t *testing.T) { plugin := Plugin{ metricScope: fakeSetupContext.MetricsScope(), cfg: GetConfig(), - client: &MockClient{}, + client: &MockClient{func(req *http.Request) (*http.Response, error) { + return nil, nil + }}, } t.Run("get config", func(t *testing.T) { cfg := defaultConfig @@ -53,61 +54,150 @@ func TestPlugin(t *testing.T) { }) } -func TestCreateTaskInfo(t *testing.T) { - t.Run("create task info", func(t *testing.T) { - taskInfo := createTaskInfo("run-id", "job-id", testInstance) +func TestSendRequest(t *testing.T) { + fakeSetupContext := pluginCoreMocks.SetupContext{} + fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test1")) + databricksJob := map[string]interface{}{"sparkConfig": map[string]interface{}{"sparkVersion": "7.3.x-scala2.12"}} + token := "token" - assert.Equal(t, 1, len(taskInfo.Logs)) - assert.Equal(t, taskInfo.Logs[0].Uri, "https://test-account.cloud.databricks.com/#job/job-id/run/run-id") - assert.Equal(t, taskInfo.Logs[0].Name, "Databricks Console") + plugin := Plugin{ + metricScope: fakeSetupContext.MetricsScope(), + cfg: GetConfig(), + client: &MockClient{MockDo: func(req *http.Request) (*http.Response, error) { + assert.Equal(t, req.Method, http.MethodPost) + return &http.Response{ + StatusCode: http.StatusOK, + Body: ioutils.NewBytesReadCloser([]byte(`{"id":"someID","data":"someData"}`)), + }, nil + }}, + } + + t.Run("create a Databricks job", func(t *testing.T) { + data, err := plugin.sendRequest(create, databricksJob, token, "") + assert.NotNil(t, data) + assert.Equal(t, "someID", data["id"]) + assert.Equal(t, "someData", data["data"]) + assert.Nil(t, err) }) -} -func TestBuildRequest(t *testing.T) { - token := "test-token" - runID := "019e70eb" - databricksEndpoint := "" - databricksURL := "https://" + testInstance + "/api/2.1/jobs/runs" - t.Run("build http request for submitting a databricks job", func(t *testing.T) { - req, err := buildRequest(post, nil, databricksEndpoint, testInstance, token, runID, false) - header := http.Header{} - header.Add("Authorization", "Bearer "+token) - header.Add("Content-Type", "application/json") + t.Run("failed to create a Databricks job", func(t *testing.T) { + plugin.client = &MockClient{MockDo: func(req *http.Request) (*http.Response, error) { + assert.Equal(t, req.Method, http.MethodPost) + return &http.Response{ + StatusCode: http.StatusBadRequest, + Body: ioutils.NewBytesReadCloser([]byte(`{"message":"failed"}`)), + }, nil + }} + data, err := plugin.sendRequest(create, databricksJob, token, "") + assert.Nil(t, data) + assert.Equal(t, err.Error(), "failed to create Databricks job with error [failed]") + }) - assert.NoError(t, err) - assert.Equal(t, header, req.Header) - assert.Equal(t, databricksURL+"/submit", req.URL.String()) - assert.Equal(t, post, req.Method) + t.Run("failed to send request to Databricks", func(t *testing.T) { + plugin.client = &MockClient{MockDo: func(req *http.Request) (*http.Response, error) { + assert.Equal(t, req.Method, http.MethodPost) + return nil, errors.New("failed to send request") + }} + data, err := plugin.sendRequest(create, databricksJob, token, "") + assert.Nil(t, data) + assert.Equal(t, err.Error(), "failed to send request to Databricks platform with err: [failed to send request]") }) - t.Run("Get a databricks spark job status", func(t *testing.T) { - req, err := buildRequest(get, nil, databricksEndpoint, testInstance, token, runID, false) - assert.NoError(t, err) - assert.Equal(t, databricksURL+"/get?run_id="+runID, req.URL.String()) - assert.Equal(t, get, req.Method) + t.Run("failed to send request to Databricks", func(t *testing.T) { + plugin.client = &MockClient{MockDo: func(req *http.Request) (*http.Response, error) { + assert.Equal(t, req.Method, http.MethodPost) + return &http.Response{ + StatusCode: http.StatusOK, + Body: ioutils.NewBytesReadCloser([]byte(`123`)), + }, nil + }} + data, err := plugin.sendRequest(create, databricksJob, token, "") + assert.Nil(t, data) + assert.Equal(t, err.Error(), "failed to parse response with err: [json: cannot unmarshal number into Go value of type map[string]interface {}]") }) - t.Run("Cancel a spark job", func(t *testing.T) { - req, err := buildRequest(post, nil, databricksEndpoint, testInstance, token, runID, true) - assert.NoError(t, err) - assert.Equal(t, databricksURL+"/cancel", req.URL.String()) - assert.Equal(t, post, req.Method) + t.Run("get a Databricks job", func(t *testing.T) { + plugin.client = &MockClient{MockDo: func(req *http.Request) (*http.Response, error) { + assert.Equal(t, req.Method, http.MethodGet) + return &http.Response{ + StatusCode: http.StatusOK, + Body: ioutils.NewBytesReadCloser([]byte(`{"message":"ok"}`)), + }, nil + }} + data, err := plugin.sendRequest(get, databricksJob, token, "") + assert.NotNil(t, data) + assert.Nil(t, err) + }) + + t.Run("cancel a Databricks job", func(t *testing.T) { + plugin.client = &MockClient{MockDo: func(req *http.Request) (*http.Response, error) { + assert.Equal(t, req.Method, http.MethodPost) + return &http.Response{ + StatusCode: http.StatusOK, + Body: ioutils.NewBytesReadCloser([]byte(`{"message":"ok"}`)), + }, nil + }} + data, err := plugin.sendRequest(cancel, databricksJob, token, "") + assert.NotNil(t, data) + assert.Nil(t, err) }) } -func TestBuildResponse(t *testing.T) { - t.Run("build http response", func(t *testing.T) { - bodyStr := `{"job_id":"019c06a4-0000", "message":"Statement executed successfully."}` - responseBody := ioutil.NopCloser(strings.NewReader(bodyStr)) - response := &http.Response{Body: responseBody} - actualData, err := buildResponse(response) - assert.NoError(t, err) +func TestCreateTaskInfo(t *testing.T) { + t.Run("create task info", func(t *testing.T) { + taskInfo := createTaskInfo("run-id", "job-id", testInstance) - bodyByte, err := ioutil.ReadAll(strings.NewReader(bodyStr)) - assert.NoError(t, err) - var expectedData map[string]interface{} - err = json.Unmarshal(bodyByte, &expectedData) - assert.NoError(t, err) - assert.Equal(t, expectedData, actualData) + assert.Equal(t, 1, len(taskInfo.Logs)) + assert.Equal(t, taskInfo.Logs[0].Uri, "https://test-account.cloud.databricks.com/#job/job-id/run/run-id") + assert.Equal(t, taskInfo.Logs[0].Name, "Databricks Console") }) } + +//func TestBuildRequest(t *testing.T) { +// token := "test-token" +// runID := "019e70eb" +// databricksEndpoint := "" +// databricksURL := "https://" + testInstance + "/api/2.1/jobs/runs" +// t.Run("build http request for submitting a databricks job", func(t *testing.T) { +// req, err := buildRequest(post, nil, databricksEndpoint, testInstance, token, runID, false) +// header := http.Header{} +// header.Add("Authorization", "Bearer "+token) +// header.Add("Content-Type", "application/json") +// +// assert.NoError(t, err) +// assert.Equal(t, header, req.Header) +// assert.Equal(t, databricksURL+"/submit", req.URL.String()) +// assert.Equal(t, post, req.Method) +// }) +// t.Run("Get a databricks spark job status", func(t *testing.T) { +// req, err := buildRequest(get, nil, databricksEndpoint, testInstance, token, runID, false) +// +// assert.NoError(t, err) +// assert.Equal(t, databricksURL+"/get?run_id="+runID, req.URL.String()) +// assert.Equal(t, get, req.Method) +// }) +// t.Run("Cancel a spark job", func(t *testing.T) { +// req, err := buildRequest(post, nil, databricksEndpoint, testInstance, token, runID, true) +// +// assert.NoError(t, err) +// assert.Equal(t, databricksURL+"/cancel", req.URL.String()) +// assert.Equal(t, post, req.Method) +// }) +//} +// +//func TestBuildResponse(t *testing.T) { +// t.Run("build http response", func(t *testing.T) { +// bodyStr := `{"job_id":"019c06a4-0000", "message":"Statement executed successfully."}` +// responseBody := ioutil.NopCloser(strings.NewReader(bodyStr)) +// response := &http.Response{Body: responseBody} +// actualData, err := buildResponse(response) +// assert.NoError(t, err) +// +// bodyByte, err := ioutil.ReadAll(strings.NewReader(bodyStr)) +// assert.NoError(t, err) +// var expectedData map[string]interface{} +// err = json.Unmarshal(bodyByte, &expectedData) +// assert.NoError(t, err) +// assert.Equal(t, expectedData, actualData) +// }) +//} From e95018db0e7adf236568a57885e5b93492b47baf Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 23 Nov 2023 17:52:43 -0800 Subject: [PATCH 2/6] nit Signed-off-by: Kevin Su --- Makefile | 4 +- .../plugins/webapi/databricks/plugin_test.go | 49 ------------------- 2 files changed, 2 insertions(+), 51 deletions(-) diff --git a/Makefile b/Makefile index 73d866e964..595f91b16e 100644 --- a/Makefile +++ b/Makefile @@ -97,6 +97,6 @@ setup_local_dev: ## Sets up k3d cluster with Flyte dependencies for local develo .PHONY: build_native_flyte build_native_flyte: FLYTECONSOLE_VERSION := latest build_native_flyte: - docker build --platform linux/arm64 \ + docker build \ --build-arg FLYTECONSOLE_VERSION=$(FLYTECONSOLE_VERSION) \ - --tag pingsutw/flyte-binary:t1 . + --tag flyte-binary:native . diff --git a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go index b8b0bebb22..228914af93 100644 --- a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go @@ -152,52 +152,3 @@ func TestCreateTaskInfo(t *testing.T) { assert.Equal(t, taskInfo.Logs[0].Name, "Databricks Console") }) } - -//func TestBuildRequest(t *testing.T) { -// token := "test-token" -// runID := "019e70eb" -// databricksEndpoint := "" -// databricksURL := "https://" + testInstance + "/api/2.1/jobs/runs" -// t.Run("build http request for submitting a databricks job", func(t *testing.T) { -// req, err := buildRequest(post, nil, databricksEndpoint, testInstance, token, runID, false) -// header := http.Header{} -// header.Add("Authorization", "Bearer "+token) -// header.Add("Content-Type", "application/json") -// -// assert.NoError(t, err) -// assert.Equal(t, header, req.Header) -// assert.Equal(t, databricksURL+"/submit", req.URL.String()) -// assert.Equal(t, post, req.Method) -// }) -// t.Run("Get a databricks spark job status", func(t *testing.T) { -// req, err := buildRequest(get, nil, databricksEndpoint, testInstance, token, runID, false) -// -// assert.NoError(t, err) -// assert.Equal(t, databricksURL+"/get?run_id="+runID, req.URL.String()) -// assert.Equal(t, get, req.Method) -// }) -// t.Run("Cancel a spark job", func(t *testing.T) { -// req, err := buildRequest(post, nil, databricksEndpoint, testInstance, token, runID, true) -// -// assert.NoError(t, err) -// assert.Equal(t, databricksURL+"/cancel", req.URL.String()) -// assert.Equal(t, post, req.Method) -// }) -//} -// -//func TestBuildResponse(t *testing.T) { -// t.Run("build http response", func(t *testing.T) { -// bodyStr := `{"job_id":"019c06a4-0000", "message":"Statement executed successfully."}` -// responseBody := ioutil.NopCloser(strings.NewReader(bodyStr)) -// response := &http.Response{Body: responseBody} -// actualData, err := buildResponse(response) -// assert.NoError(t, err) -// -// bodyByte, err := ioutil.ReadAll(strings.NewReader(bodyStr)) -// assert.NoError(t, err) -// var expectedData map[string]interface{} -// err = json.Unmarshal(bodyByte, &expectedData) -// assert.NoError(t, err) -// assert.Equal(t, expectedData, actualData) -// }) -//} From 232e7e68f820d32bdbb4d62f6c490e2ffcce66f3 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 24 Nov 2023 00:00:00 -0800 Subject: [PATCH 3/6] nit Signed-off-by: Kevin Su --- .../go/tasks/pluginmachinery/internal/webapi/launcher.go | 3 ++- flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go | 5 ----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher.go b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher.go index bf6ef791ac..d5185b94fa 100644 --- a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher.go +++ b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher.go @@ -2,6 +2,7 @@ package webapi import ( "context" + pluginErrors "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" "time" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" @@ -15,7 +16,7 @@ func launch(ctx context.Context, p webapi.AsyncPlugin, tCtx core.TaskExecutionCo rMeta, r, err := p.Create(ctx, tCtx) if err != nil { logger.Errorf(ctx, "Failed to create resource. Error: %v", err) - return nil, core.PhaseInfo{}, err + return state, core.PhaseInfoFailure(pluginErrors.TaskFailedWithError, err.Error(), nil), nil } // If the plugin also returned the created resource, check to see if it's already in a terminal state. diff --git a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go index 1862125c75..33bef67152 100644 --- a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go @@ -218,13 +218,11 @@ func (p Plugin) sendRequest(method string, databricksJob map[string]interface{}, } defer resp.Body.Close() - logger.Errorf(context.Background(), "resp.Body [%v]", resp) // Parse the response body responseBody, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, err } - logger.Errorf(context.Background(), "responseBody [%v]", string(responseBody)) var data map[string]interface{} err = json.Unmarshal(responseBody, &data) if err != nil { @@ -248,12 +246,9 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase lifeCycleState := resource.LifeCycleState resultState := resource.ResultState - logger.Errorf(context.Background(), "lifeCycleStateeeee [%v]", lifeCycleState) taskInfo := createTaskInfo(exec.RunID, jobID, exec.DatabricksInstance) switch lifeCycleState { // Job response format. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runlifecyclestate - case "": - return core.PhaseInfoQueued(time.Now(), core.DefaultPhaseVersion, "job queued"), nil case "PENDING": return core.PhaseInfoInitializing(time.Now(), core.DefaultPhaseVersion, message, taskInfo), nil case "RUNNING": From 1b67946bbf18bf111d9dd0b448508ecf18feb243 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 24 Nov 2023 00:14:09 -0800 Subject: [PATCH 4/6] nit Signed-off-by: Kevin Su --- .../plugins/webapi/agent/integration_test.go | 12 ++++++------ .../go/tasks/plugins/webapi/databricks/plugin.go | 16 +++++++++++++--- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index 0aecffdfc7..1f1bdb8fc2 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -149,8 +149,8 @@ func TestEndToEnd(t *testing.T) { tCtx.OnInputReader().Return(inputReader) trns, err := plugin.Handle(context.Background(), tCtx) - assert.Error(t, err) - assert.Equal(t, trns.Info().Phase(), core.PhaseUndefined) + assert.Nil(t, err) + assert.Equal(t, trns.Info().Phase(), core.PhasePermanentFailure) err = plugin.Abort(context.Background(), tCtx) assert.Nil(t, err) }) @@ -167,8 +167,8 @@ func TestEndToEnd(t *testing.T) { assert.NoError(t, err) trns, err := plugin.Handle(context.Background(), tCtx) - assert.Error(t, err) - assert.Equal(t, trns.Info().Phase(), core.PhaseUndefined) + assert.Nil(t, err) + assert.Equal(t, trns.Info().Phase(), core.PhasePermanentFailure) }) t.Run("failed to read inputs", func(t *testing.T) { @@ -188,8 +188,8 @@ func TestEndToEnd(t *testing.T) { assert.NoError(t, err) trns, err := plugin.Handle(context.Background(), tCtx) - assert.Error(t, err) - assert.Equal(t, trns.Info().Phase(), core.PhaseUndefined) + assert.Nil(t, err) + assert.Equal(t, trns.Info().Phase(), core.PhasePermanentFailure) }) } diff --git a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go index 33bef67152..5ebe1d0075 100644 --- a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go @@ -133,6 +133,9 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR return nil, nil, err } + if _, ok := data["run_id"]; !ok { + return nil, nil, errors.Errorf("CorruptedPluginState", "can't get the run_id") + } runID := fmt.Sprintf("%.0f", data["run_id"]) return ResourceMetaWrapper{runID, p.cfg.DatabricksInstance, token}, nil, nil @@ -144,14 +147,21 @@ func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest weba if err != nil { return nil, err } - if res == nil || res["state"] == nil { + if _, ok := res["state"]; !ok { return nil, errors.Errorf("CorruptedPluginState", "can't get the job state") } jobState := res["state"].(map[string]interface{}) - message := fmt.Sprintf("%s", jobState["state_message"]) jobID := fmt.Sprintf("%.0f", res["job_id"]) + message := fmt.Sprintf("%s", jobState["state_message"]) lifeCycleState := fmt.Sprintf("%s", jobState["life_cycle_state"]) - resultState := fmt.Sprintf("%s", jobState["result_state"]) + var resultState string + if _, ok := jobState["result_state"]; !ok { + // The result_state is not available until the job is finished. + // https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate + resultState = "" + } else { + resultState = fmt.Sprintf("%s", jobState["result_state"]) + } return ResourceWrapper{ JobID: jobID, LifeCycleState: lifeCycleState, From f40f62f31c34ce470dbe1f2cfb696fcaf4711211 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 24 Nov 2023 00:24:49 -0800 Subject: [PATCH 5/6] fix tests Signed-off-by: Kevin Su --- .../tasks/pluginmachinery/internal/webapi/launcher_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher_test.go b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher_test.go index 85ba42d0c6..4533b1d851 100644 --- a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher_test.go @@ -79,8 +79,9 @@ func Test_launch(t *testing.T) { plgn := newPluginWithProperties(webapi.PluginConfig{}) plgn.OnCreate(ctx, tCtx).Return("", nil, fmt.Errorf("error creating")) - _, _, err := launch(ctx, plgn, tCtx, c, &s) - assert.Error(t, err) + _, phase, err := launch(ctx, plgn, tCtx, c, &s) + assert.Nil(t, err) + assert.Equal(t, core.PhasePermanentFailure, phase.Phase()) }) t.Run("Failed to cache", func(t *testing.T) { From a2623e3dc4215a56f4fd2fbb81aacaf119985f48 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sun, 3 Mar 2024 23:11:55 -0800 Subject: [PATCH 6/6] PhaseInfoRetryableFailure Signed-off-by: Kevin Su --- .../go/tasks/pluginmachinery/internal/webapi/launcher.go | 2 +- .../tasks/pluginmachinery/internal/webapi/launcher_test.go | 2 +- .../go/tasks/plugins/webapi/agent/integration_test.go | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher.go b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher.go index d5185b94fa..99a3ccdf7a 100644 --- a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher.go +++ b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher.go @@ -16,7 +16,7 @@ func launch(ctx context.Context, p webapi.AsyncPlugin, tCtx core.TaskExecutionCo rMeta, r, err := p.Create(ctx, tCtx) if err != nil { logger.Errorf(ctx, "Failed to create resource. Error: %v", err) - return state, core.PhaseInfoFailure(pluginErrors.TaskFailedWithError, err.Error(), nil), nil + return state, core.PhaseInfoRetryableFailure(pluginErrors.TaskFailedWithError, err.Error(), nil), nil } // If the plugin also returned the created resource, check to see if it's already in a terminal state. diff --git a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher_test.go b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher_test.go index 4533b1d851..7836cc591d 100644 --- a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher_test.go @@ -81,7 +81,7 @@ func Test_launch(t *testing.T) { plgn.OnCreate(ctx, tCtx).Return("", nil, fmt.Errorf("error creating")) _, phase, err := launch(ctx, plgn, tCtx, c, &s) assert.Nil(t, err) - assert.Equal(t, core.PhasePermanentFailure, phase.Phase()) + assert.Equal(t, core.PhaseRetryableFailure, phase.Phase()) }) t.Run("Failed to cache", func(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index c84458dbba..d863d77a8c 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -113,7 +113,7 @@ func TestEndToEnd(t *testing.T) { trns, err := plugin.Handle(context.Background(), tCtx) assert.Nil(t, err) - assert.Equal(t, trns.Info().Phase(), core.PhasePermanentFailure) + assert.Equal(t, trns.Info().Phase(), core.PhaseRetryableFailure) err = plugin.Abort(context.Background(), tCtx) assert.Nil(t, err) }) @@ -131,7 +131,7 @@ func TestEndToEnd(t *testing.T) { trns, err := plugin.Handle(context.Background(), tCtx) assert.Nil(t, err) - assert.Equal(t, trns.Info().Phase(), core.PhasePermanentFailure) + assert.Equal(t, trns.Info().Phase(), core.PhaseRetryableFailure) }) t.Run("failed to read inputs", func(t *testing.T) { @@ -152,7 +152,7 @@ func TestEndToEnd(t *testing.T) { trns, err := plugin.Handle(context.Background(), tCtx) assert.Nil(t, err) - assert.Equal(t, trns.Info().Phase(), core.PhasePermanentFailure) + assert.Equal(t, trns.Info().Phase(), core.PhaseRetryableFailure) }) }