diff --git a/proxy/proxy/proxy.go b/proxy/proxy/proxy.go index 607a939d..bb893f11 100644 --- a/proxy/proxy/proxy.go +++ b/proxy/proxy/proxy.go @@ -269,18 +269,41 @@ func (p *proxyStream) RecvMsg(m interface{}) error { } // Since the API is an interface{} we can change what this normally - // expects from a proto.Message to a *[]*ProxyRet instead. + // expects from a proto.Message to a *[]*Ret instead. // - // Anything else is an error. - manyRet, ok := m.(*[]*Ret) - if !ok { - return status.Errorf(codes.InvalidArgument, "args for proxy RecvMsg must be a *[]*ProxyRet) - got %T", m) + // Anything else is an error if we have > 1 target. In the one target + // case validate it's a proto.Message and unwrap into that instead. + var replyMsg proto.Message + var manyRet *[]*Ret + switch v := m.(type) { + case *[]*Ret: + manyRet = v + case proto.Message: + if len(p.ids) != 1 { + return status.Errorf(codes.InvalidArgument, "args for proxy RecvMsg must be a *[]*Ret) when called in OneMany context - got %T", m) + } + replyMsg = v + default: + if len(p.ids) != 1 { + return status.Errorf(codes.InvalidArgument, "args for proxy RecvMsg must be a *[]*Ret) when called in OneMany context - got %T", m) + } + return status.Errorf(codes.InvalidArgument, "args for proxy RecvMsg must be proto.Message when called directly - got %T", m) } // If we have any pre-canned errors push them on now. // Only send once or else the user gets spammed with errors for every Recv called. if !p.sentErrors { - *manyRet = append(*manyRet, p.errors...) + // In non OneMany context just return this directly as an error. + // Any other calls to RecvMsg will fall through below and get whatever + // the stream returns at that point. + if len(p.errors) > 0 { + if replyMsg != nil { + p.sentErrors = true + return p.errors[0].Error + } else { + *manyRet = append(*manyRet, p.errors...) + } + } p.sentErrors = true } @@ -302,7 +325,17 @@ func (p *proxyStream) RecvMsg(m interface{}) error { } p.ids[id].Resp = d.Payload p.ids[id].Error = nil - *manyRet = append(*manyRet, p.ids[id]) + if manyRet != nil { + *manyRet = append(*manyRet, p.ids[id]) + } + if replyMsg != nil { + if err := d.Payload.UnmarshalTo(replyMsg); err != nil { + return status.Errorf(codes.Internal, "can't unmarshal reply: %v", err) + } + // We know there's only one due to the precheck when we construct + // replyMsg. + return nil + } } case cl != nil: code := codes.Code(cl.GetStatus().GetCode()) @@ -325,6 +358,10 @@ func (p *proxyStream) RecvMsg(m interface{}) error { if streamStatus.Code() != codes.OK { closedErr = streamStatus.Err() } + if replyMsg != nil { + // Easy case for Recv() + return closedErr + } for _, id := range cl.StreamIds { p.ids[id].Error = closedErr p.ids[id].Resp = nil diff --git a/proxy/proxy/proxy_test.go b/proxy/proxy/proxy_test.go index a4732714..0f99c0b3 100644 --- a/proxy/proxy/proxy_test.go +++ b/proxy/proxy/proxy_test.go @@ -293,6 +293,41 @@ func TestStreaming(t *testing.T) { }, } { tc := tc + t.Run(tc.name+" direct", func(t *testing.T) { + conn, err := proxy.Dial(tc.proxy, tc.targets, testutil.WithBufDialer(bufMap), grpc.WithTransportCredentials(insecure.NewCredentials())) + tu.FatalOnErr("Dial", err, t) + + ts := tdpb.NewTestServiceClientProxy(conn) + stream, err := ts.TestBidiStream(context.Background()) + tu.FatalOnErr("getting stream", err, t) + + // We only care about validating Send/Recv work cleanly in 1:1 or error in 1:N + + // Should always be able to Send + err = stream.Send(&tdpb.TestRequest{Input: "input"}) + tu.FatalOnErr("Send", err, t) + + // Now a normal recv should either work or fail depending on > 1 target (or not) + _, err = stream.Recv() + if len(tc.targets) > 1 { + tu.FatalOnNoErr("recv didn't fail for > 1 target", err, t) + } else { + tu.FatalOnErr("Recv", err, t) + } + + // Now test the error case + err = stream.Send(&tdpb.TestRequest{Input: "error"}) + tu.FatalOnErr("Send error", err, t) + + // Shouldn't fail even we close send twice. + err = stream.CloseSend() + tu.FatalOnErr("CloseSend", err, t) + err = stream.CloseSend() + tu.FatalOnErr("CloseSend", err, t) + _, err = stream.Recv() + tu.FatalOnNoErr("recv should get error from send", err, t) + t.Log(err) + }) t.Run(tc.name, func(t *testing.T) { conn, err := proxy.Dial(tc.proxy, tc.targets, testutil.WithBufDialer(bufMap), grpc.WithTransportCredentials(insecure.NewCredentials())) tu.FatalOnErr("Dial", err, t)