diff --git a/flytepropeller/pkg/controller/nodes/executor.go b/flytepropeller/pkg/controller/nodes/executor.go index 9752839eb4..ca5ab11bcc 100644 --- a/flytepropeller/pkg/controller/nodes/executor.go +++ b/flytepropeller/pkg/controller/nodes/executor.go @@ -95,6 +95,8 @@ type nodeMetrics struct { reservationGetFailureCount labeled.Counter reservationReleaseSuccessCount labeled.Counter reservationReleaseFailureCount labeled.Counter + + acceleratedInputCount labeled.Counter } // recursiveNodeExector implements the executors.Node interfaces and is the starting point for @@ -760,7 +762,7 @@ func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructur if nodeInputs != nil { if config.GetConfig().AcceleratedInputs.Enabled { - replaceRemotePathsForMap(ctx, nodeInputs) + c.replaceRemotePathsForMap(ctx, nodeInputs) } inputsFile := v1alpha1.GetInputsFile(dataDir) if err := c.store.WriteProtobuf(ctx, inputsFile, storage.Options{}, nodeInputs); err != nil { @@ -1395,28 +1397,28 @@ func (c *nodeExecutor) HandleNode(ctx context.Context, dag executors.DAGStructur return c.handleQueuedOrRunningNode(ctx, nCtx, h) } -func replaceRemotePathsForMap(ctx context.Context, inputs *core.LiteralMap) { +func (c *nodeExecutor) replaceRemotePathsForMap(ctx context.Context, inputs *core.LiteralMap) { for _, value := range inputs.GetLiterals() { - replaceRemotePathsForLiteral(ctx, value) + c.replaceRemotePathsForLiteral(ctx, value) } } -func replaceRemotePathsForLiteral(ctx context.Context, literal *core.Literal) { +func (c *nodeExecutor) replaceRemotePathsForLiteral(ctx context.Context, literal *core.Literal) { initialURI := "" switch v := literal.GetValue().(type) { case *core.Literal_Scalar: switch s := v.Scalar.GetValue().(type) { case *core.Scalar_Blob: initialURI = s.Blob.GetUri() - s.Blob.Uri = replaceRemotePrefix(ctx, initialURI) + s.Blob.Uri = c.replaceRemotePrefix(ctx, initialURI) case *core.Scalar_Schema: initialURI = s.Schema.GetUri() - s.Schema.Uri = replaceRemotePrefix(ctx, initialURI) + s.Schema.Uri = c.replaceRemotePrefix(ctx, initialURI) case *core.Scalar_StructuredDataset: initialURI = s.StructuredDataset.GetUri() - s.StructuredDataset.Uri = replaceRemotePrefix(ctx, initialURI) + s.StructuredDataset.Uri = c.replaceRemotePrefix(ctx, initialURI) case *core.Scalar_Union: - replaceRemotePathsForLiteral(ctx, s.Union.GetValue()) + c.replaceRemotePathsForLiteral(ctx, s.Union.GetValue()) } if initialURI != "" { if literal.Metadata == nil { @@ -1425,20 +1427,21 @@ func replaceRemotePathsForLiteral(ctx context.Context, literal *core.Literal) { literal.Metadata["initial_uri"] = initialURI } case *core.Literal_Map: - replaceRemotePathsForMap(ctx, v.Map) + c.replaceRemotePathsForMap(ctx, v.Map) case *core.Literal_Collection: for _, item := range v.Collection.GetLiterals() { - replaceRemotePathsForLiteral(ctx, item) + c.replaceRemotePathsForLiteral(ctx, item) } } } -func replaceRemotePrefix(ctx context.Context, s string) string { +func (c *nodeExecutor) replaceRemotePrefix(ctx context.Context, s string) string { cfg := config.GetConfig().AcceleratedInputs remotePrefix := cfg.RemotePathPrefix localPrefix := cfg.LocalPathPrefix if strings.HasPrefix(s, remotePrefix) { logger.Debugf(ctx, "replacing remote input prefix in %s with local %s", s, localPrefix) + c.metrics.acceleratedInputCount.Inc(ctx) return path.Join(localPrefix, strings.TrimPrefix(s, remotePrefix)) } return s @@ -1487,6 +1490,7 @@ func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *stora reservationGetSuccessCount: labeled.NewCounter("reservation_get_success_count", "Reservation GetOrExtend success count", scope), reservationReleaseFailureCount: labeled.NewCounter("reservation_release_failure_count", "Reservation Release failure count", scope), reservationReleaseSuccessCount: labeled.NewCounter("reservation_release_success_count", "Reservation Release success count", scope), + acceleratedInputCount: labeled.NewCounter("accelerated_input_count", "Number of accelerated inputs", scope), } nodeExecutor := &nodeExecutor{ diff --git a/flytepropeller/pkg/controller/nodes/executor_test.go b/flytepropeller/pkg/controller/nodes/executor_test.go index 4a63bef96d..9df891b4db 100644 --- a/flytepropeller/pkg/controller/nodes/executor_test.go +++ b/flytepropeller/pkg/controller/nodes/executor_test.go @@ -3101,8 +3101,15 @@ func Test_replaceRemotePathsForMap(t *testing.T) { }, }, } + nodeExecutor := &nodeExecutor{ + metrics: &nodeMetrics{ + acceleratedInputCount: labeled.Counter{ + CounterVec: prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{}), + }, + }, + } - replaceRemotePathsForMap(ctx, input) + nodeExecutor.replaceRemotePathsForMap(ctx, input) assert.Equal(t, expected, input) }