From 01c35199bd248e59736ae5b697673f2b90846d50 Mon Sep 17 00:00:00 2001 From: Katrina Rogan Date: Wed, 30 Oct 2024 12:06:46 +0100 Subject: [PATCH] Remove unnecessary joins for list node and task execution entities in flyteadmin db queries (#5935) --- .../implementations/cloudevent_publisher.go | 2 +- flyteadmin/pkg/common/filters.go | 9 +- .../manager/impl/node_execution_manager.go | 17 +- .../impl/node_execution_manager_test.go | 145 +++++++++++++- flyteadmin/pkg/manager/impl/signal_manager.go | 2 +- .../manager/impl/task_execution_manager.go | 15 +- .../impl/task_execution_manager_test.go | 181 +++++++++++++++++- flyteadmin/pkg/manager/impl/util/filters.go | 14 +- .../pkg/manager/impl/util/filters_test.go | 12 +- .../pkg/repositories/gormimpl/common.go | 26 +-- .../gormimpl/node_execution_repo.go | 22 +-- .../gormimpl/node_execution_repo_test.go | 59 +++++- .../gormimpl/task_execution_repo.go | 43 +++-- .../gormimpl/task_execution_repo_test.go | 102 +++++++++- 14 files changed, 561 insertions(+), 88 deletions(-) diff --git a/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go b/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go index 228db852d0..7aaab0bb60 100644 --- a/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go +++ b/flyteadmin/pkg/async/cloudevent/implementations/cloudevent_publisher.go @@ -207,7 +207,7 @@ func getNodeExecutionContext(ctx context.Context, identifier *core.NodeExecution func (c *CloudEventWrappedPublisher) getLatestTaskExecutions(ctx context.Context, nodeExecutionID *core.NodeExecutionIdentifier) (*admin.TaskExecution, error) { ctx = getNodeExecutionContext(ctx, nodeExecutionID) - identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, nodeExecutionID) + identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, nodeExecutionID, common.TaskExecution) if err != nil { return nil, err } diff --git a/flyteadmin/pkg/common/filters.go b/flyteadmin/pkg/common/filters.go index cf7987bdf5..57756e7820 100644 --- a/flyteadmin/pkg/common/filters.go +++ b/flyteadmin/pkg/common/filters.go @@ -96,6 +96,13 @@ var executionIdentifierFields = map[string]bool{ "name": true, } +// Entities that have special case handling for execution identifier fields. +var executionIdentifierEntities = map[Entity]bool{ + Execution: true, + NodeExecution: true, + TaskExecution: true, +} + var entityMetadataFields = map[string]bool{ "description": true, "state": true, @@ -253,7 +260,7 @@ func (f *inlineFilterImpl) GetGormJoinTableQueryExpr(tableName string) (GormQuer func customizeField(field string, entity Entity) string { // Execution identifier fields have to be customized because we differ from convention in those column names. - if entity == Execution && executionIdentifierFields[field] { + if executionIdentifierEntities[entity] && executionIdentifierFields[field] { return fmt.Sprintf("execution_%s", field) } // admin_tag table has been migrated to an execution_tag table, so we need to customize the field name. diff --git a/flyteadmin/pkg/manager/impl/node_execution_manager.go b/flyteadmin/pkg/manager/impl/node_execution_manager.go index 2c6709dc6c..2f0f60977c 100644 --- a/flyteadmin/pkg/manager/impl/node_execution_manager.go +++ b/flyteadmin/pkg/manager/impl/node_execution_manager.go @@ -407,11 +407,16 @@ func (m *NodeExecutionManager) listNodeExecutions( return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid pagination token %s for ListNodeExecutions", requestToken) } + joinTableEntities := make(map[common.Entity]bool) + for _, filter := range filters { + joinTableEntities[filter.GetEntity()] = true + } listInput := repoInterfaces.ListResourceInput{ - Limit: int(limit), - Offset: offset, - InlineFilters: filters, - SortParameter: sortParameter, + Limit: int(limit), + Offset: offset, + InlineFilters: filters, + SortParameter: sortParameter, + JoinTableEntities: joinTableEntities, } listInput.MapFilters = mapFilters @@ -445,7 +450,7 @@ func (m *NodeExecutionManager) ListNodeExecutions( } ctx = getExecutionContext(ctx, request.WorkflowExecutionId) - identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, request.WorkflowExecutionId) + identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, request.WorkflowExecutionId, common.NodeExecution) if err != nil { return nil, err } @@ -483,7 +488,7 @@ func (m *NodeExecutionManager) ListNodeExecutionsForTask( } ctx = getTaskExecutionContext(ctx, request.TaskExecutionId) identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters( - ctx, request.TaskExecutionId.NodeExecutionId.ExecutionId) + ctx, request.TaskExecutionId.NodeExecutionId.ExecutionId, common.NodeExecution) if err != nil { return nil, err } diff --git a/flyteadmin/pkg/manager/impl/node_execution_manager_test.go b/flyteadmin/pkg/manager/impl/node_execution_manager_test.go index cfc3db2bff..b43c785b33 100644 --- a/flyteadmin/pkg/manager/impl/node_execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/node_execution_manager_test.go @@ -784,17 +784,17 @@ func TestListNodeExecutionsLevelZero(t *testing.T) { assert.Equal(t, 1, input.Limit) assert.Equal(t, 2, input.Offset) assert.Len(t, input.InlineFilters, 3) - assert.Equal(t, common.Execution, input.InlineFilters[0].GetEntity()) + assert.Equal(t, common.NodeExecution, input.InlineFilters[0].GetEntity()) queryExpr, _ := input.InlineFilters[0].GetGormQueryExpr() assert.Equal(t, "project", queryExpr.Args) assert.Equal(t, "execution_project = ?", queryExpr.Query) - assert.Equal(t, common.Execution, input.InlineFilters[1].GetEntity()) + assert.Equal(t, common.NodeExecution, input.InlineFilters[1].GetEntity()) queryExpr, _ = input.InlineFilters[1].GetGormQueryExpr() assert.Equal(t, "domain", queryExpr.Args) assert.Equal(t, "execution_domain = ?", queryExpr.Query) - assert.Equal(t, common.Execution, input.InlineFilters[2].GetEntity()) + assert.Equal(t, common.NodeExecution, input.InlineFilters[2].GetEntity()) queryExpr, _ = input.InlineFilters[2].GetGormQueryExpr() assert.Equal(t, "name", queryExpr.Args) assert.Equal(t, "execution_name = ?", queryExpr.Query) @@ -806,6 +806,10 @@ func TestListNodeExecutionsLevelZero(t *testing.T) { "parent_task_execution_id": nil, }, filter) + assert.EqualValues(t, input.JoinTableEntities, map[common.Entity]bool{ + common.NodeExecution: true, + }) + assert.Equal(t, "execution_domain asc", input.SortParameter.GetGormOrderExpr()) return interfaces.NodeExecutionCollectionOutput{ NodeExecutions: []models.NodeExecution{ @@ -904,17 +908,17 @@ func TestListNodeExecutionsWithParent(t *testing.T) { assert.Equal(t, 1, input.Limit) assert.Equal(t, 2, input.Offset) assert.Len(t, input.InlineFilters, 4) - assert.Equal(t, common.Execution, input.InlineFilters[0].GetEntity()) + assert.Equal(t, common.NodeExecution, input.InlineFilters[0].GetEntity()) queryExpr, _ := input.InlineFilters[0].GetGormQueryExpr() assert.Equal(t, "project", queryExpr.Args) assert.Equal(t, "execution_project = ?", queryExpr.Query) - assert.Equal(t, common.Execution, input.InlineFilters[1].GetEntity()) + assert.Equal(t, common.NodeExecution, input.InlineFilters[1].GetEntity()) queryExpr, _ = input.InlineFilters[1].GetGormQueryExpr() assert.Equal(t, "domain", queryExpr.Args) assert.Equal(t, "execution_domain = ?", queryExpr.Query) - assert.Equal(t, common.Execution, input.InlineFilters[2].GetEntity()) + assert.Equal(t, common.NodeExecution, input.InlineFilters[2].GetEntity()) queryExpr, _ = input.InlineFilters[2].GetGormQueryExpr() assert.Equal(t, "name", queryExpr.Args) assert.Equal(t, "execution_name = ?", queryExpr.Query) @@ -979,6 +983,129 @@ func TestListNodeExecutionsWithParent(t *testing.T) { assert.Equal(t, "3", nodeExecutions.Token) } +func TestListNodeExecutions_WithJoinTableFilter(t *testing.T) { + repository := repositoryMocks.NewMockRepository() + expectedClosure := admin.NodeExecutionClosure{ + Phase: core.NodeExecution_SUCCEEDED, + } + expectedMetadata := admin.NodeExecutionMetaData{ + SpecNodeId: "spec_node_id", + RetryGroup: "retry_group", + } + metadataBytes, _ := proto.Marshal(&expectedMetadata) + closureBytes, _ := proto.Marshal(&expectedClosure) + + repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetListCallback( + func(ctx context.Context, input interfaces.ListResourceInput) ( + interfaces.NodeExecutionCollectionOutput, error) { + assert.Equal(t, 1, input.Limit) + assert.Equal(t, 2, input.Offset) + assert.Len(t, input.InlineFilters, 4) + assert.Equal(t, common.NodeExecution, input.InlineFilters[0].GetEntity()) + queryExpr, _ := input.InlineFilters[0].GetGormQueryExpr() + assert.Equal(t, "project", queryExpr.Args) + assert.Equal(t, "execution_project = ?", queryExpr.Query) + + assert.Equal(t, common.NodeExecution, input.InlineFilters[1].GetEntity()) + queryExpr, _ = input.InlineFilters[1].GetGormQueryExpr() + assert.Equal(t, "domain", queryExpr.Args) + assert.Equal(t, "execution_domain = ?", queryExpr.Query) + + assert.Equal(t, common.NodeExecution, input.InlineFilters[2].GetEntity()) + queryExpr, _ = input.InlineFilters[2].GetGormQueryExpr() + assert.Equal(t, "name", queryExpr.Args) + assert.Equal(t, "execution_name = ?", queryExpr.Query) + + assert.Equal(t, common.Execution, input.InlineFilters[3].GetEntity()) + queryExpr, _ = input.InlineFilters[3].GetGormQueryExpr() + assert.Equal(t, "SUCCEEDED", queryExpr.Args) + assert.Equal(t, "phase = ?", queryExpr.Query) + + assert.Len(t, input.MapFilters, 1) + filter := input.MapFilters[0].GetFilter() + assert.Equal(t, map[string]interface{}{ + "parent_id": nil, + "parent_task_execution_id": nil, + }, filter) + + assert.EqualValues(t, input.JoinTableEntities, map[common.Entity]bool{ + common.NodeExecution: true, + common.Execution: true, + }) + + assert.Equal(t, "execution_domain asc", input.SortParameter.GetGormOrderExpr()) + return interfaces.NodeExecutionCollectionOutput{ + NodeExecutions: []models.NodeExecution{ + { + NodeExecutionKey: models.NodeExecutionKey{ + NodeID: "node id", + ExecutionKey: models.ExecutionKey{ + Project: "project", + Domain: "domain", + Name: "name", + }, + }, + Phase: core.NodeExecution_SUCCEEDED.String(), + InputURI: "input uri", + StartedAt: &occurredAt, + Closure: closureBytes, + NodeExecutionMetadata: metadataBytes, + }, + }, + }, nil + }) + repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetWithChildrenCallback( + func( + ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) { + return models.NodeExecution{ + NodeExecutionKey: models.NodeExecutionKey{ + NodeID: "node id", + ExecutionKey: models.ExecutionKey{ + Project: "project", + Domain: "domain", + Name: "name", + }, + }, + Phase: core.NodeExecution_SUCCEEDED.String(), + InputURI: "input uri", + StartedAt: &occurredAt, + Closure: closureBytes, + NodeExecutionMetadata: metadataBytes, + }, nil + }) + nodeExecManager := NewNodeExecutionManager(repository, getMockExecutionsConfigProvider(), make([]string, 0), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, nil, nil, &eventWriterMocks.NodeExecutionEventWriter{}) + nodeExecutions, err := nodeExecManager.ListNodeExecutions(context.Background(), &admin.NodeExecutionListRequest{ + WorkflowExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Limit: 1, + Token: "2", + SortBy: &admin.Sort{ + Direction: admin.Sort_ASCENDING, + Key: "execution_domain", + }, + Filters: "eq(execution.phase, SUCCEEDED)", + }) + assert.NoError(t, err) + assert.Len(t, nodeExecutions.NodeExecutions, 1) + assert.True(t, proto.Equal(&admin.NodeExecution{ + Id: &core.NodeExecutionIdentifier{ + NodeId: "node id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + }, + InputUri: "input uri", + Closure: &expectedClosure, + Metadata: &expectedMetadata, + }, nodeExecutions.NodeExecutions[0])) + assert.Equal(t, "3", nodeExecutions.Token) +} + func TestListNodeExecutions_InvalidParams(t *testing.T) { nodeExecManager := NewNodeExecutionManager(nil, getMockExecutionsConfigProvider(), make([]string, 0), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, nil, nil, &eventWriterMocks.NodeExecutionEventWriter{}) _, err := nodeExecManager.ListNodeExecutions(context.Background(), &admin.NodeExecutionListRequest{ @@ -1120,17 +1247,17 @@ func TestListNodeExecutionsForTask(t *testing.T) { assert.Equal(t, 1, input.Limit) assert.Equal(t, 2, input.Offset) assert.Len(t, input.InlineFilters, 4) - assert.Equal(t, common.Execution, input.InlineFilters[0].GetEntity()) + assert.Equal(t, common.NodeExecution, input.InlineFilters[0].GetEntity()) queryExpr, _ := input.InlineFilters[0].GetGormQueryExpr() assert.Equal(t, "project", queryExpr.Args) assert.Equal(t, "execution_project = ?", queryExpr.Query) - assert.Equal(t, common.Execution, input.InlineFilters[1].GetEntity()) + assert.Equal(t, common.NodeExecution, input.InlineFilters[1].GetEntity()) queryExpr, _ = input.InlineFilters[1].GetGormQueryExpr() assert.Equal(t, "domain", queryExpr.Args) assert.Equal(t, "execution_domain = ?", queryExpr.Query) - assert.Equal(t, common.Execution, input.InlineFilters[2].GetEntity()) + assert.Equal(t, common.NodeExecution, input.InlineFilters[2].GetEntity()) queryExpr, _ = input.InlineFilters[2].GetGormQueryExpr() assert.Equal(t, "name", queryExpr.Args) assert.Equal(t, "execution_name = ?", queryExpr.Query) diff --git a/flyteadmin/pkg/manager/impl/signal_manager.go b/flyteadmin/pkg/manager/impl/signal_manager.go index 49bfc8ac45..f98edae674 100644 --- a/flyteadmin/pkg/manager/impl/signal_manager.go +++ b/flyteadmin/pkg/manager/impl/signal_manager.go @@ -72,7 +72,7 @@ func (s *SignalManager) ListSignals(ctx context.Context, request *admin.SignalLi } ctx = getExecutionContext(ctx, request.WorkflowExecutionId) - identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, request.WorkflowExecutionId) + identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, request.WorkflowExecutionId, common.Signal) if err != nil { return nil, err } diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager.go b/flyteadmin/pkg/manager/impl/task_execution_manager.go index 04811200ac..f8b8e12e21 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager.go @@ -247,7 +247,7 @@ func (m *TaskExecutionManager) ListTaskExecutions( } ctx = getNodeExecutionContext(ctx, request.NodeExecutionId) - identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, request.NodeExecutionId) + identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, request.NodeExecutionId, common.TaskExecution) if err != nil { return nil, err } @@ -267,12 +267,17 @@ func (m *TaskExecutionManager) ListTaskExecutions( return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid pagination token %s for ListTaskExecutions", request.Token) } + joinTableEntities := make(map[common.Entity]bool) + for _, filter := range filters { + joinTableEntities[filter.GetEntity()] = true + } output, err := m.db.TaskExecutionRepo().List(ctx, repoInterfaces.ListResourceInput{ - InlineFilters: filters, - Offset: offset, - Limit: int(request.Limit), - SortParameter: sortParameter, + InlineFilters: filters, + Offset: offset, + Limit: int(request.Limit), + SortParameter: sortParameter, + JoinTableEntities: joinTableEntities, }) if err != nil { logger.Debugf(ctx, "Failed to list task executions with request [%+v] with err %v", diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go index b59b1c1b31..7e2a14131e 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager_test.go @@ -644,22 +644,22 @@ func TestListTaskExecutions(t *testing.T) { assert.Equal(t, 1, input.Offset) assert.Len(t, input.InlineFilters, 4) - assert.Equal(t, common.Execution, input.InlineFilters[0].GetEntity()) + assert.Equal(t, common.TaskExecution, input.InlineFilters[0].GetEntity()) queryExpr, _ := input.InlineFilters[0].GetGormQueryExpr() assert.Equal(t, "exec project b", queryExpr.Args) assert.Equal(t, "execution_project = ?", queryExpr.Query) - assert.Equal(t, common.Execution, input.InlineFilters[1].GetEntity()) + assert.Equal(t, common.TaskExecution, input.InlineFilters[1].GetEntity()) queryExpr, _ = input.InlineFilters[1].GetGormQueryExpr() assert.Equal(t, "exec domain b", queryExpr.Args) assert.Equal(t, "execution_domain = ?", queryExpr.Query) - assert.Equal(t, common.Execution, input.InlineFilters[2].GetEntity()) + assert.Equal(t, common.TaskExecution, input.InlineFilters[2].GetEntity()) queryExpr, _ = input.InlineFilters[2].GetGormQueryExpr() assert.Equal(t, "exec name b", queryExpr.Args) assert.Equal(t, "execution_name = ?", queryExpr.Query) - assert.Equal(t, common.NodeExecution, input.InlineFilters[3].GetEntity()) + assert.Equal(t, common.TaskExecution, input.InlineFilters[3].GetEntity()) queryExpr, _ = input.InlineFilters[3].GetGormQueryExpr() assert.Equal(t, "nodey b", queryExpr.Args) assert.Equal(t, "node_id = ?", queryExpr.Query) @@ -777,6 +777,179 @@ func TestListTaskExecutions(t *testing.T) { }, taskExecutions.TaskExecutions[1])) } +func TestListTaskExecutions_Filters(t *testing.T) { + repository := repositoryMocks.NewMockRepository() + + expectedLogs := []*core.TaskLog{{Uri: "test-log1.txt"}} + extraLongErrMsg := string(make([]byte, 2*100)) + expectedOutputResult := &admin.TaskExecutionClosure_Error{ + Error: &core.ExecutionError{ + Message: extraLongErrMsg, + }, + } + expectedClosure := &admin.TaskExecutionClosure{ + StartedAt: sampleTaskEventOccurredAt, + Phase: core.TaskExecution_SUCCEEDED, + Duration: ptypes.DurationProto(time.Minute), + OutputResult: expectedOutputResult, + Logs: expectedLogs, + } + + closureBytes, _ := proto.Marshal(expectedClosure) + + firstRetryAttempt := uint32(1) + secondRetryAttempt := uint32(2) + listTaskExecutionsCalled := false + repository.TaskExecutionRepo().(*repositoryMocks.MockTaskExecutionRepo).SetListCallback( + func(ctx context.Context, input interfaces.ListResourceInput) (interfaces.TaskExecutionCollectionOutput, error) { + listTaskExecutionsCalled = true + assert.Equal(t, 99, input.Limit) + assert.Equal(t, 1, input.Offset) + + assert.Len(t, input.InlineFilters, 5) + assert.Equal(t, common.TaskExecution, input.InlineFilters[0].GetEntity()) + queryExpr, _ := input.InlineFilters[0].GetGormQueryExpr() + assert.Equal(t, "exec project b", queryExpr.Args) + assert.Equal(t, "execution_project = ?", queryExpr.Query) + + assert.Equal(t, common.TaskExecution, input.InlineFilters[1].GetEntity()) + queryExpr, _ = input.InlineFilters[1].GetGormQueryExpr() + assert.Equal(t, "exec domain b", queryExpr.Args) + assert.Equal(t, "execution_domain = ?", queryExpr.Query) + + assert.Equal(t, common.TaskExecution, input.InlineFilters[2].GetEntity()) + queryExpr, _ = input.InlineFilters[2].GetGormQueryExpr() + assert.Equal(t, "exec name b", queryExpr.Args) + assert.Equal(t, "execution_name = ?", queryExpr.Query) + + assert.Equal(t, common.TaskExecution, input.InlineFilters[3].GetEntity()) + queryExpr, _ = input.InlineFilters[3].GetGormQueryExpr() + assert.Equal(t, "nodey b", queryExpr.Args) + assert.Equal(t, "node_id = ?", queryExpr.Query) + + assert.Equal(t, common.Execution, input.InlineFilters[4].GetEntity()) + queryExpr, _ = input.InlineFilters[4].GetGormQueryExpr() + assert.Equal(t, "SUCCEEDED", queryExpr.Args) + assert.Equal(t, "phase = ?", queryExpr.Query) + assert.EqualValues(t, input.JoinTableEntities, map[common.Entity]bool{ + common.TaskExecution: true, + common.Execution: true, + }) + + return interfaces.TaskExecutionCollectionOutput{ + TaskExecutions: []models.TaskExecution{ + { + TaskExecutionKey: models.TaskExecutionKey{ + TaskKey: models.TaskKey{ + Project: "task project a", + Domain: "task domain a", + Name: "task name a", + Version: "task version a", + }, + NodeExecutionKey: models.NodeExecutionKey{ + NodeID: "nodey a", + ExecutionKey: models.ExecutionKey{ + Project: "exec project a", + Domain: "exec domain a", + Name: "exec name a", + }, + }, + RetryAttempt: &firstRetryAttempt, + }, + Phase: core.TaskExecution_SUCCEEDED.String(), + InputURI: "input-uri.pb", + StartedAt: &taskStartedAt, + Closure: closureBytes, + }, + { + TaskExecutionKey: models.TaskExecutionKey{ + TaskKey: models.TaskKey{ + Project: "task project b", + Domain: "task domain b", + Name: "task name b", + Version: "task version b", + }, + NodeExecutionKey: models.NodeExecutionKey{ + NodeID: "nodey b", + ExecutionKey: models.ExecutionKey{ + Project: "exec project b", + Domain: "exec domain b", + Name: "exec name b", + }, + }, + RetryAttempt: &secondRetryAttempt, + }, + Phase: core.TaskExecution_SUCCEEDED.String(), + InputURI: "input-uri2.pb", + StartedAt: &taskStartedAt, + Closure: closureBytes, + }, + }, + }, nil + }) + taskExecManager := NewTaskExecutionManager(repository, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockTaskExecutionRemoteURL, nil, nil) + taskExecutions, err := taskExecManager.ListTaskExecutions(context.Background(), &admin.TaskExecutionListRequest{ + NodeExecutionId: &core.NodeExecutionIdentifier{ + NodeId: "nodey b", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "exec project b", + Domain: "exec domain b", + Name: "exec name b", + }, + }, + Token: "1", + Limit: 99, + Filters: "eq(execution.phase, SUCCEEDED)", + }) + assert.Nil(t, err) + assert.True(t, listTaskExecutionsCalled) + + assert.True(t, proto.Equal(&admin.TaskExecution{ + Id: &core.TaskExecutionIdentifier{ + RetryAttempt: firstRetryAttempt, + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "exec project a", + Domain: "exec domain a", + Name: "exec name a", + }, + NodeId: "nodey a", + }, + TaskId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "task project a", + Domain: "task domain a", + Name: "task name a", + Version: "task version a", + }, + }, + InputUri: "input-uri.pb", + Closure: expectedClosure, + }, taskExecutions.TaskExecutions[0])) + assert.True(t, proto.Equal(&admin.TaskExecution{ + Id: &core.TaskExecutionIdentifier{ + RetryAttempt: secondRetryAttempt, + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "exec project b", + Domain: "exec domain b", + Name: "exec name b", + }, + NodeId: "nodey b", + }, + TaskId: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: "task project b", + Domain: "task domain b", + Name: "task name b", + Version: "task version b", + }, + }, + InputUri: "input-uri2.pb", + Closure: expectedClosure, + }, taskExecutions.TaskExecutions[1])) +} + func TestListTaskExecutions_NoFilters(t *testing.T) { repository := repositoryMocks.NewMockRepository() diff --git a/flyteadmin/pkg/manager/impl/util/filters.go b/flyteadmin/pkg/manager/impl/util/filters.go index 81cb55a994..377dcdab51 100644 --- a/flyteadmin/pkg/manager/impl/util/filters.go +++ b/flyteadmin/pkg/manager/impl/util/filters.go @@ -271,10 +271,10 @@ func GetDbFilters(spec FilterSpec, primaryEntity common.Entity) ([]common.Inline } func GetWorkflowExecutionIdentifierFilters( - ctx context.Context, workflowExecutionIdentifier *core.WorkflowExecutionIdentifier) ([]common.InlineFilter, error) { + ctx context.Context, workflowExecutionIdentifier *core.WorkflowExecutionIdentifier, entity common.Entity) ([]common.InlineFilter, error) { identifierFilters := make([]common.InlineFilter, 3) identifierProjectFilter, err := GetSingleValueEqualityFilter( - common.Execution, shared.Project, workflowExecutionIdentifier.Project) + entity, shared.Project, workflowExecutionIdentifier.Project) if err != nil { logger.Warningf(ctx, "Failed to create execution identifier filter for project: %s with identifier [%+v]", workflowExecutionIdentifier.Project, workflowExecutionIdentifier) @@ -283,7 +283,7 @@ func GetWorkflowExecutionIdentifierFilters( identifierFilters[0] = identifierProjectFilter identifierDomainFilter, err := GetSingleValueEqualityFilter( - common.Execution, shared.Domain, workflowExecutionIdentifier.Domain) + entity, shared.Domain, workflowExecutionIdentifier.Domain) if err != nil { logger.Warningf(ctx, "Failed to create execution identifier filter for domain: %s with identifier [%+v]", workflowExecutionIdentifier.Domain, workflowExecutionIdentifier) @@ -292,7 +292,7 @@ func GetWorkflowExecutionIdentifierFilters( identifierFilters[1] = identifierDomainFilter identifierNameFilter, err := GetSingleValueEqualityFilter( - common.Execution, shared.Name, workflowExecutionIdentifier.Name) + entity, shared.Name, workflowExecutionIdentifier.Name) if err != nil { logger.Warningf(ctx, "Failed to create execution identifier filter for domain: %s with identifier [%+v]", workflowExecutionIdentifier.Name, workflowExecutionIdentifier) @@ -304,14 +304,14 @@ func GetWorkflowExecutionIdentifierFilters( // All inputs to this function must be validated. func GetNodeExecutionIdentifierFilters( - ctx context.Context, nodeExecutionIdentifier *core.NodeExecutionIdentifier) ([]common.InlineFilter, error) { + ctx context.Context, nodeExecutionIdentifier *core.NodeExecutionIdentifier, entity common.Entity) ([]common.InlineFilter, error) { workflowExecutionIdentifierFilters, err := - GetWorkflowExecutionIdentifierFilters(ctx, nodeExecutionIdentifier.ExecutionId) + GetWorkflowExecutionIdentifierFilters(ctx, nodeExecutionIdentifier.ExecutionId, entity) if err != nil { return nil, err } nodeIDFilter, err := GetSingleValueEqualityFilter( - common.NodeExecution, shared.NodeID, nodeExecutionIdentifier.NodeId) + entity, shared.NodeID, nodeExecutionIdentifier.NodeId) if err != nil { logger.Warningf(ctx, "Failed to create node execution identifier filter for node id: %s with identifier [%+v]", nodeExecutionIdentifier.NodeId, nodeExecutionIdentifier) diff --git a/flyteadmin/pkg/manager/impl/util/filters_test.go b/flyteadmin/pkg/manager/impl/util/filters_test.go index 29c1116a8e..72a8c9971b 100644 --- a/flyteadmin/pkg/manager/impl/util/filters_test.go +++ b/flyteadmin/pkg/manager/impl/util/filters_test.go @@ -176,7 +176,7 @@ func TestGetWorkflowExecutionIdentifierFilters(t *testing.T) { Project: "ex project", Domain: "ex domain", Name: "ex name", - }) + }, common.Execution) assert.Nil(t, err) assert.Len(t, identifierFilters, 3) @@ -205,26 +205,26 @@ func TestGetNodeExecutionIdentifierFilters(t *testing.T) { Name: "ex name", }, NodeId: "nodey", - }) + }, common.TaskExecution) assert.Nil(t, err) assert.Len(t, identifierFilters, 4) - assert.Equal(t, common.Execution, identifierFilters[0].GetEntity()) + assert.Equal(t, common.TaskExecution, identifierFilters[0].GetEntity()) queryExpr, _ := identifierFilters[0].GetGormQueryExpr() assert.Equal(t, "ex project", queryExpr.Args) assert.Equal(t, "execution_project = ?", queryExpr.Query) - assert.Equal(t, common.Execution, identifierFilters[1].GetEntity()) + assert.Equal(t, common.TaskExecution, identifierFilters[1].GetEntity()) queryExpr, _ = identifierFilters[1].GetGormQueryExpr() assert.Equal(t, "ex domain", queryExpr.Args) assert.Equal(t, "execution_domain = ?", queryExpr.Query) - assert.Equal(t, common.Execution, identifierFilters[2].GetEntity()) + assert.Equal(t, common.TaskExecution, identifierFilters[2].GetEntity()) queryExpr, _ = identifierFilters[2].GetGormQueryExpr() assert.Equal(t, "ex name", queryExpr.Args) assert.Equal(t, "execution_name = ?", queryExpr.Query) - assert.Equal(t, common.NodeExecution, identifierFilters[3].GetEntity()) + assert.Equal(t, common.TaskExecution, identifierFilters[3].GetEntity()) queryExpr, _ = identifierFilters[3].GetGormQueryExpr() assert.Equal(t, "nodey", queryExpr.Args) assert.Equal(t, "node_id = ?", queryExpr.Query) diff --git a/flyteadmin/pkg/repositories/gormimpl/common.go b/flyteadmin/pkg/repositories/gormimpl/common.go index 40a54f8878..330555be8f 100644 --- a/flyteadmin/pkg/repositories/gormimpl/common.go +++ b/flyteadmin/pkg/repositories/gormimpl/common.go @@ -52,25 +52,25 @@ var entityToTableName = map[common.Entity]string{ } var innerJoinExecToNodeExec = fmt.Sprintf( - "INNER JOIN %s ON %s.execution_project = %s.execution_project AND "+ - "%s.execution_domain = %s.execution_domain AND %s.execution_name = %s.execution_name", - executionTableName, nodeExecutionTableName, executionTableName, nodeExecutionTableName, executionTableName, - nodeExecutionTableName, executionTableName) + "INNER JOIN %[1]s ON %[2]s.execution_project = %[1]s.execution_project AND "+ + "%[2]s.execution_domain = %[1]s.execution_domain AND %[2]s.execution_name = %[1]s.execution_name", + executionTableName, nodeExecutionTableName) +var innerJoinExecToTaskExec = fmt.Sprintf( + "INNER JOIN %[1]s ON %[2]s.execution_project = %[1]s.execution_project AND "+ + "%[2]s.execution_domain = %[1]s.execution_domain AND %[2]s.execution_name = %[1]s.execution_name", + executionTableName, taskExecutionTableName) var innerJoinNodeExecToTaskExec = fmt.Sprintf( - "INNER JOIN %s ON %s.node_id = %s.node_id AND %s.execution_project = %s.execution_project AND "+ - "%s.execution_domain = %s.execution_domain AND %s.execution_name = %s.execution_name", - nodeExecutionTableName, taskExecutionTableName, nodeExecutionTableName, taskExecutionTableName, - nodeExecutionTableName, taskExecutionTableName, nodeExecutionTableName, taskExecutionTableName, - nodeExecutionTableName) + "INNER JOIN %[1]s ON %s.node_id = %[1]s.node_id AND %[2]s.execution_project = %[1]s.execution_project AND "+ + "%[2]s.execution_domain = %[1]s.execution_domain AND %[2]s.execution_name = %[1]s.execution_name", + nodeExecutionTableName, taskExecutionTableName) // Because dynamic tasks do NOT necessarily register static task definitions, we use a left join to not exclude // dynamic tasks from list queries. var leftJoinTaskToTaskExec = fmt.Sprintf( - "LEFT JOIN %s ON %s.project = %s.project AND %s.domain = %s.domain AND %s.name = %s.name AND "+ - "%s.version = %s.version", - taskTableName, taskExecutionTableName, taskTableName, taskExecutionTableName, taskTableName, - taskExecutionTableName, taskTableName, taskExecutionTableName, taskTableName) + "LEFT JOIN %[1]s ON %[2]s.project = %[1]s.project AND %[2]s.domain = %[1]s.domain AND %[2]s.name = %[1]s.name AND "+ + " %[2]s.version = %[1]s.version", + taskTableName, taskExecutionTableName) // Validates there are no missing but required parameters in ListResourceInput func ValidateListInput(input interfaces.ListResourceInput) adminErrors.FlyteAdminError { diff --git a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go index 0fe97d2f8c..70833d4d77 100644 --- a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo.go @@ -3,10 +3,10 @@ package gormimpl import ( "context" "errors" - "fmt" "gorm.io/gorm" + "github.com/flyteorg/flyte/flyteadmin/pkg/common" adminErrors "github.com/flyteorg/flyte/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyte/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyte/flyteadmin/pkg/repositories/models" @@ -113,12 +113,10 @@ func (r *NodeExecutionRepo) List(ctx context.Context, input interfaces.ListResou } var nodeExecutions []models.NodeExecution tx := r.db.WithContext(ctx).Limit(input.Limit).Offset(input.Offset).Preload("ChildNodeExecutions") - // And add join condition (joining multiple tables is fine even we only filter on a subset of table attributes). - // (this query isn't called for deletes). - tx = tx.Joins(fmt.Sprintf("INNER JOIN %s ON %s.execution_project = %s.execution_project AND "+ - "%s.execution_domain = %s.execution_domain AND %s.execution_name = %s.execution_name", - executionTableName, nodeExecutionTableName, executionTableName, - nodeExecutionTableName, executionTableName, nodeExecutionTableName, executionTableName)) + // And add join condition, if any + if input.JoinTableEntities[common.Execution] { + tx = tx.Joins(innerJoinExecToNodeExec) + } // Apply filters tx, err := applyScopedFilters(tx, input.InlineFilters, input.MapFilters) @@ -165,12 +163,10 @@ func (r *NodeExecutionRepo) Count(ctx context.Context, input interfaces.CountRes var err error tx := r.db.WithContext(ctx).Model(&models.NodeExecution{}).Preload("ChildNodeExecutions") - // Add join condition (joining multiple tables is fine even we only filter on a subset of table attributes). - // (this query isn't called for deletes). - tx = tx.Joins(fmt.Sprintf("INNER JOIN %s ON %s.execution_project = %s.execution_project AND "+ - "%s.execution_domain = %s.execution_domain AND %s.execution_name = %s.execution_name", - executionTableName, nodeExecutionTableName, executionTableName, - nodeExecutionTableName, executionTableName, nodeExecutionTableName, executionTableName)) + // And add join condition, if any + if input.JoinTableEntities[common.Execution] { + tx = tx.Joins(innerJoinExecToNodeExec) + } // Apply filters tx, err = applyScopedFilters(tx, input.InlineFilters, input.MapFilters) diff --git a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go index e90e342a7c..d35f8ac4f4 100644 --- a/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/node_execution_repo_test.go @@ -215,7 +215,7 @@ func TestListNodeExecutions(t *testing.T) { } GlobalMock := mocket.Catcher.Reset() - GlobalMock.NewMock().WithQuery(`SELECT "node_executions"."id","node_executions"."created_at","node_executions"."updated_at","node_executions"."deleted_at","node_executions"."execution_project","node_executions"."execution_domain","node_executions"."execution_name","node_executions"."node_id","node_executions"."phase","node_executions"."input_uri","node_executions"."closure","node_executions"."started_at","node_executions"."node_execution_created_at","node_executions"."node_execution_updated_at","node_executions"."duration","node_executions"."node_execution_metadata","node_executions"."parent_id","node_executions"."parent_task_execution_id","node_executions"."error_kind","node_executions"."error_code","node_executions"."cache_status","node_executions"."dynamic_workflow_remote_closure_reference","node_executions"."internal_data" FROM "node_executions" INNER JOIN executions ON node_executions.execution_project = executions.execution_project AND node_executions.execution_domain = executions.execution_domain AND node_executions.execution_name = executions.execution_name WHERE node_executions.phase = $1 LIMIT 20%`). + GlobalMock.NewMock().WithQuery(`SELECT * FROM "node_executions" WHERE node_executions.phase = $1 LIMIT 20`). WithReply(nodeExecutions) collection, err := nodeExecutionRepo.List(context.Background(), interfaces.ListResourceInput{ @@ -240,6 +240,59 @@ func TestListNodeExecutions(t *testing.T) { } } +func TestListNodeExecutions_WithJoins(t *testing.T) { + nodeExecutionRepo := NewNodeExecutionRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + nodeExecutions := make([]map[string]interface{}, 0) + executionIDs := []string{"100", "200"} + for _, executionID := range executionIDs { + nodeExecution := getMockNodeExecutionResponseFromDb(models.NodeExecution{ + NodeExecutionKey: models.NodeExecutionKey{ + ExecutionKey: models.ExecutionKey{ + Project: "project", + Domain: "domain", + Name: executionID, + }, + }, + Phase: nodePhase, + Closure: []byte("closure"), + InputURI: "input uri", + StartedAt: &nodeStartedAt, + Duration: time.Hour, + NodeExecutionCreatedAt: &nodeCreatedAt, + NodeExecutionUpdatedAt: &nodePlanUpdatedAt, + }) + nodeExecutions = append(nodeExecutions, nodeExecution) + } + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + GlobalMock.NewMock().WithQuery(`SELECT "node_executions"."id","node_executions"."created_at","node_executions"."updated_at","node_executions"."deleted_at","node_executions"."execution_project","node_executions"."execution_domain","node_executions"."execution_name","node_executions"."node_id","node_executions"."phase","node_executions"."input_uri","node_executions"."closure","node_executions"."started_at","node_executions"."node_execution_created_at","node_executions"."node_execution_updated_at","node_executions"."duration","node_executions"."node_execution_metadata","node_executions"."parent_id","node_executions"."parent_task_execution_id","node_executions"."error_kind","node_executions"."error_code","node_executions"."cache_status","node_executions"."dynamic_workflow_remote_closure_reference","node_executions"."internal_data" FROM "node_executions" INNER JOIN executions ON node_executions.execution_project = executions.execution_project AND node_executions.execution_domain = executions.execution_domain AND node_executions.execution_name = executions.execution_name WHERE node_executions.phase = $1 LIMIT 20`). + WithReply(nodeExecutions) + + collection, err := nodeExecutionRepo.List(context.Background(), interfaces.ListResourceInput{ + InlineFilters: []common.InlineFilter{ + getEqualityFilter(common.NodeExecution, "phase", nodePhase), + }, + JoinTableEntities: map[common.Entity]bool{ + common.Execution: true, + }, + Limit: 20, + }) + assert.NoError(t, err) + assert.NotEmpty(t, collection) + assert.NotEmpty(t, collection.NodeExecutions) + assert.Len(t, collection.NodeExecutions, 2) + for _, nodeExecution := range collection.NodeExecutions { + assert.Equal(t, "project", nodeExecution.ExecutionKey.Project) + assert.Equal(t, "domain", nodeExecution.ExecutionKey.Domain) + assert.Contains(t, executionIDs, nodeExecution.ExecutionKey.Name) + assert.Equal(t, nodePhase, nodeExecution.Phase) + assert.Equal(t, []byte("closure"), nodeExecution.Closure) + assert.Equal(t, "input uri", nodeExecution.InputURI) + assert.Equal(t, nodeStartedAt, *nodeExecution.StartedAt) + assert.Equal(t, time.Hour, nodeExecution.Duration) + } +} + func TestListNodeExecutions_Order(t *testing.T) { nodeExecutionRepo := NewNodeExecutionRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) nodeExecutions := make([]map[string]interface{}, 0) @@ -305,7 +358,7 @@ func TestListNodeExecutionsForExecution(t *testing.T) { nodeExecutions = append(nodeExecutions, nodeExecution) GlobalMock := mocket.Catcher.Reset() - query := `SELECT "node_executions"."id","node_executions"."created_at","node_executions"."updated_at","node_executions"."deleted_at","node_executions"."execution_project","node_executions"."execution_domain","node_executions"."execution_name","node_executions"."node_id","node_executions"."phase","node_executions"."input_uri","node_executions"."closure","node_executions"."started_at","node_executions"."node_execution_created_at","node_executions"."node_execution_updated_at","node_executions"."duration","node_executions"."node_execution_metadata","node_executions"."parent_id","node_executions"."parent_task_execution_id","node_executions"."error_kind","node_executions"."error_code","node_executions"."cache_status","node_executions"."dynamic_workflow_remote_closure_reference","node_executions"."internal_data" FROM "node_executions" INNER JOIN executions ON node_executions.execution_project = executions.execution_project AND node_executions.execution_domain = executions.execution_domain AND node_executions.execution_name = executions.execution_name WHERE node_executions.phase = $1 AND executions.execution_name = $2 LIMIT 20%` + query := `SELECT * FROM "node_executions" WHERE node_executions.phase = $1 AND executions.execution_name = $2 LIMIT 20` GlobalMock.NewMock().WithQuery(query).WithReply(nodeExecutions) collection, err := nodeExecutionRepo.List(context.Background(), interfaces.ListResourceInput{ @@ -392,7 +445,7 @@ func TestCountNodeExecutions_Filters(t *testing.T) { GlobalMock := mocket.Catcher.Reset() GlobalMock.NewMock().WithQuery( - `SELECT count(*) FROM "node_executions" INNER JOIN executions ON node_executions.execution_project = executions.execution_project AND node_executions.execution_domain = executions.execution_domain AND node_executions.execution_name = executions.execution_name WHERE node_executions.phase = $1 AND "node_executions"."error_code" IS NULL`).WithReply([]map[string]interface{}{{"rows": 3}}) + `SELECT count(*) FROM "node_executions" WHERE node_executions.phase = $1 AND "node_executions"."error_code" IS NULL`).WithReply([]map[string]interface{}{{"rows": 3}}) count, err := nodeExecutionRepo.Count(context.Background(), interfaces.CountResourceInput{ InlineFilters: []common.InlineFilter{ diff --git a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go index f2ac2adf52..c42d36b1bc 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo.go @@ -6,6 +6,7 @@ import ( "gorm.io/gorm" + "github.com/flyteorg/flyte/flyteadmin/pkg/common" flyteAdminDbErrors "github.com/flyteorg/flyte/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyte/flyteadmin/pkg/repositories/interfaces" "github.com/flyteorg/flyte/flyteadmin/pkg/repositories/models" @@ -97,13 +98,20 @@ func (r *TaskExecutionRepo) List(ctx context.Context, input interfaces.ListResou var taskExecutions []models.TaskExecution tx := r.db.WithContext(ctx).Limit(input.Limit).Offset(input.Offset).Preload("ChildNodeExecution") - // And add three join conditions (joining multiple tables is fine even we only filter on a subset of table attributes). - // We are joining on task -> taskExec -> NodeExec -> Exec. - // NOTE: the order in which the joins are called below are important because postgres will only know about certain - // tables as they are joined. So we should do it in the order specified above. - tx = tx.Joins(leftJoinTaskToTaskExec) - tx = tx.Joins(innerJoinNodeExecToTaskExec) - tx = tx.Joins(innerJoinExecToNodeExec) + // And add three join conditions + // We enable joining on + // - task x task exec + // - node exec x task exec + // - exec x task exec + if input.JoinTableEntities[common.Task] { + tx = tx.Joins(leftJoinTaskToTaskExec) + } + if input.JoinTableEntities[common.NodeExecution] { + tx = tx.Joins(innerJoinNodeExecToTaskExec) + } + if input.JoinTableEntities[common.Execution] { + tx = tx.Joins(innerJoinExecToTaskExec) + } // Apply filters tx, err := applyScopedFilters(tx, input.InlineFilters, input.MapFilters) @@ -132,13 +140,20 @@ func (r *TaskExecutionRepo) Count(ctx context.Context, input interfaces.CountRes var err error tx := r.db.WithContext(ctx).Model(&models.TaskExecution{}) - // Add three join conditions (joining multiple tables is fine even we only filter on a subset of table attributes). - // We are joining on task -> taskExec -> NodeExec -> Exec. - // NOTE: the order in which the joins are called below are important because postgres will only know about certain - // tables as they are joined. So we should do it in the order specified above. - tx = tx.Joins(leftJoinTaskToTaskExec) - tx = tx.Joins(innerJoinNodeExecToTaskExec) - tx = tx.Joins(innerJoinExecToNodeExec) + // And add three join conditions + // We enable joining on + // - task x task exec + // - node exec x task exec + // - exec x task exec + if input.JoinTableEntities[common.Task] { + tx = tx.Joins(leftJoinTaskToTaskExec) + } + if input.JoinTableEntities[common.NodeExecution] { + tx = tx.Joins(innerJoinNodeExecToTaskExec) + } + if input.JoinTableEntities[common.Execution] { + tx = tx.Joins(innerJoinExecToTaskExec) + } // Apply filters tx, err = applyScopedFilters(tx, input.InlineFilters, input.MapFilters) diff --git a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go index 5947edf175..8ccee763c2 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_execution_repo_test.go @@ -134,7 +134,7 @@ func TestListTaskExecutionForExecution(t *testing.T) { taskExecutions = append(taskExecutions, taskExecution) GlobalMock := mocket.Catcher.Reset() GlobalMock.Logging = true - GlobalMock.NewMock().WithQuery(`SELECT "task_executions"."id","task_executions"."created_at","task_executions"."updated_at","task_executions"."deleted_at","task_executions"."project","task_executions"."domain","task_executions"."name","task_executions"."version","task_executions"."execution_project","task_executions"."execution_domain","task_executions"."execution_name","task_executions"."node_id","task_executions"."retry_attempt","task_executions"."phase","task_executions"."phase_version","task_executions"."input_uri","task_executions"."closure","task_executions"."started_at","task_executions"."task_execution_created_at","task_executions"."task_execution_updated_at","task_executions"."duration" FROM "task_executions" LEFT JOIN tasks ON task_executions.project = tasks.project AND task_executions.domain = tasks.domain AND task_executions.name = tasks.name AND task_executions.version = tasks.version INNER JOIN node_executions ON task_executions.node_id = node_executions.node_id AND task_executions.execution_project = node_executions.execution_project AND task_executions.execution_domain = node_executions.execution_domain AND task_executions.execution_name = node_executions.execution_name INNER JOIN executions ON node_executions.execution_project = executions.execution_project AND node_executions.execution_domain = executions.execution_domain AND node_executions.execution_name = executions.execution_name WHERE executions.execution_project = $1 AND executions.execution_domain = $2 AND executions.execution_name = $3 LIMIT 20`).WithReply(taskExecutions) + GlobalMock.NewMock().WithQuery(`SELECT * FROM "task_executions" WHERE executions.execution_project = $1 AND executions.execution_domain = $2 AND executions.execution_name = $3 LIMIT 20`).WithReply(taskExecutions) collection, err := taskExecutionRepo.List(context.Background(), interfaces.ListResourceInput{ InlineFilters: []common.InlineFilter{ @@ -160,7 +160,7 @@ func TestListTaskExecutionForExecution(t *testing.T) { } } -func TestListTaskExecutionsForTaskExecution(t *testing.T) { +func TestListTaskExecutionsForNodeExecution(t *testing.T) { taskExecutionRepo := NewTaskExecutionRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) taskExecutions := make([]map[string]interface{}, 0) @@ -168,9 +168,50 @@ func TestListTaskExecutionsForTaskExecution(t *testing.T) { taskExecutions = append(taskExecutions, taskExecution) GlobalMock := mocket.Catcher.Reset() - GlobalMock.Logging = true + GlobalMock.NewMock().WithQuery(`SELECT "task_executions"."id","task_executions"."created_at","task_executions"."updated_at","task_executions"."deleted_at","task_executions"."project","task_executions"."domain","task_executions"."name","task_executions"."version","task_executions"."execution_project","task_executions"."execution_domain","task_executions"."execution_name","task_executions"."node_id","task_executions"."retry_attempt","task_executions"."phase","task_executions"."phase_version","task_executions"."input_uri","task_executions"."closure","task_executions"."started_at","task_executions"."task_execution_created_at","task_executions"."task_execution_updated_at","task_executions"."duration" FROM "task_executions" INNER JOIN node_executions ON task_executions.node_id = node_executions.node_id AND task_executions.execution_project = node_executions.execution_project AND task_executions.execution_domain = node_executions.execution_domain AND task_executions.execution_name = node_executions.execution_name WHERE tasks.project = $1 AND tasks.domain = $2 AND tasks.name = $3 AND tasks.version = $4 AND node_executions.phase = $5 AND executions.execution_project = $6 AND executions.execution_domain = $7 AND executions.execution_name = $8 LIMIT 20`).WithReply(taskExecutions) + + collection, err := taskExecutionRepo.List(context.Background(), interfaces.ListResourceInput{ + InlineFilters: []common.InlineFilter{ + getEqualityFilter(common.Task, "project", "project_tn"), + getEqualityFilter(common.Task, "domain", "domain_t"), + getEqualityFilter(common.Task, "name", "domain_t"), + getEqualityFilter(common.Task, "version", "version_t"), + + getEqualityFilter(common.NodeExecution, "phase", nodePhase), + getEqualityFilter(common.Execution, "project", "project_name"), + getEqualityFilter(common.Execution, "domain", "domain_name"), + getEqualityFilter(common.Execution, "name", "execution_name"), + }, + JoinTableEntities: map[common.Entity]bool{ + common.NodeExecution: true, + }, + Limit: 20, + }) + assert.NoError(t, err) + assert.NotEmpty(t, collection) + assert.NotEmpty(t, collection.TaskExecutions) + assert.Len(t, collection.TaskExecutions, 1) + + for _, taskExecution := range collection.TaskExecutions { + assert.Equal(t, testTaskExecution.TaskExecutionKey, taskExecution.TaskExecutionKey) + assert.Equal(t, &retryAttemptValue, taskExecution.RetryAttempt) + assert.Equal(t, taskPhase, taskExecution.Phase) + assert.Equal(t, []byte("Test"), taskExecution.Closure) + assert.Equal(t, "testInput.pb", taskExecution.InputURI) + assert.Equal(t, taskStartedAt, *taskExecution.StartedAt) + assert.Equal(t, time.Hour, taskExecution.Duration) + } +} + +func TestListTaskExecutionsForExecution(t *testing.T) { + taskExecutionRepo := NewTaskExecutionRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + + taskExecutions := make([]map[string]interface{}, 0) + taskExecution := getMockTaskExecutionResponseFromDb(testTaskExecution) + taskExecutions = append(taskExecutions, taskExecution) - GlobalMock.NewMock().WithQuery(`SELECT "task_executions"."id","task_executions"."created_at","task_executions"."updated_at","task_executions"."deleted_at","task_executions"."project","task_executions"."domain","task_executions"."name","task_executions"."version","task_executions"."execution_project","task_executions"."execution_domain","task_executions"."execution_name","task_executions"."node_id","task_executions"."retry_attempt","task_executions"."phase","task_executions"."phase_version","task_executions"."input_uri","task_executions"."closure","task_executions"."started_at","task_executions"."task_execution_created_at","task_executions"."task_execution_updated_at","task_executions"."duration" FROM "task_executions" LEFT JOIN tasks ON task_executions.project = tasks.project AND task_executions.domain = tasks.domain AND task_executions.name = tasks.name AND task_executions.version = tasks.version INNER JOIN node_executions ON task_executions.node_id = node_executions.node_id AND task_executions.execution_project = node_executions.execution_project AND task_executions.execution_domain = node_executions.execution_domain AND task_executions.execution_name = node_executions.execution_name INNER JOIN executions ON node_executions.execution_project = executions.execution_project AND node_executions.execution_domain = executions.execution_domain AND node_executions.execution_name = executions.execution_name WHERE tasks.project = $1 AND tasks.domain = $2 AND tasks.name = $3 AND tasks.version = $4 AND node_executions.phase = $5 AND executions.execution_project = $6 AND executions.execution_domain = $7 AND executions.execution_name = $8 LIMIT 20`).WithReply(taskExecutions) + GlobalMock := mocket.Catcher.Reset() + GlobalMock.NewMock().WithQuery(`SELECT "task_executions"."id","task_executions"."created_at","task_executions"."updated_at","task_executions"."deleted_at","task_executions"."project","task_executions"."domain","task_executions"."name","task_executions"."version","task_executions"."execution_project","task_executions"."execution_domain","task_executions"."execution_name","task_executions"."node_id","task_executions"."retry_attempt","task_executions"."phase","task_executions"."phase_version","task_executions"."input_uri","task_executions"."closure","task_executions"."started_at","task_executions"."task_execution_created_at","task_executions"."task_execution_updated_at","task_executions"."duration" FROM "task_executions" INNER JOIN executions ON task_executions.execution_project = executions.execution_project AND task_executions.execution_domain = executions.execution_domain AND task_executions.execution_name = executions.execution_name WHERE tasks.project = $1 AND tasks.domain = $2 AND tasks.name = $3 AND tasks.version = $4 AND tasks.org = $5 AND executions.execution_project = $6 AND executions.execution_domain = $7 AND executions.execution_name = $8 AND executions.org = $9 LIMIT 20`).WithReply(taskExecutions) collection, err := taskExecutionRepo.List(context.Background(), interfaces.ListResourceInput{ InlineFilters: []common.InlineFilter{ @@ -178,11 +219,62 @@ func TestListTaskExecutionsForTaskExecution(t *testing.T) { getEqualityFilter(common.Task, "domain", "domain_t"), getEqualityFilter(common.Task, "name", "domain_t"), getEqualityFilter(common.Task, "version", "version_t"), + getEqualityFilter(common.Task, "org", "org_t"), + + getEqualityFilter(common.Execution, "project", "project_name"), + getEqualityFilter(common.Execution, "domain", "domain_name"), + getEqualityFilter(common.Execution, "name", "execution_name"), + getEqualityFilter(common.Execution, "org", "execution_org"), + }, + JoinTableEntities: map[common.Entity]bool{ + common.Execution: true, + }, + Limit: 20, + }) + assert.NoError(t, err) + assert.NotEmpty(t, collection) + assert.NotEmpty(t, collection.TaskExecutions) + assert.Len(t, collection.TaskExecutions, 1) + + for _, taskExecution := range collection.TaskExecutions { + assert.Equal(t, testTaskExecution.TaskExecutionKey, taskExecution.TaskExecutionKey) + assert.Equal(t, &retryAttemptValue, taskExecution.RetryAttempt) + assert.Equal(t, taskPhase, taskExecution.Phase) + assert.Equal(t, []byte("Test"), taskExecution.Closure) + assert.Equal(t, "testInput.pb", taskExecution.InputURI) + assert.Equal(t, taskStartedAt, *taskExecution.StartedAt) + assert.Equal(t, time.Hour, taskExecution.Duration) + } +} + +func TestListTaskExecutionsForNodeAndExecution(t *testing.T) { + taskExecutionRepo := NewTaskExecutionRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + + taskExecutions := make([]map[string]interface{}, 0) + taskExecution := getMockTaskExecutionResponseFromDb(testTaskExecution) + taskExecutions = append(taskExecutions, taskExecution) + + GlobalMock := mocket.Catcher.Reset() + + GlobalMock.NewMock().WithQuery(`SELECT "task_executions"."id","task_executions"."created_at","task_executions"."updated_at","task_executions"."deleted_at","task_executions"."project","task_executions"."domain","task_executions"."name","task_executions"."version","task_executions"."execution_project","task_executions"."execution_domain","task_executions"."execution_name","task_executions"."node_id","task_executions"."retry_attempt","task_executions"."phase","task_executions"."phase_version","task_executions"."input_uri","task_executions"."closure","task_executions"."started_at","task_executions"."task_execution_created_at","task_executions"."task_execution_updated_at","task_executions"."duration" FROM "task_executions" INNER JOIN node_executions ON task_executions.node_id = node_executions.node_id AND task_executions.execution_project = node_executions.execution_project AND task_executions.execution_domain = node_executions.execution_domain AND task_executions.execution_name = node_executions.execution_name INNER JOIN executions ON task_executions.execution_project = executions.execution_project AND task_executions.execution_domain = executions.execution_domain AND task_executions.execution_name = executions.execution_name WHERE tasks.project = $1 AND tasks.domain = $2 AND tasks.name = $3 AND tasks.version = $4 AND tasks.org = $5 AND node_executions.phase = $6 AND executions.execution_project = $7 AND executions.execution_domain = $8 AND executions.execution_name = $9 AND executions.org = $10 LIMIT 20`).WithReply(taskExecutions) + + collection, err := taskExecutionRepo.List(context.Background(), interfaces.ListResourceInput{ + InlineFilters: []common.InlineFilter{ + getEqualityFilter(common.Task, "project", "project_tn"), + getEqualityFilter(common.Task, "domain", "domain_t"), + getEqualityFilter(common.Task, "name", "domain_t"), + getEqualityFilter(common.Task, "version", "version_t"), + getEqualityFilter(common.Task, "org", "org_t"), getEqualityFilter(common.NodeExecution, "phase", nodePhase), getEqualityFilter(common.Execution, "project", "project_name"), getEqualityFilter(common.Execution, "domain", "domain_name"), getEqualityFilter(common.Execution, "name", "execution_name"), + getEqualityFilter(common.Execution, "org", "execution_org"), + }, + JoinTableEntities: map[common.Entity]bool{ + common.NodeExecution: true, + common.Execution: true, }, Limit: 20, }) @@ -219,7 +311,7 @@ func TestCountTaskExecutions_Filters(t *testing.T) { GlobalMock := mocket.Catcher.Reset() GlobalMock.NewMock().WithQuery( - `SELECT count(*) FROM "task_executions" LEFT JOIN tasks ON task_executions.project = tasks.project AND task_executions.domain = tasks.domain AND task_executions.name = tasks.name AND task_executions.version = tasks.version INNER JOIN node_executions ON task_executions.node_id = node_executions.node_id AND task_executions.execution_project = node_executions.execution_project AND task_executions.execution_domain = node_executions.execution_domain AND task_executions.execution_name = node_executions.execution_name INNER JOIN executions ON node_executions.execution_project = executions.execution_project AND node_executions.execution_domain = executions.execution_domain AND node_executions.execution_name = executions.execution_name WHERE task_executions.phase = $1 AND "task_execution_updated_at" IS NULL`).WithReply([]map[string]interface{}{{"rows": 3}}) + `SELECT count(*) FROM "task_executions" WHERE task_executions.phase = $1 AND "task_execution_updated_at" IS NULL`).WithReply([]map[string]interface{}{{"rows": 3}}) count, err := taskExecutionRepo.Count(context.Background(), interfaces.CountResourceInput{ InlineFilters: []common.InlineFilter{