Skip to content

Commit

Permalink
make execution manager changes
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor committed Jan 17, 2024
1 parent 81278f9 commit 2bc6e1d
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 45 deletions.
80 changes: 35 additions & 45 deletions flyteadmin/pkg/manager/impl/execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,7 @@ func (m *ExecutionManager) getStringFromInput(ctx context.Context, inputBinding
strVal = p.GetStringValue()
case *core.Primitive_Datetime:
t := time.Unix(p.GetDatetime().Seconds, int64(p.GetDatetime().Nanos))
t = t.In(time.UTC)
strVal = t.Format("2006-01-02")
case *core.Primitive_StringValue:
strVal = p.GetStringValue()
Expand Down Expand Up @@ -776,46 +777,6 @@ func (m *ExecutionManager) fillInTemplateArgs(ctx context.Context, query core.Ar
if query.GetUri() != "" {
// If a query string, then just pass it through, nothing to fill in.
return query, nil
} else if query.GetArtifactTag() != nil {
t := query.GetArtifactTag()
ak := t.GetArtifactKey()
if ak == nil {
return query, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "tag doesn't have key")
}
var project, domain string
if ak.GetProject() == "" {
project = contextutils.Value(ctx, contextutils.ProjectKey)
} else {
project = ak.GetProject()
}
if ak.GetDomain() == "" {
domain = contextutils.Value(ctx, contextutils.DomainKey)
} else {
domain = ak.GetDomain()
}
strValue, err := m.getLabelValue(ctx, t.GetValue(), inputs)
if err != nil {
logger.Errorf(ctx, "Failed to template input string [%s] [%v]", t.GetValue(), err)
return query, err
}

return core.ArtifactQuery{
Identifier: &core.ArtifactQuery_ArtifactTag{
ArtifactTag: &core.ArtifactTag{
ArtifactKey: &core.ArtifactKey{
Project: project,
Domain: domain,
Name: ak.GetName(),
},
Value: &core.LabelValue{
Value: &core.LabelValue_StaticValue{
StaticValue: strValue,
},
},
},
},
}, nil

} else if query.GetArtifactId() != nil {
artifactID := query.GetArtifactId()
ak := artifactID.GetArtifactKey()
Expand All @@ -836,7 +797,7 @@ func (m *ExecutionManager) fillInTemplateArgs(ctx context.Context, query core.Ar

var partitions map[string]*core.LabelValue

if artifactID.GetPartitions() != nil && artifactID.GetPartitions().GetValue() != nil {
if artifactID.GetPartitions().GetValue() != nil {
partitions = make(map[string]*core.LabelValue, len(artifactID.GetPartitions().Value))
for k, v := range artifactID.GetPartitions().GetValue() {
newValue, err := m.getLabelValue(ctx, v, inputs)
Expand All @@ -847,6 +808,36 @@ func (m *ExecutionManager) fillInTemplateArgs(ctx context.Context, query core.Ar
partitions[k] = &core.LabelValue{Value: &core.LabelValue_StaticValue{StaticValue: newValue}}
}
}

var timePartition *core.TimePartition
if artifactID.GetTimePartition().GetValue() != nil {
if artifactID.GetTimePartition().Value.GetTimeValue() != nil {
// If the time value is set, then just pass it through, nothing to fill in.
timePartition = artifactID.GetTimePartition()

Check warning on line 816 in flyteadmin/pkg/manager/impl/execution_manager.go

View check run for this annotation

Codecov / codecov/patch

flyteadmin/pkg/manager/impl/execution_manager.go#L815-L816

Added lines #L815 - L816 were not covered by tests
} else if artifactID.GetTimePartition().Value.GetInputBinding() != nil {
// Evaluate the time partition input binding
lit, ok := inputs[artifactID.GetTimePartition().Value.GetInputBinding().GetVar()]
if !ok {
return query, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "time partition input binding var [%s] not found in inputs %v", artifactID.GetTimePartition().Value.GetInputBinding().GetVar(), inputs)
}

if lit.GetScalar().GetPrimitive().GetDatetime() == nil {
return query, errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"time partition binding to input var [%s] failing because %v is not a datetime",
artifactID.GetTimePartition().Value.GetInputBinding().GetVar(), lit)
}

Check warning on line 828 in flyteadmin/pkg/manager/impl/execution_manager.go

View check run for this annotation

Codecov / codecov/patch

flyteadmin/pkg/manager/impl/execution_manager.go#L825-L828

Added lines #L825 - L828 were not covered by tests
timePartition = &core.TimePartition{
Value: &core.LabelValue{
Value: &core.LabelValue_TimeValue{
TimeValue: lit.GetScalar().GetPrimitive().GetDatetime(),
},
},
}
} else {
return query, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "time partition value cannot be empty when evaluating query: %v", query)
}

Check warning on line 838 in flyteadmin/pkg/manager/impl/execution_manager.go

View check run for this annotation

Codecov / codecov/patch

flyteadmin/pkg/manager/impl/execution_manager.go#L836-L838

Added lines #L836 - L838 were not covered by tests
}

return core.ArtifactQuery{
Identifier: &core.ArtifactQuery_ArtifactId{
ArtifactId: &core.ArtifactID{
Expand All @@ -855,11 +846,10 @@ func (m *ExecutionManager) fillInTemplateArgs(ctx context.Context, query core.Ar
Domain: domain,
Name: ak.GetName(),
},
Dimensions: &core.ArtifactID_Partitions{
Partitions: &core.Partitions{
Value: partitions,
},
Partitions: &core.Partitions{
Value: partitions,
},
TimePartition: timePartition,
},
},
}, nil
Expand Down
91 changes: 91 additions & 0 deletions flyteadmin/pkg/manager/impl/execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5719,5 +5719,96 @@ func TestAddStateFilter(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "state <> ?", expression.Query)
})
}

func TestQueryTemplate(t *testing.T) {
ctx := context.Background()

aTime := time.Date(
2063, 4, 5, 00, 00, 00, 0, time.UTC)

rawInputs := map[string]interface{}{
"aStr": "hello world",
"anInt": 1,
"aFloat": 1.3,
"aTime": aTime,
}

otherInputs, err := coreutils.MakeLiteralMap(rawInputs)
assert.NoError(t, err)

m := ExecutionManager{}

ak := &core.ArtifactKey{
Project: "project",
Domain: "domain",
Name: "testname",
}

t.Run("test all present, nothing to fill in", func(t *testing.T) {
pMap := map[string]*core.LabelValue{
"partition1": {Value: &core.LabelValue_StaticValue{StaticValue: "my value"}},
"partition2": {Value: &core.LabelValue_StaticValue{StaticValue: "my value 2"}},
}
p := &core.Partitions{Value: pMap}

q := core.ArtifactQuery{
Identifier: &core.ArtifactQuery_ArtifactId{
ArtifactId: &core.ArtifactID{
ArtifactKey: ak,
Partitions: p,
TimePartition: nil,
},
},
}

filledQuery, err := m.fillInTemplateArgs(ctx, q, otherInputs.Literals)
assert.NoError(t, err)
assert.True(t, proto.Equal(&q, &filledQuery))
})

t.Run("template date-times, both in explicit tp and not", func(t *testing.T) {
pMap := map[string]*core.LabelValue{
"partition1": {Value: &core.LabelValue_InputBinding{InputBinding: &core.InputBindingData{Var: "aTime"}}},
"partition2": {Value: &core.LabelValue_StaticValue{StaticValue: "my value 2"}},
}
p := &core.Partitions{Value: pMap}

q := core.ArtifactQuery{
Identifier: &core.ArtifactQuery_ArtifactId{
ArtifactId: &core.ArtifactID{
ArtifactKey: ak,
Partitions: p,
TimePartition: &core.TimePartition{Value: &core.LabelValue{Value: &core.LabelValue_InputBinding{InputBinding: &core.InputBindingData{Var: "aTime"}}}},
},
},
}

filledQuery, err := m.fillInTemplateArgs(ctx, q, otherInputs.Literals)
assert.NoError(t, err)
staticTime := filledQuery.GetArtifactId().Partitions.Value["partition1"].GetStaticValue()
assert.Equal(t, "2063-04-05", staticTime)
assert.Equal(t, int64(2942956800), filledQuery.GetArtifactId().TimePartition.Value.GetTimeValue().Seconds)
})

t.Run("something missing", func(t *testing.T) {
pMap := map[string]*core.LabelValue{
"partition1": {Value: &core.LabelValue_StaticValue{StaticValue: "my value"}},
"partition2": {Value: &core.LabelValue_StaticValue{StaticValue: "my value 2"}},
}
p := &core.Partitions{Value: pMap}

q := core.ArtifactQuery{
Identifier: &core.ArtifactQuery_ArtifactId{
ArtifactId: &core.ArtifactID{
ArtifactKey: ak,
Partitions: p,
TimePartition: &core.TimePartition{Value: &core.LabelValue{Value: &core.LabelValue_InputBinding{InputBinding: &core.InputBindingData{Var: "wrong var"}}}},
},
},
}

_, err := m.fillInTemplateArgs(ctx, q, otherInputs.Literals)
assert.Error(t, err)
})
}

0 comments on commit 2bc6e1d

Please sign in to comment.