diff --git a/go.mod b/go.mod
index c8816706d..3d15572f8 100644
--- a/go.mod
+++ b/go.mod
@@ -6,8 +6,8 @@ require (
 	github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295
 	github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1
 	github.com/fatih/color v1.10.0
-	github.com/flyteorg/flyteidl v0.21.4
-	github.com/flyteorg/flyteplugins v0.8.0
+	github.com/flyteorg/flyteidl v0.21.11
+	github.com/flyteorg/flyteplugins v0.8.1
 	github.com/flyteorg/flytestdlib v0.4.4
 	github.com/ghodss/yaml v1.0.0
 	github.com/go-redis/redis v6.15.7+incompatible
diff --git a/go.sum b/go.sum
index 2542a4de2..b2378eecd 100644
--- a/go.sum
+++ b/go.sum
@@ -236,10 +236,10 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv
 github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg=
 github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM=
 github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94=
-github.com/flyteorg/flyteidl v0.21.4 h1:gtJK5rX2ydLAo2xLRHHznOSLuLHrRRdXDbpEAlxluhk=
-github.com/flyteorg/flyteidl v0.21.4/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
-github.com/flyteorg/flyteplugins v0.8.0 h1:Jiy7Ugm9olGmm5OFAbbxv/VfVmYib3JqGdeytyoiwnU=
-github.com/flyteorg/flyteplugins v0.8.0/go.mod h1:kOiuXk1ddIEVSPoHcc4kBfVQcLuyf8jw3vWJT2Was90=
+github.com/flyteorg/flyteidl v0.21.11 h1:oH9YPoR7scO9GFF/I8D0gCTOB+JP5HRK7b7cLUBRz90=
+github.com/flyteorg/flyteidl v0.21.11/go.mod h1:576W2ViEyjTpT+kEVHAGbrTP3HARNUZ/eCwrNPmdx9U=
+github.com/flyteorg/flyteplugins v0.8.1 h1:wZ8JRWOXPZ2+O5LI2kxwkTaoxER2ag+iYpm5S8KLmww=
+github.com/flyteorg/flyteplugins v0.8.1/go.mod h1:tmU5lkRQjftCNd7T4gHykh5zZNNTdrxNmQRSBrFWCyg=
 github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220=
 github.com/flyteorg/flytestdlib v0.3.36/go.mod h1:7cDWkY3v7xsoesFcDdu6DSW5Q2U2W5KlHUbUHSwBG1Q=
 github.com/flyteorg/flytestdlib v0.4.4 h1:oPADei4KEjxtUqkTwrIjXB1nuH+JEKjwmwF92DSO3NM=
diff --git a/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go b/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go
index 31ffda868..d91936553 100644
--- a/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go
+++ b/pkg/controller/nodes/task/catalog/datacatalog/datacatalog.go
@@ -264,6 +264,85 @@ func (m *CatalogClient) Put(ctx context.Context, key catalog.Key, reader io.Outp
 	return catalog.NewStatus(core.CatalogCacheStatus_CACHE_POPULATED, EventCatalogMetadata(datasetID, tag, nil)), nil
 }
 
+// GetOrExtendReservation attempts to acquire a reservation for the cacheable task. If the caller
+// has previously acquired the reservation, it is extended. If another entity holds the
+// reservation, that reservation is returned.
+func (m *CatalogClient) GetOrExtendReservation(ctx context.Context, key catalog.Key, ownerID string, heartbeatInterval time.Duration) (*datacatalog.Reservation, error) {
+    datasetID, err := GenerateDatasetIDForTask(ctx, key)
+    if err != nil {
+        return nil, err
+    }
+
+    inputs := &core.LiteralMap{}
+    if key.TypedInterface.Inputs != nil {
+        retInputs, err := key.InputReader.Get(ctx)
+        if err != nil {
+            return nil, errors.Wrap(err, "failed to read inputs when trying to query catalog")
+        }
+        inputs = retInputs
+    }
+
+    tag, err := GenerateArtifactTagName(ctx, inputs)
+    if err != nil {
+        return nil, err
+    }
+
+    reservationQuery := &datacatalog.GetOrExtendReservationRequest{
+        ReservationId: &datacatalog.ReservationID{
+            DatasetId: datasetID,
+            TagName:   tag,
+        },
+        OwnerId:           ownerID,
+        HeartbeatInterval: ptypes.DurationProto(heartbeatInterval),
+    }
+
+    response, err := m.client.GetOrExtendReservation(ctx, reservationQuery)
+    if err != nil {
+        return nil, err
+    }
+
+    return response.Reservation, nil
+}
+
+// ReleaseReservation attempts to release a reservation for a cacheable task. If the reservation
+// does not exist (e.g., it never existed or has been acquired by another owner), this call
+// still succeeds.
+func (m *CatalogClient) ReleaseReservation(ctx context.Context, key catalog.Key, ownerID string) error {
+    datasetID, err := GenerateDatasetIDForTask(ctx, key)
+    if err != nil {
+        return err
+    }
+
+    inputs := &core.LiteralMap{}
+    if key.TypedInterface.Inputs != nil {
+        retInputs, err := key.InputReader.Get(ctx)
+        if err != nil {
+            return errors.Wrap(err, "failed to read inputs when trying to query catalog")
+        }
+        inputs = retInputs
+    }
+
+    tag, err := GenerateArtifactTagName(ctx, inputs)
+    if err != nil {
+        return err
+    }
+
+    reservationQuery := &datacatalog.ReleaseReservationRequest{
+        ReservationId: &datacatalog.ReservationID{
+            DatasetId: datasetID,
+            TagName:   tag,
+        },
+        OwnerId: ownerID,
+    }
+
+    _, err = m.client.ReleaseReservation(ctx, reservationQuery)
+    if err != nil {
+        return err
+    }
+
+    return nil
+}
+
 // Create a new Datacatalog client for task execution caching
 func NewDataCatalog(ctx context.Context, endpoint string, insecureConnection bool, maxCacheAge time.Duration) (*CatalogClient, error) {
 	var opts []grpc.DialOption
diff --git a/pkg/controller/nodes/task/catalog/datacatalog/datacatalog_test.go b/pkg/controller/nodes/task/catalog/datacatalog/datacatalog_test.go
index 6d45676ef..72d9a4efc 100644
--- a/pkg/controller/nodes/task/catalog/datacatalog/datacatalog_test.go
+++ b/pkg/controller/nodes/task/catalog/datacatalog/datacatalog_test.go
@@ -539,3 +539,132 @@ func TestCatalog_Put(t *testing.T) {
 	})
 }
+
+var tagName = "flyte_cached-BE6CZsMk6N3ExR_4X9EuwBgj2Jh2UwasXK3a_pM9xlY"
+var reservationID = datacatalog.ReservationID{
+    DatasetId: datasetID,
+    TagName:   tagName,
+}
+var prevOwner = "prevOwner"
+var currentOwner = "currentOwner"
+
+func TestCatalog_GetOrExtendReservation(t *testing.T) {
+    ctx := context.Background()
+
+    heartbeatInterval := time.Second * 5
+    prevReservation := datacatalog.Reservation{
+        ReservationId: &reservationID,
+        OwnerId:       prevOwner,
+    }
+
+    currentReservation := datacatalog.Reservation{
+        ReservationId: &reservationID,
+        OwnerId:       currentOwner,
+    }
+
+    t.Run("CreateOrUpdateReservation", func(t *testing.T) {
+        ir := &mocks2.InputReader{}
+        ir.On("Get", mock.Anything).Return(sampleParameters, nil, nil)
+
+        mockClient := &mocks.DataCatalogClient{}
+        catalogClient := &CatalogClient{
+            client: mockClient,
+        }
+
+        mockClient.On("GetOrExtendReservation",
+            ctx,
+            mock.MatchedBy(func(o *datacatalog.GetOrExtendReservationRequest) bool {
+                assert.EqualValues(t, datasetID.String(), o.ReservationId.DatasetId.String())
+                assert.EqualValues(t, tagName, o.ReservationId.TagName)
+                return true
+            }),
+        ).Return(&datacatalog.GetOrExtendReservationResponse{Reservation: &currentReservation}, nil, "")
+
+        newKey := sampleKey
+        newKey.InputReader = ir
+        reservation, err := catalogClient.GetOrExtendReservation(ctx, newKey, currentOwner, heartbeatInterval)
+
+        assert.NoError(t, err)
+        assert.Equal(t, reservation.OwnerId, currentOwner)
+    })
+
+    t.Run("ExistingReservation", func(t *testing.T) {
+        ir := &mocks2.InputReader{}
+        ir.On("Get", mock.Anything).Return(sampleParameters, nil, nil)
+
+        mockClient := &mocks.DataCatalogClient{}
+        catalogClient := &CatalogClient{
+            client: mockClient,
+        }
+
+        mockClient.On("GetOrExtendReservation",
+            ctx,
+            mock.MatchedBy(func(o *datacatalog.GetOrExtendReservationRequest) bool {
+                assert.EqualValues(t, datasetID.String(), o.ReservationId.DatasetId.String())
+                assert.EqualValues(t, tagName, o.ReservationId.TagName)
+                return true
+            }),
+        ).Return(&datacatalog.GetOrExtendReservationResponse{Reservation: &prevReservation}, nil, "")
+
+        newKey := sampleKey
+        newKey.InputReader = ir
+        reservation, err := catalogClient.GetOrExtendReservation(ctx, newKey, currentOwner, heartbeatInterval)
+
+        assert.NoError(t, err)
+        assert.Equal(t, reservation.OwnerId, prevOwner)
+    })
+}
+
+func TestCatalog_ReleaseReservation(t *testing.T) {
+    ctx := context.Background()
+
+    t.Run("ReleaseReservation", func(t *testing.T) {
+        ir := &mocks2.InputReader{}
+        ir.On("Get", mock.Anything).Return(sampleParameters, nil, nil)
+
+        mockClient := &mocks.DataCatalogClient{}
+        catalogClient := &CatalogClient{
+            client: mockClient,
+        }
+
+        mockClient.On("ReleaseReservation",
+            ctx,
+            mock.MatchedBy(func(o *datacatalog.ReleaseReservationRequest) bool {
+                assert.EqualValues(t, datasetID.String(), o.ReservationId.DatasetId.String())
+                assert.EqualValues(t, tagName, o.ReservationId.TagName)
+                return true
+            }),
+        ).Return(&datacatalog.ReleaseReservationResponse{}, nil, "")
+
+        newKey := sampleKey
+        newKey.InputReader = ir
+        err := catalogClient.ReleaseReservation(ctx, newKey, currentOwner)
+
+        assert.NoError(t, err)
+    })
+
+    t.Run("ReleaseReservationFailure", func(t *testing.T) {
+        ir := &mocks2.InputReader{}
+        ir.On("Get", mock.Anything).Return(sampleParameters, nil, nil)
+
+        mockClient := &mocks.DataCatalogClient{}
+        catalogClient := &CatalogClient{
+            client: mockClient,
+        }
+
+        mockClient.On("ReleaseReservation",
+            ctx,
+            mock.MatchedBy(func(o *datacatalog.ReleaseReservationRequest) bool {
+                assert.EqualValues(t, datasetID.String(), o.ReservationId.DatasetId.String())
+                assert.EqualValues(t, tagName, o.ReservationId.TagName)
+                return true
+            }),
+        ).Return(nil, status.Error(codes.NotFound, "reservation not found"))
+
+        newKey := sampleKey
+        newKey.InputReader = ir
+        err := catalogClient.ReleaseReservation(ctx, newKey, currentOwner)
+
+        assertGrpcErr(t, err, codes.NotFound)
+    })
+}
diff --git a/pkg/controller/nodes/task/catalog/noop_catalog.go b/pkg/controller/nodes/task/catalog/noop_catalog.go
index 9d33d3e59..2f799e442 100644
--- a/pkg/controller/nodes/task/catalog/noop_catalog.go
+++ b/pkg/controller/nodes/task/catalog/noop_catalog.go
@@ -2,8 +2,10 @@ package catalog
 
 import (
 	"context"
+	"time"
 
 	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
+	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/datacatalog"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" ) @@ -20,3 +22,11 @@ func (n NOOPCatalog) Get(_ context.Context, _ catalog.Key) (catalog.Entry, error func (n NOOPCatalog) Put(_ context.Context, _ catalog.Key, _ io.OutputReader, _ catalog.Metadata) (catalog.Status, error) { return disabledStatus, nil } + +func (n NOOPCatalog) GetOrExtendReservation(_ context.Context, _ catalog.Key, _ string, _ time.Duration) (*datacatalog.Reservation, error) { + return nil, nil +} + +func (n NOOPCatalog) ReleaseReservation(_ context.Context, _ catalog.Key, _ string) error { + return nil +} diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go index b5cb30ce2..4b0cb9f85 100644 --- a/pkg/controller/nodes/task/handler.go +++ b/pkg/controller/nodes/task/handler.go @@ -42,15 +42,19 @@ import ( const pluginContextKey = contextutils.Key("plugin") type metrics struct { - pluginPanics labeled.Counter - unsupportedTaskType labeled.Counter - catalogPutFailureCount labeled.Counter - catalogGetFailureCount labeled.Counter - catalogPutSuccessCount labeled.Counter - catalogMissCount labeled.Counter - catalogHitCount labeled.Counter - pluginExecutionLatency labeled.StopWatch - pluginQueueLatency labeled.StopWatch + pluginPanics labeled.Counter + unsupportedTaskType labeled.Counter + catalogPutFailureCount labeled.Counter + catalogGetFailureCount labeled.Counter + catalogPutSuccessCount labeled.Counter + catalogMissCount labeled.Counter + catalogHitCount labeled.Counter + pluginExecutionLatency labeled.StopWatch + pluginQueueLatency labeled.StopWatch + reservationGetSuccessCount labeled.Counter + reservationGetFailureCount labeled.Counter + reservationReleaseSuccessCount labeled.Counter + reservationReleaseFailureCount labeled.Counter // TODO We should have a metric to capture custom state size scope promutils.Scope @@ -90,6 +94,20 @@ func (p *pluginRequestedTransition) PopulateCacheInfo(entry catalog.Entry) { } } +// PopulateReservationInfo sets the ReservationStatus of a requested plugin transition based on the +// provided ReservationEntry. +func (p *pluginRequestedTransition) PopulateReservationInfo(entry catalog.ReservationEntry) { + if p.execInfo.TaskNodeInfo == nil { + p.execInfo.TaskNodeInfo = &handler.TaskNodeInfo{ + TaskNodeMetadata: &event.TaskNodeMetadata{ + ReservationStatus: entry.GetStatus(), + }, + } + } else { + p.execInfo.TaskNodeInfo.TaskNodeMetadata.ReservationStatus = entry.GetStatus() + } +} + func (p *pluginRequestedTransition) ObservedTransitionAndState(trns pluginCore.Transition, pluginStateVersion uint32, pluginState []byte) { p.ttype = ToTransitionType(trns.Type()) p.pInfo = trns.Info() @@ -494,7 +512,7 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) // TODO @kumare re-evaluate this decision // STEP 1: Check Cache - if ts.PluginPhase == pluginCore.PhaseUndefined && checkCatalog { + if (ts.PluginPhase == pluginCore.PhaseUndefined || ts.PluginPhase == pluginCore.PhaseWaitingForCache) && checkCatalog { // This is assumed to be first time. 
 		entry, err := t.CheckCatalogCache(ctx, tCtx.tr, nCtx.InputReader(), tCtx.ow)
 		if err != nil {
@@ -527,9 +545,42 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext)
 		}
 	}
 
+    // Check catalog for cache reservation and acquire if none exists
+    if checkCatalog && (pluginTrns.execInfo.TaskNodeInfo == nil || pluginTrns.execInfo.TaskNodeInfo.TaskNodeMetadata.CacheStatus != core.CatalogCacheStatus_CACHE_HIT) {
+        ownerID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
+        reservation, err := t.GetOrExtendCatalogReservation(ctx, ownerID, controllerConfig.GetConfig().WorkflowReEval.Duration, tCtx.tr, nCtx.InputReader())
+        if err != nil {
+            logger.Errorf(ctx, "failed to get or extend catalog reservation with error: %v", err)
+            return handler.UnknownTransition, err
+        }
+
+        pluginTrns.PopulateReservationInfo(reservation)
+
+        if reservation.GetStatus() == core.CatalogReservation_RESERVATION_ACQUIRED &&
+            (ts.PluginPhase == pluginCore.PhaseUndefined || ts.PluginPhase == pluginCore.PhaseWaitingForCache) {
+            logger.Infof(ctx, "Acquired cache reservation")
+        }
+
+        // If we do not own the reservation then we transition to the WaitingForCache phase. If we are
+        // already running (i.e., in a phase other than PhaseUndefined or PhaseWaitingForCache) and
+        // somehow lost the reservation (e.g., by expiration), continue to execute until completion.
+        if reservation.GetStatus() == core.CatalogReservation_RESERVATION_EXISTS {
+            if ts.PluginPhase == pluginCore.PhaseUndefined || ts.PluginPhase == pluginCore.PhaseWaitingForCache {
+                pluginTrns.ttype = handler.TransitionTypeEphemeral
+                pluginTrns.pInfo = pluginCore.PhaseInfoWaitingForCache(pluginCore.DefaultPhaseVersion, nil)
+            }
+
+            if ts.PluginPhase == pluginCore.PhaseWaitingForCache {
+                logger.Debugf(ctx, "No state change for Task, previously observed same transition. Short circuiting.")
+                return pluginTrns.FinalTransition(ctx)
+            }
+        }
+    }
+
 	barrierTick := uint32(0)
-	// STEP 2: If no cache-hit, then lets invoke the plugin and wait for a transition out of undefined
-	if pluginTrns.execInfo.TaskNodeInfo == nil || pluginTrns.execInfo.TaskNodeInfo.TaskNodeMetadata.CacheStatus != core.CatalogCacheStatus_CACHE_HIT {
+	// STEP 2: If no cache-hit and not transitioning to PhaseWaitingForCache, then let's invoke the plugin and wait for a transition out of undefined
+	if pluginTrns.execInfo.TaskNodeInfo == nil || (pluginTrns.pInfo.Phase() != pluginCore.PhaseWaitingForCache &&
+		pluginTrns.execInfo.TaskNodeInfo.TaskNodeMetadata.CacheStatus != core.CatalogCacheStatus_CACHE_HIT) {
 		prevBarrier := t.barrierCache.GetPreviousBarrierTransition(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName())
 		// Lets start with the current barrierTick (the value to be stored) same as the barrierTick in the cache
 		barrierTick = prevBarrier.BarrierClockTick
@@ -734,6 +785,14 @@ func (t Handler) Finalize(ctx context.Context, nCtx handler.NodeExecutionContext
 			err = fmt.Errorf("panic when executing a plugin for TaskType [%s]. Stack: [%s]", ttype, string(stack))
 		}
 	}()
+
+    // release catalog reservation (if one exists)
+    ownerID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
+    _, err = t.ReleaseCatalogReservation(ctx, ownerID, tCtx.tr, tCtx.InputReader())
+    if err != nil {
+        return errors.Wrapf(errors.CatalogCallFailed, nCtx.NodeID(), err, "failed to release reservation")
+    }
+
 	childCtx := context.WithValue(ctx, pluginContextKey, p.GetID())
 	err = p.Finalize(childCtx, tCtx)
 	return
@@ -758,16 +817,20 @@ func New(ctx context.Context, kubeClient executors.Client, client catalog.Client
 		pluginsForType: make(map[pluginCore.TaskType]map[pluginID]pluginCore.Plugin),
 		taskMetricsMap: make(map[MetricKey]*taskMetrics),
 		metrics: &metrics{
-			pluginPanics:           labeled.NewCounter("plugin_panic", "Task plugin paniced when trying to execute a Handler.", scope),
-			unsupportedTaskType:    labeled.NewCounter("unsupported_tasktype", "No Handler plugin configured for Handler type", scope),
-			catalogHitCount:        labeled.NewCounter("discovery_hit_count", "Task cached in Discovery", scope),
-			catalogMissCount:       labeled.NewCounter("discovery_miss_count", "Task not cached in Discovery", scope),
-			catalogPutSuccessCount: labeled.NewCounter("discovery_put_success_count", "Discovery Put success count", scope),
-			catalogPutFailureCount: labeled.NewCounter("discovery_put_failure_count", "Discovery Put failure count", scope),
-			catalogGetFailureCount: labeled.NewCounter("discovery_get_failure_count", "Discovery Get faillure count", scope),
-			pluginExecutionLatency: labeled.NewStopWatch("plugin_exec_latency", "Time taken to invoke plugin for one round", time.Microsecond, scope),
-			pluginQueueLatency:     labeled.NewStopWatch("plugin_queue_latency", "Time spent by plugin in queued phase", time.Microsecond, scope),
-			scope:                  scope,
+			pluginPanics:                   labeled.NewCounter("plugin_panic", "Task plugin panicked when trying to execute a Handler.", scope),
+			unsupportedTaskType:            labeled.NewCounter("unsupported_tasktype", "No Handler plugin configured for Handler type", scope),
+			catalogHitCount:                labeled.NewCounter("discovery_hit_count", "Task cached in Discovery", scope),
+			catalogMissCount:               labeled.NewCounter("discovery_miss_count", "Task not cached in Discovery", scope),
+			catalogPutSuccessCount:         labeled.NewCounter("discovery_put_success_count", "Discovery Put success count", scope),
+			catalogPutFailureCount:         labeled.NewCounter("discovery_put_failure_count", "Discovery Put failure count", scope),
+			catalogGetFailureCount:         labeled.NewCounter("discovery_get_failure_count", "Discovery Get failure count", scope),
+			pluginExecutionLatency:         labeled.NewStopWatch("plugin_exec_latency", "Time taken to invoke plugin for one round", time.Microsecond, scope),
+			pluginQueueLatency:             labeled.NewStopWatch("plugin_queue_latency", "Time spent by plugin in queued phase", time.Microsecond, scope),
+			reservationGetFailureCount:     labeled.NewCounter("reservation_get_failure_count", "Reservation GetOrExtend failure count", scope),
+			reservationGetSuccessCount:     labeled.NewCounter("reservation_get_success_count", "Reservation GetOrExtend success count", scope),
+			reservationReleaseFailureCount: labeled.NewCounter("reservation_release_failure_count", "Reservation Release failure count", scope),
+			reservationReleaseSuccessCount: labeled.NewCounter("reservation_release_success_count", "Reservation Release success count", scope),
+			scope:                          scope,
 		},
 		pluginScope: scope.NewSubScope("plugin"),
 		kubeClient:  kubeClient,
diff --git a/pkg/controller/nodes/task/handler_test.go b/pkg/controller/nodes/task/handler_test.go
index 142c4ef34..f5e0d0ee2 100644
--- a/pkg/controller/nodes/task/handler_test.go
+++ b/pkg/controller/nodes/task/handler_test.go
@@ -23,6 +23,7 @@ import (
 	"github.com/flyteorg/flytestdlib/promutils/labeled"
 
 	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
+	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/datacatalog"
 	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event"
 	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
 	pluginCatalogMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog/mocks"
@@ -939,6 +940,215 @@ func Test_task_Handle_Catalog(t *testing.T) {
 	}
 }
 
+func Test_task_Handle_Reservation(t *testing.T) {
+
+    createNodeContext := func(recorder events.TaskEventRecorder, ttype string, s *taskNodeStateHolder) *nodeMocks.NodeExecutionContext {
+        wfExecID := &core.WorkflowExecutionIdentifier{
+            Project: "project",
+            Domain:  "domain",
+            Name:    "name",
+        }
+
+        nodeID := "n1"
+
+        nm := &nodeMocks.NodeExecutionMetadata{}
+        nm.OnGetAnnotations().Return(map[string]string{})
+        nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{
+            NodeId:      nodeID,
+            ExecutionId: wfExecID,
+        })
+        nm.OnGetK8sServiceAccount().Return("service-account")
+        nm.OnGetLabels().Return(map[string]string{})
+        nm.OnGetNamespace().Return("namespace")
+        nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"})
+        nm.OnGetOwnerReference().Return(v12.OwnerReference{
+            Kind: "sample",
+            Name: "name",
+        })
+        nm.OnIsInterruptible().Return(true)
+
+        taskID := &core.Identifier{}
+        tk := &core.TaskTemplate{
+            Id:   taskID,
+            Type: "test",
+            Metadata: &core.TaskMetadata{
+                Discoverable:      true,
+                CacheSerializable: true,
+            },
+            Interface: &core.TypedInterface{
+                Outputs: &core.VariableMap{
+                    Variables: map[string]*core.Variable{
+                        "x": {
+                            Type: &core.LiteralType{
+                                Type: &core.LiteralType_Simple{
+                                    Simple: core.SimpleType_BOOLEAN,
+                                },
+                            },
+                        },
+                    },
+                },
+            },
+        }
+        tr := &nodeMocks.TaskReader{}
+        tr.OnGetTaskID().Return(taskID)
+        tr.OnGetTaskType().Return(ttype)
+        tr.OnReadMatch(mock.Anything).Return(tk, nil)
+
+        ns := &flyteMocks.ExecutableNodeStatus{}
+        ns.OnGetDataDir().Return(storage.DataReference("data-dir"))
+        ns.OnGetOutputDir().Return(storage.DataReference("output-dir"))
+
+        res := &v1.ResourceRequirements{}
+        n := &flyteMocks.ExecutableNode{}
+        ma := 5
+        n.OnGetRetryStrategy().Return(&v1alpha1.RetryStrategy{MinAttempts: &ma})
+        n.OnGetResources().Return(res)
+
+        ir := &ioMocks.InputReader{}
+        ir.OnGetInputPath().Return(storage.DataReference("input"))
+        nCtx := &nodeMocks.NodeExecutionContext{}
+        nCtx.OnNodeExecutionMetadata().Return(nm)
+        nCtx.OnNode().Return(n)
+        nCtx.OnInputReader().Return(ir)
+        nCtx.OnInputReader().Return(ir)
+        ds, err := storage.NewDataStore(
+            &storage.Config{
+                Type: storage.TypeMemory,
+            },
+            promutils.NewTestScope(),
+        )
+        assert.NoError(t, err)
+        nCtx.OnDataStore().Return(ds)
+        nCtx.OnCurrentAttempt().Return(uint32(1))
+        nCtx.OnTaskReader().Return(tr)
+        nCtx.OnMaxDatasetSizeBytes().Return(int64(1))
+        nCtx.OnNodeStatus().Return(ns)
+        nCtx.OnNodeID().Return(nodeID)
+        nCtx.OnEventsRecorder().Return(recorder)
+        nCtx.OnEnqueueOwnerFunc().Return(nil)
+
+        executionContext := &mocks.ExecutionContext{}
+        executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{})
+        executionContext.OnGetEventVersion().Return(v1alpha1.EventVersion0)
+        executionContext.OnGetParentInfo().Return(nil)
+        executionContext.OnIncrementParallelism().Return(1)
+        nCtx.OnExecutionContext().Return(executionContext)
+
+        nCtx.OnRawOutputPrefix().Return("s3://sandbox/")
+        nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"}))
+
+        nCtx.OnNodeStateWriter().Return(s)
+        return nCtx
+    }
+
+    noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope())
+
+    type args struct {
+        catalogFetch bool
+        pluginPhase  pluginCore.Phase
+        ownerID      string
+    }
+    type want struct {
+        pluginPhase  pluginCore.Phase
+        handlerPhase handler.EPhase
+        eventPhase   core.TaskExecution_Phase
+    }
+    tests := []struct {
+        name string
+        args args
+        want want
+    }{
+        {
+            "reservation-create-or-update",
+            args{
+                catalogFetch: false,
+                pluginPhase:  pluginCore.PhaseUndefined,
+                ownerID:      "name-n1-1",
+            },
+            want{
+                pluginPhase:  pluginCore.PhaseSuccess,
+                handlerPhase: handler.EPhaseSuccess,
+                eventPhase:   core.TaskExecution_SUCCEEDED,
+            },
+        },
+        {
+            "reservation-exists",
+            args{
+                catalogFetch: false,
+                pluginPhase:  pluginCore.PhaseUndefined,
+                ownerID:      "nilOwner",
+            },
+            want{
+                pluginPhase:  pluginCore.PhaseWaitingForCache,
+                handlerPhase: handler.EPhaseRunning,
+                eventPhase:   core.TaskExecution_UNDEFINED,
+            },
+        },
+        {
+            "cache-hit",
+            args{
+                catalogFetch: true,
+                pluginPhase:  pluginCore.PhaseWaitingForCache,
+            },
+            want{
+                pluginPhase:  pluginCore.PhaseSuccess,
+                handlerPhase: handler.EPhaseSuccess,
+                eventPhase:   core.TaskExecution_SUCCEEDED,
+            },
+        },
+    }
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            state := &taskNodeStateHolder{}
+            ev := &fakeBufferedTaskEventRecorder{}
+            nCtx := createNodeContext(ev, "test", state)
+            c := &pluginCatalogMocks.Client{}
+            nr := &nodeMocks.NodeStateReader{}
+            st := bytes.NewBuffer([]byte{})
+            cod := codex.GobStateCodec{}
+            assert.NoError(t, cod.Encode(&fakeplugins.NextPhaseState{
+                Phase:        pluginCore.PhaseSuccess,
+                OutputExists: true,
+            }, st))
+            nr.OnGetTaskNodeState().Return(handler.TaskNodeState{
+                PluginPhase: tt.args.pluginPhase,
+                PluginState: st.Bytes(),
+            })
+            nCtx.OnNodeStateReader().Return(nr)
+            if tt.args.catalogFetch {
+                or := &ioMocks.OutputReader{}
+                or.OnReadMatch(mock.Anything).Return(&core.LiteralMap{}, nil, nil)
+                c.OnGetMatch(mock.Anything, mock.Anything).Return(catalog.NewCatalogEntry(or, catalog.NewStatus(core.CatalogCacheStatus_CACHE_HIT, nil)), nil)
+            } else {
+                c.OnGetMatch(mock.Anything, mock.Anything).Return(catalog.NewFailedCatalogEntry(catalog.NewStatus(core.CatalogCacheStatus_CACHE_MISS, nil)), nil)
+            }
+            c.OnPutMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(catalog.NewStatus(core.CatalogCacheStatus_CACHE_POPULATED, nil), nil)
+            c.OnGetOrExtendReservationMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&datacatalog.Reservation{OwnerId: tt.args.ownerID}, nil)
+            tk, err := New(context.TODO(), mocks.NewFakeKubeClient(), c, eventConfig, promutils.NewTestScope())
+            assert.NoError(t, err)
+            tk.defaultPlugins = map[pluginCore.TaskType]pluginCore.Plugin{
+                "test": fakeplugins.NewPhaseBasedPlugin(),
+            }
+            tk.catalog = c
+            tk.resourceManager = noopRm
+            got, err := tk.Handle(context.TODO(), nCtx)
+            if err != nil {
+                t.Errorf("Handler.Handle() error = %v", err)
+                return
+            }
+            if err == nil {
+                assert.Equal(t, tt.want.handlerPhase.String(), got.Info().GetPhase().String())
+                if assert.Equal(t, 1, len(ev.evs)) {
+                    e := ev.evs[0]
+                    assert.Equal(t, tt.want.eventPhase.String(), e.Phase.String())
+                }
+                assert.Equal(t, tt.want.pluginPhase.String(), state.s.PluginPhase.String())
+                assert.Equal(t, uint32(0), state.s.PluginPhaseVersion)
+            }
+        })
+    }
+}
+
 func Test_task_Handle_Barrier(t *testing.T) {
 
 	// NOTE: Caching is disabled for this test
@@ -1532,93 +1742,119 @@ func Test_task_Abort_v1(t *testing.T) {
 
 func Test_task_Finalize(t *testing.T) {
-	wfExecID := &core.WorkflowExecutionIdentifier{
-		Project: "project",
-		Domain:  "domain",
-		Name:    "name",
-	}
+    createNodeContext := func(cacheSerializable bool) *nodeMocks.NodeExecutionContext {
+        wfExecID := &core.WorkflowExecutionIdentifier{
+            Project: "project",
+            Domain:  "domain",
+            Name:    "name",
+        }
-	nodeID := "n1"
-
-	nm := &nodeMocks.NodeExecutionMetadata{}
-	nm.OnGetAnnotations().Return(map[string]string{})
-	nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{
-		NodeId:      nodeID,
-		ExecutionId: wfExecID,
-	})
-	nm.OnGetK8sServiceAccount().Return("service-account")
-	nm.OnGetLabels().Return(map[string]string{})
-	nm.OnGetNamespace().Return("namespace")
-	nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"})
-	nm.OnGetOwnerReference().Return(v12.OwnerReference{
-		Kind: "sample",
-		Name: "name",
-	})
-
-	taskID := &core.Identifier{}
-	tr := &nodeMocks.TaskReader{}
-	tr.OnGetTaskID().Return(taskID)
-	tr.OnGetTaskType().Return("x")
-
-	ns := &flyteMocks.ExecutableNodeStatus{}
-	ns.OnGetDataDir().Return(storage.DataReference("data-dir"))
-	ns.OnGetOutputDir().Return(storage.DataReference("output-dir"))
-
-	res := &v1.ResourceRequirements{}
-	n := &flyteMocks.ExecutableNode{}
-	ma := 5
-	n.OnGetRetryStrategy().Return(&v1alpha1.RetryStrategy{MinAttempts: &ma})
-	n.OnGetResources().Return(res)
-
-	ir := &ioMocks.InputReader{}
-	nCtx := &nodeMocks.NodeExecutionContext{}
-	nCtx.OnNodeExecutionMetadata().Return(nm)
-	nCtx.OnNode().Return(n)
-	nCtx.OnInputReader().Return(ir)
-	ds, err := storage.NewDataStore(
-		&storage.Config{
-			Type: storage.TypeMemory,
-		},
-		promutils.NewTestScope(),
-	)
-	assert.NoError(t, err)
-	nCtx.OnDataStore().Return(ds)
-	nCtx.OnCurrentAttempt().Return(uint32(1))
-	nCtx.OnTaskReader().Return(tr)
-	nCtx.OnMaxDatasetSizeBytes().Return(int64(1))
-	nCtx.OnNodeStatus().Return(ns)
-	nCtx.OnNodeID().Return("n1")
-	nCtx.OnEventsRecorder().Return(nil)
-	nCtx.OnEnqueueOwnerFunc().Return(nil)
-
-	executionContext := &mocks.ExecutionContext{}
-	executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{})
-	executionContext.OnGetParentInfo().Return(nil)
-	executionContext.OnGetEventVersion().Return(v1alpha1.EventVersion0)
-	nCtx.OnExecutionContext().Return(executionContext)
-
-	nCtx.OnRawOutputPrefix().Return("s3://sandbox/")
-	nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"}))
+        nodeID := "n1"
-	noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope())
+        nm := &nodeMocks.NodeExecutionMetadata{}
+        nm.OnGetAnnotations().Return(map[string]string{})
+        nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{
+            NodeId:      nodeID,
+            ExecutionId: wfExecID,
+        })
+        nm.OnGetK8sServiceAccount().Return("service-account")
+        nm.OnGetLabels().Return(map[string]string{})
+        nm.OnGetNamespace().Return("namespace")
+        nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"})
+        nm.OnGetOwnerReference().Return(v12.OwnerReference{
+            Kind: "sample",
+            Name: "name",
+        })
-	st := bytes.NewBuffer([]byte{})
-	a := 45
-	type test struct {
-		A int
+        taskID := &core.Identifier{}
+        tk := &core.TaskTemplate{
+            Id:   taskID,
+            Type: "test",
+            Metadata: &core.TaskMetadata{
+                CacheSerializable: cacheSerializable,
+                Discoverable:      cacheSerializable,
+            },
+            Interface: &core.TypedInterface{
&core.TypedInterface{ + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "x": { + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BOOLEAN, + }, + }, + }, + }, + }, + }, + } + tr := &nodeMocks.TaskReader{} + tr.OnGetTaskID().Return(taskID) + tr.OnGetTaskType().Return("x") + tr.OnReadMatch(mock.Anything).Return(tk, nil) + + ns := &flyteMocks.ExecutableNodeStatus{} + ns.OnGetDataDir().Return(storage.DataReference("data-dir")) + ns.OnGetOutputDir().Return(storage.DataReference("output-dir")) + + res := &v1.ResourceRequirements{} + n := &flyteMocks.ExecutableNode{} + ma := 5 + n.OnGetRetryStrategy().Return(&v1alpha1.RetryStrategy{MinAttempts: &ma}) + n.OnGetResources().Return(res) + + ir := &ioMocks.InputReader{} + nCtx := &nodeMocks.NodeExecutionContext{} + nCtx.OnNodeExecutionMetadata().Return(nm) + nCtx.OnNode().Return(n) + nCtx.OnInputReader().Return(ir) + ds, err := storage.NewDataStore( + &storage.Config{ + Type: storage.TypeMemory, + }, + promutils.NewTestScope(), + ) + assert.NoError(t, err) + nCtx.OnDataStore().Return(ds) + nCtx.OnCurrentAttempt().Return(uint32(1)) + nCtx.OnTaskReader().Return(tr) + nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) + nCtx.OnNodeStatus().Return(ns) + nCtx.OnNodeID().Return("n1") + nCtx.OnEventsRecorder().Return(nil) + nCtx.OnEnqueueOwnerFunc().Return(nil) + + executionContext := &mocks.ExecutionContext{} + executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{}) + executionContext.OnGetParentInfo().Return(nil) + executionContext.OnGetEventVersion().Return(v1alpha1.EventVersion0) + nCtx.OnExecutionContext().Return(executionContext) + + nCtx.OnRawOutputPrefix().Return("s3://sandbox/") + nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"})) + + st := bytes.NewBuffer([]byte{}) + a := 45 + type test struct { + A int + } + cod := codex.GobStateCodec{} + assert.NoError(t, cod.Encode(test{A: a}, st)) + nr := &nodeMocks.NodeStateReader{} + nr.On("GetTaskNodeState").Return(handler.TaskNodeState{ + PluginState: st.Bytes(), + }) + nCtx.On("NodeStateReader").Return(nr) + return nCtx } - cod := codex.GobStateCodec{} - assert.NoError(t, cod.Encode(test{A: a}, st)) - nr := &nodeMocks.NodeStateReader{} - nr.On("GetTaskNodeState").Return(handler.TaskNodeState{ - PluginState: st.Bytes(), - }) - nCtx.On("NodeStateReader").Return(nr) + + noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope()) type fields struct { defaultPluginCallback func() pluginCore.Plugin } type args struct { - nCtx handler.NodeExecutionContext + releaseReservation bool + releaseReservationError bool } tests := []struct { name string @@ -1629,7 +1865,7 @@ func Test_task_Finalize(t *testing.T) { }{ {"no-plugin", fields{defaultPluginCallback: func() pluginCore.Plugin { return nil - }}, args{nCtx: nCtx}, true, false}, + }}, args{}, true, false}, {"finalize-fails", fields{defaultPluginCallback: func() pluginCore.Plugin { p := &pluginCoreMocks.Plugin{} @@ -1637,23 +1873,46 @@ func Test_task_Finalize(t *testing.T) { p.OnGetProperties().Return(pluginCore.PluginProperties{}) p.On("Finalize", mock.Anything, mock.Anything).Return(fmt.Errorf("error")) return p - }}, args{nCtx: nCtx}, true, true}, + }}, args{}, true, true}, {"finalize-success", fields{defaultPluginCallback: func() pluginCore.Plugin { p := &pluginCoreMocks.Plugin{} p.On("GetID").Return("id") p.OnGetProperties().Return(pluginCore.PluginProperties{}) p.On("Finalize", mock.Anything, mock.Anything).Return(nil) return p - }}, 
+		}}, args{}, false, true},
+		{"release-reservation", fields{defaultPluginCallback: func() pluginCore.Plugin {
+			p := &pluginCoreMocks.Plugin{}
+			p.On("GetID").Return("id")
+			p.OnGetProperties().Return(pluginCore.PluginProperties{})
+			p.On("Finalize", mock.Anything, mock.Anything).Return(nil)
+			return p
+		}}, args{releaseReservation: true}, false, true},
+		{"release-reservation-error", fields{defaultPluginCallback: func() pluginCore.Plugin {
+			p := &pluginCoreMocks.Plugin{}
+			p.On("GetID").Return("id")
+			p.OnGetProperties().Return(pluginCore.PluginProperties{})
+			p.On("Finalize", mock.Anything, mock.Anything).Return(nil)
+			return p
+		}}, args{releaseReservation: true, releaseReservationError: true}, true, false},
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			m := tt.fields.defaultPluginCallback()
-			tk := Handler{
-				defaultPlugin:   m,
-				resourceManager: noopRm,
+			nCtx := createNodeContext(tt.args.releaseReservation)
+
+			catalog := &pluginCatalogMocks.Client{}
+			if tt.args.releaseReservationError {
+				catalog.OnReleaseReservationMatch(mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("failed to release reservation"))
+			} else {
+				catalog.OnReleaseReservationMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)
 			}
-			if err := tk.Finalize(context.TODO(), tt.args.nCtx); (err != nil) != tt.wantErr {
+
+			m := tt.fields.defaultPluginCallback()
+			tk, err := New(context.TODO(), mocks.NewFakeKubeClient(), catalog, eventConfig, promutils.NewTestScope())
+			assert.NoError(t, err)
+			tk.defaultPlugin = m
+			tk.resourceManager = noopRm
+			if err := tk.Finalize(context.TODO(), nCtx); (err != nil) != tt.wantErr {
 				t.Errorf("Handler.Finalize() error = %v, wantErr %v", err, tt.wantErr)
 			}
 			c := 0
diff --git a/pkg/controller/nodes/task/pre_post_execution.go b/pkg/controller/nodes/task/pre_post_execution.go
index 3f3a54ae7..1a6903b76 100644
--- a/pkg/controller/nodes/task/pre_post_execution.go
+++ b/pkg/controller/nodes/task/pre_post_execution.go
@@ -2,6 +2,7 @@ package task
 
 import (
 	"context"
+	"time"
 
 	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"
 
@@ -71,6 +72,49 @@ func (t *Handler) CheckCatalogCache(ctx context.Context, tr pluginCore.TaskReade
 	return catalog.NewCatalogEntry(nil, cacheDisabled), nil
 }
 
+// GetOrExtendCatalogReservation attempts to acquire an artifact reservation if the task is
+// cacheable and cache serializable. If the reservation already exists for this owner, the
+// reservation is extended.
+func (t *Handler) GetOrExtendCatalogReservation(ctx context.Context, ownerID string, heartbeatInterval time.Duration, tr pluginCore.TaskReader, inputReader io.InputReader) (catalog.ReservationEntry, error) {
+    tk, err := tr.Read(ctx)
+    if err != nil {
+        logger.Errorf(ctx, "Failed to read TaskTemplate, error: %s", err.Error())
+        return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_FAILURE), err
+    }
+
+    if tk.Metadata.Discoverable && tk.Metadata.CacheSerializable {
+        logger.Infof(ctx, "Catalog CacheSerializeEnabled: creating catalog reservation.")
+        key := catalog.Key{
+            Identifier:     *tk.Id,
+            CacheVersion:   tk.Metadata.DiscoveryVersion,
+            TypedInterface: *tk.Interface,
+            InputReader:    inputReader,
+        }
+
+        reservation, err := t.catalog.GetOrExtendReservation(ctx, key, ownerID, heartbeatInterval)
+        if err != nil {
+            t.metrics.reservationGetFailureCount.Inc(ctx)
+            logger.Errorf(ctx, "Catalog Failure: reservation get or extend failed. err: %v", err.Error())
+            return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_FAILURE), err
+        }
+
+        expiresAt := reservation.ExpiresAt.AsTime()
+        heartbeatInterval := reservation.HeartbeatInterval.AsDuration()
+
+        var status core.CatalogReservation_Status
+        if reservation.OwnerId == ownerID {
+            status = core.CatalogReservation_RESERVATION_ACQUIRED
+        } else {
+            status = core.CatalogReservation_RESERVATION_EXISTS
+        }
+
+        t.metrics.reservationGetSuccessCount.Inc(ctx)
+        return catalog.NewReservationEntry(expiresAt, heartbeatInterval, reservation.OwnerId, status), nil
+    }
+    logger.Infof(ctx, "Catalog CacheSerializeDisabled: for Task [%s/%s/%s/%s]", tk.Id.Project, tk.Id.Domain, tk.Id.Name, tk.Id.Version)
+    return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_DISABLED), nil
+}
+
 func (t *Handler) ValidateOutputAndCacheAdd(ctx context.Context, nodeID v1alpha1.NodeID, i io.InputReader,
 	r io.OutputReader, outputCommitter io.OutputWriter, executionConfig v1alpha1.ExecutionConfig,
 	tr ioutils.SimpleTaskReader, m catalog.Metadata) (catalog.Status, *io.ExecutionError, error) {
 
@@ -188,3 +232,36 @@ func (t *Handler) ValidateOutputAndCacheAdd(ctx context.Context, nodeID v1alpha1
 	logger.Infof(ctx, "Successfully cached results to catalog - Task [%v]", tk.GetId())
 	return s, nil, nil
 }
+
+// ReleaseCatalogReservation attempts to release an artifact reservation if the task is cacheable
+// and cache serializable. If the reservation does not exist for this owner (e.g., it never existed
+// or has been acquired by another owner), this call is still successful.
+func (t *Handler) ReleaseCatalogReservation(ctx context.Context, ownerID string, tr pluginCore.TaskReader, inputReader io.InputReader) (catalog.ReservationEntry, error) {
+    tk, err := tr.Read(ctx)
+    if err != nil {
+        logger.Errorf(ctx, "Failed to read TaskTemplate, error: %s", err.Error())
+        return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_FAILURE), err
+    }
+
+    if tk.Metadata.Discoverable && tk.Metadata.CacheSerializable {
+        logger.Infof(ctx, "Catalog CacheSerializeEnabled: releasing catalog reservation.")
+        key := catalog.Key{
+            Identifier:     *tk.Id,
+            CacheVersion:   tk.Metadata.DiscoveryVersion,
+            TypedInterface: *tk.Interface,
+            InputReader:    inputReader,
+        }
+
+        err := t.catalog.ReleaseReservation(ctx, key, ownerID)
+        if err != nil {
+            t.metrics.reservationReleaseFailureCount.Inc(ctx)
+            logger.Errorf(ctx, "Catalog Failure: release reservation failed. err: %v", err.Error())
+            return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_FAILURE), err
+        }
+
+        t.metrics.reservationReleaseSuccessCount.Inc(ctx)
+        return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_RELEASED), nil
+    }
+    logger.Infof(ctx, "Catalog CacheSerializeDisabled: for Task [%s/%s/%s/%s]", tk.Id.Project, tk.Id.Domain, tk.Id.Name, tk.Id.Version)
+    return catalog.NewReservationEntryStatus(core.CatalogReservation_RESERVATION_DISABLED), nil
+}
diff --git a/pkg/controller/workflow/errors/errors.go b/pkg/controller/workflow/errors/errors.go
index b6695a0fb..1905cbfbc 100644
--- a/pkg/controller/workflow/errors/errors.go
+++ b/pkg/controller/workflow/errors/errors.go
@@ -26,6 +26,12 @@ func (w *WorkflowError) Is(target error) bool {
 	if !ok {
 		return false
 	}
+	if w == nil && t == nil {
+		return true
+	}
+	if w == nil || t == nil {
+		return false
+	}
 	return w.Code == t.Code
 }
 
@@ -47,6 +53,12 @@ func (w *WorkflowErrorWithCause) Is(target error) bool {
 	if !ok {
 		return false
 	}
+	if w == nil && t == nil {
+		return true
+	}
+	if w == nil || t == nil {
+		return false
+	}
 	return w.Code == t.Code && (w.cause == t.cause || t.cause == nil) && (w.Message == t.Message || t.Message == "") && (w.Workflow == t.Workflow || t.Workflow == "")
 }
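
The changes above compose into a small serialization protocol: before executing a cache-serializable task, a worker acquires (or extends) the artifact reservation; it executes only while it owns the reservation; and it releases the reservation during finalization. The Go sketch below illustrates that flow against the two CatalogClient methods added in this diff. It is illustrative only — the reservationClient interface, the runSerialized helper, the heartbeat value, and the run callback are assumptions for the example, not part of the change.

package example

import (
    "context"
    "time"

    "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/datacatalog"
    "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/catalog"
)

// reservationClient captures just the two calls introduced in this diff.
type reservationClient interface {
    GetOrExtendReservation(ctx context.Context, key catalog.Key, ownerID string, heartbeatInterval time.Duration) (*datacatalog.Reservation, error)
    ReleaseReservation(ctx context.Context, key catalog.Key, ownerID string) error
}

// runSerialized is a hypothetical helper showing the protocol: acquire or
// extend the reservation, run only while owning it, and release on the way out.
func runSerialized(ctx context.Context, c reservationClient, key catalog.Key, ownerID string, run func(context.Context) error) error {
    // Assumed heartbeat interval; the handler uses the workflow
    // re-evaluation duration (controllerConfig WorkflowReEval) instead.
    const heartbeat = 10 * time.Second

    reservation, err := c.GetOrExtendReservation(ctx, key, ownerID, heartbeat)
    if err != nil {
        return err
    }
    if reservation.OwnerId != ownerID {
        // Another worker owns the reservation: back off instead of running,
        // mirroring the handler's transition to PhaseWaitingForCache.
        return nil
    }
    // Best-effort release, mirroring Finalize; releasing a reservation that
    // expired or changed hands still succeeds on the server side.
    defer c.ReleaseReservation(ctx, key, ownerID) //nolint:errcheck

    return run(ctx)
}

Note that once a task has moved past PhaseUndefined/PhaseWaitingForCache, the handler intentionally keeps executing even if the reservation is lost to expiration; only tasks that have not yet started defer to the current owner.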
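The nil guards added to errors.go are worth a second look: Is can be handed a typed-nil *WorkflowError on either side of an errors.Is chain, and the old `return w.Code == t.Code` would dereference a nil pointer. A hypothetical illustration (not from the diff) using the WorkflowError type patched above:

// Hypothetical usage inside the errors package.
var a, b *WorkflowError // typed-nil pointers

a.Is(b)                  // true: both sides nil (previously panicked on w.Code)
(&WorkflowError{}).Is(a) // false: non-nil receiver, nil target (previously panicked on t.Code)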