diff --git a/auth/mtls/mtls_test.go b/auth/mtls/mtls_test.go index 1c3d9e46..12494df2 100644 --- a/auth/mtls/mtls_test.go +++ b/auth/mtls/mtls_test.go @@ -31,6 +31,7 @@ import ( "google.golang.org/grpc/test/bufconn" "google.golang.org/protobuf/types/known/emptypb" + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" "github.com/Snowflake-Labs/sansshell/server" hcpb "github.com/Snowflake-Labs/sansshell/services/healthcheck" _ "github.com/Snowflake-Labs/sansshell/services/healthcheck/server" @@ -100,7 +101,7 @@ func serverWithPolicy(t *testing.T, policy string, CAPool *x509.CertPool) (*bufc creds, err := LoadServerTLS("testdata/leaf.pem", "testdata/leaf.key", CAPool) testutil.FatalOnErr("Failed to load client cert", err, t) lis := bufconn.Listen(bufSize) - s, err := server.BuildServer(creds, policy, lis.Addr(), logr.Discard()) + s, err := server.BuildServer(creds, policy, logr.Discard(), rpcauth.HostNetHook(lis.Addr())) testutil.FatalOnErr("Could not build server", err, t) listening := make(chan struct{}) go func() { diff --git a/auth/opa/rpcauth/hooks.go b/auth/opa/rpcauth/hooks.go index f134d10d..3baf2a7d 100644 --- a/auth/opa/rpcauth/hooks.go +++ b/auth/opa/rpcauth/hooks.go @@ -19,6 +19,9 @@ package rpcauth import ( "context" "net" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) // RPCAuthzHookFunc implements RpcAuthzHook for a simple function @@ -53,7 +56,7 @@ func (c *conditionalHook) Hook(ctx context.Context, input *RPCAuthInput) error { return nil } -// HostNetHook returns an RpcAuthzHook that sets host networking information. +// HostNetHook returns an RPCAuthzHook that sets host networking information. func HostNetHook(addr net.Addr) RPCAuthzHook { return RPCAuthzHookFunc(func(ctx context.Context, input *RPCAuthInput) error { if input.Host == nil { @@ -63,3 +66,38 @@ func HostNetHook(addr net.Addr) RPCAuthzHook { return nil }) } + +const ( + // ReqJustKey is the key name that must exist in the incoming + // context metadata if client side provided justification is required. + ReqJustKey = "sansshell-justification" +) + +var ( + // ErrJustification is the error returned for missing justification. + ErrJustification = status.Error(codes.FailedPrecondition, "missing justification") +) + +// JustificationHook takes the given optional justification function and returns an RPCAuthzHook +// that validates if justification was included. If it is required and passes the optional validation function +// it will return nil, otherwise an error. +func JustificationHook(justificationFunc func(string) error) RPCAuthzHook { + return RPCAuthzHookFunc(func(ctx context.Context, input *RPCAuthInput) error { + // See if we got any metadata and if it contains the justification + var j string + v := input.Metadata[ReqJustKey] + if len(v) > 0 { + j = v[0] + } + + if j == "" { + return ErrJustification + } + if justificationFunc != nil { + if err := justificationFunc(j); err != nil { + return status.Errorf(codes.FailedPrecondition, "justification failed: %v", err) + } + } + return nil + }) +} diff --git a/auth/opa/rpcauth/rpcauth_test.go b/auth/opa/rpcauth/rpcauth_test.go index 1194eff5..a559b6c1 100644 --- a/auth/opa/rpcauth/rpcauth_test.go +++ b/auth/opa/rpcauth/rpcauth_test.go @@ -219,6 +219,59 @@ func TestAuthzHook(t *testing.T) { }, errFunc: wantStatusCode(codes.OK), }, + { + name: "network data allow with justification (no func)", + input: &RPCAuthInput{ + Method: "/Foo.Bar/Foo", + Metadata: metadata.MD{ + ReqJustKey: []string{"justification"}, + }, + }, + hooks: []RPCAuthzHook{ + HostNetHook(tcp), + JustificationHook(nil), + }, + errFunc: wantStatusCode(codes.OK), + }, + { + name: "network data allow with justification req but none given (no func)", + input: &RPCAuthInput{ + Method: "/Foo.Bar/Foo", + }, + hooks: []RPCAuthzHook{ + HostNetHook(tcp), + JustificationHook(nil), + }, + errFunc: wantStatusCode(codes.FailedPrecondition), + }, + { + name: "network data allow with justification (with func)", + input: &RPCAuthInput{ + Method: "/Foo.Bar/Foo", + Metadata: metadata.MD{ + ReqJustKey: []string{"justification"}, + }, + }, + hooks: []RPCAuthzHook{ + HostNetHook(tcp), + JustificationHook(func(string) error { return nil }), + }, + errFunc: wantStatusCode(codes.OK), + }, + { + name: "network data allow with justification req given and func fails", + input: &RPCAuthInput{ + Method: "/Foo.Bar/Foo", + Metadata: metadata.MD{ + ReqJustKey: []string{"justification"}, + }, + }, + hooks: []RPCAuthzHook{ + HostNetHook(tcp), + JustificationHook(func(string) error { return errors.New("error") }), + }, + errFunc: wantStatusCode(codes.FailedPrecondition), + }, { name: "conditional hook, triggered", input: &RPCAuthInput{Method: "/Some.Random/Method"}, diff --git a/cmd/proxy-server/main.go b/cmd/proxy-server/main.go index e6191f38..3496d417 100644 --- a/cmd/proxy-server/main.go +++ b/cmd/proxy-server/main.go @@ -28,6 +28,7 @@ import ( "github.com/Snowflake-Labs/sansshell/auth/mtls" mtlsFlags "github.com/Snowflake-Labs/sansshell/auth/mtls/flags" "github.com/Snowflake-Labs/sansshell/auth/opa" + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" "github.com/Snowflake-Labs/sansshell/cmd/proxy-server/server" "github.com/Snowflake-Labs/sansshell/cmd/util" "github.com/go-logr/stdr" @@ -37,12 +38,13 @@ var ( //go:embed default-policy.rego defaultPolicy string - policyFlag = flag.String("policy", defaultPolicy, "Local OPA policy governing access. If empty, use builtin policy.") - policyFile = flag.String("policy-file", "", "Path to a file with an OPA policy. If empty, uses --policy.") - hostport = flag.String("hostport", "localhost:50043", "Where to listen for connections.") - credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS creds (one of [%s])", strings.Join(mtls.Loaders(), ","))) - verbosity = flag.Int("v", 0, "Verbosity level. > 0 indicates more extensive logging") - validate = flag.Bool("validate", false, "If true will evaluate the policy and then exit (non-zero on error)") + policyFlag = flag.String("policy", defaultPolicy, "Local OPA policy governing access. If empty, use builtin policy.") + policyFile = flag.String("policy-file", "", "Path to a file with an OPA policy. If empty, uses --policy.") + hostport = flag.String("hostport", "localhost:50043", "Where to listen for connections.") + credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS creds (one of [%s])", strings.Join(mtls.Loaders(), ","))) + verbosity = flag.Int("v", 0, "Verbosity level. > 0 indicates more extensive logging") + validate = flag.Bool("validate", false, "If true will evaluate the policy and then exit (non-zero on error)") + justification = flag.Bool("justification", false, "If true then justification (which is logged and possibly validated) must be passed along in the client context Metadata with the key '"+rpcauth.ReqJustKey+"'") ) func main() { @@ -65,10 +67,11 @@ func main() { } rs := server.RunState{ - Logger: logger, - Policy: policy, - CredSource: *credSource, - Hostport: *hostport, + Logger: logger, + Policy: policy, + CredSource: *credSource, + Hostport: *hostport, + Justification: *justification, } server.Run(ctx, rs) } diff --git a/cmd/proxy-server/server/server.go b/cmd/proxy-server/server/server.go index e7915e31..70a13ad9 100644 --- a/cmd/proxy-server/server/server.go +++ b/cmd/proxy-server/server/server.go @@ -53,6 +53,13 @@ type RunState struct { CredSource string // Hostport is the host:port to run the server. Hostport string + // Justification if true requires justification to be set in the + // incoming RPC context Metadata (to the key defined in the telemetry package). + Justification bool + // JustificationFunc will be called if Justication is true and a justification + // entry is found. The supplied function can then do any validation it wants + // in order to ensure it's compliant. + JustificationFunc func(string) error } // Run takes the given context and RunState along with any authz hooks and starts up a sansshell proxy server @@ -80,7 +87,11 @@ func Run(ctx context.Context, rs RunState, hooks ...rpcauth.RPCAuthzHook) { addressHook := rpcauth.HookIf(rpcauth.HostNetHook(lis.Addr()), func(input *rpcauth.RPCAuthInput) bool { return input.Host == nil || input.Host.Net == nil }) - h := []rpcauth.RPCAuthzHook{addressHook} + justificationHook := rpcauth.HookIf(rpcauth.JustificationHook(rs.JustificationFunc), func(input *rpcauth.RPCAuthInput) bool { + return rs.Justification + }) + + h := []rpcauth.RPCAuthzHook{addressHook, justificationHook} h = append(h, hooks...) authz, err := rpcauth.NewWithPolicy(ctx, rs.Policy, h...) if err != nil { diff --git a/cmd/sanssh/main.go b/cmd/sanssh/main.go index 3e9d3139..474e5670 100644 --- a/cmd/sanssh/main.go +++ b/cmd/sanssh/main.go @@ -26,19 +26,22 @@ import ( "github.com/Snowflake-Labs/sansshell/auth/mtls" mtlsFlags "github.com/Snowflake-Labs/sansshell/auth/mtls/flags" + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" "github.com/Snowflake-Labs/sansshell/cmd/sanssh/client" "github.com/Snowflake-Labs/sansshell/services/util" "github.com/google/subcommands" + "google.golang.org/grpc/metadata" ) var ( defaultAddress = "localhost:50042" defaultTimeout = 3 * time.Second - proxyAddr = flag.String("proxy", "", "Address to contact for proxy to sansshell-server. If blank a direct connection to the first entry in --targets will be made") - timeout = flag.Duration("timeout", defaultTimeout, "How long to wait for the command to complete") - credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS credentials (one of [%s])", strings.Join(mtls.Loaders(), ","))) - outputsDir = flag.String("output-dir", "", "If set defines a directory to emit output/errors from commands. Files will be generated based on target as destination/0 destination/0.error, etc.") + proxyAddr = flag.String("proxy", "", "Address to contact for proxy to sansshell-server. If blank a direct connection to the first entry in --targets will be made") + timeout = flag.Duration("timeout", defaultTimeout, "How long to wait for the command to complete") + credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS credentials (one of [%s])", strings.Join(mtls.Loaders(), ","))) + outputsDir = flag.String("output-dir", "", "If set defines a directory to emit output/errors from commands. Files will be generated based on target as destination/0 destination/0.error, etc.") + justification = flag.String("justification", "", "If non-empty will add the key '"+rpcauth.ReqJustKey+"' to the outgoing context Metadata to be passed along to the server for possbile validation and logging.") // targets will be bound to --targets for sending a single request to N nodes. targetsFlag util.StringSliceFlag @@ -62,6 +65,8 @@ func init() { subcommands.ImportantFlag("proxy") subcommands.ImportantFlag("targets") subcommands.ImportantFlag("outputs") + subcommands.ImportantFlag("output-dir") + subcommands.ImportantFlag("justification") } func main() { @@ -75,5 +80,9 @@ func main() { CredSource: *credSource, Timeout: *timeout, } - client.Run(context.Background(), rs) + ctx := context.Background() + if *justification != "" { + ctx = metadata.AppendToOutgoingContext(ctx, rpcauth.ReqJustKey, *justification) + } + client.Run(ctx, rs) } diff --git a/cmd/sansshell-server/main.go b/cmd/sansshell-server/main.go index 08e02347..11510419 100644 --- a/cmd/sansshell-server/main.go +++ b/cmd/sansshell-server/main.go @@ -34,6 +34,7 @@ import ( "github.com/Snowflake-Labs/sansshell/auth/mtls" mtlsFlags "github.com/Snowflake-Labs/sansshell/auth/mtls/flags" "github.com/Snowflake-Labs/sansshell/auth/opa" + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" "github.com/Snowflake-Labs/sansshell/cmd/sansshell-server/server" "github.com/Snowflake-Labs/sansshell/cmd/util" ) @@ -42,12 +43,13 @@ var ( //go:embed default-policy.rego defaultPolicy string - policyFlag = flag.String("policy", defaultPolicy, "Local OPA policy governing access. If empty, use builtin policy.") - policyFile = flag.String("policy-file", "", "Path to a file with an OPA policy. If empty, uses --policy.") - hostport = flag.String("hostport", "localhost:50042", "Where to listen for connections.") - credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS credentials (one of [%s])", strings.Join(mtls.Loaders(), ","))) - verbosity = flag.Int("v", 0, "Verbosity level. > 0 indicates more extensive logging") - validate = flag.Bool("validate", false, "If true will evaluate the policy and then exit (non-zero on error)") + policyFlag = flag.String("policy", defaultPolicy, "Local OPA policy governing access. If empty, use builtin policy.") + policyFile = flag.String("policy-file", "", "Path to a file with an OPA policy. If empty, uses --policy.") + hostport = flag.String("hostport", "localhost:50042", "Where to listen for connections.") + credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS credentials (one of [%s])", strings.Join(mtls.Loaders(), ","))) + verbosity = flag.Int("v", 0, "Verbosity level. > 0 indicates more extensive logging") + validate = flag.Bool("validate", false, "If true will evaluate the policy and then exit (non-zero on error)") + justification = flag.Bool("justification", false, "If true then justification (which is logged and possibly validated) must be passed along in the client context Metadata with the key '"+rpcauth.ReqJustKey+"'") ) func main() { @@ -70,10 +72,11 @@ func main() { } rs := server.RunState{ - Logger: logger, - CredSource: *credSource, - Hostport: *hostport, - Policy: policy, + Logger: logger, + CredSource: *credSource, + Hostport: *hostport, + Policy: policy, + Justification: *justification, } server.Run(ctx, rs) } diff --git a/cmd/sansshell-server/server/server.go b/cmd/sansshell-server/server/server.go index 7c20b599..c58b781a 100644 --- a/cmd/sansshell-server/server/server.go +++ b/cmd/sansshell-server/server/server.go @@ -24,6 +24,7 @@ import ( "os" "github.com/Snowflake-Labs/sansshell/auth/mtls" + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" "github.com/Snowflake-Labs/sansshell/server" "github.com/go-logr/logr" @@ -46,6 +47,13 @@ type RunState struct { Hostport string // Policy is an OPA policy for determining authz decisions. Policy string + // Justification if true requires justification to be set in the + // incoming RPC context Metadata (to the key defined in the telemetry package). + Justification bool + // JustificationFunc will be called if Justication is true and a justification + // entry is found. The supplied function can then do any validation it wants + // in order to ensure it's compliant. + JustificationFunc func(string) error } // Run takes the given context and RunState and starts up a sansshell server. @@ -57,7 +65,10 @@ func Run(ctx context.Context, rs RunState) { os.Exit(1) } - if err := server.Serve(rs.Hostport, creds, rs.Policy, rs.Logger); err != nil { + justificationHook := rpcauth.HookIf(rpcauth.JustificationHook(rs.JustificationFunc), func(input *rpcauth.RPCAuthInput) bool { + return rs.Justification + }) + if err := server.Serve(rs.Hostport, creds, rs.Policy, rs.Logger, justificationHook); err != nil { rs.Logger.Error(err, "server.Serve", "hostport", rs.Hostport) os.Exit(1) } diff --git a/proxy/proxy/proxy.go b/proxy/proxy/proxy.go index fe04f37c..78a70996 100644 --- a/proxy/proxy/proxy.go +++ b/proxy/proxy/proxy.go @@ -336,9 +336,16 @@ func (p *Conn) createStreams(ctx context.Context, method string) (proxypb.Proxy_ } err = stream.Send(req) - if err != nil { + // If Send reports an error and is EOF we have to use Recv to get the actual error according to documentation + // for SendMsg. However it appears SendMsg will return actual errors "sometimes" when it's the first stream + // a server has ever handled so account for that here. + if err != nil && err != io.EOF { return nil, nil, status.Errorf(codes.Internal, "can't send request for %s on stream - %v", method, err) } + if err != nil { + _, err := stream.Recv() + return nil, nil, status.Errorf(codes.Internal, "remote error from Send for %s - %v", method, err) + } resp, err := stream.Recv() if err != nil { return nil, nil, status.Errorf(codes.Internal, "can't get response for %s on stream - %v", method, err) @@ -391,8 +398,7 @@ func (p *Conn) InvokeOneMany(ctx context.Context, method string, args interface{ if err := s.send(requestMsg); err != nil { return nil, err } - // TODO(): Put this back when the race in server.go is figured out. Causes a send and close - // on the channel processing server side which isn't allowed. + if err := s.closeClients(); err != nil { return nil, err } diff --git a/server/server.go b/server/server.go index ae2b906e..02701007 100644 --- a/server/server.go +++ b/server/server.go @@ -38,15 +38,19 @@ var ( mu sync.Mutex ) -// Serve wraps up BuildServer in a succinct API for callers -func Serve(hostport string, c credentials.TransportCredentials, policy string, logger logr.Logger) error { +// Serve wraps up BuildServer in a succinct API for callers passing along various parameters. It will automatically add +// an authz hook for HostNet based on the listener address. Additional hooks are passed along after this one. +func Serve(hostport string, c credentials.TransportCredentials, policy string, logger logr.Logger, authzHooks ...rpcauth.RPCAuthzHook) error { lis, err := net.Listen("tcp", hostport) if err != nil { return fmt.Errorf("failed to listen: %v", err) } mu.Lock() - srv, err = BuildServer(c, policy, lis.Addr(), logger) + h := []rpcauth.RPCAuthzHook{rpcauth.HostNetHook(lis.Addr())} + h = append(h, authzHooks...) + + srv, err = BuildServer(c, policy, logger, h...) mu.Unlock() if err != nil { return err @@ -62,11 +66,11 @@ func getSrv() *grpc.Server { return srv } -// BuildServer creates a gRPC server, attaches the OPA policy interceptor, +// BuildServer creates a gRPC server, attaches the OPA policy interceptor with supplied args and then // registers all of the imported SansShell modules. Separating this from Serve // primarily facilitates testing. -func BuildServer(c credentials.TransportCredentials, policy string, address net.Addr, logger logr.Logger) (*grpc.Server, error) { - authz, err := rpcauth.NewWithPolicy(context.Background(), policy, rpcauth.HostNetHook(address)) +func BuildServer(c credentials.TransportCredentials, policy string, logger logr.Logger, authzHooks ...rpcauth.RPCAuthzHook) (*grpc.Server, error) { + authz, err := rpcauth.NewWithPolicy(context.Background(), policy, authzHooks...) if err != nil { return nil, err } diff --git a/server/server_test.go b/server/server_test.go index a15527e7..172dc51f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -32,6 +32,7 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" _ "github.com/Snowflake-Labs/sansshell/services/healthcheck/server" lfpb "github.com/Snowflake-Labs/sansshell/services/localfile" _ "github.com/Snowflake-Labs/sansshell/services/localfile/server" @@ -67,7 +68,7 @@ func bufDialer(context.Context, string) (net.Conn, error) { func TestMain(m *testing.M) { lis = bufconn.Listen(bufSize) - s, err := BuildServer(nil, policy, lis.Addr(), logr.Discard()) + s, err := BuildServer(nil, policy, logr.Discard(), rpcauth.HostNetHook(lis.Addr())) if err != nil { log.Fatalf("Could not build server: %s", err) } @@ -83,7 +84,7 @@ func TestMain(m *testing.M) { func TestBuildServer(t *testing.T) { // Make sure a bad policy fails - _, err := BuildServer(nil, "", lis.Addr(), logr.Discard()) + _, err := BuildServer(nil, "", logr.Discard(), rpcauth.HostNetHook(lis.Addr())) t.Log(err) testutil.FatalOnNoErr("empty policy", err, t) } diff --git a/telemetry/telemetry.go b/telemetry/telemetry.go index c8504b42..5ec99f17 100644 --- a/telemetry/telemetry.go +++ b/telemetry/telemetry.go @@ -21,12 +21,18 @@ package telemetry import ( "context" "io" + "strings" "github.com/go-logr/logr" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" ) +const ( + sansshellMetadata = "sansshell-" +) + // UnaryClientLogInterceptor returns a new grpc.UnaryClientInterceptor that logs // outgoing requests using the supplied logger, as well as injecting it into the // context of the invoker. @@ -34,6 +40,8 @@ func UnaryClientLogInterceptor(logger logr.Logger) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { l := logger.WithValues("method", method, "target", cc.Target()) logCtx := logr.NewContext(ctx, l) + logCtx = passAlongMetadata(logCtx) + l = logMetadata(logCtx, l) l.Info("new client request") err := invoker(logCtx, method, req, reply, cc, opts...) if err != nil { @@ -44,13 +52,15 @@ func UnaryClientLogInterceptor(logger logr.Logger) grpc.UnaryClientInterceptor { } // StreamClientLogInterceptor returns a new grpc.StreamClientInterceptor that logs -// client requests using the supplied logger, as as as injecting into into the Context +// client requests using the supplied logger, as well as injecting it into the context // of the created stream. func StreamClientLogInterceptor(logger logr.Logger) grpc.StreamClientInterceptor { return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { l := logger.WithValues("method", method, "target", cc.Target()) - l.Info("new client stream") logCtx := logr.NewContext(ctx, l) + logCtx = passAlongMetadata(logCtx) + l = logMetadata(logCtx, l) + l.Info("new client stream") stream, err := streamer(logCtx, desc, cc, method, opts...) if err != nil { l.Error(err, "create stream") @@ -64,6 +74,37 @@ func StreamClientLogInterceptor(logger logr.Logger) grpc.StreamClientInterceptor } } +func logMetadata(ctx context.Context, l logr.Logger) logr.Logger { + // Add any sansshell specific metadata to the logging we do. + md, ok := metadata.FromIncomingContext(ctx) + if ok { + for k, v := range md { + if strings.HasPrefix(k, sansshellMetadata) { + for _, val := range v { + l = l.WithValues(k, val) + } + } + } + } + return l +} + +func passAlongMetadata(ctx context.Context) context.Context { + // See if we got any metadata that has our prefix and pass it along + // downstream (i.e. proxy case). + md, ok := metadata.FromIncomingContext(ctx) + if ok { + for k, v := range md { + if strings.HasPrefix(k, sansshellMetadata) { + for _, val := range v { + ctx = metadata.AppendToOutgoingContext(ctx, k, val) + } + } + } + } + return ctx +} + type loggedClientStream struct { grpc.ClientStream ctx context.Context @@ -107,13 +148,16 @@ func (l *loggedClientStream) CloseSend() error { // UnaryServerLogInterceptor returns a new gprc.UnaryServerInterceptor that logs // incoming requests using the supplied logger, as well as injecting it into the -// context of downstream handlers. +// context of downstream handlers. If incoming calls require client side provided justification +// (which is logged) then the justification parameter should be true and a required +// key of ReqJustKey must be in the context when the interceptor runs. func UnaryServerLogInterceptor(logger logr.Logger) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { l := logger.WithValues("method", info.FullMethod) if p, ok := peer.FromContext(ctx); ok { l = l.WithValues("peer-address", p.Addr) } + l = logMetadata(ctx, l) l.Info("new request") logCtx := logr.NewContext(ctx, l) resp, err := handler(logCtx, req) @@ -126,13 +170,16 @@ func UnaryServerLogInterceptor(logger logr.Logger) grpc.UnaryServerInterceptor { // StreamServerLogInterceptor returns a new grpc.StreamServerInterceptor that logs // incoming streams using the supplied logger, and makes it available via the stream -// context to stream handlers. +// context to stream handlers. If incoming calls require client side provided justification +// (which is logged) then the justification parameter should be true and a required +// key of ReqJustKey must be in the context when the interceptor runs. func StreamServerLogInterceptor(logger logr.Logger) grpc.StreamServerInterceptor { return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { l := logger.WithValues("method", info.FullMethod) if p, ok := peer.FromContext(ss.Context()); ok { l = l.WithValues("peer-address", p.Addr) } + l = logMetadata(ss.Context(), l) l.Info("new stream") stream := &loggedStream{ ServerStream: ss, diff --git a/telemetry/telemetry_test.go b/telemetry/telemetry_test.go index 8759ba90..6941d1ad 100644 --- a/telemetry/telemetry_test.go +++ b/telemetry/telemetry_test.go @@ -28,8 +28,10 @@ import ( "github.com/Snowflake-Labs/sansshell/testing/testutil" "github.com/go-logr/logr" "github.com/go-logr/logr/funcr" + "github.com/google/go-cmp/cmp" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/test/bufconn" ) @@ -79,6 +81,8 @@ func TestUnaryClient(t *testing.T) { wantMethod := "foo" wantError := "error" + mdKey := "sansshell-key" + mdVal := "some value" // Testing is a little weird. This will be called below when we call intercept. Then additional state // gets set on the error return we test below that. invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { @@ -86,21 +90,34 @@ func TestUnaryClient(t *testing.T) { if _, err := logr.FromContext(ctx); err != nil { t.Fatal("didn't get passed a logging context") } + // Test the outgoing context has the MD key. + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + t.Fatal("can't find outgoing context") + } + if got, want := md[mdKey], []string{mdVal}; !cmp.Equal(got, want) { + t.Fatalf("Invalid MD key/value. Want %s/%v got %s/%v", mdKey, want, mdKey, got) + } if got, want := method, wantMethod; got != want { t.Fatalf("didn't get expected method. got %s want %s", got, want) } // The logging should have happened by now testLogging(t, args, "new client request") + testLogging(t, args, mdVal) // Return an error return errors.New(wantError) } - err = intercept(context.Background(), wantMethod, nil, nil, conn, invoker) + md := metadata.Pairs(mdKey, mdVal) + // This has to be an incoming context because there's no RPC layer to transform it. + ctx = metadata.NewIncomingContext(ctx, md) + err = intercept(ctx, wantMethod, nil, nil, conn, invoker) t.Log(err) testutil.FatalOnNoErr("intercept", err, t) if got, want := err.Error(), wantError; got != want { t.Fatalf("didn't get expected error. got %v want %v", got, want) } + } func TestStreamClient(t *testing.T) { @@ -118,7 +135,10 @@ func TestStreamClient(t *testing.T) { intercept := StreamClientLogInterceptor(logger) wantMethod := "sendError" + errorCase := wantMethod wantError := "error" + mdKey := "sansshell-key" + mdVal := "some value" // Testing is a little weird. This will be called below when we call intercept. Then additional state // gets set on the error return we test below that. streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { @@ -126,18 +146,36 @@ func TestStreamClient(t *testing.T) { if _, err := logr.FromContext(ctx); err != nil { t.Fatal("didn't get passed a logging context") } + t.Log(method) + if wantMethod != errorCase { + // Test the outgoing context has the MD key. + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + t.Fatal("can't find outgoing context") + } + if got, want := md[mdKey], []string{mdVal}; !cmp.Equal(got, want) { + t.Fatalf("Invalid MD key/value. Want %s/%v got %s/%v", mdKey, want, mdKey, got) + } + } if got, want := method, wantMethod; got != want { t.Fatalf("didn't get expected method. got %s want %s", got, want) } // The logging should have happened by now testLogging(t, args, "new client stream") + testLogging(t, args, mdVal) + // Return an error if method == "sendError" { return nil, errors.New(wantError) } return &testutil.FakeClientStream{}, nil } - stream, err := intercept(context.Background(), nil, conn, wantMethod, streamer) + + md := metadata.Pairs(mdKey, mdVal) + // This has to be an incoming context because there's no RPC layer to transform it. + ctx = metadata.NewIncomingContext(ctx, md) + + stream, err := intercept(ctx, nil, conn, wantMethod, streamer) t.Log(err) testutil.FatalOnNoErr("streamer", err, t) if got, want := err.Error(), wantError; got != want { @@ -149,7 +187,8 @@ func TestStreamClient(t *testing.T) { // Shouldn't get an error now and we get a real stream. wantMethod = "bar" - stream, err = intercept(context.Background(), nil, conn, wantMethod, streamer) + ctx = metadata.NewIncomingContext(context.Background(), md) + stream, err = intercept(ctx, nil, conn, wantMethod, streamer) testutil.FatalOnErr("2nd streamer call", err, t) if _, err := logr.FromContext(stream.Context()); err != nil { @@ -187,6 +226,8 @@ func TestUnaryServer(t *testing.T) { wantMethod := "foo" wantError := "error" + mdKey := "sansshell-key" + mdVal := "some value" // Testing is a little weird. This will be called below when we call intercept. Then additional state // gets set on the error return we test below that. handler := func(ctx context.Context, req interface{}) (interface{}, error) { @@ -198,6 +239,8 @@ func TestUnaryServer(t *testing.T) { // The logging should have happened by now testLogging(t, args, wantMethod) testLogging(t, args, "new request") + testLogging(t, args, mdVal) + // Return an error return nil, errors.New(wantError) } @@ -206,6 +249,9 @@ func TestUnaryServer(t *testing.T) { FullMethod: wantMethod, } ctx := peer.NewContext(context.Background(), &peer.Peer{}) + md := metadata.Pairs(mdKey, mdVal) + // This has to be an incoming context because there's no RPC layer to transform it. + ctx = metadata.NewIncomingContext(ctx, md) _, err := intercept(ctx, nil, info, handler) t.Log(err) testutil.FatalOnNoErr("intercept", err, t) @@ -225,6 +271,8 @@ func TestStreamServer(t *testing.T) { wantMethod := "foo" wantError := "error" + mdKey := "sansshell-key" + mdVal := "some value" // Testing is a little weird. This will be called below when we call intercept. Then additional state // gets set on the error return we test below that. handler := func(srv interface{}, stream grpc.ServerStream) error { @@ -236,6 +284,7 @@ func TestStreamServer(t *testing.T) { // The logging should have happened by now testLogging(t, args, wantMethod) testLogging(t, args, "new stream") + testLogging(t, args, mdVal) if err := stream.SendMsg(nil); err == nil { t.Fatal("didn't get error from SendMsg on fake client stream") @@ -258,6 +307,10 @@ func TestStreamServer(t *testing.T) { FullMethod: wantMethod, } ctx := peer.NewContext(context.Background(), &peer.Peer{}) + md := metadata.Pairs(mdKey, mdVal) + // This has to be an incoming context because there's no RPC layer to transform it. + ctx = metadata.NewIncomingContext(ctx, md) + ss := &testutil.FakeServerStream{ Ctx: ctx, } diff --git a/testing/integrate.sh b/testing/integrate.sh index ee8a49b1..ec754796 100755 --- a/testing/integrate.sh +++ b/testing/integrate.sh @@ -472,13 +472,13 @@ check_status $? /dev/null policy check failed for server echo echo "Starting servers. Logs in ${LOGS}" -./bin/proxy-server -v=1 --root-ca=./auth/mtls/testdata/root.pem --server-cert=./auth/mtls/testdata/leaf.pem --server-key=./auth/mtls/testdata/leaf.key --client-cert=./auth/mtls/testdata/client.pem --client-key=./auth/mtls/testdata/client.key --policy-file=${LOGS}/policy --hostport=localhost:50043 >& ${LOGS}/proxy.log & +./bin/proxy-server -v=1 --justification --root-ca=./auth/mtls/testdata/root.pem --server-cert=./auth/mtls/testdata/leaf.pem --server-key=./auth/mtls/testdata/leaf.key --client-cert=./auth/mtls/testdata/client.pem --client-key=./auth/mtls/testdata/client.key --policy-file=${LOGS}/policy --hostport=localhost:50043 >& ${LOGS}/proxy.log & PROXY_PID=$! # Since we're controlling lifetime the shell can ignore this (avoids useless termination messages). disown %% # The server needs to be root in order for package installation tests (and the nodes run this as root). -sudo --preserve-env=AWS_ACCESS_KEY_ID,AWS_SECRET_ACCESS_KEY -b ./bin/sansshell-server -v=1 --root-ca=./auth/mtls/testdata/root.pem --server-cert=./auth/mtls/testdata/leaf.pem --server-key=./auth/mtls/testdata/leaf.key --policy-file=${LOGS}/policy --hostport=localhost:50042 >& ${LOGS}/server.log +sudo --preserve-env=AWS_ACCESS_KEY_ID,AWS_SECRET_ACCESS_KEY -b ./bin/sansshell-server -v=1 --justification --root-ca=./auth/mtls/testdata/root.pem --server-cert=./auth/mtls/testdata/leaf.pem --server-key=./auth/mtls/testdata/leaf.key --policy-file=${LOGS}/policy --hostport=localhost:50042 >& ${LOGS}/server.log # Skip if on github if [ -z "${ON_GITHUB}" ]; then @@ -501,7 +501,8 @@ else echo "Skipping remote cloud setup on Github" fi -SANSSH_NOPROXY="./bin/sanssh --root-ca=./auth/mtls/testdata/root.pem --client-cert=./auth/mtls/testdata/client.pem --client-key=./auth/mtls/testdata/client.key --timeout=120s" +SANSSH_NOPROXY_NO_JUSTIFY="./bin/sanssh --root-ca=./auth/mtls/testdata/root.pem --client-cert=./auth/mtls/testdata/client.pem --client-key=./auth/mtls/testdata/client.key --timeout=120s" +SANSSH_NOPROXY="${SANSSH_NOPROXY_NO_JUSTIFY} --justification=yes" SANSSH_PROXY="${SANSSH_NOPROXY} --proxy=localhost:50043" SINGLE_TARGET="--targets=localhost:50042" MULTI_TARGETS="--targets=localhost:50042,localhost:50042" @@ -528,6 +529,15 @@ if [ "${HEALTHY}" != "true" ]; then fi echo "Servers healthy" +# Now a simple test to validate justification requirements are working +# and are passing along from proxy->clients +echo "Expect a failure here about a lack of justification" +${SANSSH_NOPROXY_NO_JUSTIFY} --proxy=localhost:50043 ${MULTI_TARGETS} healthcheck validate +if [ $? != 1 ]; then + check_status 1 /dev/null missing justification failed +fi + + run_a_test false 50 ansible playbook --playbook=$PWD/services/ansible/server/testdata/test.yml --vars=path=/tmp,path2=/ run_a_test false 1 exec run /usr/bin/echo Hello World