Skip to content

Commit

Permalink
Add recovery interceptor
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Parraga <[email protected]>
  • Loading branch information
Sovietaced committed Jul 9, 2024
1 parent 9b78125 commit 3c1920b
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 4 deletions.
61 changes: 61 additions & 0 deletions flyteadmin/pkg/rpc/adminservice/middleware/recovery_interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package middleware

import (
"context"
"github.com/flyteorg/flyte/flytestdlib/logger"
"github.com/flyteorg/flyte/flytestdlib/promutils"
"github.com/prometheus/client_golang/prometheus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"runtime/debug"
)

// RecoveryInterceptor is a struct for creating gRPC interceptors that handle panics in go
type RecoveryInterceptor struct {
panicCounter prometheus.Counter
}

// NewRecoveryInterceptor creates a new RecoveryInterceptor with metrics under the provided scope
func NewRecoveryInterceptor(adminScope promutils.Scope) *RecoveryInterceptor {
panicCounter := adminScope.MustNewCounter("handler_panic", "panics encountered while handling gRPC requests")
return &RecoveryInterceptor{
panicCounter: panicCounter,
}
}

// UnaryServerInterceptor returns a new unary server interceptor for panic recovery.
func (ri *RecoveryInterceptor) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ any, err error) {

defer func() {
if r := recover(); r != nil {
ri.panicCounter.Inc()
logger.Fatalf(ctx, "panic-ed for request: [%+v] to %s with err: %v with Stack: %v", req, info.FullMethod, r, string(debug.Stack()))
// Return INTERNAL to client with no info as to not leak implementation details
err = status.Errorf(codes.Internal, "")
}
}()

resp, err := handler(ctx, req)
return resp, err
}
}

// StreamServerInterceptor returns a new streaming server interceptor for panic recovery.
func (ri *RecoveryInterceptor) StreamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {

defer func() {
if r := recover(); r != nil {
ri.panicCounter.Inc()
logger.Fatalf(stream.Context(), "panic-ed for stream to %s with err: %v with Stack: %v", info.FullMethod, r, string(debug.Stack()))
// Return INTERNAL to client with no info as to not leak implementation details
err = status.Errorf(codes.Internal, "")
}
}()

err = handler(srv, stream)
return err
}
}
29 changes: 25 additions & 4 deletions flyteadmin/pkg/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"fmt"
"github.com/flyteorg/flyte/flyteadmin/pkg/rpc/adminservice/middleware"
"net"
"net/http"
"strings"
Expand All @@ -12,6 +13,7 @@ import (
"github.com/gorilla/handlers"
grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
grpcrecovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/pkg/errors"
Expand Down Expand Up @@ -98,11 +100,18 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c
otelgrpc.WithPropagators(propagation.TraceContext{}),
)

adminScope := scope.NewSubScope("admin")
recoveryInterceptor := middleware.NewRecoveryInterceptor(adminScope)

var chainedUnaryInterceptors grpc.UnaryServerInterceptor
if cfg.Security.UseAuth {
logger.Infof(ctx, "Creating gRPC server with authentication")
middlewareInterceptors := plugins.Get[grpc.UnaryServerInterceptor](pluginRegistry, plugins.PluginIDUnaryServiceMiddleware)
chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor,
chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(
// recovery interceptor should always be first in order to handle any panics in the middleware or server
recoveryInterceptor.UnaryServerInterceptor(),
grpcrecovery.UnaryServerInterceptor(),
grpcprometheus.UnaryServerInterceptor,
otelUnaryServerInterceptor,
auth.GetAuthenticationCustomMetadataInterceptor(authCtx),
grpcauth.UnaryServerInterceptor(auth.GetAuthenticationInterceptor(authCtx)),
Expand All @@ -111,11 +120,23 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c
)
} else {
logger.Infof(ctx, "Creating gRPC server without authentication")
chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor, otelUnaryServerInterceptor)
chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(
// recovery interceptor should always be first in order to handle any panics in the middleware or server
recoveryInterceptor.UnaryServerInterceptor(),
grpcprometheus.UnaryServerInterceptor,
otelUnaryServerInterceptor,
)
}

chainedStreamInterceptors := grpcmiddleware.ChainStreamServer(
// recovery interceptor should always be first in order to handle any panics in the middleware or server
recoveryInterceptor.StreamServerInterceptor(),
grpcprometheus.StreamServerInterceptor,
)

serverOpts := []grpc.ServerOption{
grpc.StreamInterceptor(grpcprometheus.StreamServerInterceptor),
// recovery interceptor should always be first in order to handle any panics in the middleware or server
grpc.StreamInterceptor(chainedStreamInterceptors),
grpc.UnaryInterceptor(chainedUnaryInterceptors),
}
if cfg.GrpcConfig.MaxMessageSizeBytes > 0 {
Expand All @@ -131,7 +152,7 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c
}

configuration := runtime2.NewConfigurationProvider()
adminServer := adminservice.NewAdminServer(ctx, pluginRegistry, configuration, cfg.KubeConfig, cfg.Master, dataStorageClient, scope.NewSubScope("admin"))
adminServer := adminservice.NewAdminServer(ctx, pluginRegistry, configuration, cfg.KubeConfig, cfg.Master, dataStorageClient, adminScope)
grpcService.RegisterAdminServiceServer(grpcServer, adminServer)
if cfg.Security.UseAuth {
grpcService.RegisterAuthMetadataServiceServer(grpcServer, authCtx.AuthMetadataService())
Expand Down

0 comments on commit 3c1920b

Please sign in to comment.