From 599fda8d4cb65a1b8e69999ffe80db86e6c77286 Mon Sep 17 00:00:00 2001 From: colindickson Date: Thu, 10 Oct 2024 10:21:43 -0400 Subject: [PATCH] move metering middleware definitions to metering package --- metering/metering.go | 37 +++++++++++++++++++++++++++++++++++++ service/tier1.go | 34 ++-------------------------------- service/tier2.go | 19 +------------------ 3 files changed, 40 insertions(+), 50 deletions(-) diff --git a/metering/metering.go b/metering/metering.go index 477ea998..7794fc1e 100644 --- a/metering/metering.go +++ b/metering/metering.go @@ -5,6 +5,9 @@ import ( "fmt" "time" + "github.com/streamingfast/bstream" + pbbstream "github.com/streamingfast/bstream/pb/sf/bstream/v1" + "github.com/streamingfast/substreams/metrics" "github.com/streamingfast/dmetering" @@ -79,6 +82,40 @@ func GetTotalBytesWritten(meter dmetering.Meter) uint64 { return total } +func LiveSourceMiddlewareHandlerFactory(ctx context.Context) func(handler bstream.Handler) bstream.Handler { + return func(next bstream.Handler) bstream.Handler { + return bstream.HandlerFunc(func(blk *pbbstream.Block, obj interface{}) error { + stepable, ok := obj.(bstream.Stepable) + if ok { + step := stepable.Step() + if step.Matches(bstream.StepNew) { + dmetering.GetBytesMeter(ctx).CountInc(MeterLiveUncompressedReadBytes, len(blk.GetPayload().GetValue())) + } else { + dmetering.GetBytesMeter(ctx).CountInc(MeterLiveUncompressedReadForkedBytes, len(blk.GetPayload().GetValue())) + } + } + return next.ProcessBlock(blk, obj) + }) + } +} + +func FileSourceMiddlewareHandlerFactory(ctx context.Context) func(handler bstream.Handler) bstream.Handler { + return func(next bstream.Handler) bstream.Handler { + return bstream.HandlerFunc(func(blk *pbbstream.Block, obj interface{}) error { + stepable, ok := obj.(bstream.Stepable) + if ok { + step := stepable.Step() + if step.Matches(bstream.StepNew) { + dmetering.GetBytesMeter(ctx).CountInc(MeterFileUncompressedReadBytes, len(blk.GetPayload().GetValue())) + } else { + dmetering.GetBytesMeter(ctx).CountInc(MeterFileUncompressedReadForkedBytes, len(blk.GetPayload().GetValue())) + } + } + return next.ProcessBlock(blk, obj) + }) + } +} + func Send(ctx context.Context, userID, apiKeyID, ip, userMeta, endpoint string, resp proto.Message) { if reqctx.IsBackfillerRequest(ctx) { endpoint = fmt.Sprintf("%s%s", endpoint, "Backfill") diff --git a/service/tier1.go b/service/tier1.go index f7343b56..6ec7397f 100644 --- a/service/tier1.go +++ b/service/tier1.go @@ -576,36 +576,6 @@ func (s *Tier1Service) blocks(ctx context.Context, request *pbsubstreamsrpc.Requ streamHandler = pipe } - liveSourceMiddlewareHandler := func(next bstream.Handler) bstream.Handler { - return bstream.HandlerFunc(func(blk *pbbstream.Block, obj interface{}) error { - stepable, ok := obj.(bstream.Stepable) - if ok { - step := stepable.Step() - if step.Matches(bstream.StepNew) { - dmetering.GetBytesMeter(ctx).CountInc(metering.MeterLiveUncompressedReadBytes, len(blk.GetPayload().GetValue())) - } else { - dmetering.GetBytesMeter(ctx).CountInc(metering.MeterLiveUncompressedReadForkedBytes, len(blk.GetPayload().GetValue())) - } - } - return next.ProcessBlock(blk, obj) - }) - } - - fileSourceMiddlewareHandler := func(next bstream.Handler) bstream.Handler { - return bstream.HandlerFunc(func(blk *pbbstream.Block, obj interface{}) error { - stepable, ok := obj.(bstream.Stepable) - if ok { - step := stepable.Step() - if step.Matches(bstream.StepNew) { - dmetering.GetBytesMeter(ctx).CountInc(metering.MeterFileUncompressedReadBytes, len(blk.GetPayload().GetValue())) - } else { - dmetering.GetBytesMeter(ctx).CountInc(metering.MeterFileUncompressedReadForkedBytes, len(blk.GetPayload().GetValue())) - } - } - return next.ProcessBlock(blk, obj) - }) - } - blockStream, err := s.streamFactoryFunc( ctx, streamHandler, @@ -615,8 +585,8 @@ func (s *Tier1Service) blocks(ctx context.Context, request *pbsubstreamsrpc.Requ request.FinalBlocksOnly, cursorIsTarget, logger.Named("stream"), - bsstream.WithLiveSourceHandlerMiddleware(liveSourceMiddlewareHandler), - bsstream.WithFileSourceHandlerMiddleware(fileSourceMiddlewareHandler), + bsstream.WithLiveSourceHandlerMiddleware(metering.LiveSourceMiddlewareHandlerFactory(ctx)), + bsstream.WithFileSourceHandlerMiddleware(metering.FileSourceMiddlewareHandlerFactory(ctx)), ) if err != nil { return fmt.Errorf("error getting stream: %w", err) diff --git a/service/tier2.go b/service/tier2.go index aa67a3a1..afd42516 100644 --- a/service/tier2.go +++ b/service/tier2.go @@ -11,8 +11,6 @@ import ( "connectrpc.com/connect" "github.com/RoaringBitmap/roaring/roaring64" - "github.com/streamingfast/bstream" - pbbstream "github.com/streamingfast/bstream/pb/sf/bstream/v1" "github.com/streamingfast/bstream/stream" bsstream "github.com/streamingfast/bstream/stream" "github.com/streamingfast/dauth" @@ -417,21 +415,6 @@ excludable: streamFactoryFunc = s.streamFactoryFuncOverride } - fileSourceMiddlewareHandler := func(next bstream.Handler) bstream.Handler { - return bstream.HandlerFunc(func(blk *pbbstream.Block, obj interface{}) error { - stepable, ok := obj.(bstream.Stepable) - if ok { - step := stepable.Step() - if step.Matches(bstream.StepNew) { - dmetering.GetBytesMeter(ctx).CountInc(metering.MeterFileUncompressedReadBytes, len(blk.GetPayload().GetValue())) - } else { - dmetering.GetBytesMeter(ctx).CountInc(metering.MeterFileUncompressedReadForkedBytes, len(blk.GetPayload().GetValue())) - } - } - return next.ProcessBlock(blk, obj) - }) - } - blockStream, err := streamFactoryFunc( ctx, pipe, @@ -441,7 +424,7 @@ excludable: true, false, logger.Named("stream"), - bsstream.WithFileSourceHandlerMiddleware(fileSourceMiddlewareHandler), + bsstream.WithFileSourceHandlerMiddleware(metering.FileSourceMiddlewareHandlerFactory(ctx)), ) if err != nil { return fmt.Errorf("error getting stream: %w", err)