diff --git a/flyteadmin/pkg/manager/impl/execution_manager.go b/flyteadmin/pkg/manager/impl/execution_manager.go index 7c64ccbc3c5..aab62c12f98 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager.go +++ b/flyteadmin/pkg/manager/impl/execution_manager.go @@ -403,15 +403,33 @@ func (m *ExecutionManager) getExecutionConfig(ctx context.Context, request *admi return workflowExecConfig, nil } -func (m *ExecutionManager) getClusterAssignment(ctx context.Context, request *admin.ExecutionCreateRequest) ( - *admin.ClusterAssignment, error) { - if request.Spec.ClusterAssignment != nil { - return request.Spec.ClusterAssignment, nil +func (m *ExecutionManager) getClusterAssignment(ctx context.Context, req *admin.ExecutionCreateRequest) (*admin.ClusterAssignment, error) { + storedAssignment, err := m.fetchClusterAssignment(ctx, req.Org, req.Project, req.Domain) + if err != nil { + return nil, err + } + + if req.GetSpec().GetClusterAssignment() == nil { + return storedAssignment, nil + } + + if storedAssignment == nil { + return req.GetSpec().GetClusterAssignment(), nil } + reqPool := req.Spec.ClusterAssignment.GetClusterPoolName() + storedPool := storedAssignment.GetClusterPoolName() + if reqPool != storedPool { + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "execution with project %q and domain %q cannot run on cluster pool %q, because its configured to run on pool %q", req.Project, req.Domain, reqPool, storedPool) + } + + return storedAssignment, nil +} + +func (m *ExecutionManager) fetchClusterAssignment(ctx context.Context, org, project, domain string) (*admin.ClusterAssignment, error) { resource, err := m.resourceManager.GetResource(ctx, interfaces.ResourceRequest{ - Project: request.Project, - Domain: request.Domain, + Project: project, + Domain: domain, ResourceType: admin.MatchableResource_CLUSTER_ASSIGNMENT, }) if err != nil && !errors.IsDoesNotExistError(err) { @@ -421,11 +439,13 @@ func (m *ExecutionManager) getClusterAssignment(ctx context.Context, request *ad if resource != nil && resource.Attributes.GetClusterAssignment() != nil { return resource.Attributes.GetClusterAssignment(), nil } - clusterPoolAssignment := m.config.ClusterPoolAssignmentConfiguration().GetClusterPoolAssignments()[request.GetDomain()] - return &admin.ClusterAssignment{ - ClusterPoolName: clusterPoolAssignment.Pool, - }, nil + var clusterAssignment *admin.ClusterAssignment + domainAssignment := m.config.ClusterPoolAssignmentConfiguration().GetClusterPoolAssignments()[domain] + if domainAssignment.Pool != "" { + clusterAssignment = &admin.ClusterAssignment{ClusterPoolName: domainAssignment.Pool} + } + return clusterAssignment, nil } func (m *ExecutionManager) launchSingleTaskExecution( diff --git a/flyteadmin/pkg/manager/impl/execution_manager_test.go b/flyteadmin/pkg/manager/impl/execution_manager_test.go index 93d327bd534..1c8c2b9f60c 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/execution_manager_test.go @@ -304,8 +304,7 @@ func TestCreateExecution(t *testing.T) { }} repository.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func( ctx context.Context, projectID string) (models.Project, error) { - return transformers.CreateProjectModel(&admin.Project{ - Labels: &labels}), nil + return transformers.CreateProjectModel(&admin.Project{Labels: &labels}), nil } clusterAssignment := admin.ClusterAssignment{ClusterPoolName: "gpu"} @@ -382,8 +381,6 @@ func TestCreateExecution(t *testing.T) { mockConfig := getMockExecutionsConfigProvider() mockConfig.(*runtimeMocks.MockConfigurationProvider).AddQualityOfServiceConfiguration(qosProvider) - - execManager := NewExecutionManager(repository, r, mockConfig, getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, &mockPublisher, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) request := testutils.GetExecutionRequest() request.Spec.Metadata = &admin.ExecutionMetadata{ Principal: "unused - populated from authenticated context", @@ -392,16 +389,18 @@ func TestCreateExecution(t *testing.T) { request.Spec.ClusterAssignment = &clusterAssignment request.Spec.ExecutionClusterLabel = &admin.ExecutionClusterLabel{Value: executionClusterLabel} + execManager := NewExecutionManager(repository, r, mockConfig, getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, &mockPublisher, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + identity, err := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil) assert.NoError(t, err) ctx := identity.WithContext(context.Background()) response, err := execManager.CreateExecution(ctx, request, requestedAt) - assert.Nil(t, err) + assert.NoError(t, err) expectedResponse := &admin.ExecutionCreateResponse{ Id: &executionIdentifier, } - assert.Nil(t, err) + assert.NoError(t, err) assert.True(t, proto.Equal(expectedResponse.Id, response.Id)) // TODO: Check for offloaded inputs @@ -632,7 +631,6 @@ func TestCreateExecutionInCompatibleInputs(t *testing.T) { } func TestCreateExecutionPropellerFailure(t *testing.T) { - clusterAssignment := admin.ClusterAssignment{ClusterPoolName: "gpu"} repository := getMockRepositoryForExecTest() setDefaultLpCallbackForExecTest(repository) expectedErr := flyteAdminErrors.NewFlyteAdminErrorf(codes.Internal, "ABC") @@ -666,7 +664,6 @@ func TestCreateExecutionPropellerFailure(t *testing.T) { Principal: "unused - populated from authenticated context", } request.Spec.RawOutputDataConfig = &admin.RawOutputDataConfig{OutputLocationPrefix: rawOutput} - request.Spec.ClusterAssignment = &clusterAssignment identity, err := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil) assert.NoError(t, err) @@ -5467,18 +5464,6 @@ func TestGetClusterAssignment(t *testing.T) { assert.NoError(t, err) assert.True(t, proto.Equal(ca, &clusterAssignment)) }) - t.Run("value from request", func(t *testing.T) { - reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "swimming-pool"} - ca, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{ - Project: workflowIdentifier.Project, - Domain: workflowIdentifier.Domain, - Spec: &admin.ExecutionSpec{ - ClusterAssignment: &reqClusterAssignment, - }, - }) - assert.NoError(t, err) - assert.True(t, proto.Equal(ca, &reqClusterAssignment)) - }) t.Run("value from config", func(t *testing.T) { customCP := "my_cp" clusterPoolAsstProvider := &runtimeIFaceMocks.ClusterPoolAssignmentConfiguration{} @@ -5503,6 +5488,51 @@ func TestGetClusterAssignment(t *testing.T) { assert.NoError(t, err) assert.Equal(t, customCP, ca.GetClusterPoolName()) }) + t.Run("value from request matches value from config", func(t *testing.T) { + reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "gpu"} + ca, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{ + Project: workflowIdentifier.Project, + Domain: workflowIdentifier.Domain, + Spec: &admin.ExecutionSpec{ + ClusterAssignment: &reqClusterAssignment, + }, + }) + assert.NoError(t, err) + assert.True(t, proto.Equal(ca, &reqClusterAssignment)) + }) + t.Run("no value in DB nor in config, takes value from request", func(t *testing.T) { + mockConfig := getMockExecutionsConfigProvider() + + executionManager := ExecutionManager{ + resourceManager: &managerMocks.MockResourceManager{}, + config: mockConfig, + } + + reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "gpu"} + ca, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{ + Project: workflowIdentifier.Project, + Domain: workflowIdentifier.Domain, + Spec: &admin.ExecutionSpec{ + ClusterAssignment: &reqClusterAssignment, + }, + }) + assert.NoError(t, err) + assert.True(t, proto.Equal(ca, &reqClusterAssignment)) + }) + t.Run("value from request doesn't match value from config", func(t *testing.T) { + reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "swimming-pool"} + _, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{ + Project: workflowIdentifier.Project, + Domain: workflowIdentifier.Domain, + Spec: &admin.ExecutionSpec{ + ClusterAssignment: &reqClusterAssignment, + }, + }) + st, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Equal(t, `execution with project "project" and domain "domain" cannot run on cluster pool "swimming-pool", because its configured to run on pool "gpu"`, st.Message()) + }) } func TestResolvePermissions(t *testing.T) { diff --git a/flyteadmin/pkg/manager/interfaces/resource.go b/flyteadmin/pkg/manager/interfaces/resource.go index 928a910d6ca..3d586a59c91 100644 --- a/flyteadmin/pkg/manager/interfaces/resource.go +++ b/flyteadmin/pkg/manager/interfaces/resource.go @@ -6,6 +6,8 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" ) +//go:generate mockery -name ResourceInterface -output=../mocks -case=underscore + // ResourceInterface manages project, domain and workflow -specific attributes. type ResourceInterface interface { ListAll(ctx context.Context, request *admin.ListMatchableAttributesRequest) ( diff --git a/flyteadmin/pkg/manager/mocks/resource_interface.go b/flyteadmin/pkg/manager/mocks/resource_interface.go new file mode 100644 index 00000000000..c1b416eb9d1 --- /dev/null +++ b/flyteadmin/pkg/manager/mocks/resource_interface.go @@ -0,0 +1,469 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + admin "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + + interfaces "github.com/flyteorg/flyte/flyteadmin/pkg/manager/interfaces" + + mock "github.com/stretchr/testify/mock" +) + +// ResourceInterface is an autogenerated mock type for the ResourceInterface type +type ResourceInterface struct { + mock.Mock +} + +type ResourceInterface_DeleteProjectAttributes struct { + *mock.Call +} + +func (_m ResourceInterface_DeleteProjectAttributes) Return(_a0 *admin.ProjectAttributesDeleteResponse, _a1 error) *ResourceInterface_DeleteProjectAttributes { + return &ResourceInterface_DeleteProjectAttributes{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ResourceInterface) OnDeleteProjectAttributes(ctx context.Context, request *admin.ProjectAttributesDeleteRequest) *ResourceInterface_DeleteProjectAttributes { + c_call := _m.On("DeleteProjectAttributes", ctx, request) + return &ResourceInterface_DeleteProjectAttributes{Call: c_call} +} + +func (_m *ResourceInterface) OnDeleteProjectAttributesMatch(matchers ...interface{}) *ResourceInterface_DeleteProjectAttributes { + c_call := _m.On("DeleteProjectAttributes", matchers...) + return &ResourceInterface_DeleteProjectAttributes{Call: c_call} +} + +// DeleteProjectAttributes provides a mock function with given fields: ctx, request +func (_m *ResourceInterface) DeleteProjectAttributes(ctx context.Context, request *admin.ProjectAttributesDeleteRequest) (*admin.ProjectAttributesDeleteResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *admin.ProjectAttributesDeleteResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.ProjectAttributesDeleteRequest) *admin.ProjectAttributesDeleteResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.ProjectAttributesDeleteResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.ProjectAttributesDeleteRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ResourceInterface_DeleteProjectDomainAttributes struct { + *mock.Call +} + +func (_m ResourceInterface_DeleteProjectDomainAttributes) Return(_a0 *admin.ProjectDomainAttributesDeleteResponse, _a1 error) *ResourceInterface_DeleteProjectDomainAttributes { + return &ResourceInterface_DeleteProjectDomainAttributes{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ResourceInterface) OnDeleteProjectDomainAttributes(ctx context.Context, request *admin.ProjectDomainAttributesDeleteRequest) *ResourceInterface_DeleteProjectDomainAttributes { + c_call := _m.On("DeleteProjectDomainAttributes", ctx, request) + return &ResourceInterface_DeleteProjectDomainAttributes{Call: c_call} +} + +func (_m *ResourceInterface) OnDeleteProjectDomainAttributesMatch(matchers ...interface{}) *ResourceInterface_DeleteProjectDomainAttributes { + c_call := _m.On("DeleteProjectDomainAttributes", matchers...) + return &ResourceInterface_DeleteProjectDomainAttributes{Call: c_call} +} + +// DeleteProjectDomainAttributes provides a mock function with given fields: ctx, request +func (_m *ResourceInterface) DeleteProjectDomainAttributes(ctx context.Context, request *admin.ProjectDomainAttributesDeleteRequest) (*admin.ProjectDomainAttributesDeleteResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *admin.ProjectDomainAttributesDeleteResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.ProjectDomainAttributesDeleteRequest) *admin.ProjectDomainAttributesDeleteResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.ProjectDomainAttributesDeleteResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.ProjectDomainAttributesDeleteRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ResourceInterface_DeleteWorkflowAttributes struct { + *mock.Call +} + +func (_m ResourceInterface_DeleteWorkflowAttributes) Return(_a0 *admin.WorkflowAttributesDeleteResponse, _a1 error) *ResourceInterface_DeleteWorkflowAttributes { + return &ResourceInterface_DeleteWorkflowAttributes{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ResourceInterface) OnDeleteWorkflowAttributes(ctx context.Context, request *admin.WorkflowAttributesDeleteRequest) *ResourceInterface_DeleteWorkflowAttributes { + c_call := _m.On("DeleteWorkflowAttributes", ctx, request) + return &ResourceInterface_DeleteWorkflowAttributes{Call: c_call} +} + +func (_m *ResourceInterface) OnDeleteWorkflowAttributesMatch(matchers ...interface{}) *ResourceInterface_DeleteWorkflowAttributes { + c_call := _m.On("DeleteWorkflowAttributes", matchers...) + return &ResourceInterface_DeleteWorkflowAttributes{Call: c_call} +} + +// DeleteWorkflowAttributes provides a mock function with given fields: ctx, request +func (_m *ResourceInterface) DeleteWorkflowAttributes(ctx context.Context, request *admin.WorkflowAttributesDeleteRequest) (*admin.WorkflowAttributesDeleteResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *admin.WorkflowAttributesDeleteResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.WorkflowAttributesDeleteRequest) *admin.WorkflowAttributesDeleteResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.WorkflowAttributesDeleteResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.WorkflowAttributesDeleteRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ResourceInterface_GetProjectAttributes struct { + *mock.Call +} + +func (_m ResourceInterface_GetProjectAttributes) Return(_a0 *admin.ProjectAttributesGetResponse, _a1 error) *ResourceInterface_GetProjectAttributes { + return &ResourceInterface_GetProjectAttributes{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ResourceInterface) OnGetProjectAttributes(ctx context.Context, request *admin.ProjectAttributesGetRequest) *ResourceInterface_GetProjectAttributes { + c_call := _m.On("GetProjectAttributes", ctx, request) + return &ResourceInterface_GetProjectAttributes{Call: c_call} +} + +func (_m *ResourceInterface) OnGetProjectAttributesMatch(matchers ...interface{}) *ResourceInterface_GetProjectAttributes { + c_call := _m.On("GetProjectAttributes", matchers...) + return &ResourceInterface_GetProjectAttributes{Call: c_call} +} + +// GetProjectAttributes provides a mock function with given fields: ctx, request +func (_m *ResourceInterface) GetProjectAttributes(ctx context.Context, request *admin.ProjectAttributesGetRequest) (*admin.ProjectAttributesGetResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *admin.ProjectAttributesGetResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.ProjectAttributesGetRequest) *admin.ProjectAttributesGetResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.ProjectAttributesGetResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.ProjectAttributesGetRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ResourceInterface_GetProjectDomainAttributes struct { + *mock.Call +} + +func (_m ResourceInterface_GetProjectDomainAttributes) Return(_a0 *admin.ProjectDomainAttributesGetResponse, _a1 error) *ResourceInterface_GetProjectDomainAttributes { + return &ResourceInterface_GetProjectDomainAttributes{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ResourceInterface) OnGetProjectDomainAttributes(ctx context.Context, request *admin.ProjectDomainAttributesGetRequest) *ResourceInterface_GetProjectDomainAttributes { + c_call := _m.On("GetProjectDomainAttributes", ctx, request) + return &ResourceInterface_GetProjectDomainAttributes{Call: c_call} +} + +func (_m *ResourceInterface) OnGetProjectDomainAttributesMatch(matchers ...interface{}) *ResourceInterface_GetProjectDomainAttributes { + c_call := _m.On("GetProjectDomainAttributes", matchers...) + return &ResourceInterface_GetProjectDomainAttributes{Call: c_call} +} + +// GetProjectDomainAttributes provides a mock function with given fields: ctx, request +func (_m *ResourceInterface) GetProjectDomainAttributes(ctx context.Context, request *admin.ProjectDomainAttributesGetRequest) (*admin.ProjectDomainAttributesGetResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *admin.ProjectDomainAttributesGetResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.ProjectDomainAttributesGetRequest) *admin.ProjectDomainAttributesGetResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.ProjectDomainAttributesGetResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.ProjectDomainAttributesGetRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ResourceInterface_GetResource struct { + *mock.Call +} + +func (_m ResourceInterface_GetResource) Return(_a0 *interfaces.ResourceResponse, _a1 error) *ResourceInterface_GetResource { + return &ResourceInterface_GetResource{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ResourceInterface) OnGetResource(ctx context.Context, request interfaces.ResourceRequest) *ResourceInterface_GetResource { + c_call := _m.On("GetResource", ctx, request) + return &ResourceInterface_GetResource{Call: c_call} +} + +func (_m *ResourceInterface) OnGetResourceMatch(matchers ...interface{}) *ResourceInterface_GetResource { + c_call := _m.On("GetResource", matchers...) + return &ResourceInterface_GetResource{Call: c_call} +} + +// GetResource provides a mock function with given fields: ctx, request +func (_m *ResourceInterface) GetResource(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *interfaces.ResourceResponse + if rf, ok := ret.Get(0).(func(context.Context, interfaces.ResourceRequest) *interfaces.ResourceResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*interfaces.ResourceResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, interfaces.ResourceRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ResourceInterface_GetWorkflowAttributes struct { + *mock.Call +} + +func (_m ResourceInterface_GetWorkflowAttributes) Return(_a0 *admin.WorkflowAttributesGetResponse, _a1 error) *ResourceInterface_GetWorkflowAttributes { + return &ResourceInterface_GetWorkflowAttributes{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ResourceInterface) OnGetWorkflowAttributes(ctx context.Context, request *admin.WorkflowAttributesGetRequest) *ResourceInterface_GetWorkflowAttributes { + c_call := _m.On("GetWorkflowAttributes", ctx, request) + return &ResourceInterface_GetWorkflowAttributes{Call: c_call} +} + +func (_m *ResourceInterface) OnGetWorkflowAttributesMatch(matchers ...interface{}) *ResourceInterface_GetWorkflowAttributes { + c_call := _m.On("GetWorkflowAttributes", matchers...) + return &ResourceInterface_GetWorkflowAttributes{Call: c_call} +} + +// GetWorkflowAttributes provides a mock function with given fields: ctx, request +func (_m *ResourceInterface) GetWorkflowAttributes(ctx context.Context, request *admin.WorkflowAttributesGetRequest) (*admin.WorkflowAttributesGetResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *admin.WorkflowAttributesGetResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.WorkflowAttributesGetRequest) *admin.WorkflowAttributesGetResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.WorkflowAttributesGetResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.WorkflowAttributesGetRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ResourceInterface_ListAll struct { + *mock.Call +} + +func (_m ResourceInterface_ListAll) Return(_a0 *admin.ListMatchableAttributesResponse, _a1 error) *ResourceInterface_ListAll { + return &ResourceInterface_ListAll{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ResourceInterface) OnListAll(ctx context.Context, request *admin.ListMatchableAttributesRequest) *ResourceInterface_ListAll { + c_call := _m.On("ListAll", ctx, request) + return &ResourceInterface_ListAll{Call: c_call} +} + +func (_m *ResourceInterface) OnListAllMatch(matchers ...interface{}) *ResourceInterface_ListAll { + c_call := _m.On("ListAll", matchers...) + return &ResourceInterface_ListAll{Call: c_call} +} + +// ListAll provides a mock function with given fields: ctx, request +func (_m *ResourceInterface) ListAll(ctx context.Context, request *admin.ListMatchableAttributesRequest) (*admin.ListMatchableAttributesResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *admin.ListMatchableAttributesResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.ListMatchableAttributesRequest) *admin.ListMatchableAttributesResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.ListMatchableAttributesResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.ListMatchableAttributesRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ResourceInterface_UpdateProjectAttributes struct { + *mock.Call +} + +func (_m ResourceInterface_UpdateProjectAttributes) Return(_a0 *admin.ProjectAttributesUpdateResponse, _a1 error) *ResourceInterface_UpdateProjectAttributes { + return &ResourceInterface_UpdateProjectAttributes{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ResourceInterface) OnUpdateProjectAttributes(ctx context.Context, request *admin.ProjectAttributesUpdateRequest) *ResourceInterface_UpdateProjectAttributes { + c_call := _m.On("UpdateProjectAttributes", ctx, request) + return &ResourceInterface_UpdateProjectAttributes{Call: c_call} +} + +func (_m *ResourceInterface) OnUpdateProjectAttributesMatch(matchers ...interface{}) *ResourceInterface_UpdateProjectAttributes { + c_call := _m.On("UpdateProjectAttributes", matchers...) + return &ResourceInterface_UpdateProjectAttributes{Call: c_call} +} + +// UpdateProjectAttributes provides a mock function with given fields: ctx, request +func (_m *ResourceInterface) UpdateProjectAttributes(ctx context.Context, request *admin.ProjectAttributesUpdateRequest) (*admin.ProjectAttributesUpdateResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *admin.ProjectAttributesUpdateResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.ProjectAttributesUpdateRequest) *admin.ProjectAttributesUpdateResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.ProjectAttributesUpdateResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.ProjectAttributesUpdateRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ResourceInterface_UpdateProjectDomainAttributes struct { + *mock.Call +} + +func (_m ResourceInterface_UpdateProjectDomainAttributes) Return(_a0 *admin.ProjectDomainAttributesUpdateResponse, _a1 error) *ResourceInterface_UpdateProjectDomainAttributes { + return &ResourceInterface_UpdateProjectDomainAttributes{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ResourceInterface) OnUpdateProjectDomainAttributes(ctx context.Context, request *admin.ProjectDomainAttributesUpdateRequest) *ResourceInterface_UpdateProjectDomainAttributes { + c_call := _m.On("UpdateProjectDomainAttributes", ctx, request) + return &ResourceInterface_UpdateProjectDomainAttributes{Call: c_call} +} + +func (_m *ResourceInterface) OnUpdateProjectDomainAttributesMatch(matchers ...interface{}) *ResourceInterface_UpdateProjectDomainAttributes { + c_call := _m.On("UpdateProjectDomainAttributes", matchers...) + return &ResourceInterface_UpdateProjectDomainAttributes{Call: c_call} +} + +// UpdateProjectDomainAttributes provides a mock function with given fields: ctx, request +func (_m *ResourceInterface) UpdateProjectDomainAttributes(ctx context.Context, request *admin.ProjectDomainAttributesUpdateRequest) (*admin.ProjectDomainAttributesUpdateResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *admin.ProjectDomainAttributesUpdateResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.ProjectDomainAttributesUpdateRequest) *admin.ProjectDomainAttributesUpdateResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.ProjectDomainAttributesUpdateResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.ProjectDomainAttributesUpdateRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ResourceInterface_UpdateWorkflowAttributes struct { + *mock.Call +} + +func (_m ResourceInterface_UpdateWorkflowAttributes) Return(_a0 *admin.WorkflowAttributesUpdateResponse, _a1 error) *ResourceInterface_UpdateWorkflowAttributes { + return &ResourceInterface_UpdateWorkflowAttributes{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ResourceInterface) OnUpdateWorkflowAttributes(ctx context.Context, request *admin.WorkflowAttributesUpdateRequest) *ResourceInterface_UpdateWorkflowAttributes { + c_call := _m.On("UpdateWorkflowAttributes", ctx, request) + return &ResourceInterface_UpdateWorkflowAttributes{Call: c_call} +} + +func (_m *ResourceInterface) OnUpdateWorkflowAttributesMatch(matchers ...interface{}) *ResourceInterface_UpdateWorkflowAttributes { + c_call := _m.On("UpdateWorkflowAttributes", matchers...) + return &ResourceInterface_UpdateWorkflowAttributes{Call: c_call} +} + +// UpdateWorkflowAttributes provides a mock function with given fields: ctx, request +func (_m *ResourceInterface) UpdateWorkflowAttributes(ctx context.Context, request *admin.WorkflowAttributesUpdateRequest) (*admin.WorkflowAttributesUpdateResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *admin.WorkflowAttributesUpdateResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.WorkflowAttributesUpdateRequest) *admin.WorkflowAttributesUpdateResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.WorkflowAttributesUpdateResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.WorkflowAttributesUpdateRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/flyteadmin/tests/execution_test.go b/flyteadmin/tests/execution_test.go index a3d226562b4..784f87ae0f3 100644 --- a/flyteadmin/tests/execution_test.go +++ b/flyteadmin/tests/execution_test.go @@ -12,6 +12,8 @@ import ( "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/wrapperspb" "github.com/flyteorg/flyte/flyteadmin/pkg/repositories" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" @@ -329,3 +331,226 @@ func TestListWorkflowExecutions_Pagination(t *testing.T) { assert.Equal(t, len(resp.Executions), 1) assert.Empty(t, resp.Token) } + +func TestGetWorkflowExecutionCounts(t *testing.T) { + truncateAllTablesForTestingOnly() + populateWorkflowExecutionsForTestingOnly() + + ctx := context.Background() + client, conn := GetTestAdminServiceClient() + defer conn.Close() + + executionCountsResp, err := client.GetExecutionCounts(ctx, &admin.ExecutionCountsGetRequest{ + Project: "project1", + Domain: "domain1", + Filters: "gte(execution_created_at,2000-01-01T00:00:00Z)", + }) + assert.Nil(t, err) + assert.Equal(t, 3, len(executionCountsResp.ExecutionCounts)) + otherPhase := false + for _, item := range executionCountsResp.ExecutionCounts { + if item.Phase == core.WorkflowExecution_FAILED { + assert.Equal(t, int64(1), item.Count) + } else if item.Phase == core.WorkflowExecution_SUCCEEDED { + assert.Equal(t, int64(1), item.Count) + } else if item.Phase == core.WorkflowExecution_RUNNING { + assert.Equal(t, int64(2), item.Count) + } else { + otherPhase = true + } + } + assert.False(t, otherPhase) +} + +func TestGetWorkflowExecutionCounts_Filters(t *testing.T) { + truncateAllTablesForTestingOnly() + populateWorkflowExecutionsForTestingOnly() + + ctx := context.Background() + client, conn := GetTestAdminServiceClient() + defer conn.Close() + + executionCountsResp, err := client.GetExecutionCounts(ctx, &admin.ExecutionCountsGetRequest{ + Project: "project1", + Domain: "domain1", + Filters: "gte(execution_created_at,2000-01-01T00:00:00Z)+eq(launch_plan_id, 1)", + }) + assert.Nil(t, err) + assert.Equal(t, 3, len(executionCountsResp.ExecutionCounts)) + otherPhase := false + for _, item := range executionCountsResp.ExecutionCounts { + if item.Phase == core.WorkflowExecution_SUCCEEDED { + assert.Equal(t, int64(1), item.Count) + } else if item.Phase == core.WorkflowExecution_RUNNING { + assert.Equal(t, int64(1), item.Count) + } else if item.Phase == core.WorkflowExecution_FAILED { + assert.Equal(t, int64(1), item.Count) + } else { + otherPhase = true + } + } + assert.False(t, otherPhase) +} + +func TestGetWorkflowExecutionCounts_PhaseFilter(t *testing.T) { + truncateAllTablesForTestingOnly() + populateWorkflowExecutionsForTestingOnly() + + ctx := context.Background() + client, conn := GetTestAdminServiceClient() + defer conn.Close() + + executionCountsResp, err := client.GetExecutionCounts(ctx, &admin.ExecutionCountsGetRequest{ + Project: "project1", + Domain: "domain1", + Filters: "gte(execution_created_at,2000-01-01T00:00:00Z)+eq(phase,RUNNING)", + }) + assert.Nil(t, err) + assert.Equal(t, 1, len(executionCountsResp.ExecutionCounts)) + assert.Equal(t, core.WorkflowExecution_RUNNING, executionCountsResp.ExecutionCounts[0].Phase) + assert.Equal(t, int64(2), executionCountsResp.ExecutionCounts[0].Count) +} + +func TestGetRunningExecutionsCount(t *testing.T) { + truncateAllTablesForTestingOnly() + populateWorkflowExecutionsForTestingOnly() + + ctx := context.Background() + client, conn := GetTestAdminServiceClient() + defer conn.Close() + + runningExecutionsCountResp, err := client.GetRunningExecutionsCount(ctx, &admin.RunningExecutionsCountGetRequest{ + Project: "project1", + Domain: "domain1", + }) + assert.Nil(t, err) + assert.Equal(t, int64(2), runningExecutionsCountResp.Count) +} + +func TestResolvedSpec(t *testing.T) { + truncateAllTablesForTestingOnly() + + ctx := context.Background() + client, conn := GetTestAdminServiceClient() + defer conn.Close() + + insertTasksForTests(t, client) + createWorkflowReq := getWorkflowCreateRequest() + _, err := client.CreateWorkflow(ctx, &createWorkflowReq) + assert.Nil(t, err) + createLaunchPlanReq := getLaunchPlanCreateRequest(createWorkflowReq.Id) + _, err = client.CreateLaunchPlan(ctx, &createLaunchPlanReq) + + spec := &admin.ExecutionSpec{ + LaunchPlan: &launchPlanIdentifier, + Labels: &admin.Labels{Values: map[string]string{ + "foo": "bar", + }}, + OverwriteCache: true, + MaxParallelism: 10, + ClusterAssignment: &admin.ClusterAssignment{ + ClusterPoolName: "cluster", + }, + Metadata: &admin.ExecutionMetadata{}, + Annotations: &admin.Annotations{Values: map[string]string{"foo": "bar"}}, + SecurityContext: &core.SecurityContext{RunAs: &core.Identity{IamRole: "iamrole"}}, + TaskResourceAttributes: &admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "1", + Gpu: "0", + Memory: "1Gi", + EphemeralStorage: "0", + }, + Limits: &admin.TaskResourceSpec{ + Cpu: "2", + Gpu: "0", + Memory: "2Gi", + EphemeralStorage: "0", + }, + }, + } + _, err = client.CreateExecution(ctx, &admin.ExecutionCreateRequest{ + Project: launchPlanIdentifier.Project, + Domain: launchPlanIdentifier.Domain, + Name: launchPlanIdentifier.Name, + Spec: spec, + }) + require.NoError(t, err) + + resp, err := client.GetExecution(ctx, &admin.WorkflowExecutionGetRequest{ + Id: &core.WorkflowExecutionIdentifier{ + Project: launchPlanIdentifier.Project, + Domain: launchPlanIdentifier.Domain, + Name: launchPlanIdentifier.Name, + }, + }) + assert.NoError(t, err) + spec.Interruptible = &wrapperspb.BoolValue{ + Value: false, + } + assert.True(t, proto.Equal(spec, resp.Closure.ResolvedSpec)) +} + +func TestSingleTaskResolvedSpec(t *testing.T) { + truncateAllTablesForTestingOnly() + + ctx := context.Background() + client, conn := GetTestAdminServiceClient() + defer conn.Close() + + insertTasksForTests(t, client) + createWorkflowReq := getWorkflowCreateRequest() + _, err := client.CreateWorkflow(ctx, &createWorkflowReq) + assert.Nil(t, err) + createLaunchPlanReq := getLaunchPlanCreateRequest(createWorkflowReq.Id) + _, err = client.CreateLaunchPlan(ctx, &createLaunchPlanReq) + + spec := &admin.ExecutionSpec{ + LaunchPlan: &launchPlanIdentifier, + Labels: &admin.Labels{Values: map[string]string{ + "foo": "bar", + }}, + OverwriteCache: true, + MaxParallelism: 10, + ClusterAssignment: &admin.ClusterAssignment{ + ClusterPoolName: "cluster", + }, + Metadata: &admin.ExecutionMetadata{}, + Annotations: &admin.Annotations{Values: map[string]string{"foo": "bar"}}, + SecurityContext: &core.SecurityContext{RunAs: &core.Identity{IamRole: "iamrole"}}, + TaskResourceAttributes: &admin.TaskResourceAttributes{ + Defaults: &admin.TaskResourceSpec{ + Cpu: "1", + Gpu: "0", + Memory: "1Gi", + EphemeralStorage: "0", + }, + Limits: &admin.TaskResourceSpec{ + Cpu: "2", + Gpu: "0", + Memory: "2Gi", + EphemeralStorage: "0", + }, + }, + } + _, err = client.CreateExecution(ctx, &admin.ExecutionCreateRequest{ + Project: launchPlanIdentifier.Project, + Domain: launchPlanIdentifier.Domain, + Name: launchPlanIdentifier.Name, + Spec: spec, + }) + assert.Nil(t, err) + + resp, err := client.GetExecution(ctx, &admin.WorkflowExecutionGetRequest{ + Id: &core.WorkflowExecutionIdentifier{ + Project: launchPlanIdentifier.Project, + Domain: launchPlanIdentifier.Domain, + Name: launchPlanIdentifier.Name, + }, + }) + assert.Nil(t, err) + spec.Interruptible = &wrapperspb.BoolValue{ + Value: false, + } + assert.True(t, proto.Equal(spec, resp.Closure.ResolvedSpec)) +}