From fa853129fa9b620ddcaeb1c428d6f9ad380ac450 Mon Sep 17 00:00:00 2001 From: James Chacon Date: Wed, 16 Feb 2022 15:23:07 -0800 Subject: [PATCH] Add a justification optional requirement. (#75) * Add a justification optional requirement. Servers can require a justification string to be passed in from client side metadata. They can also provide a user defined function to validate this as well. Plumb into servers, sanssh and integration tests. * Update usage information and important flags * Refactor into authz hook style. Telemetry just extracts anything which is sansshell-* from metadata and logs it. If it gets an error from the handler (which is where authz hooks) it'll bail at that point. So we get logging in one place and authz handled correctly in it's place. * Remove debugging * Convert server startup to take lists of authz hooks. --- auth/mtls/mtls_test.go | 3 +- auth/opa/rpcauth/hooks.go | 40 +++++++++++++++++- auth/opa/rpcauth/rpcauth_test.go | 53 ++++++++++++++++++++++++ cmd/proxy-server/main.go | 23 ++++++----- cmd/proxy-server/server/server.go | 13 +++++- cmd/sanssh/main.go | 19 ++++++--- cmd/sansshell-server/main.go | 23 ++++++----- cmd/sansshell-server/server/server.go | 13 +++++- proxy/proxy/proxy.go | 12 ++++-- server/server.go | 16 +++++--- server/server_test.go | 5 ++- telemetry/telemetry.go | 55 +++++++++++++++++++++++-- telemetry/telemetry_test.go | 59 +++++++++++++++++++++++++-- testing/integrate.sh | 16 ++++++-- 14 files changed, 300 insertions(+), 50 deletions(-) diff --git a/auth/mtls/mtls_test.go b/auth/mtls/mtls_test.go index 1c3d9e46..12494df2 100644 --- a/auth/mtls/mtls_test.go +++ b/auth/mtls/mtls_test.go @@ -31,6 +31,7 @@ import ( "google.golang.org/grpc/test/bufconn" "google.golang.org/protobuf/types/known/emptypb" + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" "github.com/Snowflake-Labs/sansshell/server" hcpb "github.com/Snowflake-Labs/sansshell/services/healthcheck" _ "github.com/Snowflake-Labs/sansshell/services/healthcheck/server" @@ -100,7 +101,7 @@ func serverWithPolicy(t *testing.T, policy string, CAPool *x509.CertPool) (*bufc creds, err := LoadServerTLS("testdata/leaf.pem", "testdata/leaf.key", CAPool) testutil.FatalOnErr("Failed to load client cert", err, t) lis := bufconn.Listen(bufSize) - s, err := server.BuildServer(creds, policy, lis.Addr(), logr.Discard()) + s, err := server.BuildServer(creds, policy, logr.Discard(), rpcauth.HostNetHook(lis.Addr())) testutil.FatalOnErr("Could not build server", err, t) listening := make(chan struct{}) go func() { diff --git a/auth/opa/rpcauth/hooks.go b/auth/opa/rpcauth/hooks.go index f134d10d..3baf2a7d 100644 --- a/auth/opa/rpcauth/hooks.go +++ b/auth/opa/rpcauth/hooks.go @@ -19,6 +19,9 @@ package rpcauth import ( "context" "net" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) // RPCAuthzHookFunc implements RpcAuthzHook for a simple function @@ -53,7 +56,7 @@ func (c *conditionalHook) Hook(ctx context.Context, input *RPCAuthInput) error { return nil } -// HostNetHook returns an RpcAuthzHook that sets host networking information. +// HostNetHook returns an RPCAuthzHook that sets host networking information. func HostNetHook(addr net.Addr) RPCAuthzHook { return RPCAuthzHookFunc(func(ctx context.Context, input *RPCAuthInput) error { if input.Host == nil { @@ -63,3 +66,38 @@ func HostNetHook(addr net.Addr) RPCAuthzHook { return nil }) } + +const ( + // ReqJustKey is the key name that must exist in the incoming + // context metadata if client side provided justification is required. + ReqJustKey = "sansshell-justification" +) + +var ( + // ErrJustification is the error returned for missing justification. + ErrJustification = status.Error(codes.FailedPrecondition, "missing justification") +) + +// JustificationHook takes the given optional justification function and returns an RPCAuthzHook +// that validates if justification was included. If it is required and passes the optional validation function +// it will return nil, otherwise an error. +func JustificationHook(justificationFunc func(string) error) RPCAuthzHook { + return RPCAuthzHookFunc(func(ctx context.Context, input *RPCAuthInput) error { + // See if we got any metadata and if it contains the justification + var j string + v := input.Metadata[ReqJustKey] + if len(v) > 0 { + j = v[0] + } + + if j == "" { + return ErrJustification + } + if justificationFunc != nil { + if err := justificationFunc(j); err != nil { + return status.Errorf(codes.FailedPrecondition, "justification failed: %v", err) + } + } + return nil + }) +} diff --git a/auth/opa/rpcauth/rpcauth_test.go b/auth/opa/rpcauth/rpcauth_test.go index 1194eff5..a559b6c1 100644 --- a/auth/opa/rpcauth/rpcauth_test.go +++ b/auth/opa/rpcauth/rpcauth_test.go @@ -219,6 +219,59 @@ func TestAuthzHook(t *testing.T) { }, errFunc: wantStatusCode(codes.OK), }, + { + name: "network data allow with justification (no func)", + input: &RPCAuthInput{ + Method: "/Foo.Bar/Foo", + Metadata: metadata.MD{ + ReqJustKey: []string{"justification"}, + }, + }, + hooks: []RPCAuthzHook{ + HostNetHook(tcp), + JustificationHook(nil), + }, + errFunc: wantStatusCode(codes.OK), + }, + { + name: "network data allow with justification req but none given (no func)", + input: &RPCAuthInput{ + Method: "/Foo.Bar/Foo", + }, + hooks: []RPCAuthzHook{ + HostNetHook(tcp), + JustificationHook(nil), + }, + errFunc: wantStatusCode(codes.FailedPrecondition), + }, + { + name: "network data allow with justification (with func)", + input: &RPCAuthInput{ + Method: "/Foo.Bar/Foo", + Metadata: metadata.MD{ + ReqJustKey: []string{"justification"}, + }, + }, + hooks: []RPCAuthzHook{ + HostNetHook(tcp), + JustificationHook(func(string) error { return nil }), + }, + errFunc: wantStatusCode(codes.OK), + }, + { + name: "network data allow with justification req given and func fails", + input: &RPCAuthInput{ + Method: "/Foo.Bar/Foo", + Metadata: metadata.MD{ + ReqJustKey: []string{"justification"}, + }, + }, + hooks: []RPCAuthzHook{ + HostNetHook(tcp), + JustificationHook(func(string) error { return errors.New("error") }), + }, + errFunc: wantStatusCode(codes.FailedPrecondition), + }, { name: "conditional hook, triggered", input: &RPCAuthInput{Method: "/Some.Random/Method"}, diff --git a/cmd/proxy-server/main.go b/cmd/proxy-server/main.go index e6191f38..3496d417 100644 --- a/cmd/proxy-server/main.go +++ b/cmd/proxy-server/main.go @@ -28,6 +28,7 @@ import ( "github.com/Snowflake-Labs/sansshell/auth/mtls" mtlsFlags "github.com/Snowflake-Labs/sansshell/auth/mtls/flags" "github.com/Snowflake-Labs/sansshell/auth/opa" + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" "github.com/Snowflake-Labs/sansshell/cmd/proxy-server/server" "github.com/Snowflake-Labs/sansshell/cmd/util" "github.com/go-logr/stdr" @@ -37,12 +38,13 @@ var ( //go:embed default-policy.rego defaultPolicy string - policyFlag = flag.String("policy", defaultPolicy, "Local OPA policy governing access. If empty, use builtin policy.") - policyFile = flag.String("policy-file", "", "Path to a file with an OPA policy. If empty, uses --policy.") - hostport = flag.String("hostport", "localhost:50043", "Where to listen for connections.") - credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS creds (one of [%s])", strings.Join(mtls.Loaders(), ","))) - verbosity = flag.Int("v", 0, "Verbosity level. > 0 indicates more extensive logging") - validate = flag.Bool("validate", false, "If true will evaluate the policy and then exit (non-zero on error)") + policyFlag = flag.String("policy", defaultPolicy, "Local OPA policy governing access. If empty, use builtin policy.") + policyFile = flag.String("policy-file", "", "Path to a file with an OPA policy. If empty, uses --policy.") + hostport = flag.String("hostport", "localhost:50043", "Where to listen for connections.") + credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS creds (one of [%s])", strings.Join(mtls.Loaders(), ","))) + verbosity = flag.Int("v", 0, "Verbosity level. > 0 indicates more extensive logging") + validate = flag.Bool("validate", false, "If true will evaluate the policy and then exit (non-zero on error)") + justification = flag.Bool("justification", false, "If true then justification (which is logged and possibly validated) must be passed along in the client context Metadata with the key '"+rpcauth.ReqJustKey+"'") ) func main() { @@ -65,10 +67,11 @@ func main() { } rs := server.RunState{ - Logger: logger, - Policy: policy, - CredSource: *credSource, - Hostport: *hostport, + Logger: logger, + Policy: policy, + CredSource: *credSource, + Hostport: *hostport, + Justification: *justification, } server.Run(ctx, rs) } diff --git a/cmd/proxy-server/server/server.go b/cmd/proxy-server/server/server.go index e7915e31..70a13ad9 100644 --- a/cmd/proxy-server/server/server.go +++ b/cmd/proxy-server/server/server.go @@ -53,6 +53,13 @@ type RunState struct { CredSource string // Hostport is the host:port to run the server. Hostport string + // Justification if true requires justification to be set in the + // incoming RPC context Metadata (to the key defined in the telemetry package). + Justification bool + // JustificationFunc will be called if Justication is true and a justification + // entry is found. The supplied function can then do any validation it wants + // in order to ensure it's compliant. + JustificationFunc func(string) error } // Run takes the given context and RunState along with any authz hooks and starts up a sansshell proxy server @@ -80,7 +87,11 @@ func Run(ctx context.Context, rs RunState, hooks ...rpcauth.RPCAuthzHook) { addressHook := rpcauth.HookIf(rpcauth.HostNetHook(lis.Addr()), func(input *rpcauth.RPCAuthInput) bool { return input.Host == nil || input.Host.Net == nil }) - h := []rpcauth.RPCAuthzHook{addressHook} + justificationHook := rpcauth.HookIf(rpcauth.JustificationHook(rs.JustificationFunc), func(input *rpcauth.RPCAuthInput) bool { + return rs.Justification + }) + + h := []rpcauth.RPCAuthzHook{addressHook, justificationHook} h = append(h, hooks...) authz, err := rpcauth.NewWithPolicy(ctx, rs.Policy, h...) if err != nil { diff --git a/cmd/sanssh/main.go b/cmd/sanssh/main.go index 3e9d3139..474e5670 100644 --- a/cmd/sanssh/main.go +++ b/cmd/sanssh/main.go @@ -26,19 +26,22 @@ import ( "github.com/Snowflake-Labs/sansshell/auth/mtls" mtlsFlags "github.com/Snowflake-Labs/sansshell/auth/mtls/flags" + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" "github.com/Snowflake-Labs/sansshell/cmd/sanssh/client" "github.com/Snowflake-Labs/sansshell/services/util" "github.com/google/subcommands" + "google.golang.org/grpc/metadata" ) var ( defaultAddress = "localhost:50042" defaultTimeout = 3 * time.Second - proxyAddr = flag.String("proxy", "", "Address to contact for proxy to sansshell-server. If blank a direct connection to the first entry in --targets will be made") - timeout = flag.Duration("timeout", defaultTimeout, "How long to wait for the command to complete") - credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS credentials (one of [%s])", strings.Join(mtls.Loaders(), ","))) - outputsDir = flag.String("output-dir", "", "If set defines a directory to emit output/errors from commands. Files will be generated based on target as destination/0 destination/0.error, etc.") + proxyAddr = flag.String("proxy", "", "Address to contact for proxy to sansshell-server. If blank a direct connection to the first entry in --targets will be made") + timeout = flag.Duration("timeout", defaultTimeout, "How long to wait for the command to complete") + credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS credentials (one of [%s])", strings.Join(mtls.Loaders(), ","))) + outputsDir = flag.String("output-dir", "", "If set defines a directory to emit output/errors from commands. Files will be generated based on target as destination/0 destination/0.error, etc.") + justification = flag.String("justification", "", "If non-empty will add the key '"+rpcauth.ReqJustKey+"' to the outgoing context Metadata to be passed along to the server for possbile validation and logging.") // targets will be bound to --targets for sending a single request to N nodes. targetsFlag util.StringSliceFlag @@ -62,6 +65,8 @@ func init() { subcommands.ImportantFlag("proxy") subcommands.ImportantFlag("targets") subcommands.ImportantFlag("outputs") + subcommands.ImportantFlag("output-dir") + subcommands.ImportantFlag("justification") } func main() { @@ -75,5 +80,9 @@ func main() { CredSource: *credSource, Timeout: *timeout, } - client.Run(context.Background(), rs) + ctx := context.Background() + if *justification != "" { + ctx = metadata.AppendToOutgoingContext(ctx, rpcauth.ReqJustKey, *justification) + } + client.Run(ctx, rs) } diff --git a/cmd/sansshell-server/main.go b/cmd/sansshell-server/main.go index 08e02347..11510419 100644 --- a/cmd/sansshell-server/main.go +++ b/cmd/sansshell-server/main.go @@ -34,6 +34,7 @@ import ( "github.com/Snowflake-Labs/sansshell/auth/mtls" mtlsFlags "github.com/Snowflake-Labs/sansshell/auth/mtls/flags" "github.com/Snowflake-Labs/sansshell/auth/opa" + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" "github.com/Snowflake-Labs/sansshell/cmd/sansshell-server/server" "github.com/Snowflake-Labs/sansshell/cmd/util" ) @@ -42,12 +43,13 @@ var ( //go:embed default-policy.rego defaultPolicy string - policyFlag = flag.String("policy", defaultPolicy, "Local OPA policy governing access. If empty, use builtin policy.") - policyFile = flag.String("policy-file", "", "Path to a file with an OPA policy. If empty, uses --policy.") - hostport = flag.String("hostport", "localhost:50042", "Where to listen for connections.") - credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS credentials (one of [%s])", strings.Join(mtls.Loaders(), ","))) - verbosity = flag.Int("v", 0, "Verbosity level. > 0 indicates more extensive logging") - validate = flag.Bool("validate", false, "If true will evaluate the policy and then exit (non-zero on error)") + policyFlag = flag.String("policy", defaultPolicy, "Local OPA policy governing access. If empty, use builtin policy.") + policyFile = flag.String("policy-file", "", "Path to a file with an OPA policy. If empty, uses --policy.") + hostport = flag.String("hostport", "localhost:50042", "Where to listen for connections.") + credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS credentials (one of [%s])", strings.Join(mtls.Loaders(), ","))) + verbosity = flag.Int("v", 0, "Verbosity level. > 0 indicates more extensive logging") + validate = flag.Bool("validate", false, "If true will evaluate the policy and then exit (non-zero on error)") + justification = flag.Bool("justification", false, "If true then justification (which is logged and possibly validated) must be passed along in the client context Metadata with the key '"+rpcauth.ReqJustKey+"'") ) func main() { @@ -70,10 +72,11 @@ func main() { } rs := server.RunState{ - Logger: logger, - CredSource: *credSource, - Hostport: *hostport, - Policy: policy, + Logger: logger, + CredSource: *credSource, + Hostport: *hostport, + Policy: policy, + Justification: *justification, } server.Run(ctx, rs) } diff --git a/cmd/sansshell-server/server/server.go b/cmd/sansshell-server/server/server.go index 7c20b599..c58b781a 100644 --- a/cmd/sansshell-server/server/server.go +++ b/cmd/sansshell-server/server/server.go @@ -24,6 +24,7 @@ import ( "os" "github.com/Snowflake-Labs/sansshell/auth/mtls" + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" "github.com/Snowflake-Labs/sansshell/server" "github.com/go-logr/logr" @@ -46,6 +47,13 @@ type RunState struct { Hostport string // Policy is an OPA policy for determining authz decisions. Policy string + // Justification if true requires justification to be set in the + // incoming RPC context Metadata (to the key defined in the telemetry package). + Justification bool + // JustificationFunc will be called if Justication is true and a justification + // entry is found. The supplied function can then do any validation it wants + // in order to ensure it's compliant. + JustificationFunc func(string) error } // Run takes the given context and RunState and starts up a sansshell server. @@ -57,7 +65,10 @@ func Run(ctx context.Context, rs RunState) { os.Exit(1) } - if err := server.Serve(rs.Hostport, creds, rs.Policy, rs.Logger); err != nil { + justificationHook := rpcauth.HookIf(rpcauth.JustificationHook(rs.JustificationFunc), func(input *rpcauth.RPCAuthInput) bool { + return rs.Justification + }) + if err := server.Serve(rs.Hostport, creds, rs.Policy, rs.Logger, justificationHook); err != nil { rs.Logger.Error(err, "server.Serve", "hostport", rs.Hostport) os.Exit(1) } diff --git a/proxy/proxy/proxy.go b/proxy/proxy/proxy.go index fe04f37c..78a70996 100644 --- a/proxy/proxy/proxy.go +++ b/proxy/proxy/proxy.go @@ -336,9 +336,16 @@ func (p *Conn) createStreams(ctx context.Context, method string) (proxypb.Proxy_ } err = stream.Send(req) - if err != nil { + // If Send reports an error and is EOF we have to use Recv to get the actual error according to documentation + // for SendMsg. However it appears SendMsg will return actual errors "sometimes" when it's the first stream + // a server has ever handled so account for that here. + if err != nil && err != io.EOF { return nil, nil, status.Errorf(codes.Internal, "can't send request for %s on stream - %v", method, err) } + if err != nil { + _, err := stream.Recv() + return nil, nil, status.Errorf(codes.Internal, "remote error from Send for %s - %v", method, err) + } resp, err := stream.Recv() if err != nil { return nil, nil, status.Errorf(codes.Internal, "can't get response for %s on stream - %v", method, err) @@ -391,8 +398,7 @@ func (p *Conn) InvokeOneMany(ctx context.Context, method string, args interface{ if err := s.send(requestMsg); err != nil { return nil, err } - // TODO(): Put this back when the race in server.go is figured out. Causes a send and close - // on the channel processing server side which isn't allowed. + if err := s.closeClients(); err != nil { return nil, err } diff --git a/server/server.go b/server/server.go index ae2b906e..02701007 100644 --- a/server/server.go +++ b/server/server.go @@ -38,15 +38,19 @@ var ( mu sync.Mutex ) -// Serve wraps up BuildServer in a succinct API for callers -func Serve(hostport string, c credentials.TransportCredentials, policy string, logger logr.Logger) error { +// Serve wraps up BuildServer in a succinct API for callers passing along various parameters. It will automatically add +// an authz hook for HostNet based on the listener address. Additional hooks are passed along after this one. +func Serve(hostport string, c credentials.TransportCredentials, policy string, logger logr.Logger, authzHooks ...rpcauth.RPCAuthzHook) error { lis, err := net.Listen("tcp", hostport) if err != nil { return fmt.Errorf("failed to listen: %v", err) } mu.Lock() - srv, err = BuildServer(c, policy, lis.Addr(), logger) + h := []rpcauth.RPCAuthzHook{rpcauth.HostNetHook(lis.Addr())} + h = append(h, authzHooks...) + + srv, err = BuildServer(c, policy, logger, h...) mu.Unlock() if err != nil { return err @@ -62,11 +66,11 @@ func getSrv() *grpc.Server { return srv } -// BuildServer creates a gRPC server, attaches the OPA policy interceptor, +// BuildServer creates a gRPC server, attaches the OPA policy interceptor with supplied args and then // registers all of the imported SansShell modules. Separating this from Serve // primarily facilitates testing. -func BuildServer(c credentials.TransportCredentials, policy string, address net.Addr, logger logr.Logger) (*grpc.Server, error) { - authz, err := rpcauth.NewWithPolicy(context.Background(), policy, rpcauth.HostNetHook(address)) +func BuildServer(c credentials.TransportCredentials, policy string, logger logr.Logger, authzHooks ...rpcauth.RPCAuthzHook) (*grpc.Server, error) { + authz, err := rpcauth.NewWithPolicy(context.Background(), policy, authzHooks...) if err != nil { return nil, err } diff --git a/server/server_test.go b/server/server_test.go index a15527e7..172dc51f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -32,6 +32,7 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" _ "github.com/Snowflake-Labs/sansshell/services/healthcheck/server" lfpb "github.com/Snowflake-Labs/sansshell/services/localfile" _ "github.com/Snowflake-Labs/sansshell/services/localfile/server" @@ -67,7 +68,7 @@ func bufDialer(context.Context, string) (net.Conn, error) { func TestMain(m *testing.M) { lis = bufconn.Listen(bufSize) - s, err := BuildServer(nil, policy, lis.Addr(), logr.Discard()) + s, err := BuildServer(nil, policy, logr.Discard(), rpcauth.HostNetHook(lis.Addr())) if err != nil { log.Fatalf("Could not build server: %s", err) } @@ -83,7 +84,7 @@ func TestMain(m *testing.M) { func TestBuildServer(t *testing.T) { // Make sure a bad policy fails - _, err := BuildServer(nil, "", lis.Addr(), logr.Discard()) + _, err := BuildServer(nil, "", logr.Discard(), rpcauth.HostNetHook(lis.Addr())) t.Log(err) testutil.FatalOnNoErr("empty policy", err, t) } diff --git a/telemetry/telemetry.go b/telemetry/telemetry.go index c8504b42..5ec99f17 100644 --- a/telemetry/telemetry.go +++ b/telemetry/telemetry.go @@ -21,12 +21,18 @@ package telemetry import ( "context" "io" + "strings" "github.com/go-logr/logr" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" ) +const ( + sansshellMetadata = "sansshell-" +) + // UnaryClientLogInterceptor returns a new grpc.UnaryClientInterceptor that logs // outgoing requests using the supplied logger, as well as injecting it into the // context of the invoker. @@ -34,6 +40,8 @@ func UnaryClientLogInterceptor(logger logr.Logger) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { l := logger.WithValues("method", method, "target", cc.Target()) logCtx := logr.NewContext(ctx, l) + logCtx = passAlongMetadata(logCtx) + l = logMetadata(logCtx, l) l.Info("new client request") err := invoker(logCtx, method, req, reply, cc, opts...) if err != nil { @@ -44,13 +52,15 @@ func UnaryClientLogInterceptor(logger logr.Logger) grpc.UnaryClientInterceptor { } // StreamClientLogInterceptor returns a new grpc.StreamClientInterceptor that logs -// client requests using the supplied logger, as as as injecting into into the Context +// client requests using the supplied logger, as well as injecting it into the context // of the created stream. func StreamClientLogInterceptor(logger logr.Logger) grpc.StreamClientInterceptor { return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { l := logger.WithValues("method", method, "target", cc.Target()) - l.Info("new client stream") logCtx := logr.NewContext(ctx, l) + logCtx = passAlongMetadata(logCtx) + l = logMetadata(logCtx, l) + l.Info("new client stream") stream, err := streamer(logCtx, desc, cc, method, opts...) if err != nil { l.Error(err, "create stream") @@ -64,6 +74,37 @@ func StreamClientLogInterceptor(logger logr.Logger) grpc.StreamClientInterceptor } } +func logMetadata(ctx context.Context, l logr.Logger) logr.Logger { + // Add any sansshell specific metadata to the logging we do. + md, ok := metadata.FromIncomingContext(ctx) + if ok { + for k, v := range md { + if strings.HasPrefix(k, sansshellMetadata) { + for _, val := range v { + l = l.WithValues(k, val) + } + } + } + } + return l +} + +func passAlongMetadata(ctx context.Context) context.Context { + // See if we got any metadata that has our prefix and pass it along + // downstream (i.e. proxy case). + md, ok := metadata.FromIncomingContext(ctx) + if ok { + for k, v := range md { + if strings.HasPrefix(k, sansshellMetadata) { + for _, val := range v { + ctx = metadata.AppendToOutgoingContext(ctx, k, val) + } + } + } + } + return ctx +} + type loggedClientStream struct { grpc.ClientStream ctx context.Context @@ -107,13 +148,16 @@ func (l *loggedClientStream) CloseSend() error { // UnaryServerLogInterceptor returns a new gprc.UnaryServerInterceptor that logs // incoming requests using the supplied logger, as well as injecting it into the -// context of downstream handlers. +// context of downstream handlers. If incoming calls require client side provided justification +// (which is logged) then the justification parameter should be true and a required +// key of ReqJustKey must be in the context when the interceptor runs. func UnaryServerLogInterceptor(logger logr.Logger) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { l := logger.WithValues("method", info.FullMethod) if p, ok := peer.FromContext(ctx); ok { l = l.WithValues("peer-address", p.Addr) } + l = logMetadata(ctx, l) l.Info("new request") logCtx := logr.NewContext(ctx, l) resp, err := handler(logCtx, req) @@ -126,13 +170,16 @@ func UnaryServerLogInterceptor(logger logr.Logger) grpc.UnaryServerInterceptor { // StreamServerLogInterceptor returns a new grpc.StreamServerInterceptor that logs // incoming streams using the supplied logger, and makes it available via the stream -// context to stream handlers. +// context to stream handlers. If incoming calls require client side provided justification +// (which is logged) then the justification parameter should be true and a required +// key of ReqJustKey must be in the context when the interceptor runs. func StreamServerLogInterceptor(logger logr.Logger) grpc.StreamServerInterceptor { return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { l := logger.WithValues("method", info.FullMethod) if p, ok := peer.FromContext(ss.Context()); ok { l = l.WithValues("peer-address", p.Addr) } + l = logMetadata(ss.Context(), l) l.Info("new stream") stream := &loggedStream{ ServerStream: ss, diff --git a/telemetry/telemetry_test.go b/telemetry/telemetry_test.go index 8759ba90..6941d1ad 100644 --- a/telemetry/telemetry_test.go +++ b/telemetry/telemetry_test.go @@ -28,8 +28,10 @@ import ( "github.com/Snowflake-Labs/sansshell/testing/testutil" "github.com/go-logr/logr" "github.com/go-logr/logr/funcr" + "github.com/google/go-cmp/cmp" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/test/bufconn" ) @@ -79,6 +81,8 @@ func TestUnaryClient(t *testing.T) { wantMethod := "foo" wantError := "error" + mdKey := "sansshell-key" + mdVal := "some value" // Testing is a little weird. This will be called below when we call intercept. Then additional state // gets set on the error return we test below that. invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { @@ -86,21 +90,34 @@ func TestUnaryClient(t *testing.T) { if _, err := logr.FromContext(ctx); err != nil { t.Fatal("didn't get passed a logging context") } + // Test the outgoing context has the MD key. + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + t.Fatal("can't find outgoing context") + } + if got, want := md[mdKey], []string{mdVal}; !cmp.Equal(got, want) { + t.Fatalf("Invalid MD key/value. Want %s/%v got %s/%v", mdKey, want, mdKey, got) + } if got, want := method, wantMethod; got != want { t.Fatalf("didn't get expected method. got %s want %s", got, want) } // The logging should have happened by now testLogging(t, args, "new client request") + testLogging(t, args, mdVal) // Return an error return errors.New(wantError) } - err = intercept(context.Background(), wantMethod, nil, nil, conn, invoker) + md := metadata.Pairs(mdKey, mdVal) + // This has to be an incoming context because there's no RPC layer to transform it. + ctx = metadata.NewIncomingContext(ctx, md) + err = intercept(ctx, wantMethod, nil, nil, conn, invoker) t.Log(err) testutil.FatalOnNoErr("intercept", err, t) if got, want := err.Error(), wantError; got != want { t.Fatalf("didn't get expected error. got %v want %v", got, want) } + } func TestStreamClient(t *testing.T) { @@ -118,7 +135,10 @@ func TestStreamClient(t *testing.T) { intercept := StreamClientLogInterceptor(logger) wantMethod := "sendError" + errorCase := wantMethod wantError := "error" + mdKey := "sansshell-key" + mdVal := "some value" // Testing is a little weird. This will be called below when we call intercept. Then additional state // gets set on the error return we test below that. streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { @@ -126,18 +146,36 @@ func TestStreamClient(t *testing.T) { if _, err := logr.FromContext(ctx); err != nil { t.Fatal("didn't get passed a logging context") } + t.Log(method) + if wantMethod != errorCase { + // Test the outgoing context has the MD key. + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + t.Fatal("can't find outgoing context") + } + if got, want := md[mdKey], []string{mdVal}; !cmp.Equal(got, want) { + t.Fatalf("Invalid MD key/value. Want %s/%v got %s/%v", mdKey, want, mdKey, got) + } + } if got, want := method, wantMethod; got != want { t.Fatalf("didn't get expected method. got %s want %s", got, want) } // The logging should have happened by now testLogging(t, args, "new client stream") + testLogging(t, args, mdVal) + // Return an error if method == "sendError" { return nil, errors.New(wantError) } return &testutil.FakeClientStream{}, nil } - stream, err := intercept(context.Background(), nil, conn, wantMethod, streamer) + + md := metadata.Pairs(mdKey, mdVal) + // This has to be an incoming context because there's no RPC layer to transform it. + ctx = metadata.NewIncomingContext(ctx, md) + + stream, err := intercept(ctx, nil, conn, wantMethod, streamer) t.Log(err) testutil.FatalOnNoErr("streamer", err, t) if got, want := err.Error(), wantError; got != want { @@ -149,7 +187,8 @@ func TestStreamClient(t *testing.T) { // Shouldn't get an error now and we get a real stream. wantMethod = "bar" - stream, err = intercept(context.Background(), nil, conn, wantMethod, streamer) + ctx = metadata.NewIncomingContext(context.Background(), md) + stream, err = intercept(ctx, nil, conn, wantMethod, streamer) testutil.FatalOnErr("2nd streamer call", err, t) if _, err := logr.FromContext(stream.Context()); err != nil { @@ -187,6 +226,8 @@ func TestUnaryServer(t *testing.T) { wantMethod := "foo" wantError := "error" + mdKey := "sansshell-key" + mdVal := "some value" // Testing is a little weird. This will be called below when we call intercept. Then additional state // gets set on the error return we test below that. handler := func(ctx context.Context, req interface{}) (interface{}, error) { @@ -198,6 +239,8 @@ func TestUnaryServer(t *testing.T) { // The logging should have happened by now testLogging(t, args, wantMethod) testLogging(t, args, "new request") + testLogging(t, args, mdVal) + // Return an error return nil, errors.New(wantError) } @@ -206,6 +249,9 @@ func TestUnaryServer(t *testing.T) { FullMethod: wantMethod, } ctx := peer.NewContext(context.Background(), &peer.Peer{}) + md := metadata.Pairs(mdKey, mdVal) + // This has to be an incoming context because there's no RPC layer to transform it. + ctx = metadata.NewIncomingContext(ctx, md) _, err := intercept(ctx, nil, info, handler) t.Log(err) testutil.FatalOnNoErr("intercept", err, t) @@ -225,6 +271,8 @@ func TestStreamServer(t *testing.T) { wantMethod := "foo" wantError := "error" + mdKey := "sansshell-key" + mdVal := "some value" // Testing is a little weird. This will be called below when we call intercept. Then additional state // gets set on the error return we test below that. handler := func(srv interface{}, stream grpc.ServerStream) error { @@ -236,6 +284,7 @@ func TestStreamServer(t *testing.T) { // The logging should have happened by now testLogging(t, args, wantMethod) testLogging(t, args, "new stream") + testLogging(t, args, mdVal) if err := stream.SendMsg(nil); err == nil { t.Fatal("didn't get error from SendMsg on fake client stream") @@ -258,6 +307,10 @@ func TestStreamServer(t *testing.T) { FullMethod: wantMethod, } ctx := peer.NewContext(context.Background(), &peer.Peer{}) + md := metadata.Pairs(mdKey, mdVal) + // This has to be an incoming context because there's no RPC layer to transform it. + ctx = metadata.NewIncomingContext(ctx, md) + ss := &testutil.FakeServerStream{ Ctx: ctx, } diff --git a/testing/integrate.sh b/testing/integrate.sh index ee8a49b1..ec754796 100755 --- a/testing/integrate.sh +++ b/testing/integrate.sh @@ -472,13 +472,13 @@ check_status $? /dev/null policy check failed for server echo echo "Starting servers. Logs in ${LOGS}" -./bin/proxy-server -v=1 --root-ca=./auth/mtls/testdata/root.pem --server-cert=./auth/mtls/testdata/leaf.pem --server-key=./auth/mtls/testdata/leaf.key --client-cert=./auth/mtls/testdata/client.pem --client-key=./auth/mtls/testdata/client.key --policy-file=${LOGS}/policy --hostport=localhost:50043 >& ${LOGS}/proxy.log & +./bin/proxy-server -v=1 --justification --root-ca=./auth/mtls/testdata/root.pem --server-cert=./auth/mtls/testdata/leaf.pem --server-key=./auth/mtls/testdata/leaf.key --client-cert=./auth/mtls/testdata/client.pem --client-key=./auth/mtls/testdata/client.key --policy-file=${LOGS}/policy --hostport=localhost:50043 >& ${LOGS}/proxy.log & PROXY_PID=$! # Since we're controlling lifetime the shell can ignore this (avoids useless termination messages). disown %% # The server needs to be root in order for package installation tests (and the nodes run this as root). -sudo --preserve-env=AWS_ACCESS_KEY_ID,AWS_SECRET_ACCESS_KEY -b ./bin/sansshell-server -v=1 --root-ca=./auth/mtls/testdata/root.pem --server-cert=./auth/mtls/testdata/leaf.pem --server-key=./auth/mtls/testdata/leaf.key --policy-file=${LOGS}/policy --hostport=localhost:50042 >& ${LOGS}/server.log +sudo --preserve-env=AWS_ACCESS_KEY_ID,AWS_SECRET_ACCESS_KEY -b ./bin/sansshell-server -v=1 --justification --root-ca=./auth/mtls/testdata/root.pem --server-cert=./auth/mtls/testdata/leaf.pem --server-key=./auth/mtls/testdata/leaf.key --policy-file=${LOGS}/policy --hostport=localhost:50042 >& ${LOGS}/server.log # Skip if on github if [ -z "${ON_GITHUB}" ]; then @@ -501,7 +501,8 @@ else echo "Skipping remote cloud setup on Github" fi -SANSSH_NOPROXY="./bin/sanssh --root-ca=./auth/mtls/testdata/root.pem --client-cert=./auth/mtls/testdata/client.pem --client-key=./auth/mtls/testdata/client.key --timeout=120s" +SANSSH_NOPROXY_NO_JUSTIFY="./bin/sanssh --root-ca=./auth/mtls/testdata/root.pem --client-cert=./auth/mtls/testdata/client.pem --client-key=./auth/mtls/testdata/client.key --timeout=120s" +SANSSH_NOPROXY="${SANSSH_NOPROXY_NO_JUSTIFY} --justification=yes" SANSSH_PROXY="${SANSSH_NOPROXY} --proxy=localhost:50043" SINGLE_TARGET="--targets=localhost:50042" MULTI_TARGETS="--targets=localhost:50042,localhost:50042" @@ -528,6 +529,15 @@ if [ "${HEALTHY}" != "true" ]; then fi echo "Servers healthy" +# Now a simple test to validate justification requirements are working +# and are passing along from proxy->clients +echo "Expect a failure here about a lack of justification" +${SANSSH_NOPROXY_NO_JUSTIFY} --proxy=localhost:50043 ${MULTI_TARGETS} healthcheck validate +if [ $? != 1 ]; then + check_status 1 /dev/null missing justification failed +fi + + run_a_test false 50 ansible playbook --playbook=$PWD/services/ansible/server/testdata/test.yml --vars=path=/tmp,path2=/ run_a_test false 1 exec run /usr/bin/echo Hello World