diff --git a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher.go b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher.go
index bf6ef791ac..99a3ccdf7a 100644
--- a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher.go
+++ b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher.go
@@ -2,6 +2,7 @@ package webapi

 import (
 	"context"
+	pluginErrors "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors"
 	"time"

 	"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
@@ -15,7 +16,7 @@ func launch(ctx context.Context, p webapi.AsyncPlugin, tCtx core.TaskExecutionCo
 	rMeta, r, err := p.Create(ctx, tCtx)
 	if err != nil {
 		logger.Errorf(ctx, "Failed to create resource. Error: %v", err)
-		return nil, core.PhaseInfo{}, err
+		return state, core.PhaseInfoRetryableFailure(pluginErrors.TaskFailedWithError, err.Error(), nil), nil
 	}

 	// If the plugin also returned the created resource, check to see if it's already in a terminal state.
diff --git a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher_test.go b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher_test.go
index 85ba42d0c6..7836cc591d 100644
--- a/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher_test.go
+++ b/flyteplugins/go/tasks/pluginmachinery/internal/webapi/launcher_test.go
@@ -79,8 +79,9 @@ func Test_launch(t *testing.T) {
 		plgn := newPluginWithProperties(webapi.PluginConfig{})
 		plgn.OnCreate(ctx, tCtx).Return("", nil, fmt.Errorf("error creating"))

-		_, _, err := launch(ctx, plgn, tCtx, c, &s)
-		assert.Error(t, err)
+		_, phase, err := launch(ctx, plgn, tCtx, c, &s)
+		assert.Nil(t, err)
+		assert.Equal(t, core.PhaseRetryableFailure, phase.Phase())
 	})

 	t.Run("Failed to cache", func(t *testing.T) {
diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go
index a4ddc5e303..689527ee3b 100644
--- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go
+++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go
@@ -137,8 +137,8 @@ func TestEndToEnd(t *testing.T) {
 		tCtx.OnInputReader().Return(inputReader)

 		trns, err := plugin.Handle(context.Background(), tCtx)
-		assert.Error(t, err)
-		assert.Equal(t, trns.Info().Phase(), core.PhaseUndefined)
+		assert.Nil(t, err)
+		assert.Equal(t, trns.Info().Phase(), core.PhaseRetryableFailure)
 		err = plugin.Abort(context.Background(), tCtx)
 		assert.Nil(t, err)
 	})
@@ -155,8 +155,8 @@ func TestEndToEnd(t *testing.T) {
 		assert.NoError(t, err)

 		trns, err := plugin.Handle(context.Background(), tCtx)
-		assert.Error(t, err)
-		assert.Equal(t, trns.Info().Phase(), core.PhaseUndefined)
+		assert.Nil(t, err)
+		assert.Equal(t, trns.Info().Phase(), core.PhaseRetryableFailure)
 	})

 	t.Run("failed to read inputs", func(t *testing.T) {
@@ -176,8 +176,8 @@ func TestEndToEnd(t *testing.T) {
 		assert.NoError(t, err)

 		trns, err := plugin.Handle(context.Background(), tCtx)
-		assert.Error(t, err)
-		assert.Equal(t, trns.Info().Phase(), core.PhaseUndefined)
+		assert.Nil(t, err)
+		assert.Equal(t, trns.Info().Phase(), core.PhaseRetryableFailure)
 	})
 }
diff --git a/flyteplugins/go/tasks/plugins/webapi/databricks/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/databricks/integration_test.go
index 651892f672..d18f4ba79e 100644
--- a/flyteplugins/go/tasks/plugins/webapi/databricks/integration_test.go
+++ b/flyteplugins/go/tasks/plugins/webapi/databricks/integration_test.go
@@ -109,8 +109,8 @@ func newFakeDatabricksServer() *httptest.Server {
 	runID := "065168461"
 	jobID := "019e7546"
 	return httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
-		if request.URL.Path == fmt.Sprintf("%v/submit", databricksAPI) && request.Method == post {
-			writer.WriteHeader(202)
+		if request.URL.Path == fmt.Sprintf("%v/submit", databricksAPI) && request.Method == http.MethodPost {
+			writer.WriteHeader(http.StatusOK)
 			bytes := []byte(fmt.Sprintf(`{
 				"run_id": "%v"
 			}`, runID))
@@ -118,8 +118,8 @@
 			return
 		}

-		if request.URL.Path == fmt.Sprintf("%v/get", databricksAPI) && request.Method == get {
-			writer.WriteHeader(200)
+		if request.URL.Path == fmt.Sprintf("%v/get", databricksAPI) && request.Method == http.MethodGet {
+			writer.WriteHeader(http.StatusOK)
 			bytes := []byte(fmt.Sprintf(`{
 				"job_id": "%v",
 				"state": {"state_message": "execution in progress.", "life_cycle_state": "TERMINATED", "result_state": "SUCCESS"}
@@ -128,12 +128,12 @@
 			return
 		}

-		if request.URL.Path == fmt.Sprintf("%v/cancel", databricksAPI) && request.Method == post {
-			writer.WriteHeader(200)
+		if request.URL.Path == fmt.Sprintf("%v/cancel", databricksAPI) && request.Method == http.MethodPost {
+			writer.WriteHeader(http.StatusOK)
 			return
 		}

-		writer.WriteHeader(500)
+		writer.WriteHeader(http.StatusInternalServerError)
 	}))
 }
diff --git a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go
index 3e9b37ea93..5ebe1d0075 100644
--- a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go
+++ b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin.go
@@ -6,6 +6,7 @@ import (
 	"encoding/gob"
 	"encoding/json"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"net/http"
 	"time"
@@ -26,20 +27,20 @@ import (
 )

 const (
-	ErrSystem       errors.ErrorCode = "System"
-	post            string           = "POST"
-	get             string           = "GET"
-	databricksAPI   string           = "/api/2.1/jobs/runs"
-	newCluster      string           = "new_cluster"
-	dockerImage     string           = "docker_image"
-	sparkConfig     string           = "spark_conf"
-	sparkPythonTask string           = "spark_python_task"
-	pythonFile      string           = "python_file"
-	parameters      string           = "parameters"
-	url             string           = "url"
+	create          string = "create"
+	get             string = "get"
+	cancel          string = "cancel"
+	databricksAPI   string = "/api/2.1/jobs/runs"
+	newCluster      string = "new_cluster"
+	dockerImage     string = "docker_image"
+	sparkConfig     string = "spark_conf"
+	sparkPythonTask string = "spark_python_task"
+	pythonFile      string = "python_file"
+	parameters      string = "parameters"
+	url             string = "url"
 )

-// for mocking/testing purposes, and we'll override this method
+// HTTPClient for mocking/testing purposes, and we'll override this method
 type HTTPClient interface {
 	Do(req *http.Request) (*http.Response, error)
 }
@@ -127,60 +128,41 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
 	}
 	databricksJob[sparkPythonTask] = map[string]interface{}{pythonFile: p.cfg.EntrypointFile, parameters: modifiedArgs}

-	req, err := buildRequest(post, databricksJob, p.cfg.databricksEndpoint,
-		p.cfg.DatabricksInstance, token, "", false)
+	data, err := p.sendRequest(create, databricksJob, token, "")
 	if err != nil {
 		return nil, nil, err
 	}
-	resp, err := p.client.Do(req)
-	if err != nil {
-		return nil, nil, err
-	}
-	defer resp.Body.Close()
-	data, err := buildResponse(resp)
-	if err != nil {
-		return nil, nil, err
-	}
-	if data["run_id"] == "" {
-		return nil, nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err,
-			"Unable to fetch statementHandle from http response")
+	if _, ok := data["run_id"]; !ok {
+		return nil, nil, errors.Errorf("CorruptedPluginState", "can't get the run_id")
 	}
 	runID := fmt.Sprintf("%.0f", data["run_id"])
-	return ResourceMetaWrapper{runID, p.cfg.DatabricksInstance, token},
-		ResourceWrapper{StatusCode: resp.StatusCode}, nil
+	return ResourceMetaWrapper{runID, p.cfg.DatabricksInstance, token}, nil, nil
 }

 func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) {
 	exec := taskCtx.ResourceMeta().(ResourceMetaWrapper)
-	req, err := buildRequest(get, nil, p.cfg.databricksEndpoint,
-		p.cfg.DatabricksInstance, exec.Token, exec.RunID, false)
-	if err != nil {
-		logger.Errorf(ctx, "Failed to build databricks job request [%v]", err)
-		return nil, err
-	}
-	resp, err := p.client.Do(req)
-	logger.Debugf(ctx, "Get databricks job response", "resp", resp)
+	res, err := p.sendRequest(get, nil, exec.Token, exec.RunID)
 	if err != nil {
-		logger.Errorf(ctx, "Failed to get databricks job status [%v]", resp)
 		return nil, err
 	}
-	defer resp.Body.Close()
-	data, err := buildResponse(resp)
-	if err != nil {
-		return nil, err
-	}
-	if data == nil || data["state"] == nil {
+	if _, ok := res["state"]; !ok {
 		return nil, errors.Errorf("CorruptedPluginState", "can't get the job state")
 	}
-	jobState := data["state"].(map[string]interface{})
+	jobState := res["state"].(map[string]interface{})
+	jobID := fmt.Sprintf("%.0f", res["job_id"])
 	message := fmt.Sprintf("%s", jobState["state_message"])
-	jobID := fmt.Sprintf("%.0f", data["job_id"])
 	lifeCycleState := fmt.Sprintf("%s", jobState["life_cycle_state"])
-	resultState := fmt.Sprintf("%s", jobState["result_state"])
+	var resultState string
+	if _, ok := jobState["result_state"]; !ok {
+		// The result_state is not available until the job is finished.
+		// https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate
+		resultState = ""
+	} else {
+		resultState = fmt.Sprintf("%s", jobState["result_state"])
+	}
 	return ResourceWrapper{
-		StatusCode:     resp.StatusCode,
 		JobID:          jobID,
 		LifeCycleState: lifeCycleState,
 		ResultState:    resultState,
@@ -193,63 +175,111 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error
 		return nil
 	}
 	exec := taskCtx.ResourceMeta().(ResourceMetaWrapper)
-	req, err := buildRequest(post, nil, p.cfg.databricksEndpoint,
-		p.cfg.DatabricksInstance, exec.Token, exec.RunID, true)
+	_, err := p.sendRequest(cancel, nil, exec.Token, exec.RunID)
 	if err != nil {
 		return err
 	}
+	logger.Info(ctx, "Deleted Databricks job execution.")
+
+	return nil
+}
+
+func (p Plugin) sendRequest(method string, databricksJob map[string]interface{}, token string, runID string) (map[string]interface{}, error) {
+	var databricksURL string
+	// for mocking/testing purposes
+	if p.cfg.databricksEndpoint == "" {
+		databricksURL = fmt.Sprintf("https://%v%v", p.cfg.DatabricksInstance, databricksAPI)
+	} else {
+		databricksURL = fmt.Sprintf("%v%v", p.cfg.databricksEndpoint, databricksAPI)
+	}
+
+	// build the request spec
+	var body io.Reader
+	var httpMethod string
+	switch method {
+	case create:
+		databricksURL += "/submit"
+		mJSON, err := json.Marshal(databricksJob)
+		if err != nil {
+			return nil, fmt.Errorf("failed to marshal the job request: %v", err)
+		}
+		body = bytes.NewBuffer(mJSON)
+		httpMethod = http.MethodPost
+	case get:
+		databricksURL += "/get?run_id=" + runID
+		httpMethod = http.MethodGet
+	case cancel:
+		databricksURL += "/cancel"
+		body = bytes.NewBuffer([]byte(fmt.Sprintf("{ \"run_id\": %v }", runID)))
+		httpMethod = http.MethodPost
+	}
+
+	req, err := http.NewRequest(httpMethod, databricksURL, body)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Add("Authorization", "Bearer "+token)
+	req.Header.Add("Content-Type", "application/json")
+
+	// Send the request
 	resp, err := p.client.Do(req)
 	if err != nil {
-		return err
+		return nil, fmt.Errorf("failed to send request to Databricks platform with err: [%v]", err)
 	}
 	defer resp.Body.Close()
-	logger.Infof(ctx, "Deleted query execution [%v]", resp)
-	return nil
+	// Parse the response body
+	responseBody, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+	var data map[string]interface{}
+	err = json.Unmarshal(responseBody, &data)
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse response with err: [%v]", err)
+	}
+	if resp.StatusCode != http.StatusOK {
+		message := ""
+		if v, ok := data["message"]; ok {
+			message = v.(string)
+		}
+		return nil, fmt.Errorf("failed to %v Databricks job with error [%v]", method, message)
+	}
+	return data, nil
 }

 func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase core.PhaseInfo, err error) {
 	exec := taskCtx.ResourceMeta().(ResourceMetaWrapper)
 	resource := taskCtx.Resource().(ResourceWrapper)
 	message := resource.Message
-	statusCode := resource.StatusCode
 	jobID := resource.JobID
 	lifeCycleState := resource.LifeCycleState
 	resultState := resource.ResultState
-	if statusCode == 0 {
-		return core.PhaseInfoUndefined, errors.Errorf(ErrSystem, "No Status field set.")
-	}
-
 	taskInfo := createTaskInfo(exec.RunID, jobID, exec.DatabricksInstance)
-	switch statusCode {
-	// Job response format. https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit
-	case http.StatusAccepted:
-		return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, taskInfo), nil
-	case http.StatusOK:
-		if lifeCycleState == "TERMINATED" || lifeCycleState == "TERMINATING" || lifeCycleState == "INTERNAL_ERROR" {
-			if resultState == "SUCCESS" {
-				if err := writeOutput(ctx, taskCtx); err != nil {
-					pluginsCore.PhaseInfoFailure(string(rune(statusCode)), "failed to write output", taskInfo)
-				}
-				return pluginsCore.PhaseInfoSuccess(taskInfo), nil
+	switch lifeCycleState {
+	// Job response format. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runlifecyclestate
+	case "PENDING":
+		return core.PhaseInfoInitializing(time.Now(), core.DefaultPhaseVersion, message, taskInfo), nil
+	case "RUNNING":
+		fallthrough
+	case "TERMINATING":
+		return core.PhaseInfoRunning(core.DefaultPhaseVersion, taskInfo), nil
+	case "TERMINATED":
+		if resultState == "SUCCESS" {
+			// Result state details. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate
+			if err := writeOutput(ctx, taskCtx); err != nil {
+				return core.PhaseInfoFailure(string(rune(http.StatusInternalServerError)), "failed to write output", taskInfo), nil
 			}
-			return pluginsCore.PhaseInfoFailure(string(rune(statusCode)), message, taskInfo), nil
-		}
-
-		if lifeCycleState == "PENDING" {
-			return core.PhaseInfoInitializing(time.Now(), core.DefaultPhaseVersion, message, taskInfo), nil
+			return core.PhaseInfoSuccess(taskInfo), nil
 		}
-
-		return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, taskInfo), nil
-	case http.StatusBadRequest:
-		fallthrough
-	case http.StatusInternalServerError:
-		fallthrough
-	case http.StatusUnauthorized:
-		return pluginsCore.PhaseInfoFailure(string(rune(statusCode)), message, taskInfo), nil
+		return core.PhaseInfoFailure(pluginErrors.TaskFailedWithError, message, taskInfo), nil
+	case "SKIPPED":
+		return core.PhaseInfoFailure(string(rune(http.StatusConflict)), message, taskInfo), nil
+	case "INTERNAL_ERROR":
+		return core.PhaseInfoFailure(string(rune(http.StatusInternalServerError)), message, taskInfo), nil
 	}
-	return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", statusCode)
+	return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", lifeCycleState)
 }

 func writeOutput(ctx context.Context, taskCtx webapi.StatusContext) error {
@@ -266,66 +296,6 @@ func writeOutput(ctx context.Context, taskCtx webapi.StatusContext) error {
 	return taskCtx.OutputWriter().Put(ctx, outputReader)
 }

-func buildRequest(
-	method string,
-	databricksJob map[string]interface{},
-	databricksEndpoint string,
-	databricksInstance string,
-	token string,
-	runID string,
-	isCancel bool,
-) (*http.Request, error) {
-	var databricksURL string
-	// for mocking/testing purposes
-	if databricksEndpoint == "" {
-		databricksURL = fmt.Sprintf("https://%v%v", databricksInstance, databricksAPI)
-	} else {
-		databricksURL = fmt.Sprintf("%v%v", databricksEndpoint, databricksAPI)
-	}
-
-	var data []byte
-	var req *http.Request
-	var err error
-	if isCancel {
-		databricksURL += "/cancel"
-		data = []byte(fmt.Sprintf("{ \"run_id\": %v }", runID))
-	} else if method == post {
-		databricksURL += "/submit"
-		mJSON, err := json.Marshal(databricksJob)
-		if err != nil {
-			return nil, err
-		}
-		data = []byte(string(mJSON))
-	} else {
-		databricksURL += "/get?run_id=" + runID
-	}
-
-	if data == nil {
-		req, err = http.NewRequest(method, databricksURL, nil)
-	} else {
-		req, err = http.NewRequest(method, databricksURL, bytes.NewBuffer(data))
-	}
-	if err != nil {
-		return nil, err
-	}
-	req.Header.Add("Authorization", "Bearer "+token)
-	req.Header.Add("Content-Type", "application/json")
-	return req, nil
-}
-
-func buildResponse(response *http.Response) (map[string]interface{}, error) {
-	responseBody, err := ioutil.ReadAll(response.Body)
-	if err != nil {
-		return nil, err
-	}
-	var data map[string]interface{}
-	err = json.Unmarshal(responseBody, &data)
-	if err != nil {
-		return nil, err
-	}
-	return data, nil
-}
-
 func createTaskInfo(runID, jobID, databricksInstance string) *core.TaskInfo {
 	timeNow := time.Now()

diff --git a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go
index fda3ab61b0..228914af93 100644
--- a/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go
+++ b/flyteplugins/go/tasks/plugins/webapi/databricks/plugin_test.go
@@ -2,10 +2,8 @@ package databricks

 import (
 	"context"
-	"encoding/json"
-	"io/ioutil"
+	"errors"
 	"net/http"
-	"strings"
 	"testing"
 	"time"
@@ -13,21 +11,22 @@ import (

 	pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
 	pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks"
+	"github.com/flyteorg/flyte/flytestdlib/ioutils"
 	"github.com/flyteorg/flyte/flytestdlib/promutils"
 )

 type MockClient struct {
+	MockDo func(req *http.Request) (*http.Response, error)
+}
+
+func (m MockClient) Do(req *http.Request) (*http.Response, error) {
+	return m.MockDo(req)
 }

 var (
-	MockDo       func(req *http.Request) (*http.Response, error)
 	testInstance = "test-account.cloud.databricks.com"
 )

-func (m *MockClient) Do(req *http.Request) (*http.Response, error) {
-	return MockDo(req)
-}
-
 func TestPlugin(t *testing.T) {
 	fakeSetupContext := pluginCoreMocks.SetupContext{}
 	fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test"))
@@ -35,7 +34,9 @@ func TestPlugin(t *testing.T) {
 	plugin := Plugin{
 		metricScope: fakeSetupContext.MetricsScope(),
 		cfg:         GetConfig(),
-		client:      &MockClient{},
+		client: &MockClient{func(req *http.Request) (*http.Response, error) {
+			return nil, nil
+		}},
 	}
 	t.Run("get config", func(t *testing.T) {
 		cfg := defaultConfig
@@ -53,61 +54,101 @@
 	})
 }

-func TestCreateTaskInfo(t *testing.T) {
-	t.Run("create task info", func(t *testing.T) {
-		taskInfo := createTaskInfo("run-id", "job-id", testInstance)
+func TestSendRequest(t *testing.T) {
+	fakeSetupContext := pluginCoreMocks.SetupContext{}
+	fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test1"))
+	databricksJob := map[string]interface{}{"sparkConfig": map[string]interface{}{"sparkVersion": "7.3.x-scala2.12"}}
+	token := "token"

-		assert.Equal(t, 1, len(taskInfo.Logs))
-		assert.Equal(t, taskInfo.Logs[0].Uri, "https://test-account.cloud.databricks.com/#job/job-id/run/run-id")
-		assert.Equal(t, taskInfo.Logs[0].Name, "Databricks Console")
+	plugin := Plugin{
+		metricScope: fakeSetupContext.MetricsScope(),
+		cfg:         GetConfig(),
+		client: &MockClient{MockDo: func(req *http.Request) (*http.Response, error) {
+			assert.Equal(t, req.Method, http.MethodPost)
+			return &http.Response{
+				StatusCode: http.StatusOK,
+				Body:       ioutils.NewBytesReadCloser([]byte(`{"id":"someID","data":"someData"}`)),
+			}, nil
+		}},
+	}
+
+	t.Run("create a Databricks job", func(t *testing.T) {
+		data, err := plugin.sendRequest(create, databricksJob, token, "")
+		assert.NotNil(t, data)
+		assert.Equal(t, "someID", data["id"])
+		assert.Equal(t, "someData", data["data"])
+		assert.Nil(t, err)
 	})
-}

-func TestBuildRequest(t *testing.T) {
-	token := "test-token"
-	runID := "019e70eb"
-	databricksEndpoint := ""
-	databricksURL := "https://" + testInstance + "/api/2.1/jobs/runs"
-	t.Run("build http request for submitting a databricks job", func(t *testing.T) {
-		req, err := buildRequest(post, nil, databricksEndpoint, testInstance, token, runID, false)
-		header := http.Header{}
-		header.Add("Authorization", "Bearer "+token)
-		header.Add("Content-Type", "application/json")
+	t.Run("failed to create a Databricks job", func(t *testing.T) {
+		plugin.client = &MockClient{MockDo: func(req *http.Request) (*http.Response, error) {
+			assert.Equal(t, req.Method, http.MethodPost)
+			return &http.Response{
+				StatusCode: http.StatusBadRequest,
+				Body:       ioutils.NewBytesReadCloser([]byte(`{"message":"failed"}`)),
+			}, nil
+		}}
+		data, err := plugin.sendRequest(create, databricksJob, token, "")
+		assert.Nil(t, data)
+		assert.Equal(t, err.Error(), "failed to create Databricks job with error [failed]")
+	})

-		assert.NoError(t, err)
-		assert.Equal(t, header, req.Header)
-		assert.Equal(t, databricksURL+"/submit", req.URL.String())
-		assert.Equal(t, post, req.Method)
+	t.Run("failed to send request to Databricks", func(t *testing.T) {
+		plugin.client = &MockClient{MockDo: func(req *http.Request) (*http.Response, error) {
+			assert.Equal(t, req.Method, http.MethodPost)
+			return nil, errors.New("failed to send request")
+		}}
+		data, err := plugin.sendRequest(create, databricksJob, token, "")
+		assert.Nil(t, data)
+		assert.Equal(t, err.Error(), "failed to send request to Databricks platform with err: [failed to send request]")
 	})
-	t.Run("Get a databricks spark job status", func(t *testing.T) {
-		req, err := buildRequest(get, nil, databricksEndpoint, testInstance, token, runID, false)
-		assert.NoError(t, err)
-		assert.Equal(t, databricksURL+"/get?run_id="+runID, req.URL.String())
-		assert.Equal(t, get, req.Method)
+
+	t.Run("failed to send request to Databricks", func(t *testing.T) {
+		plugin.client = &MockClient{MockDo: func(req *http.Request) (*http.Response, error) {
+			assert.Equal(t, req.Method, http.MethodPost)
+			return &http.Response{
+				StatusCode: http.StatusOK,
+				Body:       ioutils.NewBytesReadCloser([]byte(`123`)),
+			}, nil
+		}}
+		data, err := plugin.sendRequest(create, databricksJob, token, "")
+		assert.Nil(t, data)
+		assert.Equal(t, err.Error(), "failed to parse response with err: [json: cannot unmarshal number into Go value of type map[string]interface {}]")
 	})
-	t.Run("Cancel a spark job", func(t *testing.T) {
-		req, err := buildRequest(post, nil, databricksEndpoint, testInstance, token, runID, true)
-		assert.NoError(t, err)
-		assert.Equal(t, databricksURL+"/cancel", req.URL.String())
-		assert.Equal(t, post, req.Method)
+
+	t.Run("get a Databricks job", func(t *testing.T) {
+		plugin.client = &MockClient{MockDo: func(req *http.Request) (*http.Response, error) {
+			assert.Equal(t, req.Method, http.MethodGet)
+			return &http.Response{
+				StatusCode: http.StatusOK,
+				Body:       ioutils.NewBytesReadCloser([]byte(`{"message":"ok"}`)),
+			}, nil
+		}}
+		data, err := plugin.sendRequest(get, databricksJob, token, "")
+		assert.NotNil(t, data)
+		assert.Nil(t, err)
+	})
+
+	t.Run("cancel a Databricks job", func(t *testing.T) {
+		plugin.client = &MockClient{MockDo: func(req *http.Request) (*http.Response, error) {
+			assert.Equal(t, req.Method, http.MethodPost)
+			return &http.Response{
+				StatusCode: http.StatusOK,
+				Body:       ioutils.NewBytesReadCloser([]byte(`{"message":"ok"}`)),
+			}, nil
+		}}
+		data, err := plugin.sendRequest(cancel, databricksJob, token, "")
+		assert.NotNil(t, data)
+		assert.Nil(t, err)
 	})
 }

-func TestBuildResponse(t *testing.T) {
-	t.Run("build http response", func(t *testing.T) {
-		bodyStr := `{"job_id":"019c06a4-0000", "message":"Statement executed successfully."}`
-		responseBody := ioutil.NopCloser(strings.NewReader(bodyStr))
-		response := &http.Response{Body: responseBody}
-		actualData, err := buildResponse(response)
-		assert.NoError(t, err)
+func TestCreateTaskInfo(t *testing.T) {
+	t.Run("create task info", func(t *testing.T) {
+		taskInfo := createTaskInfo("run-id", "job-id", testInstance)

-		bodyByte, err := ioutil.ReadAll(strings.NewReader(bodyStr))
-		assert.NoError(t, err)
-		var expectedData map[string]interface{}
-		err = json.Unmarshal(bodyByte, &expectedData)
-		assert.NoError(t, err)
-		assert.Equal(t, expectedData, actualData)
+		assert.Equal(t, 1, len(taskInfo.Logs))
+		assert.Equal(t, taskInfo.Logs[0].Uri, "https://test-account.cloud.databricks.com/#job/job-id/run/run-id")
+		assert.Equal(t, taskInfo.Logs[0].Name, "Databricks Console")
 	})
 }