From e69551ba07d14dee5dccd90b28cf8b497943f415 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Wed, 2 Oct 2024 09:25:04 +0530 Subject: [PATCH] feat: Use gRPC bidirectional streaming for source transformer (#2071) --- Makefile | 2 +- go.mod | 4 +- go.sum | 8 +- hack/generate-proto.sh | 7 +- pkg/apis/proto/daemon/daemon_grpc.pb.go | 25 +- .../proto/mvtxdaemon/mvtxdaemon_grpc.pb.go | 16 +- .../sourcetransform/v1/sourcetransform.proto | 30 +- pkg/isb/tracker/message_tracker.go | 56 ++ .../tracker/message_tracker_test.go} | 30 +- pkg/sdkclient/grpc/grpc_utils.go | 2 - pkg/sdkclient/sourcetransformer/client.go | 120 +++- .../sourcetransformer/client_test.go | 168 +++-- pkg/sdkclient/sourcetransformer/interface.go | 2 +- .../forward/applier/sourcetransformer.go | 8 +- pkg/sources/forward/data_forward.go | 69 +-- pkg/sources/forward/data_forward_test.go | 91 ++- pkg/sources/forward/shutdown_test.go | 12 +- pkg/sources/generator/tickgen.go | 1 - pkg/sources/source.go | 2 +- pkg/sources/transformer/grpc_transformer.go | 160 +++-- .../transformer/grpc_transformer_test.go | 575 +++++------------- pkg/udf/forward/forward.go | 2 +- pkg/udf/rpc/grpc_batch_map.go | 33 +- pkg/udf/rpc/tracker.go | 75 --- pkg/webhook/validator/validator.go | 5 +- rust/Cargo.lock | 2 +- rust/numaflow-core/Cargo.toml | 2 +- .../numaflow-core/proto/sourcetransform.proto | 33 +- rust/numaflow-core/src/config.rs | 18 +- rust/numaflow-core/src/message.rs | 16 +- .../numaflow-core/src/monovertex/forwarder.rs | 32 +- .../src/transformer/user_defined.rs | 224 +++++-- rust/servesink/Cargo.toml | 2 +- .../extract-event-time-from-payload.yaml | 2 +- test/transformer-e2e/transformer_test.go | 30 +- 35 files changed, 946 insertions(+), 918 deletions(-) create mode 100644 pkg/isb/tracker/message_tracker.go rename pkg/{udf/rpc/tracker_test.go => isb/tracker/message_tracker_test.go} (53%) delete mode 100644 pkg/udf/rpc/tracker.go diff --git a/Makefile b/Makefile index a4bc2012bc..11d91c5890 100644 --- a/Makefile +++ 
b/Makefile @@ -244,7 +244,7 @@ manifests: crds kubectl kustomize config/extensions/webhook > config/validating-webhook-install.yaml $(GOPATH)/bin/golangci-lint: - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b `go env GOPATH`/bin v1.54.1 + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b `go env GOPATH`/bin v1.61.0 .PHONY: lint lint: $(GOPATH)/bin/golangci-lint diff --git a/go.mod b/go.mod index c2d7d6edd5..ba62a6f28d 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe github.com/nats-io/nats-server/v2 v2.10.20 github.com/nats-io/nats.go v1.37.0 - github.com/numaproj/numaflow-go v0.8.2-0.20240923064822-e16694a878d0 + github.com/numaproj/numaflow-go v0.8.2-0.20241001031210-60188185d9c0 github.com/prometheus/client_golang v1.18.0 github.com/prometheus/client_model v0.5.0 github.com/prometheus/common v0.45.0 @@ -55,7 +55,7 @@ require ( golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d google.golang.org/genproto/googleapis/api v0.0.0-20240604185151-ef581f913117 google.golang.org/grpc v1.66.0 - google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.3.0 + google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.4.0 google.golang.org/protobuf v1.34.2 k8s.io/api v0.29.2 k8s.io/apimachinery v0.29.2 diff --git a/go.sum b/go.sum index b17e994439..9670ccac4b 100644 --- a/go.sum +++ b/go.sum @@ -485,8 +485,8 @@ github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDm github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/numaproj/numaflow-go v0.8.2-0.20240923064822-e16694a878d0 h1:qPqZfJdPdsz4qymyzMSNICQe/xBnx9P/G3hRbC1DR7k= 
-github.com/numaproj/numaflow-go v0.8.2-0.20240923064822-e16694a878d0/go.mod h1:g4JZOyUPhjfhv+kR0sX5d8taw/dasgKPXLvQBi39mJ4= +github.com/numaproj/numaflow-go v0.8.2-0.20241001031210-60188185d9c0 h1:MN4Q36mPrXqPrv2dNoK3gyV7c1CGwUF3wNJxTZSw1lk= +github.com/numaproj/numaflow-go v0.8.2-0.20241001031210-60188185d9c0/go.mod h1:FaCMeV0V9SiLcVf2fwT+GeTJHNaK2gdQsTAIqQ4x7oc= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= @@ -1049,8 +1049,8 @@ google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.66.0 h1:DibZuoBznOxbDQxRINckZcUvnCEvrW9pcWIE2yF9r1c= google.golang.org/grpc v1.66.0/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y= -google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.3.0 h1:rNBFJjBCOgVr9pWD7rs/knKL4FRTKgpZmsRfV214zcA= -google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.3.0/go.mod h1:Dk1tviKTvMCz5tvh7t+fh94dhmQVHuCt2OzJB3CTW9Y= +google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.4.0 h1:9SxA29VM43MF5Z9dQu694wmY5t8E/Gxr7s+RSxiIDmc= +google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.4.0/go.mod h1:yZOK5zhQMiALmuweVdIVoQPa6eIJyXn2B9g5dJDhqX4= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/hack/generate-proto.sh b/hack/generate-proto.sh index bf970ce318..7d9f19cb67 100755 --- a/hack/generate-proto.sh +++ b/hack/generate-proto.sh @@ -22,11 +22,14 @@ install-protobuf() { ARCH=$(uname_arch) echo "OS: $OS ARCH: $ARCH" + if [[ "$ARCH" = 
"amd64" ]]; then + ARCH="x86_64" + elif [[ "$ARCH" = "arm64" ]]; then + ARCH="aarch_64" + fi BINARY_URL=$PB_REL/download/v${PROTOBUF_VERSION}/protoc-${PROTOBUF_VERSION}-${OS}-${ARCH}.zip if [[ "$OS" = "darwin" ]]; then BINARY_URL=$PB_REL/download/v${PROTOBUF_VERSION}/protoc-${PROTOBUF_VERSION}-osx-universal_binary.zip - elif [[ "$OS" = "linux" ]]; then - BINARY_URL=$PB_REL/download/v${PROTOBUF_VERSION}/protoc-${PROTOBUF_VERSION}-linux-x86_64.zip fi echo "Downloading $BINARY_URL" diff --git a/pkg/apis/proto/daemon/daemon_grpc.pb.go b/pkg/apis/proto/daemon/daemon_grpc.pb.go index 61e15a2a62..6b348d8fdf 100644 --- a/pkg/apis/proto/daemon/daemon_grpc.pb.go +++ b/pkg/apis/proto/daemon/daemon_grpc.pb.go @@ -15,7 +15,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.3.0 +// - protoc-gen-go-grpc v1.4.0 // - protoc v5.27.2 // source: pkg/apis/proto/daemon/daemon.proto @@ -30,8 +30,8 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 +// Requires gRPC-Go v1.62.0 or later. +const _ = grpc.SupportPackageIsVersion8 const ( DaemonService_ListBuffers_FullMethodName = "/daemon.DaemonService/ListBuffers" @@ -44,6 +44,8 @@ const ( // DaemonServiceClient is the client API for DaemonService service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// DaemonService is a grpc service that is used to provide APIs for giving any pipeline information. 
type DaemonServiceClient interface { ListBuffers(ctx context.Context, in *ListBuffersRequest, opts ...grpc.CallOption) (*ListBuffersResponse, error) GetBuffer(ctx context.Context, in *GetBufferRequest, opts ...grpc.CallOption) (*GetBufferResponse, error) @@ -62,8 +64,9 @@ func NewDaemonServiceClient(cc grpc.ClientConnInterface) DaemonServiceClient { } func (c *daemonServiceClient) ListBuffers(ctx context.Context, in *ListBuffersRequest, opts ...grpc.CallOption) (*ListBuffersResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListBuffersResponse) - err := c.cc.Invoke(ctx, DaemonService_ListBuffers_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_ListBuffers_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -71,8 +74,9 @@ func (c *daemonServiceClient) ListBuffers(ctx context.Context, in *ListBuffersRe } func (c *daemonServiceClient) GetBuffer(ctx context.Context, in *GetBufferRequest, opts ...grpc.CallOption) (*GetBufferResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetBufferResponse) - err := c.cc.Invoke(ctx, DaemonService_GetBuffer_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetBuffer_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -80,8 +84,9 @@ func (c *daemonServiceClient) GetBuffer(ctx context.Context, in *GetBufferReques } func (c *daemonServiceClient) GetVertexMetrics(ctx context.Context, in *GetVertexMetricsRequest, opts ...grpc.CallOption) (*GetVertexMetricsResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetVertexMetricsResponse) - err := c.cc.Invoke(ctx, DaemonService_GetVertexMetrics_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetVertexMetrics_FullMethodName, in, out, cOpts...) 
if err != nil { return nil, err } @@ -89,8 +94,9 @@ func (c *daemonServiceClient) GetVertexMetrics(ctx context.Context, in *GetVerte } func (c *daemonServiceClient) GetPipelineWatermarks(ctx context.Context, in *GetPipelineWatermarksRequest, opts ...grpc.CallOption) (*GetPipelineWatermarksResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetPipelineWatermarksResponse) - err := c.cc.Invoke(ctx, DaemonService_GetPipelineWatermarks_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetPipelineWatermarks_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -98,8 +104,9 @@ func (c *daemonServiceClient) GetPipelineWatermarks(ctx context.Context, in *Get } func (c *daemonServiceClient) GetPipelineStatus(ctx context.Context, in *GetPipelineStatusRequest, opts ...grpc.CallOption) (*GetPipelineStatusResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetPipelineStatusResponse) - err := c.cc.Invoke(ctx, DaemonService_GetPipelineStatus_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, DaemonService_GetPipelineStatus_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -109,6 +116,8 @@ func (c *daemonServiceClient) GetPipelineStatus(ctx context.Context, in *GetPipe // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility +// +// DaemonService is a grpc service that is used to provide APIs for giving any pipeline information. 
type DaemonServiceServer interface { ListBuffers(context.Context, *ListBuffersRequest) (*ListBuffersResponse, error) GetBuffer(context.Context, *GetBufferRequest) (*GetBufferResponse, error) diff --git a/pkg/apis/proto/mvtxdaemon/mvtxdaemon_grpc.pb.go b/pkg/apis/proto/mvtxdaemon/mvtxdaemon_grpc.pb.go index 33f0b26d6b..76477c3de0 100644 --- a/pkg/apis/proto/mvtxdaemon/mvtxdaemon_grpc.pb.go +++ b/pkg/apis/proto/mvtxdaemon/mvtxdaemon_grpc.pb.go @@ -15,7 +15,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.3.0 +// - protoc-gen-go-grpc v1.4.0 // - protoc v5.27.2 // source: pkg/apis/proto/mvtxdaemon/mvtxdaemon.proto @@ -31,8 +31,8 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 +// Requires gRPC-Go v1.62.0 or later. +const _ = grpc.SupportPackageIsVersion8 const ( MonoVertexDaemonService_GetMonoVertexMetrics_FullMethodName = "/mvtxdaemon.MonoVertexDaemonService/GetMonoVertexMetrics" @@ -42,6 +42,8 @@ const ( // MonoVertexDaemonServiceClient is the client API for MonoVertexDaemonService service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// MonoVertexDaemonService is a grpc service that is used to provide APIs for giving any MonoVertex information. 
type MonoVertexDaemonServiceClient interface { GetMonoVertexMetrics(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*GetMonoVertexMetricsResponse, error) GetMonoVertexStatus(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*GetMonoVertexStatusResponse, error) @@ -56,8 +58,9 @@ func NewMonoVertexDaemonServiceClient(cc grpc.ClientConnInterface) MonoVertexDae } func (c *monoVertexDaemonServiceClient) GetMonoVertexMetrics(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*GetMonoVertexMetricsResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetMonoVertexMetricsResponse) - err := c.cc.Invoke(ctx, MonoVertexDaemonService_GetMonoVertexMetrics_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, MonoVertexDaemonService_GetMonoVertexMetrics_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -65,8 +68,9 @@ func (c *monoVertexDaemonServiceClient) GetMonoVertexMetrics(ctx context.Context } func (c *monoVertexDaemonServiceClient) GetMonoVertexStatus(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*GetMonoVertexStatusResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetMonoVertexStatusResponse) - err := c.cc.Invoke(ctx, MonoVertexDaemonService_GetMonoVertexStatus_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, MonoVertexDaemonService_GetMonoVertexStatus_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -76,6 +80,8 @@ func (c *monoVertexDaemonServiceClient) GetMonoVertexStatus(ctx context.Context, // MonoVertexDaemonServiceServer is the server API for MonoVertexDaemonService service. // All implementations must embed UnimplementedMonoVertexDaemonServiceServer // for forward compatibility +// +// MonoVertexDaemonService is a grpc service that is used to provide APIs for giving any MonoVertex information. 
type MonoVertexDaemonServiceServer interface { GetMonoVertexMetrics(context.Context, *emptypb.Empty) (*GetMonoVertexMetricsResponse, error) GetMonoVertexStatus(context.Context, *emptypb.Empty) (*GetMonoVertexStatusResponse, error) diff --git a/pkg/apis/proto/sourcetransform/v1/sourcetransform.proto b/pkg/apis/proto/sourcetransform/v1/sourcetransform.proto index b93d82b9a8..740ae1c671 100644 --- a/pkg/apis/proto/sourcetransform/v1/sourcetransform.proto +++ b/pkg/apis/proto/sourcetransform/v1/sourcetransform.proto @@ -28,21 +28,35 @@ service SourceTransform { // SourceTransformFn applies a function to each request element. // In addition to map function, SourceTransformFn also supports assigning a new event time to response. // SourceTransformFn can be used only at source vertex by source data transformer. - rpc SourceTransformFn(SourceTransformRequest) returns (SourceTransformResponse); + rpc SourceTransformFn(stream SourceTransformRequest) returns (stream SourceTransformResponse); // IsReady is the heartbeat endpoint for gRPC. rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); } +/* + * Handshake message between client and server to indicate the start of transmission. + */ + message Handshake { + // Required field indicating the start of transmission. + bool sot = 1; +} + /** * SourceTransformerRequest represents a request element. 
*/ message SourceTransformRequest { - repeated string keys = 1; - bytes value = 2; - google.protobuf.Timestamp event_time = 3; - google.protobuf.Timestamp watermark = 4; - map<string, string> headers = 5; + message Request { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + map<string, string> headers = 5; + // This ID is used to uniquely identify a transform request + string id = 6; + } + Request request = 1; + optional Handshake handshake = 2; } /** @@ -56,6 +70,10 @@ message SourceTransformResponse { repeated string tags = 4; } repeated Result results = 1; + // This ID is used to refer the responses to the request it corresponds to. + string id = 2; + // Handshake message between client and server to indicate the start of transmission. + optional Handshake handshake = 3; } /** diff --git a/pkg/isb/tracker/message_tracker.go b/pkg/isb/tracker/message_tracker.go new file mode 100644 index 0000000000..dfd608e5bf --- /dev/null +++ b/pkg/isb/tracker/message_tracker.go @@ -0,0 +1,56 @@ +package tracker + +import ( + "sync" + + "github.com/numaproj/numaflow/pkg/isb" +) + +// MessageTracker is used to store a key value pair for string and *ReadMessage +// as it can be accessed by concurrent goroutines, we keep all operations +// under a mutex +type MessageTracker struct { + lock sync.RWMutex + m map[string]*isb.ReadMessage +} + +// NewMessageTracker initializes a new instance of a Tracker +func NewMessageTracker(messages []*isb.ReadMessage) *MessageTracker { + m := make(map[string]*isb.ReadMessage, len(messages)) + for _, msg := range messages { + id := msg.ReadOffset.String() + m[id] = msg + } + return &MessageTracker{ + m: m, + lock: sync.RWMutex{}, + } +} + +// Remove will remove the entry for a given id and return the stored value corresponding to this id. +// A `nil` return value indicates that the id doesn't exist in the tracker. 
+func (t *MessageTracker) Remove(id string) *isb.ReadMessage { + t.lock.Lock() + defer t.lock.Unlock() + item, ok := t.m[id] + if !ok { + return nil + } + delete(t.m, id) + return item +} + +// IsEmpty is a helper function which checks if the Tracker map is empty +// return true if empty +func (t *MessageTracker) IsEmpty() bool { + t.lock.RLock() + defer t.lock.RUnlock() + return len(t.m) == 0 +} + +// Len returns the number of messages currently stored in the tracker +func (t *MessageTracker) Len() int { + t.lock.RLock() + defer t.lock.RUnlock() + return len(t.m) +} diff --git a/pkg/udf/rpc/tracker_test.go b/pkg/isb/tracker/message_tracker_test.go similarity index 53% rename from pkg/udf/rpc/tracker_test.go rename to pkg/isb/tracker/message_tracker_test.go index 21704f4425..3c2ae767d0 100644 --- a/pkg/udf/rpc/tracker_test.go +++ b/pkg/isb/tracker/message_tracker_test.go @@ -1,4 +1,4 @@ -package rpc +package tracker import ( "testing" @@ -6,32 +6,34 @@ import ( "github.com/stretchr/testify/assert" + "github.com/numaproj/numaflow/pkg/isb" "github.com/numaproj/numaflow/pkg/isb/testutils" ) func TestTracker_AddRequest(t *testing.T) { - tr := NewTracker() readMessages := testutils.BuildTestReadMessages(3, time.Unix(1661169600, 0), nil) - for _, msg := range readMessages { - tr.addRequest(&msg) + messages := make([]*isb.ReadMessage, len(readMessages)) + for i, msg := range readMessages { + messages[i] = &msg } + tr := NewMessageTracker(messages) id := readMessages[0].ReadOffset.String() - m, ok := tr.getRequest(id) - assert.True(t, ok) + m := tr.Remove(id) + assert.NotNil(t, m) assert.Equal(t, readMessages[0], *m) } func TestTracker_RemoveRequest(t *testing.T) { - tr := NewTracker() readMessages := testutils.BuildTestReadMessages(3, time.Unix(1661169600, 0), nil) - for _, msg := range readMessages { - tr.addRequest(&msg) + messages := make([]*isb.ReadMessage, len(readMessages)) + for i, msg := range readMessages { + messages[i] = &msg } + tr := 
NewMessageTracker(messages) id := readMessages[0].ReadOffset.String() - m, ok := tr.getRequest(id) - assert.True(t, ok) + m := tr.Remove(id) + assert.NotNil(t, m) assert.Equal(t, readMessages[0], *m) - tr.removeRequest(id) - _, ok = tr.getRequest(id) - assert.False(t, ok) + m = tr.Remove(id) + assert.Nil(t, m) } diff --git a/pkg/sdkclient/grpc/grpc_utils.go b/pkg/sdkclient/grpc/grpc_utils.go index 293ba8e8d7..71ae252738 100644 --- a/pkg/sdkclient/grpc/grpc_utils.go +++ b/pkg/sdkclient/grpc/grpc_utils.go @@ -18,7 +18,6 @@ package grpc import ( "fmt" - "log" "strconv" "google.golang.org/grpc" @@ -56,7 +55,6 @@ func ConnectToServer(udsSockAddr string, serverInfo *serverinfo.ServerInfo, maxM ) } else { sockAddr = getUdsSockAddr(udsSockAddr) - log.Println("UDS Client:", sockAddr) conn, err = grpc.NewClient(sockAddr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMessageSize), grpc.MaxCallSendMsgSize(maxMessageSize))) diff --git a/pkg/sdkclient/sourcetransformer/client.go b/pkg/sdkclient/sourcetransformer/client.go index d9d47302c0..92372ff7a4 100644 --- a/pkg/sdkclient/sourcetransformer/client.go +++ b/pkg/sdkclient/sourcetransformer/client.go @@ -18,7 +18,10 @@ package sourcetransformer import ( "context" + "fmt" + "time" + "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/emptypb" @@ -28,16 +31,18 @@ import ( sdkerr "github.com/numaproj/numaflow/pkg/sdkclient/error" grpcutil "github.com/numaproj/numaflow/pkg/sdkclient/grpc" "github.com/numaproj/numaflow/pkg/sdkclient/serverinfo" + "github.com/numaproj/numaflow/pkg/shared/logging" ) // client contains the grpc connection and the grpc client. type client struct { conn *grpc.ClientConn grpcClt transformpb.SourceTransformClient + stream transformpb.SourceTransform_SourceTransformFnClient } // New creates a new client object. 
-func New(serverInfo *serverinfo.ServerInfo, inputOptions ...sdkclient.Option) (Client, error) { +func New(ctx context.Context, serverInfo *serverinfo.ServerInfo, inputOptions ...sdkclient.Option) (Client, error) { var opts = sdkclient.DefaultOptions(sdkclient.SourceTransformerAddr) for _, inputOption := range inputOptions { @@ -53,18 +58,81 @@ func New(serverInfo *serverinfo.ServerInfo, inputOptions ...sdkclient.Option) (C c := new(client) c.conn = conn c.grpcClt = transformpb.NewSourceTransformClient(conn) + + var logger = logging.FromContext(ctx) + +waitUntilReady: + for { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("waiting for transformer gRPC server to be ready: %w", ctx.Err()) + default: + _, err := c.IsReady(ctx, &emptypb.Empty{}) + if err != nil { + logger.Warnf("Transformer server is not ready: %v", err) + time.Sleep(100 * time.Millisecond) + continue waitUntilReady + } + break waitUntilReady + } + } + + c.stream, err = c.grpcClt.SourceTransformFn(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create a gRPC stream for source transform: %w", err) + } + + if err := doHandshake(c.stream); err != nil { + return nil, err + } + return c, nil } +func doHandshake(stream transformpb.SourceTransform_SourceTransformFnClient) error { + // Send handshake request + handshakeReq := &transformpb.SourceTransformRequest{ + Handshake: &transformpb.Handshake{ + Sot: true, + }, + } + if err := stream.Send(handshakeReq); err != nil { + return fmt.Errorf("failed to send handshake request for source tansform: %w", err) + } + + handshakeResp, err := stream.Recv() + if err != nil { + return fmt.Errorf("failed to receive handshake response from source transform stream: %w", err) + } + if resp := handshakeResp.GetHandshake(); resp == nil || !resp.GetSot() { + return fmt.Errorf("invalid handshake response for source transform. Received='%+v'", resp) + } + return nil +} + // NewFromClient creates a new client object from a grpc client. 
This is used for testing. -func NewFromClient(c transformpb.SourceTransformClient) (Client, error) { +func NewFromClient(ctx context.Context, c transformpb.SourceTransformClient) (Client, error) { + stream, err := c.SourceTransformFn(ctx) + if err != nil { + return nil, err + } + + if err := doHandshake(stream); err != nil { + return nil, err + } + return &client{ grpcClt: c, + stream: stream, }, nil } // CloseConn closes the grpc client connection. -func (c *client) CloseConn(ctx context.Context) error { +func (c *client) CloseConn(_ context.Context) error { + err := c.stream.CloseSend() + if err != nil { + return err + } if c.conn == nil { return nil } @@ -81,11 +149,47 @@ func (c *client) IsReady(ctx context.Context, in *emptypb.Empty) (bool, error) { } // SourceTransformFn SourceTransformerFn applies a function to each request element. -func (c *client) SourceTransformFn(ctx context.Context, request *transformpb.SourceTransformRequest) (*transformpb.SourceTransformResponse, error) { - transformResponse, err := c.grpcClt.SourceTransformFn(ctx, request) - err = sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn", err) - if err != nil { +// It sends every request on the stream and reads exactly one response per request, in arrival order; the first send or receive error is returned. 
+func (c *client) SourceTransformFn(ctx context.Context, requests []*transformpb.SourceTransformRequest) ([]*transformpb.SourceTransformResponse, error) { + var eg errgroup.Group + // send n requests + eg.Go(func() error { + for _, req := range requests { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + if err := c.stream.Send(req); err != nil { + return sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", err) + } + } + return nil + }) + + // receive n responses + responses := make([]*transformpb.SourceTransformResponse, len(requests)) + eg.Go(func() error { + for i := 0; i < len(requests); i++ { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + resp, err := c.stream.Recv() + if err != nil { + return sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Recv", err) + } + responses[i] = resp + } + return nil + }) + + // wait for the send and receive goroutines to finish + // if any of the goroutines return an error, the error will be caught here + if err := eg.Wait(); err != nil { return nil, err } - return transformResponse, nil + + return responses, nil } diff --git a/pkg/sdkclient/sourcetransformer/client_test.go b/pkg/sdkclient/sourcetransformer/client_test.go index 27526312fd..c66abbd6ea 100644 --- a/pkg/sdkclient/sourcetransformer/client_test.go +++ b/pkg/sdkclient/sourcetransformer/client_test.go @@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -18,80 +18,132 @@ package sourcetransformer import ( "context" + "errors" "fmt" - "reflect" + "net" "testing" + "time" - "github.com/golang/mock/gomock" transformpb "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1" - transformermock "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1/transformmock" - "github.com/stretchr/testify/assert" + "github.com/numaproj/numaflow-go/pkg/sourcetransformer" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/timestamppb" ) func TestClient_IsReady(t *testing.T) { var ctx = context.Background() + svc := &sourcetransformer.Service{ + Transformer: sourcetransformer.SourceTransformFunc(func(ctx context.Context, keys []string, datum sourcetransformer.Datum) sourcetransformer.Messages { + return sourcetransformer.MessagesBuilder() + }), + } + + // Start the gRPC server + conn := newServer(t, func(server *grpc.Server) { + transformpb.RegisterSourceTransformServer(server, svc) + }) + defer conn.Close() + + // Create a client connection to the server + client := transformpb.NewSourceTransformClient(conn) - ctrl := gomock.NewController(t) - defer ctrl.Finish() + testClient, err := NewFromClient(ctx, client) + require.NoError(t, err) - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - mockClient.EXPECT().IsReady(gomock.Any(), gomock.Any()).Return(&transformpb.ReadyResponse{Ready: true}, nil) - mockClient.EXPECT().IsReady(gomock.Any(), gomock.Any()).Return(&transformpb.ReadyResponse{Ready: false}, fmt.Errorf("mock connection refused")) + 
ready, err := testClient.IsReady(ctx, &emptypb.Empty{}) + require.True(t, ready) + require.NoError(t, err) +} - testClient, err := NewFromClient(mockClient) - assert.NoError(t, err) - reflect.DeepEqual(testClient, &client{ - grpcClt: mockClient, +func newServer(t *testing.T, register func(server *grpc.Server)) *grpc.ClientConn { + lis := bufconn.Listen(100) + t.Cleanup(func() { + _ = lis.Close() }) - ready, err := testClient.IsReady(ctx, &emptypb.Empty{}) - assert.True(t, ready) - assert.NoError(t, err) + server := grpc.NewServer() + t.Cleanup(func() { + server.Stop() + }) - ready, err = testClient.IsReady(ctx, &emptypb.Empty{}) - assert.False(t, ready) - assert.EqualError(t, err, "mock connection refused") -} + register(server) -func TestClient_SourceTransformFn(t *testing.T) { - var ctx = context.Background() + errChan := make(chan error, 1) + go func() { + // t.Fatal should only be called from the goroutine running the test + if err := server.Serve(lis); err != nil { + errChan <- err + } + }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), gomock.Any()).Return(&transformpb.SourceTransformResponse{Results: []*transformpb.SourceTransformResponse_Result{ - { - Keys: []string{"temp-key"}, - Value: []byte("mock result"), - Tags: nil, - }, - }}, nil) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), gomock.Any()).Return(&transformpb.SourceTransformResponse{Results: []*transformpb.SourceTransformResponse_Result{ - { - Keys: []string{"temp-key"}, - Value: []byte("mock result"), - Tags: nil, - }, - }}, fmt.Errorf("mock connection refused")) - - testClient, err := NewFromClient(mockClient) - assert.NoError(t, err) - reflect.DeepEqual(testClient, &client{ - grpcClt: mockClient, + dialer := func(context.Context, string) (net.Conn, error) { + return lis.Dial() + } + + conn, err := grpc.NewClient("passthrough://", 
grpc.WithContextDialer(dialer), grpc.WithTransportCredentials(insecure.NewCredentials())) + t.Cleanup(func() { + _ = conn.Close() }) + if err != nil { + t.Fatalf("Creating new gRPC client connection: %v", err) + } + + var grpcServerErr error + select { + case grpcServerErr = <-errChan: + case <-time.After(500 * time.Millisecond): + grpcServerErr = errors.New("gRPC server didn't start in 500ms") + } + if err != nil { + t.Fatalf("Failed to start gRPC server: %v", grpcServerErr) + } + + return conn +} - result, err := testClient.SourceTransformFn(ctx, &transformpb.SourceTransformRequest{}) - assert.Equal(t, &transformpb.SourceTransformResponse{Results: []*transformpb.SourceTransformResponse_Result{ - { - Keys: []string{"temp-key"}, - Value: []byte("mock result"), - Tags: nil, - }, - }}, result) - assert.NoError(t, err) - - _, err = testClient.SourceTransformFn(ctx, &transformpb.SourceTransformRequest{}) - assert.EqualError(t, err, "NonRetryable: mock connection refused") +func TestClient_SourceTransformFn(t *testing.T) { + var testTime = time.Date(2021, 8, 15, 14, 30, 45, 100, time.Local) + svc := &sourcetransformer.Service{ + Transformer: sourcetransformer.SourceTransformFunc(func(ctx context.Context, keys []string, datum sourcetransformer.Datum) sourcetransformer.Messages { + msg := datum.Value() + return sourcetransformer.MessagesBuilder().Append(sourcetransformer.NewMessage(msg, testTime).WithKeys([]string{keys[0] + "_test"})) + }), + } + conn := newServer(t, func(server *grpc.Server) { + transformpb.RegisterSourceTransformServer(server, svc) + }) + transformClient := transformpb.NewSourceTransformClient(conn) + var ctx = context.Background() + client, _ := NewFromClient(ctx, transformClient) + + requests := make([]*transformpb.SourceTransformRequest, 5) + go func() { + for i := 0; i < 5; i++ { + requests[i] = &transformpb.SourceTransformRequest{ + Request: &transformpb.SourceTransformRequest_Request{ + Keys: []string{fmt.Sprintf("client_key_%d", i)}, + Value: 
[]byte("test"), + }, + } + } + }() + + responses, err := client.SourceTransformFn(ctx, requests) + require.NoError(t, err) + var results [][]*transformpb.SourceTransformResponse_Result + for _, resp := range responses { + results = append(results, resp.GetResults()) + } + expected := [][]*transformpb.SourceTransformResponse_Result{ + {{Keys: []string{"client_key_0_test"}, Value: []byte("test"), EventTime: timestamppb.New(testTime)}}, + {{Keys: []string{"client_key_1_test"}, Value: []byte("test"), EventTime: timestamppb.New(testTime)}}, + {{Keys: []string{"client_key_2_test"}, Value: []byte("test"), EventTime: timestamppb.New(testTime)}}, + {{Keys: []string{"client_key_3_test"}, Value: []byte("test"), EventTime: timestamppb.New(testTime)}}, + {{Keys: []string{"client_key_4_test"}, Value: []byte("test"), EventTime: timestamppb.New(testTime)}}, + } + require.ElementsMatch(t, expected, results) } diff --git a/pkg/sdkclient/sourcetransformer/interface.go b/pkg/sdkclient/sourcetransformer/interface.go index 4d8e3d8f71..883353f3a6 100644 --- a/pkg/sdkclient/sourcetransformer/interface.go +++ b/pkg/sdkclient/sourcetransformer/interface.go @@ -27,5 +27,5 @@ import ( type Client interface { CloseConn(ctx context.Context) error IsReady(ctx context.Context, in *emptypb.Empty) (bool, error) - SourceTransformFn(ctx context.Context, request *transformpb.SourceTransformRequest) (*transformpb.SourceTransformResponse, error) + SourceTransformFn(ctx context.Context, requests []*transformpb.SourceTransformRequest) ([]*transformpb.SourceTransformResponse, error) } diff --git a/pkg/sources/forward/applier/sourcetransformer.go b/pkg/sources/forward/applier/sourcetransformer.go index 795cd4c5a2..a935d511ea 100644 --- a/pkg/sources/forward/applier/sourcetransformer.go +++ b/pkg/sources/forward/applier/sourcetransformer.go @@ -25,13 +25,13 @@ import ( // SourceTransformApplier applies the source transform on the read message and gives back a new message. 
Any UserError will be retried here, while // InternalErr can be returned and could be retried by the callee. type SourceTransformApplier interface { - ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) + ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) } // ApplySourceTransformFunc is a function type that implements SourceTransformApplier interface. -type ApplySourceTransformFunc func(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) +type ApplySourceTransformFunc func(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) // ApplyTransform implements SourceTransformApplier interface. -func (f ApplySourceTransformFunc) ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) { - return f(ctx, message) +func (f ApplySourceTransformFunc) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + return f(ctx, messages) } diff --git a/pkg/sources/forward/data_forward.go b/pkg/sources/forward/data_forward.go index 48ff97da3a..913be67939 100644 --- a/pkg/sources/forward/data_forward.go +++ b/pkg/sources/forward/data_forward.go @@ -305,34 +305,14 @@ func (df *DataForward) forwardAChunk(ctx context.Context) { // If a user-defined transformer exists, apply it if df.opts.transformer != nil { - // user-defined transformer concurrent processing request channel - transformerCh := make(chan *isb.ReadWriteMessagePair) - - // create a pool of Transformer Processors - var wg sync.WaitGroup - for i := 0; i < df.opts.transformerConcurrency; i++ { - wg.Add(1) - go func() { - defer wg.Done() - df.concurrentApplyTransformer(ctx, transformerCh) - }() + for _, m := range readMessages { + // assign watermark to the message + m.Watermark = time.Time(processorWM) } concurrentTransformerProcessingStart := time.Now() - for idx, m := range readMessages { + 
readWriteMessagePairs = df.applyTransformer(ctx, readMessages) - // assign watermark to the message - m.Watermark = time.Time(processorWM) - readWriteMessagePairs[idx].ReadMessage = m - // send transformer processing work to the channel. Thus, the results of the transformer - // application on a read message will be stored as the corresponding writeMessage in readWriteMessagePairs - transformerCh <- &readWriteMessagePairs[idx] - } - // let the go routines know that there is no more work - close(transformerCh) - // wait till the processing is done. this will not be an infinite wait because the transformer processing will exit if - // context.Done() is closed. - wg.Wait() df.opts.logger.Debugw("concurrent applyTransformer completed", zap.Int("concurrency", df.opts.transformerConcurrency), zap.Duration("took", time.Since(concurrentTransformerProcessingStart)), @@ -536,6 +516,7 @@ func (df *DataForward) writeToBuffers( for toVertexName, toVertexMessages := range messageToStep { writeOffsets[toVertexName] = make([][]isb.Offset, len(toVertexMessages)) } + for toVertexName, toVertexBuffer := range df.toBuffers { for index, partition := range toVertexBuffer { writeOffsets[toVertexName][index], err = df.writeToBuffer(ctx, partition, messageToStep[toVertexName][index]) @@ -591,6 +572,7 @@ func (df *DataForward) writeToBuffer(ctx context.Context, toBufferPartition isb. zap.String("reason", err.Error()), zap.String("partition", toBufferPartition.GetName()), zap.String("vertex", df.vertexName), zap.String("pipeline", df.pipelineName), + zap.String("msg_id", msg.ID.String()), ) } else { needRetry = true @@ -661,42 +643,12 @@ func (df *DataForward) writeToBuffer(ctx context.Context, toBufferPartition isb. 
return writeOffsets, nil } -// concurrentApplyTransformer applies the transformer based on the request from the channel -func (df *DataForward) concurrentApplyTransformer(ctx context.Context, readMessagePair <-chan *isb.ReadWriteMessagePair) { - for message := range readMessagePair { - start := time.Now() - metrics.SourceTransformerReadMessagesCount.With(map[string]string{ - metrics.LabelVertex: df.vertexName, - metrics.LabelPipeline: df.pipelineName, - metrics.LabelVertexReplicaIndex: strconv.Itoa(int(df.vertexReplica)), - metrics.LabelPartitionName: df.reader.GetName(), - }).Inc() - - writeMessages, err := df.applyTransformer(ctx, message.ReadMessage) - metrics.SourceTransformerWriteMessagesCount.With(map[string]string{ - metrics.LabelVertex: df.vertexName, - metrics.LabelPipeline: df.pipelineName, - metrics.LabelVertexReplicaIndex: strconv.Itoa(int(df.vertexReplica)), - metrics.LabelPartitionName: df.reader.GetName(), - }).Add(float64(len(writeMessages))) - - message.WriteMessages = append(message.WriteMessages, writeMessages...) - message.Err = err - metrics.SourceTransformerProcessingTime.With(map[string]string{ - metrics.LabelVertex: df.vertexName, - metrics.LabelPipeline: df.pipelineName, - metrics.LabelVertexReplicaIndex: strconv.Itoa(int(df.vertexReplica)), - metrics.LabelPartitionName: df.reader.GetName(), - }).Observe(float64(time.Since(start).Microseconds())) - } -} - // applyTransformer applies the transformer and will block if there is any InternalErr. On the other hand, if this is a UserError // the skip flag is set. The ShutDown flag will only if there is an InternalErr and ForceStop has been invoked. // The UserError retry will be done on the applyTransformer. 
-func (df *DataForward) applyTransformer(ctx context.Context, readMessage *isb.ReadMessage) ([]*isb.WriteMessage, error) { +func (df *DataForward) applyTransformer(ctx context.Context, messages []*isb.ReadMessage) []isb.ReadWriteMessagePair { for { - writeMessages, err := df.opts.transformer.ApplyTransform(ctx, readMessage) + transformResults, err := df.opts.transformer.ApplyTransform(ctx, messages) if err != nil { df.opts.logger.Errorw("Transformer.Apply error", zap.Error(err)) // TODO: implement retry with backoff etc. @@ -712,12 +664,11 @@ func (df *DataForward) applyTransformer(ctx context.Context, readMessage *isb.Re metrics.LabelVertexType: string(dfv1.VertexTypeSource), metrics.LabelVertexReplicaIndex: strconv.Itoa(int(df.vertexReplica)), }).Inc() - - return nil, err + return []isb.ReadWriteMessagePair{{Err: err}} } continue } - return writeMessages, nil + return transformResults } } diff --git a/pkg/sources/forward/data_forward_test.go b/pkg/sources/forward/data_forward_test.go index 25e41a9fa6..96cb6760e6 100644 --- a/pkg/sources/forward/data_forward_test.go +++ b/pkg/sources/forward/data_forward_test.go @@ -121,8 +121,16 @@ func (f myForwardTest) WhereTo(_ []string, _ []string, s string) ([]forwarder.Ve }}, nil } -func (f myForwardTest) ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) { - return testutils.CopyUDFTestApply(ctx, "test-vertex", message) +func (f myForwardTest) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + out := make([]isb.ReadWriteMessagePair, len(messages)) + for i, msg := range messages { + writeMsg, _ := testutils.CopyUDFTestApply(ctx, "test-vertex", msg) + out[i] = isb.ReadWriteMessagePair{ + ReadMessage: msg, + WriteMessages: writeMsg, + } + } + return out, nil } func TestNewDataForward(t *testing.T) { @@ -856,36 +864,31 @@ func (f *mySourceForwardTestRoundRobin) WhereTo(_ []string, _ []string, s string // such that we can verify 
message IsLate attribute gets set to true. var testSourceNewEventTime = testSourceWatermark.Add(time.Duration(-1) * time.Minute) -func (f mySourceForwardTest) ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) { - return func(ctx context.Context, readMessage *isb.ReadMessage) ([]*isb.WriteMessage, error) { - _ = ctx - offset := readMessage.ReadOffset - payload := readMessage.Body.Payload - parentPaneInfo := readMessage.MessageInfo - - // apply source data transformer - _ = payload - // copy the payload - result := payload - // assign new event time - parentPaneInfo.EventTime = testSourceNewEventTime - var key []string - - writeMessage := isb.Message{ +func (f mySourceForwardTest) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + results := make([]isb.ReadWriteMessagePair, len(messages)) + for i, message := range messages { + message.MessageInfo.EventTime = testSourceNewEventTime + writeMsg := isb.Message{ Header: isb.Header{ - MessageInfo: parentPaneInfo, + MessageInfo: message.MessageInfo, ID: isb.MessageID{ VertexName: "test-vertex", - Offset: offset.String(), + Offset: message.ReadOffset.String(), }, - Keys: key, + Keys: []string{}, }, Body: isb.Body{ - Payload: result, + Payload: message.Body.Payload, }, } - return []*isb.WriteMessage{{Message: writeMessage}}, nil - }(ctx, message) + results[i] = isb.ReadWriteMessagePair{ + ReadMessage: message, + WriteMessages: []*isb.WriteMessage{{ + Message: writeMsg, + }}, + } + } + return results, nil } // TestSourceWatermarkPublisher is a dummy implementation of isb.SourceWatermarkPublisher interface @@ -1153,8 +1156,16 @@ func (f myForwardDropTest) WhereTo(_ []string, _ []string, s string) ([]forwarde return []forwarder.VertexBuffer{}, nil } -func (f myForwardDropTest) ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) { - return testutils.CopyUDFTestApply(ctx, "test-vertex", message) 
+func (f myForwardDropTest) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + results := make([]isb.ReadWriteMessagePair, len(messages)) + for i, message := range messages { + writeMsg, _ := testutils.CopyUDFTestApply(ctx, "test-vertex", message) + results[i] = isb.ReadWriteMessagePair{ + ReadMessage: message, + WriteMessages: writeMsg, + } + } + return results, nil } type myForwardToAllTest struct { @@ -1174,8 +1185,16 @@ func (f *myForwardToAllTest) WhereTo(_ []string, _ []string, s string) ([]forwar return output, nil } -func (f *myForwardToAllTest) ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) { - return testutils.CopyUDFTestApply(ctx, "test-vertex", message) +func (f *myForwardToAllTest) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + results := make([]isb.ReadWriteMessagePair, len(messages)) + for i, message := range messages { + writeMsg, _ := testutils.CopyUDFTestApply(ctx, "test-vertex", message) + results[i] = isb.ReadWriteMessagePair{ + ReadMessage: message, + WriteMessages: writeMsg, + } + } + return results, nil } type myForwardInternalErrTest struct { @@ -1188,7 +1207,7 @@ func (f myForwardInternalErrTest) WhereTo(_ []string, _ []string, s string) ([]f }}, nil } -func (f myForwardInternalErrTest) ApplyTransform(_ context.Context, _ *isb.ReadMessage) ([]*isb.WriteMessage, error) { +func (f myForwardInternalErrTest) ApplyTransform(ctx context.Context, _ []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { return nil, &udfapplier.ApplyUDFErr{ UserUDFErr: false, InternalErr: struct { @@ -1209,8 +1228,16 @@ func (f myForwardApplyWhereToErrTest) WhereTo(_ []string, _ []string, s string) }}, fmt.Errorf("whereToStep failed") } -func (f myForwardApplyWhereToErrTest) ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) { - return testutils.CopyUDFTestApply(ctx, 
"test-vertex", message) +func (f myForwardApplyWhereToErrTest) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + results := make([]isb.ReadWriteMessagePair, len(messages)) + for i, message := range messages { + writeMsg, _ := testutils.CopyUDFTestApply(ctx, "test-vertex", message) + results[i] = isb.ReadWriteMessagePair{ + ReadMessage: message, + WriteMessages: writeMsg, + } + } + return results, nil } type myForwardApplyTransformerErrTest struct { @@ -1223,7 +1250,7 @@ func (f myForwardApplyTransformerErrTest) WhereTo(_ []string, _ []string, s stri }}, nil } -func (f myForwardApplyTransformerErrTest) ApplyTransform(_ context.Context, _ *isb.ReadMessage) ([]*isb.WriteMessage, error) { +func (f myForwardApplyTransformerErrTest) ApplyTransform(_ context.Context, _ []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { return nil, fmt.Errorf("transformer error") } diff --git a/pkg/sources/forward/shutdown_test.go b/pkg/sources/forward/shutdown_test.go index a4ffc5e2e2..34003e729f 100644 --- a/pkg/sources/forward/shutdown_test.go +++ b/pkg/sources/forward/shutdown_test.go @@ -43,8 +43,16 @@ func (s myShutdownTest) WhereTo([]string, []string, string) ([]forwarder.VertexB return []forwarder.VertexBuffer{}, nil } -func (s myShutdownTest) ApplyTransform(ctx context.Context, message *isb.ReadMessage) ([]*isb.WriteMessage, error) { - return testutils.CopyUDFTestApply(ctx, "", message) +func (f myShutdownTest) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + results := make([]isb.ReadWriteMessagePair, len(messages)) + for i, message := range messages { + writeMsg, _ := testutils.CopyUDFTestApply(ctx, "", message) + results[i] = isb.ReadWriteMessagePair{ + ReadMessage: message, + WriteMessages: writeMsg, + } + } + return results, nil } func (s myShutdownTest) ApplyMapStream(ctx context.Context, message *isb.ReadMessage, writeMessageCh chan<- isb.WriteMessage) error 
{ diff --git a/pkg/sources/generator/tickgen.go b/pkg/sources/generator/tickgen.go index ff00ba8cba..c0cdb9dcf1 100644 --- a/pkg/sources/generator/tickgen.go +++ b/pkg/sources/generator/tickgen.go @@ -202,7 +202,6 @@ loop: tickgenSourceReadCount.With(map[string]string{metrics.LabelVertex: mg.vertexName, metrics.LabelPipeline: mg.pipelineName}).Inc() msgs = append(msgs, mg.newReadMessage(r.key, r.data, r.offset, r.ts)) case <-timeout: - mg.logger.Infow("Timed out waiting for messages to read.", zap.Duration("waited", mg.readTimeout)) break loop } } diff --git a/pkg/sources/source.go b/pkg/sources/source.go index 0b3e23a94b..69bc0c0099 100644 --- a/pkg/sources/source.go +++ b/pkg/sources/source.go @@ -240,7 +240,7 @@ func (sp *SourceProcessor) Start(ctx context.Context) error { return err } - srcTransformerClient, err := sourcetransformer.New(serverInfo, sdkclient.WithMaxMessageSize(maxMessageSize)) + srcTransformerClient, err := sourcetransformer.New(ctx, serverInfo, sdkclient.WithMaxMessageSize(maxMessageSize)) if err != nil { return fmt.Errorf("failed to create transformer gRPC client, %w", err) } diff --git a/pkg/sources/transformer/grpc_transformer.go b/pkg/sources/transformer/grpc_transformer.go index 14b414a348..459e99f21b 100644 --- a/pkg/sources/transformer/grpc_transformer.go +++ b/pkg/sources/transformer/grpc_transformer.go @@ -24,10 +24,8 @@ import ( v1 "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1" "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/timestamppb" - "k8s.io/apimachinery/pkg/util/wait" "github.com/numaproj/numaflow/pkg/isb" - sdkerr "github.com/numaproj/numaflow/pkg/sdkclient/error" "github.com/numaproj/numaflow/pkg/sdkclient/sourcetransformer" "github.com/numaproj/numaflow/pkg/shared/logging" "github.com/numaproj/numaflow/pkg/udf/rpc" @@ -54,7 +52,7 @@ func (u *GRPCBasedTransformer) IsHealthy(ctx context.Context) error { // WaitUntilReady waits until the client is connected. 
func (u *GRPCBasedTransformer) WaitUntilReady(ctx context.Context) error { - log := logging.FromContext(ctx) + logger := logging.FromContext(ctx) for { select { case <-ctx.Done(): @@ -63,7 +61,7 @@ func (u *GRPCBasedTransformer) WaitUntilReady(ctx context.Context) error { if _, err := u.client.IsReady(ctx, &emptypb.Empty{}); err == nil { return nil } else { - log.Infof("waiting for transformer to be ready: %v", err) + logger.Infof("waiting for transformer to be ready: %v", err) time.Sleep(1 * time.Second) } } @@ -75,103 +73,81 @@ func (u *GRPCBasedTransformer) CloseConn(ctx context.Context) error { return u.client.CloseConn(ctx) } -func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, readMessage *isb.ReadMessage) ([]*isb.WriteMessage, error) { - keys := readMessage.Keys - payload := readMessage.Body.Payload - offset := readMessage.ReadOffset - parentMessageInfo := readMessage.MessageInfo - var req = &v1.SourceTransformRequest{ - Keys: keys, - Value: payload, - EventTime: timestamppb.New(parentMessageInfo.EventTime), - Watermark: timestamppb.New(readMessage.Watermark), - Headers: readMessage.Headers, +func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*isb.ReadMessage) ([]isb.ReadWriteMessagePair, error) { + transformResults := make([]isb.ReadWriteMessagePair, len(messages)) + requests := make([]*v1.SourceTransformRequest, len(messages)) + idToMsgMapping := make(map[string]*isb.ReadMessage) + + for i, msg := range messages { + // we track the id to the message mapping to be able to match the response with the original message. + // we use the original message's event time if the user doesn't change it. Also we use the original message's + // read offset + index as the id for the response. 
+ id := msg.ReadOffset.String() + idToMsgMapping[id] = msg + req := &v1.SourceTransformRequest{ + Request: &v1.SourceTransformRequest_Request{ + Keys: msg.Keys, + Value: msg.Body.Payload, + EventTime: timestamppb.New(msg.MessageInfo.EventTime), + Watermark: timestamppb.New(msg.Watermark), + Headers: msg.Headers, + Id: id, + }, + } + requests[i] = req } - response, err := u.client.SourceTransformFn(ctx, req) + responses, err := u.client.SourceTransformFn(ctx, requests) + if err != nil { - udfErr, _ := sdkerr.FromError(err) - switch udfErr.ErrorKind() { - case sdkerr.Retryable: - var success bool - _ = wait.ExponentialBackoffWithContext(ctx, wait.Backoff{ - // retry every "duration * factor + [0, jitter]" interval for 5 times - Duration: 1 * time.Second, - Factor: 1, - Jitter: 0.1, - Steps: 5, - }, func(_ context.Context) (done bool, err error) { - response, err = u.client.SourceTransformFn(ctx, req) - if err != nil { - udfErr, _ = sdkerr.FromError(err) - switch udfErr.ErrorKind() { - case sdkerr.Retryable: - return false, nil - case sdkerr.NonRetryable: - return true, nil - default: - return true, nil - } - } - success = true - return true, nil - }) - if !success { - return nil, &rpc.ApplyUDFErr{ - UserUDFErr: false, - Message: fmt.Sprintf("gRPC client.SourceTransformFn failed, %s", err), - InternalErr: rpc.InternalErr{ - Flag: true, - MainCarDown: false, - }, - } - } - case sdkerr.NonRetryable: - return nil, &rpc.ApplyUDFErr{ - UserUDFErr: false, - Message: fmt.Sprintf("gRPC client.SourceTransformFn failed, %s", err), - InternalErr: rpc.InternalErr{ - Flag: true, - MainCarDown: false, - }, - } - default: - return nil, &rpc.ApplyUDFErr{ - UserUDFErr: false, - Message: fmt.Sprintf("gRPC client.SourceTransformFn failed, %s", err), - InternalErr: rpc.InternalErr{ - Flag: true, - MainCarDown: false, - }, - } + err = &rpc.ApplyUDFErr{ + UserUDFErr: false, + Message: fmt.Sprintf("gRPC client.SourceTransformFn failed, %s", err), + InternalErr: rpc.InternalErr{ + Flag: 
true, + MainCarDown: false, + }, } + return nil, err } - taggedMessages := make([]*isb.WriteMessage, 0) - for i, result := range response.GetResults() { - keys := result.Keys - if result.EventTime != nil { - // Transformer supports changing event time. - parentMessageInfo.EventTime = result.EventTime.AsTime() + for i, resp := range responses { + parentMessage, ok := idToMsgMapping[resp.GetId()] + if !ok { + panic("tracker doesn't contain the message ID received from the response") } - taggedMessage := &isb.WriteMessage{ - Message: isb.Message{ - Header: isb.Header{ - MessageInfo: parentMessageInfo, - ID: isb.MessageID{ - VertexName: u.vertexName, - Offset: offset.String(), - Index: int32(i), + taggedMessages := make([]*isb.WriteMessage, len(resp.GetResults())) + for i, result := range resp.GetResults() { + keys := result.Keys + if result.EventTime != nil { + // Transformer supports changing event time. + parentMessage.MessageInfo.EventTime = result.EventTime.AsTime() + } + taggedMessage := &isb.WriteMessage{ + Message: isb.Message{ + Header: isb.Header{ + MessageInfo: parentMessage.MessageInfo, + ID: isb.MessageID{ + VertexName: u.vertexName, + Offset: parentMessage.ReadOffset.String(), + Index: int32(i), + }, + Keys: keys, + }, + Body: isb.Body{ + Payload: result.Value, }, - Keys: keys, - }, - Body: isb.Body{ - Payload: result.Value, }, - }, - Tags: result.Tags, + Tags: result.Tags, + } + taggedMessages[i] = taggedMessage + } + responsePair := isb.ReadWriteMessagePair{ + ReadMessage: parentMessage, + WriteMessages: taggedMessages, + Err: nil, } - taggedMessages = append(taggedMessages, taggedMessage) + transformResults[i] = responsePair } - return taggedMessages, nil + return transformResults, nil } diff --git a/pkg/sources/transformer/grpc_transformer_test.go b/pkg/sources/transformer/grpc_transformer_test.go index 959a40bf51..cd8ccbe852 100644 --- a/pkg/sources/transformer/grpc_transformer_test.go +++ b/pkg/sources/transformer/grpc_transformer_test.go @@ -19,101 
+19,60 @@ package transformer import ( "context" "encoding/json" - "fmt" + "errors" + "net" "testing" "time" - "github.com/golang/mock/gomock" - v1 "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1" - transformermock "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1/transformmock" - "github.com/stretchr/testify/assert" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/timestamppb" + "github.com/numaproj/numaflow-go/pkg/sourcetransformer" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + transformpb "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1" "github.com/numaproj/numaflow/pkg/isb" "github.com/numaproj/numaflow/pkg/isb/testutils" - "github.com/numaproj/numaflow/pkg/sdkclient/sourcetransformer" + sourcetransformerSdk "github.com/numaproj/numaflow/pkg/sdkclient/sourcetransformer" "github.com/numaproj/numaflow/pkg/udf/rpc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" ) -func NewMockGRPCBasedTransformer(mockClient *transformermock.MockSourceTransformClient) *GRPCBasedTransformer { - c, _ := sourcetransformer.NewFromClient(mockClient) - return &GRPCBasedTransformer{"test-vertex", c} -} - -func TestGRPCBasedTransformer_WaitUntilReadyWithMockClient(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - mockClient.EXPECT().IsReady(gomock.Any(), gomock.Any()).Return(&v1.ReadyResponse{Ready: true}, nil) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - - u := NewMockGRPCBasedTransformer(mockClient) - err := u.WaitUntilReady(ctx) - assert.NoError(t, err) -} 
- -type rpcMsg struct { - msg proto.Message -} - -func (r *rpcMsg) Matches(msg interface{}) bool { - m, ok := msg.(proto.Message) - if !ok { - return false +func TestGRPCBasedTransformer_WaitUntilReadyWithServer(t *testing.T) { + svc := &sourcetransformer.Service{ + Transformer: sourcetransformer.SourceTransformFunc(func(ctx context.Context, keys []string, datum sourcetransformer.Datum) sourcetransformer.Messages { + return sourcetransformer.Messages{} + }), } - return proto.Equal(m, r.msg) -} -func (r *rpcMsg) String() string { - return fmt.Sprintf("is %s", r.msg) + conn := newServer(t, func(server *grpc.Server) { + transformpb.RegisterSourceTransformServer(server, svc) + }) + transformClient := transformpb.NewSourceTransformClient(conn) + client, _ := sourcetransformerSdk.NewFromClient(context.Background(), transformClient) + u := NewGRPCBasedTransformer("testVertex", client) + err := u.WaitUntilReady(context.Background()) + assert.NoError(t, err) } -func TestGRPCBasedTransformer_BasicApplyWithMockClient(t *testing.T) { +func TestGRPCBasedTransformer_BasicApplyWithServer(t *testing.T) { t.Run("test success", func(t *testing.T) { - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - req := &v1.SourceTransformRequest{ - Keys: []string{"test_success_key"}, - Value: []byte(`forward_message`), - EventTime: timestamppb.New(time.Unix(1661169600, 0)), - Watermark: timestamppb.New(time.Time{}), + svc := &sourcetransformer.Service{ + Transformer: sourcetransformer.SourceTransformFunc(func(ctx context.Context, keys []string, datum sourcetransformer.Datum) sourcetransformer.Messages { + return sourcetransformer.MessagesBuilder().Append(sourcetransformer.NewMessage(datum.Value(), datum.EventTime()).WithKeys(keys)) + }), } - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(&v1.SourceTransformResponse{ - Results: []*v1.SourceTransformResponse_Result{ - { - Keys: 
[]string{"test_success_key"}, - Value: []byte(`forward_message`), - }, - }, - }, nil) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - u := NewMockGRPCBasedTransformer(mockClient) - got, err := u.ApplyTransform(ctx, &isb.ReadMessage{ + conn := newServer(t, func(server *grpc.Server) { + transformpb.RegisterSourceTransformServer(server, svc) + }) + transformClient := transformpb.NewSourceTransformClient(conn) + ctx := context.Background() + client, err := sourcetransformerSdk.NewFromClient(ctx, transformClient) + require.NoError(t, err, "creating source transformer client") + u := NewGRPCBasedTransformer("testVertex", client) + + got, err := u.ApplyTransform(ctx, []*isb.ReadMessage{{ Message: isb.Message{ Header: isb.Header{ MessageInfo: isb.MessageInfo{ @@ -130,94 +89,33 @@ func TestGRPCBasedTransformer_BasicApplyWithMockClient(t *testing.T) { }, }, ReadOffset: isb.SimpleStringOffset(func() string { return "0" }), - }, + }}, ) assert.NoError(t, err) - assert.Equal(t, req.Keys, got[0].Keys) - assert.Equal(t, req.Value, got[0].Payload) + assert.Equal(t, []string{"test_success_key"}, got[0].WriteMessages[0].Keys) + assert.Equal(t, []byte(`forward_message`), got[0].WriteMessages[0].Payload) }) t.Run("test error", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - req := &v1.SourceTransformRequest{ - Keys: []string{"test_error_key"}, - Value: []byte(`forward_message`), - EventTime: timestamppb.New(time.Unix(1661169660, 0)), - Watermark: timestamppb.New(time.Time{}), + svc := &sourcetransformer.Service{ + Transformer: sourcetransformer.SourceTransformFunc(func(ctx context.Context, keys []string, datum sourcetransformer.Datum) sourcetransformer.Messages { + return sourcetransformer.Messages{} + }), } - 
mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, fmt.Errorf("mock error")) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - u := NewMockGRPCBasedTransformer(mockClient) - _, err := u.ApplyTransform(ctx, &isb.ReadMessage{ - Message: isb.Message{ - Header: isb.Header{ - MessageInfo: isb.MessageInfo{ - EventTime: time.Unix(1661169660, 0), - }, - ID: isb.MessageID{ - VertexName: "test-vertex", - Offset: "0-0", - }, - Keys: []string{"test_error_key"}, - }, - Body: isb.Body{ - Payload: []byte(`forward_message`), - }, - }, - ReadOffset: isb.SimpleStringOffset(func() string { return "0" }), - }, - ) - assert.ErrorIs(t, err, &rpc.ApplyUDFErr{ - UserUDFErr: false, - Message: fmt.Sprintf("%s", err), - InternalErr: rpc.InternalErr{ - Flag: true, - MainCarDown: false, - }, + conn := newServer(t, func(server *grpc.Server) { + transformpb.RegisterSourceTransformServer(server, svc) }) - }) + transformClient := transformpb.NewSourceTransformClient(conn) + ctx, cancel := context.WithCancel(context.Background()) + client, err := sourcetransformerSdk.NewFromClient(ctx, transformClient) + require.NoError(t, err, "creating source transformer client") + u := NewGRPCBasedTransformer("testVertex", client) - t.Run("test error retryable: failed after 5 retries", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() + // This cancelled context is passed to the ApplyTransform function to simulate failure + cancel() - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - req := &v1.SourceTransformRequest{ - Keys: []string{"test_error_key"}, - Value: []byte(`forward_message`), - EventTime: timestamppb.New(time.Unix(1661169660, 0)), - Watermark: timestamppb.New(time.Time{}), - } - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, 
status.New(codes.DeadlineExceeded, "mock test err").Err()) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - - u := NewMockGRPCBasedTransformer(mockClient) - _, err := u.ApplyTransform(ctx, &isb.ReadMessage{ + _, err = u.ApplyTransform(ctx, []*isb.ReadMessage{{ Message: isb.Message{ Header: isb.Header{ MessageInfo: isb.MessageInfo{ @@ -234,292 +132,155 @@ func TestGRPCBasedTransformer_BasicApplyWithMockClient(t *testing.T) { }, }, ReadOffset: isb.SimpleStringOffset(func() string { return "0" }), - }, + }}, ) - assert.ErrorIs(t, err, &rpc.ApplyUDFErr{ - UserUDFErr: false, - Message: fmt.Sprintf("%s", err), - InternalErr: rpc.InternalErr{ - Flag: true, - MainCarDown: false, - }, - }) - }) - - t.Run("test error retryable: failed after 1 retry", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - req := &v1.SourceTransformRequest{ - Keys: []string{"test_error_key"}, - Value: []byte(`forward_message`), - EventTime: timestamppb.New(time.Unix(1661169660, 0)), - Watermark: timestamppb.New(time.Time{}), - } - 
mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.InvalidArgument, "mock test err: non retryable").Err()) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - - u := NewMockGRPCBasedTransformer(mockClient) - _, err := u.ApplyTransform(ctx, &isb.ReadMessage{ - Message: isb.Message{ - Header: isb.Header{ - MessageInfo: isb.MessageInfo{ - EventTime: time.Unix(1661169660, 0), - }, - ID: isb.MessageID{ - VertexName: "test-vertex", - Offset: "0-0", - }, - Keys: []string{"test_error_key"}, - }, - Body: isb.Body{ - Payload: []byte(`forward_message`), - }, - }, - ReadOffset: isb.SimpleStringOffset(func() string { return "0" }), - }, - ) - assert.ErrorIs(t, err, &rpc.ApplyUDFErr{ + expectedUDFErr := &rpc.ApplyUDFErr{ UserUDFErr: false, - Message: fmt.Sprintf("%s", err), + Message: "gRPC client.SourceTransformFn failed, context canceled", InternalErr: rpc.InternalErr{ Flag: true, MainCarDown: false, }, - }) - }) - - t.Run("test error retryable: success after 1 retry", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - req := &v1.SourceTransformRequest{ - Keys: []string{"test_success_key"}, - Value: []byte(`forward_message`), - EventTime: timestamppb.New(time.Unix(1661169720, 0)), - Watermark: timestamppb.New(time.Time{}), - } - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.DeadlineExceeded, "mock test err").Err()) - 
mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(&v1.SourceTransformResponse{ - Results: []*v1.SourceTransformResponse_Result{ - { - Keys: []string{"test_success_key"}, - Value: []byte(`forward_message`), - }, - }, - }, nil) - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - - u := NewMockGRPCBasedTransformer(mockClient) - got, err := u.ApplyTransform(ctx, &isb.ReadMessage{ - Message: isb.Message{ - Header: isb.Header{ - MessageInfo: isb.MessageInfo{ - EventTime: time.Unix(1661169720, 0), - }, - ID: isb.MessageID{ - VertexName: "test-vertex", - Offset: "0-0", - }, - Keys: []string{"test_success_key"}, - }, - Body: isb.Body{ - Payload: []byte(`forward_message`), - }, - }, - ReadOffset: isb.SimpleStringOffset(func() string { return "0" }), - }, - ) - assert.NoError(t, err) - assert.Equal(t, req.Keys, got[0].Keys) - assert.Equal(t, req.Value, got[0].Payload) - }) - - t.Run("test error non retryable", func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - req := &v1.SourceTransformRequest{ - Keys: []string{"test_error_key"}, - Value: []byte(`forward_message`), - EventTime: timestamppb.New(time.Unix(1661169660, 0)), - Watermark: timestamppb.New(time.Time{}), } - mockClient.EXPECT().SourceTransformFn(gomock.Any(), &rpcMsg{msg: req}).Return(nil, status.New(codes.InvalidArgument, "mock test err: non retryable").Err()) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() - - u := NewMockGRPCBasedTransformer(mockClient) - _, err := u.ApplyTransform(ctx, &isb.ReadMessage{ - Message: isb.Message{ - Header: isb.Header{ - MessageInfo: 
isb.MessageInfo{ - EventTime: time.Unix(1661169660, 0), - }, - ID: isb.MessageID{ - VertexName: "test-vertex", - Offset: "0-0", - }, - Keys: []string{"test_error_key"}, - }, - Body: isb.Body{ - Payload: []byte(`forward_message`), - }, - }, - ReadOffset: isb.SimpleStringOffset(func() string { return "0" }), - }, - ) - assert.ErrorIs(t, err, &rpc.ApplyUDFErr{ - UserUDFErr: false, - Message: fmt.Sprintf("%s", err), - InternalErr: rpc.InternalErr{ - Flag: true, - MainCarDown: false, - }, - }) + var receivedErr *rpc.ApplyUDFErr + assert.ErrorAs(t, err, &receivedErr) + assert.Equal(t, expectedUDFErr, receivedErr) }) } -func TestGRPCBasedTransformer_ApplyWithMockClient_ChangePayload(t *testing.T) { - multiplyBy2 := func(body []byte) interface{} { - var result testutils.PayloadForTest - _ = json.Unmarshal(body, &result) - result.Value = result.Value * 2 - return result - } - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, datum *v1.SourceTransformRequest, opts ...grpc.CallOption) (*v1.SourceTransformResponse, error) { +func TestGRPCBasedTransformer_ApplyWithServer_ChangePayload(t *testing.T) { + svc := &sourcetransformer.Service{ + Transformer: sourcetransformer.SourceTransformFunc(func(ctx context.Context, keys []string, datum sourcetransformer.Datum) sourcetransformer.Messages { var originalValue testutils.PayloadForTest - _ = json.Unmarshal(datum.GetValue(), &originalValue) - doubledValue, _ := json.Marshal(multiplyBy2(datum.GetValue()).(testutils.PayloadForTest)) - var Results []*v1.SourceTransformResponse_Result + _ = json.Unmarshal(datum.Value(), &originalValue) + doubledValue := testutils.PayloadForTest{ + Value: originalValue.Value * 2, + Key: originalValue.Key, + } + doubledValueBytes, _ := json.Marshal(&doubledValue) + + var resultKeys []string if originalValue.Value%2 == 0 { - 
Results = append(Results, &v1.SourceTransformResponse_Result{ - Keys: []string{"even"}, - Value: doubledValue, - }) + resultKeys = []string{"even"} } else { - Results = append(Results, &v1.SourceTransformResponse_Result{ - Keys: []string{"odd"}, - Value: doubledValue, - }) - } - datumList := &v1.SourceTransformResponse{ - Results: Results, + resultKeys = []string{"odd"} } - return datumList, nil - }, - ).AnyTimes() - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") - } - }() + return sourcetransformer.MessagesBuilder().Append(sourcetransformer.NewMessage(doubledValueBytes, datum.EventTime()).WithKeys(resultKeys)) + }), + } - u := NewMockGRPCBasedTransformer(mockClient) + conn := newServer(t, func(server *grpc.Server) { + transformpb.RegisterSourceTransformServer(server, svc) + }) + transformClient := transformpb.NewSourceTransformClient(conn) + ctx := context.Background() + client, _ := sourcetransformerSdk.NewFromClient(ctx, transformClient) + u := NewGRPCBasedTransformer("testVertex", client) var count = int64(10) readMessages := testutils.BuildTestReadMessages(count, time.Unix(1661169600, 0), nil) - - var results = make([][]byte, len(readMessages)) - var resultKeys = make([][]string, len(readMessages)) + messages := make([]*isb.ReadMessage, len(readMessages)) for idx, readMessage := range readMessages { - apply, err := u.ApplyTransform(ctx, &readMessage) - assert.NoError(t, err) - results[idx] = apply[0].Payload - resultKeys[idx] = apply[0].Header.Keys + messages[idx] = &readMessage } + apply, err := u.ApplyTransform(context.TODO(), messages) + assert.NoError(t, err) - var expectedResults = make([][]byte, count) - var expectedKeys = make([][]string, count) - for idx, readMessage := range readMessages { + for _, pair := range apply { + resultPayload := pair.WriteMessages[0].Payload + resultKeys := 
pair.WriteMessages[0].Header.Keys var readMessagePayload testutils.PayloadForTest - _ = json.Unmarshal(readMessage.Payload, &readMessagePayload) + _ = json.Unmarshal(pair.ReadMessage.Payload, &readMessagePayload) + var expectedKeys []string if readMessagePayload.Value%2 == 0 { - expectedKeys[idx] = []string{"even"} + expectedKeys = []string{"even"} } else { - expectedKeys[idx] = []string{"odd"} + expectedKeys = []string{"odd"} } - marshal, _ := json.Marshal(multiplyBy2(readMessage.Payload)) - expectedResults[idx] = marshal - } + assert.Equal(t, expectedKeys, resultKeys) - assert.Equal(t, expectedResults, results) - assert.Equal(t, expectedKeys, resultKeys) + doubledValue := testutils.PayloadForTest{ + Key: readMessagePayload.Key, + Value: readMessagePayload.Value * 2, + } + marshal, _ := json.Marshal(doubledValue) + assert.Equal(t, marshal, resultPayload) + } } -func TestGRPCBasedTransformer_ApplyWithMockClient_ChangeEventTime(t *testing.T) { - testEventTime := time.Date(1992, 2, 8, 0, 0, 0, 100, time.UTC) - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := transformermock.NewMockSourceTransformClient(ctrl) - mockClient.EXPECT().SourceTransformFn(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, datum *v1.SourceTransformRequest, opts ...grpc.CallOption) (*v1.SourceTransformResponse, error) { - var Results []*v1.SourceTransformResponse_Result - Results = append(Results, &v1.SourceTransformResponse_Result{ - Keys: []string{"even"}, - Value: datum.Value, - EventTime: timestamppb.New(testEventTime), - }) - datumList := &v1.SourceTransformResponse{ - Results: Results, - } - return datumList, nil - }, - ).AnyTimes() +func newServer(t *testing.T, register func(server *grpc.Server)) *grpc.ClientConn { + lis := bufconn.Listen(100) + t.Cleanup(func() { + _ = lis.Close() + }) + + server := grpc.NewServer() + t.Cleanup(func() { + server.Stop() + }) + + register(server) - ctx, cancel := context.WithTimeout(context.Background(), 
time.Minute) - defer cancel() + errChan := make(chan error, 1) go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - t.Log(t.Name(), "test timeout") + // t.Fatal should only be called from the goroutine running the test + if err := server.Serve(lis); err != nil { + errChan <- err } }() - u := NewMockGRPCBasedTransformer(mockClient) + dialer := func(context.Context, string) (net.Conn, error) { + return lis.Dial() + } + + conn, err := grpc.NewClient("passthrough://", grpc.WithContextDialer(dialer), grpc.WithTransportCredentials(insecure.NewCredentials())) + t.Cleanup(func() { + _ = conn.Close() + }) + if err != nil { + t.Fatalf("Creating new gRPC client connection: %v", err) + } + + var grpcServerErr error + select { + case grpcServerErr = <-errChan: + case <-time.After(500 * time.Millisecond): + grpcServerErr = errors.New("gRPC server didn't start in 500ms") + } + if err != nil { + t.Fatalf("Failed to start gRPC server: %v", grpcServerErr) + } + + return conn +} + +func TestGRPCBasedTransformer_Apply_ChangeEventTime(t *testing.T) { + testEventTime := time.Date(1992, 2, 8, 0, 0, 0, 100, time.UTC) + svc := &sourcetransformer.Service{ + Transformer: sourcetransformer.SourceTransformFunc(func(ctx context.Context, keys []string, datum sourcetransformer.Datum) sourcetransformer.Messages { + msg := datum.Value() + return sourcetransformer.MessagesBuilder().Append(sourcetransformer.NewMessage(msg, testEventTime).WithKeys([]string{"even"})) + }), + } + conn := newServer(t, func(server *grpc.Server) { + transformpb.RegisterSourceTransformServer(server, svc) + }) + transformClient := transformpb.NewSourceTransformClient(conn) + ctx := context.Background() + client, _ := sourcetransformerSdk.NewFromClient(ctx, transformClient) + u := NewGRPCBasedTransformer("testVertex", client) var count = int64(2) readMessages := testutils.BuildTestReadMessages(count, time.Unix(1661169600, 0), nil) - for _, readMessage := range readMessages { - apply, err := 
u.ApplyTransform(ctx, &readMessage) - assert.NoError(t, err) - assert.Equal(t, testEventTime, apply[0].EventTime) + messages := make([]*isb.ReadMessage, len(readMessages)) + for idx, readMessage := range readMessages { + messages[idx] = &readMessage + } + apply, err := u.ApplyTransform(context.TODO(), messages) + assert.NoError(t, err) + for _, pair := range apply { + assert.NoError(t, pair.Err) + assert.Equal(t, testEventTime, pair.WriteMessages[0].EventTime) } } diff --git a/pkg/udf/forward/forward.go b/pkg/udf/forward/forward.go index 53efc945da..e768808cc3 100644 --- a/pkg/udf/forward/forward.go +++ b/pkg/udf/forward/forward.go @@ -481,7 +481,7 @@ func (isdf *InterStepDataForward) streamMessage(ctx context.Context, dataMessage if len(dataMessages) > 1 { errMsg := "data message size is not 1 with map UDF streaming" isdf.opts.logger.Errorw(errMsg) - return nil, fmt.Errorf(errMsg) + return nil, errors.New(errMsg) } else if len(dataMessages) == 1 { // send to map UDF only the data messages diff --git a/pkg/udf/rpc/grpc_batch_map.go b/pkg/udf/rpc/grpc_batch_map.go index 6d6c397642..ce65d201fb 100644 --- a/pkg/udf/rpc/grpc_batch_map.go +++ b/pkg/udf/rpc/grpc_batch_map.go @@ -26,26 +26,21 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/numaproj/numaflow/pkg/isb" + "github.com/numaproj/numaflow/pkg/isb/tracker" "github.com/numaproj/numaflow/pkg/sdkclient/batchmapper" "github.com/numaproj/numaflow/pkg/shared/logging" ) // GRPCBasedBatchMap is a map applier that uses gRPC client to invoke the map UDF. It implements the applier.MapApplier interface. 
type GRPCBasedBatchMap struct { - vertexName string - client batchmapper.Client - requestTracker *tracker + vertexName string + client batchmapper.Client } func NewUDSgRPCBasedBatchMap(vertexName string, client batchmapper.Client) *GRPCBasedBatchMap { return &GRPCBasedBatchMap{ vertexName: vertexName, client: client, - // requestTracker is used to store the read messages in a key, value manner where - // key is the read offset and the reference to read message as the value. - // Once the results are received from the UDF, we map the responses to the corresponding request - // using a lookup on this tracker. - requestTracker: NewTracker(), } } @@ -93,18 +88,17 @@ func (u *GRPCBasedBatchMap) ApplyBatchMap(ctx context.Context, messages []*isb.R // trackerReq is used to store the read messages in a key, value manner where // key is the read offset and the reference to read message as the value. // Once the results are received from the UDF, we map the responses to the corresponding request - // using a lookup on this tracker. - trackerReq := NewTracker() + // using a lookup on this Tracker. + trackerReq := tracker.NewMessageTracker(messages) // Read routine: this goroutine iterates over the input messages and sends each // of the read messages to the grpc client after transforming it to a BatchMapRequest. // Once all messages are sent, it closes the input channel to indicate that all requests have been read. - // On creating a new request, we add it to a tracker map so that the responses on the stream + // On creating a new request, we add it to a Tracker map so that the responses on the stream // can be mapped backed to the given parent request go func() { defer close(inputChan) for _, msg := range messages { - trackerReq.addRequest(msg) inputChan <- u.parseInputRequest(msg) } }() @@ -139,14 +133,14 @@ loop: } // Get the unique request ID for which these responses are meant for. 
msgId := grpcResp.GetId() - // Fetch the request value for the given ID from the tracker - parentMessage, ok := trackerReq.getRequest(msgId) - if !ok { - // this case is when the given request ID was not present in the tracker. + // Fetch the request value for the given ID from the Tracker + parentMessage := trackerReq.Remove(msgId) + if parentMessage == nil { + // this case is when the given request ID was not present in the Tracker. // This means that either the UDF added an incorrect ID // This cannot be processed further and should result in an error // Can there be another case for this? - logger.Error("Request missing from tracker, ", msgId) + logger.Error("Request missing from message tracker, ", msgId) return nil, fmt.Errorf("incorrect ID found during batch map processing") } // parse the responses received @@ -159,12 +153,11 @@ loop: Err: nil, } udfResults = append(udfResults, responsePair) - trackerReq.removeRequest(msgId) } } - // check if there are elements left in the tracker. This cannot be an acceptable case as we want the + // check if there are elements left in the Tracker. This cannot be an acceptable case as we want the // UDF to send responses for all elements. 
- if !trackerReq.isEmpty() { + if !trackerReq.IsEmpty() { logger.Error("BatchMap response for all requests not received from UDF") return nil, fmt.Errorf("batchMap response for all requests not received from UDF") } diff --git a/pkg/udf/rpc/tracker.go b/pkg/udf/rpc/tracker.go deleted file mode 100644 index 60b57a7af9..0000000000 --- a/pkg/udf/rpc/tracker.go +++ /dev/null @@ -1,75 +0,0 @@ -package rpc - -import ( - "sync" - - "github.com/numaproj/numaflow/pkg/isb" -) - -// tracker is used to store a key value pair for string and *isb.ReadMessage -// as it can be accessed by concurrent goroutines, we keep all operations -// under a mutex -type tracker struct { - lock sync.RWMutex - m map[string]*isb.ReadMessage -} - -// NewTracker initializes a new instance of a tracker -func NewTracker() *tracker { - return &tracker{ - m: make(map[string]*isb.ReadMessage), - lock: sync.RWMutex{}, - } -} - -// addRequest add a new entry for a given message to the tracker. -// the key is chosen as the read offset of the message -func (t *tracker) addRequest(msg *isb.ReadMessage) { - id := msg.ReadOffset.String() - t.set(id, msg) -} - -// getRequest returns the message corresponding to a given id, along with a bool -// to indicate if it does not exist -func (t *tracker) getRequest(id string) (*isb.ReadMessage, bool) { - return t.get(id) -} - -// removeRequest will remove the entry for a given id -func (t *tracker) removeRequest(id string) { - t.delete(id) -} - -// get is a helper function which fetches the message corresponding to a given id -// it acquires a lock before accessing the map -func (t *tracker) get(key string) (*isb.ReadMessage, bool) { - t.lock.RLock() - defer t.lock.RUnlock() - item, ok := t.m[key] - return item, ok -} - -// set is a helper function which add a key, value pair to the tracker map -// it acquires a lock before accessing the map -func (t *tracker) set(key string, msg *isb.ReadMessage) { - t.lock.Lock() - defer t.lock.Unlock() - t.m[key] = msg -} - -// 
delete is a helper function which will remove the entry for a given id -// it acquires a lock before accessing the map -func (t *tracker) delete(key string) { - t.lock.Lock() - defer t.lock.Unlock() - delete(t.m, key) -} - -// isEmpty is a helper function which checks if the tracker map is empty -// return true if empty -func (t *tracker) isEmpty() bool { - t.lock.RLock() - defer t.lock.RUnlock() - items := len(t.m) - return items == 0 -} diff --git a/pkg/webhook/validator/validator.go b/pkg/webhook/validator/validator.go index d5f2e86664..6d4e3e46a1 100644 --- a/pkg/webhook/validator/validator.go +++ b/pkg/webhook/validator/validator.go @@ -83,7 +83,10 @@ func GetValidator(ctx context.Context, NumaClient v1alpha1.NumaflowV1alpha1Inter // DeniedResponse constructs a denied AdmissionResponse func DeniedResponse(reason string, args ...interface{}) *admissionv1.AdmissionResponse { - result := apierrors.NewBadRequest(fmt.Sprintf(reason, args...)).Status() + if len(args) > 0 { + reason = fmt.Sprintf(reason, args) + } + result := apierrors.NewBadRequest(reason).Status() return &admissionv1.AdmissionResponse{ Result: &result, Allowed: false, diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 655f30bc4d..e2b3045712 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1557,7 +1557,7 @@ dependencies = [ [[package]] name = "numaflow" version = "0.1.1" -source = "git+https://github.com/numaproj/numaflow-rs.git?rev=0c1682864a4b906fab52e149cfd7cacc679ce688#0c1682864a4b906fab52e149cfd7cacc679ce688" +source = "git+https://github.com/numaproj/numaflow-rs.git?rev=30d8ce1972fd3f0c0b8059fee209516afeef0088#30d8ce1972fd3f0c0b8059fee209516afeef0088" dependencies = [ "chrono", "futures-util", diff --git a/rust/numaflow-core/Cargo.toml b/rust/numaflow-core/Cargo.toml index 85a3bc39b1..a10a46b9ab 100644 --- a/rust/numaflow-core/Cargo.toml +++ b/rust/numaflow-core/Cargo.toml @@ -38,7 +38,7 @@ log = "0.4.22" [dev-dependencies] tempfile = "3.11.0" -numaflow = { git = 
"https://github.com/numaproj/numaflow-rs.git", rev = "0c1682864a4b906fab52e149cfd7cacc679ce688" } +numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", rev = "30d8ce1972fd3f0c0b8059fee209516afeef0088" } [build-dependencies] tonic-build = "0.12.1" diff --git a/rust/numaflow-core/proto/sourcetransform.proto b/rust/numaflow-core/proto/sourcetransform.proto index 18e045c323..9d0a63a9dc 100644 --- a/rust/numaflow-core/proto/sourcetransform.proto +++ b/rust/numaflow-core/proto/sourcetransform.proto @@ -9,21 +9,36 @@ service SourceTransform { // SourceTransformFn applies a function to each request element. // In addition to map function, SourceTransformFn also supports assigning a new event time to response. // SourceTransformFn can be used only at source vertex by source data transformer. - rpc SourceTransformFn(SourceTransformRequest) returns (SourceTransformResponse); + rpc SourceTransformFn(stream SourceTransformRequest) returns (stream SourceTransformResponse); // IsReady is the heartbeat endpoint for gRPC. rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); } +/* + * Handshake message between client and server to indicate the start of transmission. + */ + message Handshake { + // Required field indicating the start of transmission. + bool sot = 1; +} + + /** * SourceTransformerRequest represents a request element. 
*/ message SourceTransformRequest { - repeated string keys = 1; - bytes value = 2; - google.protobuf.Timestamp event_time = 3; - google.protobuf.Timestamp watermark = 4; - map headers = 5; + message Request { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + map headers = 5; + // This ID is used to uniquely identify a transform request + string id = 6; + } + Request request = 1; + optional Handshake handshake = 2; } /** @@ -37,6 +52,10 @@ message SourceTransformResponse { repeated string tags = 4; } repeated Result results = 1; + // This ID is used to refer the responses to the request it corresponds to. + string id = 2; + // Handshake message between client and server to indicate the start of transmission. + optional Handshake handshake = 3; } /** @@ -44,4 +63,4 @@ message SourceTransformResponse { */ message ReadyResponse { bool ready = 1; -} \ No newline at end of file +} diff --git a/rust/numaflow-core/src/config.rs b/rust/numaflow-core/src/config.rs index 5d245ed397..c3263e999c 100644 --- a/rust/numaflow-core/src/config.rs +++ b/rust/numaflow-core/src/config.rs @@ -3,6 +3,7 @@ use base64::prelude::BASE64_STANDARD; use base64::Engine; use numaflow_models::models::{Backoff, MonoVertex, RetryStrategy}; use std::env; +use std::fmt::Display; use std::sync::OnceLock; const DEFAULT_SOURCE_SOCKET: &str = "/var/run/numaflow/source.sock"; @@ -53,17 +54,14 @@ impl OnFailureStrategy { _ => Some(DEFAULT_SINK_RETRY_ON_FAIL_STRATEGY), } } +} - /// Converts the `OnFailureStrategy` enum variant to a String. - /// This facilitates situations where the enum needs to be displayed or logged as a string. - /// - /// # Returns - /// A string representing the `OnFailureStrategy` enum variant. 
- fn to_string(&self) -> String { +impl Display for OnFailureStrategy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match *self { - OnFailureStrategy::Retry => "retry".to_string(), - OnFailureStrategy::Fallback => "fallback".to_string(), - OnFailureStrategy::Drop => "drop".to_string(), + OnFailureStrategy::Retry => write!(f, "retry"), + OnFailureStrategy::Fallback => write!(f, "fallback"), + OnFailureStrategy::Drop => write!(f, "drop"), } } } @@ -647,4 +645,4 @@ mod tests { let drop = OnFailureStrategy::Drop; assert_eq!(drop.to_string(), "drop"); } -} +} \ No newline at end of file diff --git a/rust/numaflow-core/src/message.rs b/rust/numaflow-core/src/message.rs index b99a61b31d..d230e994fb 100644 --- a/rust/numaflow-core/src/message.rs +++ b/rust/numaflow-core/src/message.rs @@ -7,7 +7,7 @@ use chrono::{DateTime, Utc}; use crate::error::Error; use crate::monovertex::sink_pb::sink_request::Request; use crate::monovertex::sink_pb::SinkRequest; -use crate::monovertex::source_pb; +use crate::monovertex::{source_pb, sourcetransform_pb}; use crate::monovertex::source_pb::{read_response, AckRequest}; use crate::monovertex::sourcetransform_pb::SourceTransformRequest; use crate::shared::utils::{prost_timestamp_from_utc, utc_from_timestamp}; @@ -58,11 +58,15 @@ impl From for AckRequest { impl From for SourceTransformRequest { fn from(message: Message) -> Self { Self { - keys: message.keys, - value: message.value, - event_time: prost_timestamp_from_utc(message.event_time), - watermark: None, - headers: message.headers, + request: Some(sourcetransform_pb::source_transform_request::Request { + id: message.id, + keys: message.keys, + value: message.value, + event_time: prost_timestamp_from_utc(message.event_time), + watermark: None, + headers: message.headers, + }), + handshake: None, } } } diff --git a/rust/numaflow-core/src/monovertex/forwarder.rs b/rust/numaflow-core/src/monovertex/forwarder.rs index a32aff093b..ab58cfad03 100644 --- 
a/rust/numaflow-core/src/monovertex/forwarder.rs +++ b/rust/numaflow-core/src/monovertex/forwarder.rs @@ -1,3 +1,10 @@ +use chrono::Utc; +use log::warn; +use std::collections::HashMap; +use tokio::time::sleep; +use tokio_util::sync::CancellationToken; +use tracing::{debug, info}; + use crate::config::{config, OnFailureStrategy}; use crate::error; use crate::error::Error; @@ -8,13 +15,6 @@ use crate::monovertex::sink_pb::Status::{Failure, Fallback, Success}; use crate::sink::user_defined::SinkWriter; use crate::source::user_defined::Source; use crate::transformer::user_defined::SourceTransformer; -use chrono::Utc; -use log::warn; -use std::collections::HashMap; -use tokio::task::JoinSet; -use tokio::time::sleep; -use tokio_util::sync::CancellationToken; -use tracing::{debug, info}; /// Forwarder is responsible for reading messages from the source, applying transformation if /// transformer is present, writing the messages to the sink, and then acknowledging the messages @@ -193,26 +193,14 @@ impl Forwarder { // Applies transformation to the messages if transformer is present // we concurrently apply transformation to all the messages. - async fn apply_transformer(&self, messages: Vec) -> error::Result> { - let Some(transformer_client) = &self.source_transformer else { + async fn apply_transformer(&mut self, messages: Vec) -> error::Result> { + let Some(transformer_client) = &mut self.source_transformer else { // return early if there is no transformer return Ok(messages); }; let start_time = tokio::time::Instant::now(); - let mut jh = JoinSet::new(); - for message in messages { - let mut transformer_client = transformer_client.clone(); - jh.spawn(async move { transformer_client.transform_fn(message).await }); - } - - let mut results = Vec::new(); - while let Some(task) = jh.join_next().await { - let result = task.map_err(|e| Error::TransformerError(format!("{:?}", e)))?; - if let Some(result) = result? 
{ - results.extend(result); - } - } + let results = transformer_client.transform_fn(messages).await?; debug!( "Transformer latency - {}ms", diff --git a/rust/numaflow-core/src/transformer/user_defined.rs b/rust/numaflow-core/src/transformer/user_defined.rs index de7b765b79..71a9d24cd6 100644 --- a/rust/numaflow-core/src/transformer/user_defined.rs +++ b/rust/numaflow-core/src/transformer/user_defined.rs @@ -1,67 +1,178 @@ -use crate::error; -use crate::message::Message; -use crate::monovertex::sourcetransform_pb::source_transform_client::SourceTransformClient; -use crate::monovertex::sourcetransform_pb::SourceTransformRequest; -use crate::shared::utils::utc_from_timestamp; +use std::collections::HashMap; + use tonic::transport::Channel; +use tonic::{Request, Streaming}; +use tokio::sync::mpsc; +use tokio::task::JoinHandle; +use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::CancellationToken; +use tracing::warn; +use crate::error::{Result, Error}; +use crate::message::{Message, Offset}; +use crate::monovertex::sourcetransform_pb::{self, SourceTransformRequest, SourceTransformResponse, source_transform_client::SourceTransformClient}; +use crate::shared::utils::utc_from_timestamp; +use crate::config::config; const DROP: &str = "U+005C__DROP__"; /// TransformerClient is a client to interact with the transformer server. 
-#[derive(Clone)] pub struct SourceTransformer { - client: SourceTransformClient, + read_tx: mpsc::Sender, + resp_stream: Streaming, } impl SourceTransformer { - pub(crate) async fn new(client: SourceTransformClient) -> error::Result { - Ok(Self { client }) - } + pub(crate) async fn new(mut client: SourceTransformClient) -> Result { + let (read_tx, read_rx) = mpsc::channel(config().batch_size as usize); + let read_stream = ReceiverStream::new(read_rx); - pub(crate) async fn transform_fn( - &mut self, - message: Message, - ) -> error::Result>> { - // fields which will not be changed - let offset = message.offset.clone(); - let id = message.id.clone(); - let headers = message.headers.clone(); - - // TODO: is this complex? the reason to do this is, tomorrow when we have the normal - // Pipeline CRD, we can require the Into trait. - let response = self - .client - .source_transform_fn(>::into(message)) + // do a handshake for read with the server before we start sending read requests + let handshake_request = SourceTransformRequest { + request: None, + handshake: Some(sourcetransform_pb::Handshake { sot: true }), + }; + read_tx.send(handshake_request).await.map_err(|e| { + Error::TransformerError(format!("failed to send handshake request: {}", e)) + })?; + + let mut resp_stream = client + .source_transform_fn(Request::new(read_stream)) .await? .into_inner(); - let mut messages = Vec::new(); - for result in response.results { - // if the message is tagged with DROP, we will not forward it. - if result.tags.contains(&DROP.to_string()) { - return Ok(None); + // first response from the server will be the handshake response. We need to check if the + // server has accepted the handshake. + let handshake_response = resp_stream.message().await?.ok_or(Error::TransformerError( + "failed to receive handshake response".to_string(), + ))?; + // handshake cannot to None during the initial phase and it has to set `sot` to true. 
+ if handshake_response.handshake.map_or(true, |h| !h.sot) { + return Err(Error::TransformerError( + "invalid handshake response".to_string(), + )); + } + + Ok(Self { + read_tx, + resp_stream, + }) + } + + pub(crate) async fn transform_fn(&mut self, messages: Vec) -> Result> { + // fields which will not be changed + struct MessageInfo { + offset: Offset, + headers: HashMap, + } + + let mut tracker: HashMap = HashMap::with_capacity(messages.len()); + for message in &messages { + tracker.insert( + message.id.clone(), + MessageInfo { + offset: message.offset.clone(), + headers: message.headers.clone(), + }, + ); + } + + // Cancellation token is used to cancel either sending task (if an error occurs while receiving) or receiving messages (if an error occurs on sending task) + let token = CancellationToken::new(); + + // Send transform requests to the source transformer server + let sender_task: JoinHandle> = tokio::spawn({ + let read_tx = self.read_tx.clone(); + let token = token.clone(); + async move { + for msg in messages { + let result = tokio::select! { + result = read_tx.send(msg.into()) => result, + _ = token.cancelled() => { + warn!("Cancellation token was cancelled while sending source transform requests"); + return Ok(()); + }, + }; + + match result { + Ok(()) => continue, + Err(e) => { + token.cancel(); + return Err(Error::TransformerError(e.to_string())); + } + }; + } + Ok(()) } - let message = Message { - keys: result.keys, - value: result.value, - offset: offset.clone(), - id: id.clone(), - event_time: utc_from_timestamp(result.event_time), - headers: headers.clone(), + }); + + // Receive transformer results + let mut messages = Vec::new(); + while !tracker.is_empty() { + let resp = tokio::select! 
{ + _ = token.cancelled() => { + break; + }, + resp = self.resp_stream.message() => {resp} + }; + + let resp = match resp { + Ok(Some(val)) => val, + Ok(None) => { + // Logging at warning level since we don't expect this to happen + warn!("Source transformer server closed its sending end of the stream. No more messages to receive"); + token.cancel(); + break; + } + Err(e) => { + token.cancel(); + return Err(Error::TransformerError(format!( + "gRPC error while receiving messages from source transformer server: {e:?}" + ))); + } + }; + + let Some((msg_id, msg_info)) = tracker.remove_entry(&resp.id) else { + token.cancel(); + return Err(Error::TransformerError(format!( + "Received message with unknown ID {}", + resp.id + ))); }; - messages.push(message); + + for (i, result) in resp.results.into_iter().enumerate() { + // TODO: Expose metrics + if result.tags.iter().any(|x| x == DROP) { + continue; + } + let message = Message { + id: format!("{}-{}", msg_id, i), + keys: result.keys, + value: result.value, + offset: msg_info.offset.clone(), + event_time: utc_from_timestamp(result.event_time), + headers: msg_info.headers.clone(), + }; + messages.push(message); + } } - Ok(Some(messages)) + sender_task.await.unwrap().map_err(|e| { + Error::TransformerError(format!( + "Sending messages to gRPC transformer failed: {e:?}", + )) + })?; + + Ok(messages) } } #[cfg(test)] mod tests { use std::error::Error; + use std::time::Duration; - use crate::monovertex::sourcetransform_pb::source_transform_client::SourceTransformClient; use crate::shared::utils::create_rpc_channel; + use crate::transformer::user_defined::sourcetransform_pb::source_transform_client::SourceTransformClient; use crate::transformer::user_defined::SourceTransformer; use numaflow::sourcetransform; use tempfile::TempDir; @@ -105,7 +216,7 @@ mod tests { let mut client = SourceTransformer::new(SourceTransformClient::new( create_rpc_channel(sock_file).await?, )) - .await?; + .await?; let message = crate::message::Message 
{ keys: vec!["first".into()], @@ -115,18 +226,29 @@ mod tests { offset: "0".into(), }, event_time: chrono::Utc::now(), - id: "".to_string(), + id: "1".to_string(), headers: Default::default(), }; - let resp = client.transform_fn(message).await?; - assert!(resp.is_some()); - assert_eq!(resp.unwrap().len(), 1); + let resp = tokio::time::timeout( + tokio::time::Duration::from_secs(2), + client.transform_fn(vec![message]), + ) + .await??; + assert_eq!(resp.len(), 1); + + // we need to drop the client, because if there are any in-flight requests + // server fails to shut down. https://github.com/numaproj/numaflow-rs/issues/85 + drop(client); shutdown_tx .send(()) .expect("failed to send shutdown signal"); - handle.await.expect("failed to join server task"); + tokio::time::sleep(Duration::from_millis(50)).await; + assert!( + handle.is_finished(), + "Expected gRPC server to have shut down" + ); Ok(()) } @@ -169,7 +291,7 @@ mod tests { let mut client = SourceTransformer::new(SourceTransformClient::new( create_rpc_channel(sock_file).await?, )) - .await?; + .await?; let message = crate::message::Message { keys: vec!["second".into()], @@ -183,8 +305,12 @@ mod tests { headers: Default::default(), }; - let resp = client.transform_fn(message).await?; - assert!(resp.is_none()); + let resp = client.transform_fn(vec![message]).await?; + assert!(resp.is_empty()); + + // we need to drop the client, because if there are any in-flight requests + // server fails to shut down. 
https://github.com/numaproj/numaflow-rs/issues/85 + drop(client); shutdown_tx .send(()) @@ -192,4 +318,4 @@ mod tests { handle.await.expect("failed to join server task"); Ok(()) } -} +} \ No newline at end of file diff --git a/rust/servesink/Cargo.toml b/rust/servesink/Cargo.toml index a9a768ac6c..80430c169b 100644 --- a/rust/servesink/Cargo.toml +++ b/rust/servesink/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] tonic = "0.12.0" tokio = { version = "1.0", features = ["macros", "rt-multi-thread"] } -numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", rev = "0c1682864a4b906fab52e149cfd7cacc679ce688" } +numaflow = { git = "https://github.com/numaproj/numaflow-rs.git", rev = "30d8ce1972fd3f0c0b8059fee209516afeef0088" } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } diff --git a/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml b/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml index 7bee8ef95f..8066caf9ec 100644 --- a/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml +++ b/test/transformer-e2e/testdata/extract-event-time-from-payload.yaml @@ -6,7 +6,7 @@ spec: vertices: - name: in source: - http: {} + http: { } transformer: builtin: name: eventTimeExtractor diff --git a/test/transformer-e2e/transformer_test.go b/test/transformer-e2e/transformer_test.go index 55b88f3683..e6b727fcb9 100644 --- a/test/transformer-e2e/transformer_test.go +++ b/test/transformer-e2e/transformer_test.go @@ -21,6 +21,7 @@ package e2e import ( "context" "encoding/json" + "errors" "fmt" "os" "strconv" @@ -142,7 +143,7 @@ wmLoop: for { select { case <-ctx.Done(): - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { s.T().Log("test timed out") assert.Fail(s.T(), "timed out") break wmLoop @@ -173,23 +174,24 @@ func (s *TransformerSuite) TestSourceTransformer() { } var wg sync.WaitGroup - wg.Add(4) - go func() { - defer wg.Done() - 
s.testSourceTransformer("python") - }() - go func() { - defer wg.Done() - s.testSourceTransformer("java") - }() + wg.Add(1) + // FIXME: Enable these tests after corresponding SDKs are changed to support bidirectional streaming + //go func() { + // defer wg.Done() + // s.testSourceTransformer("python") + //}() + //go func() { + // defer wg.Done() + // s.testSourceTransformer("java") + //}() go func() { defer wg.Done() s.testSourceTransformer("go") }() - go func() { - defer wg.Done() - s.testSourceTransformer("rust") - }() + //go func() { + // defer wg.Done() + // s.testSourceTransformer("rust") + //}() wg.Wait() }