diff --git a/server.go b/server.go index 70fe23f55022..1da2a542acde 100644 --- a/server.go +++ b/server.go @@ -1598,6 +1598,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv s: stream, p: &parser{r: stream, bufferPool: s.opts.bufferPool}, codec: s.getCodec(stream.ContentSubtype()), + desc: sd, maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize, trInfo: trInfo, diff --git a/stream.go b/stream.go index ca6948926f93..f2698d47f216 100644 --- a/stream.go +++ b/stream.go @@ -1580,6 +1580,7 @@ type serverStream struct { s *transport.ServerStream p *parser codec baseCodec + desc *StreamDesc compressorV0 Compressor compressorV1 encoding.Compressor @@ -1774,6 +1775,9 @@ func (ss *serverStream) RecvMsg(m any) (err error) { binlog.Log(ss.ctx, chc) } } + if !ss.desc.ClientStreams { + return status.Error(codes.Internal, "cardinality violation: received no request message from non-client-stream RPC") + } return err } if err == io.ErrUnexpectedEOF { @@ -1800,7 +1804,19 @@ func (ss *serverStream) RecvMsg(m any) (err error) { binlog.Log(ss.ctx, cm) } } - return nil + + if ss.desc.ClientStreams { + // Subsequent messages should be received by subsequent RecvMsg calls. + return nil + } + // Special handling for non-client-stream rpcs. + // This recv expects EOF or errors, so we don't collect inPayload. + if err := recv(ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true); err == io.EOF { + return nil + } else if err != nil { + return err + } + return status.Error(codes.Internal, "cardinality violation: expected for non client-streaming RPCs, but received another message") } // MethodFromServerStream returns the method string for the input stream. diff --git a/test/end2end_test.go b/test/end2end_test.go index 584c90ca3b15..416ea459a2cf 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -3740,6 +3740,345 @@ func (s) TestClientStreaming_ReturnErrorAfterSendAndClose(t *testing.T) { } } +// Tests the behavior for server-side streaming when server calls RecvMsg twice. +// Second call to RecvMsg should fail with Internal error. +func (s) TestServerStreaming_ServerCallRecvMsgTwice(t *testing.T) { + lis, err := testutils.LocalTCPListener() + if err != nil { + t.Fatal(err) + } + defer lis.Close() + + ss := stubserver.StubServer{ + StreamingOutputCallF: func(_ *testpb.StreamingOutputCallRequest, stream testgrpc.TestService_StreamingOutputCallServer) error { + // This is second call to RecvMsg(), the initial call having been performed by the server handler. + if err := stream.RecvMsg(&testpb.Empty{}); status.Code(err) != codes.Internal { + t.Errorf("stream.RecvMsg() = %v, want error %v", status.Code(err), codes.Internal) + } + return nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatal("Error starting server:", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + stream, err := ss.Client.StreamingOutputCall(ctx, &testpb.StreamingOutputCallRequest{}) + if err != nil { + t.Fatalf(".StreamingOutputCall(_) = _, %v, want ", err) + } + + if err := stream.RecvMsg(&testpb.Empty{}); status.Code(err) != codes.Internal { + t.Errorf("stream.RecvMsg() = %v, want error %v", status.Code(err), codes.Internal) + } +} + +// Tests the behavior for server-side streaming when client calls SendMsg twice. +// Second call to SendMsg should fail with Internal error. +func (s) TestServerStreaming_ClientCallSendMsgTwice(t *testing.T) { + lis, err := testutils.LocalTCPListener() + if err != nil { + t.Fatal(err) + } + defer lis.Close() + + s := grpc.NewServer() + serviceDesc := grpc.ServiceDesc{ + ServiceName: "grpc.testing.TestService", + HandlerType: (*any)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "ServerStreaming", + Handler: func(_ any, _ grpc.ServerStream) error { + return nil + }, + ClientStreams: false, + ServerStreams: true, + }, + }, + } + s.RegisterService(&serviceDesc, &testServer{}) + go s.Serve(lis) + defer s.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%q) failed unexpectedly: %v", lis.Addr(), err) + } + defer cc.Close() + + desc := &grpc.StreamDesc{ + StreamName: "ServerStreaming", + ServerStreams: true, + ClientStreams: false, + } + + stream, err := cc.NewStream(ctx, desc, "/grpc.testing.TestService/ServerStreaming") + if err != nil { + t.Fatalf("cc.NewStream() failed unexpectedly: %v", err) + } + + if err := stream.SendMsg(&testpb.Empty{}); err != nil { + t.Errorf("stream.SendMsg() = %v, want ", err) + } + + if err := stream.SendMsg(&testpb.Empty{}); status.Code(err) != codes.Internal { + t.Errorf("stream.SendMsg() = %v, want error %v", status.Code(err), codes.Internal) + } +} + +// Tests the behavior for unary RPC when server calls RecvMsg twice. Second call +// to RecvMsg should fail with Internal error. +func (s) TestUnaryRPC_ServerCallRecvMsgTwice(t *testing.T) { + lis, err := testutils.LocalTCPListener() + if err != nil { + t.Fatal(err) + } + defer lis.Close() + + s := grpc.NewServer() + serviceDesc := grpc.ServiceDesc{ + ServiceName: "grpc.testing.TestService", + HandlerType: (*any)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "UnaryCall", + Handler: func(_ any, stream grpc.ServerStream) error { + err := stream.RecvMsg(&testpb.Empty{}) + if err != nil { + t.Errorf("stream.RecvMsg() = %v, want ", err) + } + + if err = stream.RecvMsg(&testpb.Empty{}); status.Code(err) != codes.Internal { + t.Errorf("stream.RecvMsg() = %v, want error %v", status.Code(err), codes.Internal) + } + return nil + }, + ClientStreams: false, + ServerStreams: false, + }, + }, + } + s.RegisterService(&serviceDesc, &testServer{}) + go s.Serve(lis) + defer s.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%q) failed unexpectedly: %v", lis.Addr(), err) + } + defer cc.Close() + + desc := &grpc.StreamDesc{ + StreamName: "UnaryCall", + ServerStreams: false, + ClientStreams: false, + } + + stream, err := cc.NewStream(ctx, desc, "/grpc.testing.TestService/UnaryCall") + if err != nil { + t.Fatalf("cc.NewStream() failed unexpectedly: %v", err) + } + + if err := stream.SendMsg(&testpb.Empty{}); err != nil { + t.Errorf("stream.SendMsg() = %v, want ", err) + } + + if err := stream.RecvMsg(&testpb.Empty{}); status.Code(err) != codes.Internal { + t.Errorf("stream.RecvMsg() = %v, want error %v", status.Code(err), codes.Internal) + } +} + +// Tests the behavior for unary RPC when client calls SendMsg twice. Second call +// to SendMsg should fail with Internal error. +func (s) TestUnaryRPC_ClientCallSendMsgTwice(t *testing.T) { + lis, err := testutils.LocalTCPListener() + if err != nil { + t.Fatal(err) + } + defer lis.Close() + + s := grpc.NewServer() + serviceDesc := grpc.ServiceDesc{ + ServiceName: "grpc.testing.TestService", + HandlerType: (*any)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "UnaryCall", + Handler: func(_ any, _ grpc.ServerStream) error { + return nil + }, + ClientStreams: false, + ServerStreams: false, + }, + }, + } + s.RegisterService(&serviceDesc, &testServer{}) + go s.Serve(lis) + defer s.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%q) failed unexpectedly: %v", lis.Addr(), err) + } + defer cc.Close() + + desc := &grpc.StreamDesc{ + StreamName: "UnaryCall", + ServerStreams: false, + ClientStreams: false, + } + + stream, err := cc.NewStream(ctx, desc, "/grpc.testing.TestService/UnaryCall") + if err != nil { + t.Fatalf("cc.NewStream() failed unexpectedly: %v", err) + } + + if err := stream.SendMsg(&testpb.Empty{}); err != nil { + t.Errorf("stream.SendMsg() = %v, want ", err) + } + + if err := stream.SendMsg(&testpb.Empty{}); status.Code(err) != codes.Internal { + t.Errorf("stream.SendMsg() = %v, want error %v", status.Code(err), codes.Internal) + } +} + +// Tests the behavior for server-side streaming RPC when client misbehaves as Bidi-streaming +// and sends multiple nessages. +func (s) TestServerStreaming_ClientSendsMultipleMessages(t *testing.T) { + lis, err := testutils.LocalTCPListener() + if err != nil { + t.Fatal(err) + } + defer lis.Close() + + s := grpc.NewServer() + serviceDesc := grpc.ServiceDesc{ + ServiceName: "grpc.testing.TestService", + HandlerType: (*any)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "ServerStreaming", + Handler: func(_ any, stream grpc.ServerStream) error { + if err = stream.RecvMsg(&testpb.Empty{}); status.Code(err) != codes.Internal { + t.Errorf("stream.RecvMsg() = %v, want error %v", status.Code(err), codes.Internal) + } + return nil + }, + ClientStreams: false, + ServerStreams: true, + }, + }, + } + s.RegisterService(&serviceDesc, &testServer{}) + go s.Serve(lis) + defer s.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%q) failed unexpectedly: %v", lis.Addr(), err) + } + defer cc.Close() + + // Making the client bi-di to bypass the client side checks that stop a non-streaming client + // from sending multiple messages. + desc := &grpc.StreamDesc{ + StreamName: "ServerStreaming", + ServerStreams: true, + ClientStreams: true, + } + + stream, err := cc.NewStream(ctx, desc, "/grpc.testing.TestService/ServerStreaming") + if err != nil { + t.Fatalf("cc.NewStream() failed unexpectedly: %v", err) + } + + if err := stream.SendMsg(&testpb.Empty{}); err != nil { + t.Errorf("stream.SendMsg() = %v, want ", err) + } + + if err := stream.SendMsg(&testpb.Empty{}); err != nil { + t.Errorf("stream.SendMsg() = %v, want ", err) + } + + if err := stream.RecvMsg(&testpb.Empty{}); status.Code(err) != codes.Internal { + t.Errorf("stream.RecvMsg() = %v, want error %v", status.Code(err), codes.Internal) + } +} + +// Tests the behavior for server-side streaming RPC when client sends zero request message. +func (s) TestServerStreaming_ClientSendsZeroRequest(t *testing.T) { + lis, err := testutils.LocalTCPListener() + if err != nil { + t.Fatal(err) + } + defer lis.Close() + + s := grpc.NewServer() + serviceDesc := grpc.ServiceDesc{ + ServiceName: "grpc.testing.TestService", + HandlerType: (*any)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "ServerStreaming", + Handler: func(_ any, stream grpc.ServerStream) error { + if err = stream.RecvMsg(&testpb.Empty{}); status.Code(err) != codes.Internal { + t.Errorf("stream.RecvMsg() = %v, want error %v", status.Code(err), codes.Internal) + } + return nil + }, + ClientStreams: false, + ServerStreams: true, + }, + }, + } + s.RegisterService(&serviceDesc, &testServer{}) + go s.Serve(lis) + defer s.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%q) failed unexpectedly: %v", lis.Addr(), err) + } + defer cc.Close() + + desc := &grpc.StreamDesc{ + StreamName: "ServerStreaming", + ServerStreams: true, + ClientStreams: false, + } + + stream, err := cc.NewStream(ctx, desc, "/grpc.testing.TestService/ServerStreaming") + if err != nil { + t.Fatalf("cc.NewStream() failed unexpectedly: %v", err) + } + + if err := stream.CloseSend(); err != nil { + t.Errorf("stream.CloseSend() = %v, want ", err) + } + if err := stream.RecvMsg(&testpb.Empty{}); status.Code(err) != codes.Internal { + t.Errorf("stream.RecvMsg() = %v, want error %v", status.Code(err), codes.Internal) + } +} + // Tests that a client receives a cardinality violation error for client-streaming // RPCs if the server call SendMsg multiple times. func (s) TestClientStreaming_ServerHandlerSendMsgAfterSendMsg(t *testing.T) {