diff --git a/flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor.go b/flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor.go index 561309cebe..6b04e14db7 100644 --- a/flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor.go +++ b/flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor.go @@ -31,14 +31,13 @@ func (ri *RecoveryInterceptor) UnaryServerInterceptor() grpc.UnaryServerIntercep defer func() { if r := recover(); r != nil { ri.panicCounter.Inc() - logger.Fatalf(ctx, "panic-ed for request: [%+v] to %s with err: %v with Stack: %v", req, info.FullMethod, r, string(debug.Stack())) + logger.Errorf(ctx, "panic-ed for request: [%+v] to %s with err: %v with Stack: %v", req, info.FullMethod, r, string(debug.Stack())) // Return INTERNAL to client with no info as to not leak implementation details err = status.Errorf(codes.Internal, "") } }() - resp, err := handler(ctx, req) - return resp, err + return handler(ctx, req) } } @@ -49,13 +48,12 @@ func (ri *RecoveryInterceptor) StreamServerInterceptor() grpc.StreamServerInterc defer func() { if r := recover(); r != nil { ri.panicCounter.Inc() - logger.Fatalf(stream.Context(), "panic-ed for stream to %s with err: %v with Stack: %v", info.FullMethod, r, string(debug.Stack())) + logger.Errorf(stream.Context(), "panic-ed for stream to %s with err: %v with Stack: %v", info.FullMethod, r, string(debug.Stack())) // Return INTERNAL to client with no info as to not leak implementation details err = status.Errorf(codes.Internal, "") } }() - err = handler(srv, stream) - return err + return handler(srv, stream) } } diff --git a/flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor_test.go b/flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor_test.go new file mode 100644 index 0000000000..5a6dfa832d --- /dev/null +++ b/flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor_test.go @@ -0,0 +1,38 @@ +package middleware + +import ( + "context" + mockScope "github.com/flyteorg/flyte/flytestdlib/promutils" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "testing" +) + +func TestRecoveryInterceptor(t *testing.T) { + ctx := context.Background() + testScope := mockScope.NewTestScope() + recoveryInterceptor := NewRecoveryInterceptor(testScope) + unaryInterceptor := recoveryInterceptor.UnaryServerInterceptor() + info := &grpc.UnaryServerInfo{} + req := "test-request" + + t.Run("should recover from panic", func(t *testing.T) { + _, err := unaryInterceptor(ctx, req, info, func(ctx context.Context, req any) (any, error) { + panic("synthetic") + }) + expectedErr := status.Errorf(codes.Internal, "") + require.Error(t, err) + require.Equal(t, expectedErr, err) + }) + + t.Run("should plumb response without panic", func(t *testing.T) { + mockedResponse := "test" + resp, err := unaryInterceptor(ctx, req, info, func(ctx context.Context, req any) (any, error) { + return mockedResponse, nil + }) + require.NoError(t, err) + require.Equal(t, mockedResponse, resp) + }) +}