Skip to content

Commit

Permalink
Remove unnecessary joins for node and task execution entities in flyt…
Browse files Browse the repository at this point in the history
…eadmin db queries

Signed-off-by: Katrina Rogan <[email protected]>
  • Loading branch information
katrogan committed Oct 29, 2024
1 parent 553a702 commit 66f09b7
Show file tree
Hide file tree
Showing 14 changed files with 561 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func getNodeExecutionContext(ctx context.Context, identifier *core.NodeExecution
func (c *CloudEventWrappedPublisher) getLatestTaskExecutions(ctx context.Context, nodeExecutionID *core.NodeExecutionIdentifier) (*admin.TaskExecution, error) {
ctx = getNodeExecutionContext(ctx, nodeExecutionID)

identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, nodeExecutionID)
identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, nodeExecutionID, common.TaskExecution)
if err != nil {
return nil, err
}
Expand Down
9 changes: 8 additions & 1 deletion flyteadmin/pkg/common/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ var executionIdentifierFields = map[string]bool{
"name": true,
}

// Entities that have special case handling for execution identifier fields.
var executionIdentifierEntities = map[Entity]bool{
Execution: true,
NodeExecution: true,
TaskExecution: true,
}

var entityMetadataFields = map[string]bool{
"description": true,
"state": true,
Expand Down Expand Up @@ -253,7 +260,7 @@ func (f *inlineFilterImpl) GetGormJoinTableQueryExpr(tableName string) (GormQuer

func customizeField(field string, entity Entity) string {
// Execution identifier fields have to be customized because we differ from convention in those column names.
if entity == Execution && executionIdentifierFields[field] {
if executionIdentifierEntities[entity] && executionIdentifierFields[field] {
return fmt.Sprintf("execution_%s", field)
}
// admin_tag table has been migrated to an execution_tag table, so we need to customize the field name.
Expand Down
17 changes: 11 additions & 6 deletions flyteadmin/pkg/manager/impl/node_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,16 @@ func (m *NodeExecutionManager) listNodeExecutions(
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"invalid pagination token %s for ListNodeExecutions", requestToken)
}
joinTableEntities := make(map[common.Entity]bool)
for _, filter := range filters {
joinTableEntities[filter.GetEntity()] = true
}
listInput := repoInterfaces.ListResourceInput{
Limit: int(limit),
Offset: offset,
InlineFilters: filters,
SortParameter: sortParameter,
Limit: int(limit),
Offset: offset,
InlineFilters: filters,
SortParameter: sortParameter,
JoinTableEntities: joinTableEntities,
}

listInput.MapFilters = mapFilters
Expand Down Expand Up @@ -445,7 +450,7 @@ func (m *NodeExecutionManager) ListNodeExecutions(
}
ctx = getExecutionContext(ctx, request.WorkflowExecutionId)

identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, request.WorkflowExecutionId)
identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, request.WorkflowExecutionId, common.NodeExecution)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -483,7 +488,7 @@ func (m *NodeExecutionManager) ListNodeExecutionsForTask(
}
ctx = getTaskExecutionContext(ctx, request.TaskExecutionId)
identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(
ctx, request.TaskExecutionId.NodeExecutionId.ExecutionId)
ctx, request.TaskExecutionId.NodeExecutionId.ExecutionId, common.NodeExecution)
if err != nil {
return nil, err
}
Expand Down
145 changes: 136 additions & 9 deletions flyteadmin/pkg/manager/impl/node_execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -784,17 +784,17 @@ func TestListNodeExecutionsLevelZero(t *testing.T) {
assert.Equal(t, 1, input.Limit)
assert.Equal(t, 2, input.Offset)
assert.Len(t, input.InlineFilters, 3)
assert.Equal(t, common.Execution, input.InlineFilters[0].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[0].GetEntity())
queryExpr, _ := input.InlineFilters[0].GetGormQueryExpr()
assert.Equal(t, "project", queryExpr.Args)
assert.Equal(t, "execution_project = ?", queryExpr.Query)

assert.Equal(t, common.Execution, input.InlineFilters[1].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[1].GetEntity())
queryExpr, _ = input.InlineFilters[1].GetGormQueryExpr()
assert.Equal(t, "domain", queryExpr.Args)
assert.Equal(t, "execution_domain = ?", queryExpr.Query)

assert.Equal(t, common.Execution, input.InlineFilters[2].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[2].GetEntity())
queryExpr, _ = input.InlineFilters[2].GetGormQueryExpr()
assert.Equal(t, "name", queryExpr.Args)
assert.Equal(t, "execution_name = ?", queryExpr.Query)
Expand All @@ -806,6 +806,10 @@ func TestListNodeExecutionsLevelZero(t *testing.T) {
"parent_task_execution_id": nil,
}, filter)

assert.EqualValues(t, input.JoinTableEntities, map[common.Entity]bool{
common.NodeExecution: true,
})

assert.Equal(t, "execution_domain asc", input.SortParameter.GetGormOrderExpr())
return interfaces.NodeExecutionCollectionOutput{
NodeExecutions: []models.NodeExecution{
Expand Down Expand Up @@ -904,17 +908,17 @@ func TestListNodeExecutionsWithParent(t *testing.T) {
assert.Equal(t, 1, input.Limit)
assert.Equal(t, 2, input.Offset)
assert.Len(t, input.InlineFilters, 4)
assert.Equal(t, common.Execution, input.InlineFilters[0].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[0].GetEntity())
queryExpr, _ := input.InlineFilters[0].GetGormQueryExpr()
assert.Equal(t, "project", queryExpr.Args)
assert.Equal(t, "execution_project = ?", queryExpr.Query)

assert.Equal(t, common.Execution, input.InlineFilters[1].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[1].GetEntity())
queryExpr, _ = input.InlineFilters[1].GetGormQueryExpr()
assert.Equal(t, "domain", queryExpr.Args)
assert.Equal(t, "execution_domain = ?", queryExpr.Query)

assert.Equal(t, common.Execution, input.InlineFilters[2].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[2].GetEntity())
queryExpr, _ = input.InlineFilters[2].GetGormQueryExpr()
assert.Equal(t, "name", queryExpr.Args)
assert.Equal(t, "execution_name = ?", queryExpr.Query)
Expand Down Expand Up @@ -979,6 +983,129 @@ func TestListNodeExecutionsWithParent(t *testing.T) {
assert.Equal(t, "3", nodeExecutions.Token)
}

func TestListNodeExecutions_WithJoinTableFilter(t *testing.T) {
repository := repositoryMocks.NewMockRepository()
expectedClosure := admin.NodeExecutionClosure{
Phase: core.NodeExecution_SUCCEEDED,
}
expectedMetadata := admin.NodeExecutionMetaData{
SpecNodeId: "spec_node_id",
RetryGroup: "retry_group",
}
metadataBytes, _ := proto.Marshal(&expectedMetadata)
closureBytes, _ := proto.Marshal(&expectedClosure)

repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetListCallback(
func(ctx context.Context, input interfaces.ListResourceInput) (
interfaces.NodeExecutionCollectionOutput, error) {
assert.Equal(t, 1, input.Limit)
assert.Equal(t, 2, input.Offset)
assert.Len(t, input.InlineFilters, 4)
assert.Equal(t, common.NodeExecution, input.InlineFilters[0].GetEntity())
queryExpr, _ := input.InlineFilters[0].GetGormQueryExpr()
assert.Equal(t, "project", queryExpr.Args)
assert.Equal(t, "execution_project = ?", queryExpr.Query)

assert.Equal(t, common.NodeExecution, input.InlineFilters[1].GetEntity())
queryExpr, _ = input.InlineFilters[1].GetGormQueryExpr()
assert.Equal(t, "domain", queryExpr.Args)
assert.Equal(t, "execution_domain = ?", queryExpr.Query)

assert.Equal(t, common.NodeExecution, input.InlineFilters[2].GetEntity())
queryExpr, _ = input.InlineFilters[2].GetGormQueryExpr()
assert.Equal(t, "name", queryExpr.Args)
assert.Equal(t, "execution_name = ?", queryExpr.Query)

assert.Equal(t, common.Execution, input.InlineFilters[3].GetEntity())
queryExpr, _ = input.InlineFilters[3].GetGormQueryExpr()
assert.Equal(t, "SUCCEEDED", queryExpr.Args)
assert.Equal(t, "phase = ?", queryExpr.Query)

assert.Len(t, input.MapFilters, 1)
filter := input.MapFilters[0].GetFilter()
assert.Equal(t, map[string]interface{}{
"parent_id": nil,
"parent_task_execution_id": nil,
}, filter)

assert.EqualValues(t, input.JoinTableEntities, map[common.Entity]bool{
common.NodeExecution: true,
common.Execution: true,
})

assert.Equal(t, "execution_domain asc", input.SortParameter.GetGormOrderExpr())
return interfaces.NodeExecutionCollectionOutput{
NodeExecutions: []models.NodeExecution{
{
NodeExecutionKey: models.NodeExecutionKey{
NodeID: "node id",
ExecutionKey: models.ExecutionKey{
Project: "project",
Domain: "domain",
Name: "name",
},
},
Phase: core.NodeExecution_SUCCEEDED.String(),
InputURI: "input uri",
StartedAt: &occurredAt,
Closure: closureBytes,
NodeExecutionMetadata: metadataBytes,
},
},
}, nil
})
repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetGetWithChildrenCallback(
func(
ctx context.Context, input interfaces.NodeExecutionResource) (models.NodeExecution, error) {
return models.NodeExecution{
NodeExecutionKey: models.NodeExecutionKey{
NodeID: "node id",
ExecutionKey: models.ExecutionKey{
Project: "project",
Domain: "domain",
Name: "name",
},
},
Phase: core.NodeExecution_SUCCEEDED.String(),
InputURI: "input uri",
StartedAt: &occurredAt,
Closure: closureBytes,
NodeExecutionMetadata: metadataBytes,
}, nil
})
nodeExecManager := NewNodeExecutionManager(repository, getMockExecutionsConfigProvider(), make([]string, 0), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, nil, nil, &eventWriterMocks.NodeExecutionEventWriter{})
nodeExecutions, err := nodeExecManager.ListNodeExecutions(context.Background(), &admin.NodeExecutionListRequest{
WorkflowExecutionId: &core.WorkflowExecutionIdentifier{
Project: "project",
Domain: "domain",
Name: "name",
},
Limit: 1,
Token: "2",
SortBy: &admin.Sort{
Direction: admin.Sort_ASCENDING,
Key: "execution_domain",
},
Filters: "eq(execution.phase, SUCCEEDED)",
})
assert.NoError(t, err)
assert.Len(t, nodeExecutions.NodeExecutions, 1)
assert.True(t, proto.Equal(&admin.NodeExecution{
Id: &core.NodeExecutionIdentifier{
NodeId: "node id",
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: "project",
Domain: "domain",
Name: "name",
},
},
InputUri: "input uri",
Closure: &expectedClosure,
Metadata: &expectedMetadata,
}, nodeExecutions.NodeExecutions[0]))
assert.Equal(t, "3", nodeExecutions.Token)
}

func TestListNodeExecutions_InvalidParams(t *testing.T) {
nodeExecManager := NewNodeExecutionManager(nil, getMockExecutionsConfigProvider(), make([]string, 0), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockNodeExecutionRemoteURL, nil, nil, &eventWriterMocks.NodeExecutionEventWriter{})
_, err := nodeExecManager.ListNodeExecutions(context.Background(), &admin.NodeExecutionListRequest{
Expand Down Expand Up @@ -1120,17 +1247,17 @@ func TestListNodeExecutionsForTask(t *testing.T) {
assert.Equal(t, 1, input.Limit)
assert.Equal(t, 2, input.Offset)
assert.Len(t, input.InlineFilters, 4)
assert.Equal(t, common.Execution, input.InlineFilters[0].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[0].GetEntity())
queryExpr, _ := input.InlineFilters[0].GetGormQueryExpr()
assert.Equal(t, "project", queryExpr.Args)
assert.Equal(t, "execution_project = ?", queryExpr.Query)

assert.Equal(t, common.Execution, input.InlineFilters[1].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[1].GetEntity())
queryExpr, _ = input.InlineFilters[1].GetGormQueryExpr()
assert.Equal(t, "domain", queryExpr.Args)
assert.Equal(t, "execution_domain = ?", queryExpr.Query)

assert.Equal(t, common.Execution, input.InlineFilters[2].GetEntity())
assert.Equal(t, common.NodeExecution, input.InlineFilters[2].GetEntity())
queryExpr, _ = input.InlineFilters[2].GetGormQueryExpr()
assert.Equal(t, "name", queryExpr.Args)
assert.Equal(t, "execution_name = ?", queryExpr.Query)
Expand Down
2 changes: 1 addition & 1 deletion flyteadmin/pkg/manager/impl/signal_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (s *SignalManager) ListSignals(ctx context.Context, request *admin.SignalLi
}
ctx = getExecutionContext(ctx, request.WorkflowExecutionId)

identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, request.WorkflowExecutionId)
identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, request.WorkflowExecutionId, common.Signal)
if err != nil {
return nil, err
}
Expand Down
15 changes: 10 additions & 5 deletions flyteadmin/pkg/manager/impl/task_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func (m *TaskExecutionManager) ListTaskExecutions(
}
ctx = getNodeExecutionContext(ctx, request.NodeExecutionId)

identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, request.NodeExecutionId)
identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, request.NodeExecutionId, common.TaskExecution)
if err != nil {
return nil, err
}
Expand All @@ -267,12 +267,17 @@ func (m *TaskExecutionManager) ListTaskExecutions(
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"invalid pagination token %s for ListTaskExecutions", request.Token)
}
joinTableEntities := make(map[common.Entity]bool)
for _, filter := range filters {
joinTableEntities[filter.GetEntity()] = true
}

output, err := m.db.TaskExecutionRepo().List(ctx, repoInterfaces.ListResourceInput{
InlineFilters: filters,
Offset: offset,
Limit: int(request.Limit),
SortParameter: sortParameter,
InlineFilters: filters,
Offset: offset,
Limit: int(request.Limit),
SortParameter: sortParameter,
JoinTableEntities: joinTableEntities,
})
if err != nil {
logger.Debugf(ctx, "Failed to list task executions with request [%+v] with err %v",
Expand Down
Loading

0 comments on commit 66f09b7

Please sign in to comment.