From 5f1ff80b473140bc3edb92b16b317f4008d062d3 Mon Sep 17 00:00:00 2001 From: Steven Rhodes Date: Fri, 10 Nov 2023 09:20:49 -0800 Subject: [PATCH] Add server and client implementations for MPA. (#364) These changes are sufficient for MPA when using a direct connection to the server. Here's a few sample commands you can run in parallel to try it out. ``` go run ./cmd/sansshell-server go run ./cmd/sanssh -client-cert ./auth/mtls/testdata/client.pem -client-key ./auth/mtls/testdata/client.key -mpa -targets localhost healthcheck validate go run ./cmd/sanssh -client-cert ./services/mpa/testdata/approver.pem -client-key ./services/mpa/testdata/approver.key -targets localhost mpa approve a59c2fef-748944da-336c9d35 ``` I've added some new testdata certs because I'm forbidding cases where approver == requester. I've updated the sansshell server code to allow any request if it's requested by our "normal" client cert and approved by our "approver" client cert. The output of `-mpa` prints a nonconfigurable help message to stderr while waiting on approval. If the command is already approved, the message won't show up. ``` $ sanssh -mpa -targets localhost healthcheck validate Multi party auth requested, ask an approver to run: sanssh --targets localhost:50042 mpa approve a59c2fef-748944da-336c9d35 Target localhost:50042 (0) healthy` ``` This implements the client and server portion, but not the proxy portion. The proxy part mostly builds on top of what I have here and will take advantage of some other features I'm implementing. - https://github.com/Snowflake-Labs/sansshell/pull/361 for implementing the proxy equivalent of `ServerMPAAuthzHook()` - https://github.com/Snowflake-Labs/sansshell/pull/358 for implementing the proxy equivalents of `mpahooks.UnaryClientIntercepter()` and `mpahooks.StreamClientIntercepter()` - https://github.com/Snowflake-Labs/sansshell/pull/359 so that MPA can use the identity of the caller to the proxy instead of the identity of the proxy. Part of https://github.com/Snowflake-Labs/sansshell/issues/346 --- auth/opa/rpcauth/input.go | 22 +- auth/opa/rpcauth/rpcauth.go | 29 +++ cmd/sanssh/client/client.go | 7 + cmd/sanssh/main.go | 4 + cmd/sansshell-server/default-policy.rego | 10 + cmd/sansshell-server/main.go | 2 + go.mod | 1 + go.sum | 2 + services/mpa/client/client.go | 307 +++++++++++++++++++++++ services/mpa/mpahooks/mpahooks.go | 221 ++++++++++++++++ services/mpa/mpahooks/mpahooks_test.go | 289 +++++++++++++++++++++ services/mpa/server/server.go | 295 ++++++++++++++++++++++ services/mpa/server/server_test.go | 223 ++++++++++++++++ services/mpa/testdata/README.md | 13 + services/mpa/testdata/approver.key | 28 +++ services/mpa/testdata/approver.pem | 17 ++ 16 files changed, 1469 insertions(+), 1 deletion(-) create mode 100644 services/mpa/client/client.go create mode 100644 services/mpa/mpahooks/mpahooks.go create mode 100644 services/mpa/mpahooks/mpahooks_test.go create mode 100644 services/mpa/server/server.go create mode 100644 services/mpa/server/server_test.go create mode 100644 services/mpa/testdata/README.md create mode 100644 services/mpa/testdata/approver.key create mode 100644 services/mpa/testdata/approver.pem diff --git a/auth/opa/rpcauth/input.go b/auth/opa/rpcauth/input.go index edd8c9d7..bbb7bdc2 100644 --- a/auth/opa/rpcauth/input.go +++ b/auth/opa/rpcauth/input.go @@ -32,7 +32,6 @@ import ( ) // RPCAuthInput is used as policy input to validate Sansshell RPCs -// NOTE: RPCAuthInputForLogging must be updated when this changes. type RPCAuthInput struct { // The GRPC method name, as '/Package.Service/Method' Method string `json:"method"` @@ -52,6 +51,9 @@ type RPCAuthInput struct { // Information about the host serving the RPC. Host *HostAuthInput `json:"host"` + // Information about approvers when using multi-party authentication. + Approvers []*PrincipalAuthInput `json:"approvers"` + // Information about the environment in which the policy evaluation is // happening. Environment *EnvironmentInput `json:"environment"` @@ -153,9 +155,27 @@ func NewRPCAuthInput(ctx context.Context, method string, req proto.Message) (*RP return out, nil } +type peerInfoKey struct{} + +// AddPeerToContext adds a PeerAuthInput to the context. This is typically +// added by the rpcauth grpc interceptors. +func AddPeerToContext(ctx context.Context, p *PeerAuthInput) context.Context { + if p == nil { + return ctx + } + return context.WithValue(ctx, peerInfoKey{}, p) +} + // PeerInputFromContext populates peer information from the supplied // context, if available. func PeerInputFromContext(ctx context.Context) *PeerAuthInput { + // If this runs after rpcauth hooks, we can return richer data that includes + // information added by the hooks. + cached, ok := ctx.Value(peerInfoKey{}).(*PeerAuthInput) + if ok { + return cached + } + out := &PeerAuthInput{} p, ok := peer.FromContext(ctx) if !ok { diff --git a/auth/opa/rpcauth/rpcauth.go b/auth/opa/rpcauth/rpcauth.go index e118532e..f41075ef 100644 --- a/auth/opa/rpcauth/rpcauth.go +++ b/auth/opa/rpcauth/rpcauth.go @@ -22,6 +22,7 @@ import ( "context" "fmt" "strings" + "sync" "github.com/go-logr/logr" "go.opentelemetry.io/otel/attribute" @@ -171,6 +172,7 @@ func (g *Authorizer) Authorize(ctx context.Context, req interface{}, info *grpc. if err := g.Eval(ctx, authInput); err != nil { return nil, err } + ctx = AddPeerToContext(ctx, authInput.Peer) return handler(ctx, req) } @@ -187,6 +189,7 @@ func (g *Authorizer) AuthorizeClient(ctx context.Context, method string, req, re if err := g.Eval(ctx, authInput); err != nil { return err } + ctx = AddPeerToContext(ctx, authInput.Peer) return invoker(ctx, method, req, reply, cc, opts...) } @@ -209,6 +212,16 @@ type wrappedClientStream struct { grpc.ClientStream method string authz *Authorizer + + peerMu sync.Mutex + lastPeerAuthInput *PeerAuthInput +} + +func (e *wrappedClientStream) Context() context.Context { + e.peerMu.Lock() + ctx := AddPeerToContext(e.ClientStream.Context(), e.lastPeerAuthInput) + e.peerMu.Unlock() + return ctx } // see: grpc.ClientStream.SendMsg @@ -225,6 +238,9 @@ func (e *wrappedClientStream) SendMsg(req interface{}) error { if err := e.authz.Eval(ctx, authInput); err != nil { return err } + e.peerMu.Lock() + e.lastPeerAuthInput = authInput.Peer + e.peerMu.Unlock() return e.ClientStream.SendMsg(req) } @@ -243,6 +259,16 @@ type wrappedStream struct { grpc.ServerStream info *grpc.StreamServerInfo authz *Authorizer + + peerMu sync.Mutex + lastPeerAuthInput *PeerAuthInput +} + +func (e *wrappedStream) Context() context.Context { + e.peerMu.Lock() + ctx := AddPeerToContext(e.ServerStream.Context(), e.lastPeerAuthInput) + e.peerMu.Unlock() + return ctx } // see: grpc.ServerStream.RecvMsg @@ -266,5 +292,8 @@ func (e *wrappedStream) RecvMsg(req interface{}) error { if err := e.authz.Eval(ctx, authInput); err != nil { return err } + e.peerMu.Lock() + e.lastPeerAuthInput = authInput.Peer + e.peerMu.Unlock() return nil } diff --git a/cmd/sanssh/client/client.go b/cmd/sanssh/client/client.go index 7a6ffe65..4b2d60b4 100644 --- a/cmd/sanssh/client/client.go +++ b/cmd/sanssh/client/client.go @@ -37,6 +37,7 @@ import ( "github.com/Snowflake-Labs/sansshell/proxy/proxy" cmdUtil "github.com/Snowflake-Labs/sansshell/cmd/util" + "github.com/Snowflake-Labs/sansshell/services/mpa/mpahooks" "github.com/Snowflake-Labs/sansshell/services/util" ) @@ -74,6 +75,8 @@ type RunState struct { // BatchSize if non-zero will do the requested operation to the targets but in // N calls to the proxy where N is the target list size divided by BatchSize. BatchSize int + // If true, add an interceptor that performs the multi-party auth flow + EnableMPA bool } const ( @@ -317,6 +320,10 @@ func Run(ctx context.Context, rs RunState) { streamInterceptors = append(streamInterceptors, clientAuthz.AuthorizeClientStream) unaryInterceptors = append(unaryInterceptors, clientAuthz.AuthorizeClient) } + if rs.EnableMPA { + unaryInterceptors = append(unaryInterceptors, mpahooks.UnaryClientIntercepter()) + streamInterceptors = append(streamInterceptors, mpahooks.StreamClientIntercepter()) + } // timeout interceptor should be the last item in ops so that it's executed first. streamInterceptors = append(streamInterceptors, StreamClientTimeoutInterceptor(rs.IdleTimeout)) unaryInterceptors = append(unaryInterceptors, UnaryClientTimeoutInterceptor(rs.IdleTimeout)) diff --git a/cmd/sanssh/main.go b/cmd/sanssh/main.go index 591502cf..41b43177 100644 --- a/cmd/sanssh/main.go +++ b/cmd/sanssh/main.go @@ -47,6 +47,7 @@ import ( _ "github.com/Snowflake-Labs/sansshell/services/healthcheck/client" _ "github.com/Snowflake-Labs/sansshell/services/httpoverrpc/client" _ "github.com/Snowflake-Labs/sansshell/services/localfile/client" + _ "github.com/Snowflake-Labs/sansshell/services/mpa/client" _ "github.com/Snowflake-Labs/sansshell/services/packages/client" _ "github.com/Snowflake-Labs/sansshell/services/power/client" _ "github.com/Snowflake-Labs/sansshell/services/process/client" @@ -84,6 +85,7 @@ If port is blank the default of %d will be used`, proxyEnv, defaultProxyPort)) verbosity = flag.Int("v", -1, "Verbosity level. > 0 indicates more extensive logging") prefixHeader = flag.Bool("h", false, "If true prefix each line of output with '-: '") batchSize = flag.Int("batch-size", 0, "If non-zero will perform the proxy->target work in batches of this size (with any remainder done at the end).") + mpa = flag.Bool("mpa", false, "Request multi-party approval for commands. This will create an MPA request, wait for approval, and then execute the command.") // targets will be bound to --targets for sending a single request to N nodes. targetsFlag util.StringSliceCommaOrWhitespaceFlag @@ -118,6 +120,7 @@ func init() { subcommands.ImportantFlag("justification") subcommands.ImportantFlag("client-policy") subcommands.ImportantFlag("client-policy-file") + subcommands.ImportantFlag("mpa") subcommands.ImportantFlag("v") } @@ -192,6 +195,7 @@ func main() { ClientPolicy: clientPolicy, PrefixOutput: *prefixHeader, BatchSize: *batchSize, + EnableMPA: *mpa, } ctx := logr.NewContext(context.Background(), logger) diff --git a/cmd/sansshell-server/default-policy.rego b/cmd/sansshell-server/default-policy.rego index cb43b3ce..9166703d 100644 --- a/cmd/sansshell-server/default-policy.rego +++ b/cmd/sansshell-server/default-policy.rego @@ -84,3 +84,13 @@ allow { allow { input.method = "/SysInfo.SysInfo/Dmesg" } + +# Allow anything with MPA +allow { + input.peer.principal.id = "sanssh" + input.approvers[_].id = "approver" +} + +allow { + startswith(input.method, "/Mpa.Mpa/") +} diff --git a/cmd/sansshell-server/main.go b/cmd/sansshell-server/main.go index dbfdff6c..b3e7edce 100644 --- a/cmd/sansshell-server/main.go +++ b/cmd/sansshell-server/main.go @@ -64,6 +64,7 @@ import ( fdbserver "github.com/Snowflake-Labs/sansshell/services/fdb/server" _ "github.com/Snowflake-Labs/sansshell/services/healthcheck/server" _ "github.com/Snowflake-Labs/sansshell/services/localfile/server" + mpa "github.com/Snowflake-Labs/sansshell/services/mpa/server" _ "github.com/Snowflake-Labs/sansshell/services/power/server" // Packages needs a real import to bind flags. @@ -171,6 +172,7 @@ func main() { server.WithParsedPolicy(parsed), server.WithJustification(*justification), server.WithAuthzHook(rpcauth.PeerPrincipalFromCertHook()), + server.WithAuthzHook(mpa.ServerMPAAuthzHook()), server.WithRawServerOption(func(s *grpc.Server) { reflection.Register(s) }), server.WithRawServerOption(func(s *grpc.Server) { channelz.RegisterChannelzServiceToServer(s) }), server.WithDebugPort(*debugport), diff --git a/go.mod b/go.mod index 6993e956..0d204d7d 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/go-logr/stdr v1.2.2 github.com/google/go-cmp v0.6.0 github.com/google/subcommands v1.2.0 + github.com/gowebpki/jcs v1.0.1 github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.0 github.com/open-policy-agent/opa v0.58.0 github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index 0fda66b5..2807b8c9 100644 --- a/go.sum +++ b/go.sum @@ -249,6 +249,8 @@ github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56 github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/gowebpki/jcs v1.0.1 h1:Qjzg8EOkrOTuWP7DqQ1FbYtcpEbeTzUoTN9bptp8FOU= +github.com/gowebpki/jcs v1.0.1/go.mod h1:CID1cNZ+sHp1CCpAR8mPf6QRtagFBgPJE0FCUQ6+BrI= github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.0 h1:f4tggROQKKcnh4eItay6z/HbHLqghBxS8g7pyMhmDio= github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.0/go.mod h1:hKAkSgNkL0FII46ZkJcpVEAai4KV+swlIWCKfekd1pA= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.0-rc.3 h1:o95KDiV/b1xdkumY5YbLR0/n2+wBxUpgf3HgfKgTyLI= diff --git a/services/mpa/client/client.go b/services/mpa/client/client.go new file mode 100644 index 00000000..e29c0774 --- /dev/null +++ b/services/mpa/client/client.go @@ -0,0 +1,307 @@ +/* Copyright (c) 2023 Snowflake Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +// Package client provides the client interface for 'mpa' +package client + +import ( + "context" + "flag" + "fmt" + "os" + "strings" + + "github.com/Snowflake-Labs/sansshell/client" + pb "github.com/Snowflake-Labs/sansshell/services/mpa" + "github.com/Snowflake-Labs/sansshell/services/util" + "github.com/google/subcommands" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +const subPackage = "mpa" + +func init() { + subcommands.Register(&mpaCmd{}, subPackage) +} + +func (*mpaCmd) GetSubpackage(f *flag.FlagSet) *subcommands.Commander { + c := client.SetupSubpackage(subPackage, f) + c.Register(&approveCmd{}, "") + c.Register(&listCmd{}, "") + c.Register(&getCmd{}, "") + c.Register(&clearCmd{}, "") + return c +} + +type mpaCmd struct{} + +func (*mpaCmd) Name() string { return subPackage } +func (p *mpaCmd) Synopsis() string { + return client.GenerateSynopsis(p.GetSubpackage(flag.NewFlagSet("", flag.ContinueOnError)), 2) +} +func (p *mpaCmd) Usage() string { + return client.GenerateUsage(subPackage, p.Synopsis()) +} +func (*mpaCmd) SetFlags(f *flag.FlagSet) {} + +func (p *mpaCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + c := p.GetSubpackage(f) + return c.Execute(ctx, args...) +} + +func getAction(ctx context.Context, state *util.ExecuteState, c pb.MpaClientProxy, id string) *pb.Action { + resp, err := c.GetOneMany(ctx, &pb.GetRequest{Id: id}) + if err != nil { + // Emit this to every error file as it's not specific to a given target. + for _, e := range state.Err { + fmt.Fprintf(e, "All targets - could not execute: %v\n", err) + } + return nil + } + var anyAction *pb.Action + actions := make(map[int]*pb.Action) + for r := range resp { + if r.Error != nil { + fmt.Fprintf(state.Err[r.Index], "Unable to look up request: %v\n", r.Error) + continue + } + if r.Resp.Action == nil { + fmt.Fprintf(state.Err[r.Index], "Error: action was nil when looking up MPA request") + continue + } + actions[r.Index] = r.Resp.Action + anyAction = r.Resp.Action + } + if anyAction == nil { + // All commands above must have failed for this to happen. + return nil + } + for _, a := range actions { + if !proto.Equal(a, anyAction) { + // Bail if returned actions were inconsistent because we don't know which action is + // correct. + for idx := range actions { + fmt.Fprintf(state.Err[idx], "All targets - inconsistent action: <%v> vs <%v>\n", a, anyAction) + } + return nil + } + } + return anyAction +} + +type approveCmd struct{} + +func (*approveCmd) Name() string { return "approve" } +func (*approveCmd) Synopsis() string { return "Approves an MPA request" } +func (*approveCmd) Usage() string { + return `approve : + Approves an MPA request with the specified ID. +` +} + +func (p *approveCmd) SetFlags(f *flag.FlagSet) {} + +func (p *approveCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + state := args[0].(*util.ExecuteState) + if f.NArg() != 1 { + fmt.Fprintln(os.Stderr, "Please specify a single ID to approve.") + return subcommands.ExitUsageError + } + c := pb.NewMpaClientProxy(state.Conn) + action := getAction(ctx, state, c, f.Args()[0]) + if action == nil { + return subcommands.ExitFailure + } + + approved, err := c.ApproveOneMany(ctx, &pb.ApproveRequest{ + Action: action, + }) + if err != nil { + // Emit this to every error file as it's not specific to a given target. + for _, e := range state.Err { + fmt.Fprintf(e, "All targets - could not execute: %v\n", err) + } + return subcommands.ExitFailure + } + for r := range approved { + if r.Error != nil { + fmt.Fprintf(state.Err[r.Index], "Unable to approve: %v\n", r.Error) + continue + } + msg := []string{"Approved", action.Method} + if action.GetUser() != "" { + msg = append(msg, "from", action.GetUser()) + } + if action.GetJustification() != "" { + msg = append(msg, "for", action.GetJustification()) + } + fmt.Fprintln(state.Out[r.Index], strings.Join(msg, " ")) + } + return subcommands.ExitSuccess +} + +type listCmd struct { + verbose bool +} + +func (*listCmd) Name() string { return "list" } +func (*listCmd) Synopsis() string { return "Lists out pending MPA requests on machines" } +func (*listCmd) Usage() string { + return `list: + Lists out any MPA requests on machines. +` +} + +func (p *listCmd) SetFlags(f *flag.FlagSet) { + f.BoolVar(&p.verbose, "v", false, "Verbose: list full details of MPA request") +} + +func (p *listCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + state := args[0].(*util.ExecuteState) + if f.NArg() != 0 { + fmt.Fprintln(os.Stderr, "List takes no args.") + return subcommands.ExitUsageError + } + + c := pb.NewMpaClientProxy(state.Conn) + + resp, err := c.ListOneMany(ctx, &pb.ListRequest{}) + if err != nil { + // Emit this to every error file as it's not specific to a given target. + for _, e := range state.Err { + fmt.Fprintf(e, "All targets - could not execute: %v\n", err) + } + return subcommands.ExitFailure + } + for r := range resp { + if r.Error != nil { + fmt.Fprintln(state.Err[r.Index], r.Error) + continue + } + for _, item := range r.Resp.Item { + msg := []string{item.Id} + if p.verbose { + if len(item.Approver) > 0 { + var approvers []string + for _, a := range item.Approver { + approvers = append(approvers, a.Id) + } + msg = append(msg, fmt.Sprintf("(approved by %v)", strings.Join(approvers, ","))) + } + msg = append(msg, protojson.MarshalOptions{UseProtoNames: true}.Format(item.Action)) + } else { + msg = append(msg, item.Action.GetMethod()) + if item.Action.GetUser() != "" { + msg = append(msg, "from", item.Action.GetUser()) + } + if item.Action.GetJustification() != "" { + msg = append(msg, "for", item.Action.GetJustification()) + } + if len(item.Approver) > 0 { + msg = append(msg, "(approved)") + } + } + fmt.Fprintln(state.Out[r.Index], strings.Join(msg, " ")) + } + } + return subcommands.ExitSuccess +} + +type clearCmd struct{} + +func (*clearCmd) Name() string { return "clear" } +func (*clearCmd) Synopsis() string { return "Clears an MPA request" } +func (*clearCmd) Usage() string { + return `clear : + Clears an MPA request with the specified ID. +` +} + +func (p *clearCmd) SetFlags(f *flag.FlagSet) {} + +func (p *clearCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + state := args[0].(*util.ExecuteState) + if f.NArg() != 1 { + fmt.Fprintln(os.Stderr, "Please specify a single ID to clear.") + return subcommands.ExitUsageError + } + c := pb.NewMpaClientProxy(state.Conn) + action := getAction(ctx, state, c, f.Args()[0]) + if action == nil { + return subcommands.ExitFailure + } + + cleared, err := c.ClearOneMany(ctx, &pb.ClearRequest{ + Action: action, + }) + if err != nil { + // Emit this to every error file as it's not specific to a given target. + for _, e := range state.Err { + fmt.Fprintf(e, "All targets - could not execute: %v\n", err) + } + return subcommands.ExitFailure + } + for r := range cleared { + if r.Error != nil { + fmt.Fprintf(state.Err[r.Index], "Unable to clear: %v\n", r.Error) + continue + } + fmt.Fprintln(state.Out[r.Index], "Cleared") + } + return subcommands.ExitSuccess +} + +type getCmd struct{} + +func (*getCmd) Name() string { return "get" } +func (*getCmd) Synopsis() string { return "Print an MPA request" } +func (*getCmd) Usage() string { + return `get : + Prints out the MPA request with the specified ID. +` +} + +func (p *getCmd) SetFlags(f *flag.FlagSet) {} + +func (p *getCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + state := args[0].(*util.ExecuteState) + if f.NArg() != 1 { + fmt.Fprintln(os.Stderr, "Please specify a single ID to approve.") + return subcommands.ExitUsageError + } + c := pb.NewMpaClientProxy(state.Conn) + resp, err := c.GetOneMany(ctx, &pb.GetRequest{Id: f.Args()[0]}) + if err != nil { + // Emit this to every error file as it's not specific to a given target. + for _, e := range state.Err { + fmt.Fprintf(e, "All targets - could not execute: %v\n", err) + } + return subcommands.ExitFailure + } + for r := range resp { + if r.Error != nil { + fmt.Fprintf(state.Err[r.Index], "Unable to look up request: %v\n", r.Error) + continue + } + if r.Resp.Action == nil { + fmt.Fprintf(state.Err[r.Index], "Error: action was nil when looking up MPA request") + continue + } + fmt.Fprintln(state.Out[r.Index], protojson.MarshalOptions{UseProtoNames: true, Multiline: true}.Format(r.Resp)) + } + return subcommands.ExitSuccess +} diff --git a/services/mpa/mpahooks/mpahooks.go b/services/mpa/mpahooks/mpahooks.go new file mode 100644 index 00000000..1fcba70f --- /dev/null +++ b/services/mpa/mpahooks/mpahooks.go @@ -0,0 +1,221 @@ +/* Copyright (c) 2023 Snowflake Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +// Package mpahooks provides grpc interceptors and other helpers for implementing MPA. +package mpahooks + +import ( + "context" + "fmt" + "os" + + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" + "github.com/Snowflake-Labs/sansshell/services/mpa" + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/anypb" +) + +const ( + // reqMPAKey is the key name that must exist in the incoming + // context metadata if the client wants to do an MPA request. + reqMPAKey = "sansshell-mpa-request-id" +) + +// WithMPAInMetadata adds a MPA ID to the grpc metadata of an outgoing RPC call +func WithMPAInMetadata(ctx context.Context, mpaID string) context.Context { + return metadata.AppendToOutgoingContext(ctx, reqMPAKey, mpaID) +} + +// MPAFromIncomingContext reads a MPA ID from the grpc metadata of an incoming RPC call +func MPAFromIncomingContext(ctx context.Context) (mpaID string, ok bool) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", false + } + v := md.Get(reqMPAKey) + if len(v) == 0 { + return "", false + } + return v[0], true +} + +// ActionMatchesInput returns an error if an MPA action doesn't match the +// message being checked in the RPCAuthInput. +func ActionMatchesInput(action *mpa.Action, input *rpcauth.RPCAuthInput) error { + var justification string + if j := input.Metadata[rpcauth.ReqJustKey]; len(j) > 0 { + justification = j[0] + } + + // Transform the rpcauth input into the original proto + mt, err := protoregistry.GlobalTypes.FindMessageByURL(input.MessageType) + if err != nil { + return fmt.Errorf("unable to find proto type: %v", err) + } + m2 := mt.New().Interface() + if err := protojson.Unmarshal([]byte(input.Message), m2); err != nil { + return fmt.Errorf("could not marshal input into %v: %v", input.Message, err) + } + var msg anypb.Any + if err := msg.MarshalFrom(m2); err != nil { + return fmt.Errorf("unable to marshal into anyproto: %v", err) + } + if input.Peer == nil || input.Peer.Principal == nil { + return fmt.Errorf("missing peer information") + } + + sentAct := &mpa.Action{ + User: input.Peer.Principal.ID, + Method: input.Method, + Justification: justification, + Message: &msg, + } + // Make sure to use an any-proto-aware comparison + if !cmp.Equal(action, sentAct, protocmp.Transform()) { + return fmt.Errorf("request doesn't match mpa approval: want %v, got %v", action, sentAct) + } + return nil +} + +func createAndBlockOnSingleTargetMPA(ctx context.Context, method string, req any, cc *grpc.ClientConn) (mpaID string, err error) { + p, ok := req.(proto.Message) + if !ok { + return "", fmt.Errorf("unable to cast req to proto: %v", req) + } + + var msg anypb.Any + if err := msg.MarshalFrom(p); err != nil { + return "", fmt.Errorf("unable to marshal into anyproto: %v", err) + } + + mpaClient := mpa.NewMpaClient(cc) + result, err := mpaClient.Store(ctx, &mpa.StoreRequest{ + Method: method, + Message: &msg, + }) + if err != nil { + return "", err + } + if len(result.Approver) == 0 { + fmt.Fprintln(os.Stderr, "Multi party auth requested, ask an approver to run:") + fmt.Fprintf(os.Stderr, " sanssh --targets %v mpa approve %v\n", cc.Target(), result.Id) + _, err := mpaClient.WaitForApproval(ctx, &mpa.WaitForApprovalRequest{Id: result.Id}) + if err != nil { + return "", err + } + } + return result.Id, nil +} + +// UnaryClientIntercepter is a grpc.UnaryClientIntercepter that will perform the MPA flow. +func UnaryClientIntercepter() grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + // Our interceptor will run for all gRPC calls, including ones used inside the interceptor. + // We need to bail early on MPA-related ones to prevent infinite recursion. + if method == "/Mpa.Mpa/Store" || method == "/Mpa.Mpa/WaitForApproval" { + return invoker(ctx, method, req, reply, cc, opts...) + } + + mpaID, err := createAndBlockOnSingleTargetMPA(ctx, method, req, cc) + if err != nil { + return err + } + + ctx = WithMPAInMetadata(ctx, mpaID) + // Complete the call + return invoker(ctx, method, req, reply, cc, opts...) + } +} + +// newStreamAfterFirstSend creates a grpc.ClientStream that doesn't attempt to begin +// the stream until SendMsg is first called. This is useful if we want to let the initial +// message affect how we set up the stream and supply metadata. +func newStreamAfterFirstSend(sendMsg func(m any) (grpc.ClientStream, error)) grpc.ClientStream { + return &delayedStartStream{ + sendMsg: sendMsg, + innerReady: make(chan struct{}), + } +} + +type delayedStartStream struct { + sendMsg func(m any) (grpc.ClientStream, error) + inner grpc.ClientStream + innerReady chan struct{} +} + +func (w *delayedStartStream) SendMsg(m any) error { + if w.inner == nil { + s, err := w.sendMsg(m) + if err != nil { + return err + } + w.inner = s + close(w.innerReady) + } + + return w.inner.SendMsg(m) +} + +func (w *delayedStartStream) Header() (metadata.MD, error) { + <-w.innerReady + return w.inner.Header() +} +func (w *delayedStartStream) Trailer() metadata.MD { + <-w.innerReady + return w.inner.Trailer() +} +func (w *delayedStartStream) CloseSend() error { + <-w.innerReady + return w.inner.CloseSend() +} +func (w *delayedStartStream) Context() context.Context { + <-w.innerReady + return w.inner.Context() +} +func (w *delayedStartStream) RecvMsg(m any) error { + <-w.innerReady + return w.inner.RecvMsg(m) +} + +// StreamClientIntercepter is a grpc.StreamClientInterceptor that will perform +// the MPA flow. +func StreamClientIntercepter() grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if method == "/Proxy.Proxy/Proxy" { + // No need to intercept proxying, that's handled specially. + return streamer(ctx, desc, cc, method, opts...) + } + + return newStreamAfterFirstSend(func(m any) (grpc.ClientStream, error) { + // Figure out the MPA request + mpaID, err := createAndBlockOnSingleTargetMPA(ctx, method, m, cc) + if err != nil { + return nil, err + } + + // Now establish the stream we actually want because we can only do so after + // we put the MPA ID in the metadata. + ctx := WithMPAInMetadata(ctx, mpaID) + return streamer(ctx, desc, cc, method, opts...) + }), nil + } +} diff --git a/services/mpa/mpahooks/mpahooks_test.go b/services/mpa/mpahooks/mpahooks_test.go new file mode 100644 index 00000000..346f03e1 --- /dev/null +++ b/services/mpa/mpahooks/mpahooks_test.go @@ -0,0 +1,289 @@ +/* Copyright (c) 2023 Snowflake Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +package mpahooks_test + +import ( + "context" + "fmt" + "io" + "log" + "net" + "testing" + "time" + + "github.com/Snowflake-Labs/sansshell/auth/mtls" + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" + "github.com/Snowflake-Labs/sansshell/services" + "github.com/Snowflake-Labs/sansshell/services/healthcheck" + "github.com/Snowflake-Labs/sansshell/services/localfile" + "github.com/Snowflake-Labs/sansshell/services/mpa" + "github.com/Snowflake-Labs/sansshell/services/mpa/mpahooks" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/emptypb" + + _ "github.com/Snowflake-Labs/sansshell/services/healthcheck/server" + _ "github.com/Snowflake-Labs/sansshell/services/localfile/server" + mpaserver "github.com/Snowflake-Labs/sansshell/services/mpa/server" +) + +func mustAny(a *anypb.Any, err error) *anypb.Any { + if err != nil { + panic(err) + } + return a +} + +func TestActionMatchesInput(t *testing.T) { + for _, tc := range []struct { + desc string + action *mpa.Action + input *rpcauth.RPCAuthInput + matches bool + }{ + { + desc: "basic action", + action: &mpa.Action{ + User: "requester", + Method: "foobar", + Message: mustAny(anypb.New(&emptypb.Empty{})), + }, + input: &rpcauth.RPCAuthInput{ + Method: "foobar", + MessageType: "google.protobuf.Empty", + Message: []byte("{}"), + Peer: &rpcauth.PeerAuthInput{ + Principal: &rpcauth.PrincipalAuthInput{ + ID: "requester", + }, + }, + }, + matches: true, + }, + { + desc: "missing auth info", + action: &mpa.Action{ + User: "requester", + Method: "foobar", + Message: mustAny(anypb.New(&emptypb.Empty{})), + }, + input: &rpcauth.RPCAuthInput{ + Method: "foobar", + MessageType: "google.protobuf.Empty", + Message: []byte("{}"), + }, + matches: false, + }, + { + desc: "wrong message", + action: &mpa.Action{ + User: "requester", + Method: "foobar", + Message: mustAny(anypb.New(&mpa.Action{})), + }, + input: &rpcauth.RPCAuthInput{ + Method: "foobar", + MessageType: "google.protobuf.Empty", + Message: []byte("{}"), + Peer: &rpcauth.PeerAuthInput{ + Principal: &rpcauth.PrincipalAuthInput{ + ID: "requester", + }, + }, + }, + matches: false, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + err := mpahooks.ActionMatchesInput(tc.action, tc.input) + if err != nil && tc.matches { + t.Errorf("expected match: %v", err) + } + if err == nil && !tc.matches { + t.Error("unexpected match") + } + }) + } +} + +func pollForAction(ctx context.Context, m mpa.MpaClient, method string) (*mpa.Action, error) { + for { + l, err := m.List(ctx, &mpa.ListRequest{}) + if err != nil { + return nil, err + } + for _, i := range l.Item { + if i.Action.Method == method { + return i.Action, nil + } + } + time.Sleep(10 * time.Millisecond) + } +} + +var serverPolicy = ` +package sansshell.authz + +default allow = false + +allow { + input.method = "/HealthCheck.HealthCheck/Ok" + input.peer.principal.id = "sanssh" + input.approvers[_].id = "approver" +} + + +allow { + input.method = "/LocalFile.LocalFile/Read" + input.peer.principal.id = "sanssh" + input.approvers[_].id = "approver" +} + +allow { + startswith(input.method, "/Mpa.Mpa/") +} +` + +func TestClientInterceptors(t *testing.T) { + ctx := context.Background() + rot, err := mtls.LoadRootOfTrust("../../../auth/mtls/testdata/root.pem") + if err != nil { + t.Fatal(err) + } + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + srvAddr := lis.Addr().String() + authz, err := rpcauth.NewWithPolicy(ctx, serverPolicy, rpcauth.PeerPrincipalFromCertHook(), mpaserver.ServerMPAAuthzHook()) + if err != nil { + t.Fatal(err) + } + srvCreds, err := mtls.LoadServerTLS("../../../auth/mtls/testdata/leaf.pem", "../../../auth/mtls/testdata/leaf.key", rot) + if err != nil { + t.Fatal(err) + } + s := grpc.NewServer( + grpc.ChainStreamInterceptor(authz.AuthorizeStream), + grpc.ChainUnaryInterceptor(authz.Authorize), + grpc.Creds(srvCreds), + ) + for _, svc := range services.ListServices() { + svc.Register(s) + } + go func() { + if err := s.Serve(lis); err != nil { + log.Fatalf("Server exited with error: %v", err) + } + }() + defer s.GracefulStop() + + clientCreds, err := mtls.LoadClientTLS("../../../auth/mtls/testdata/client.pem", "../../../auth/mtls/testdata/client.key", rot) + if err != nil { + t.Fatal(err) + } + approverCreds, err := mtls.LoadClientTLS("../testdata/approver.pem", "../testdata/approver.key", rot) + if err != nil { + t.Fatal(err) + } + + // Confirm that we get Permission Denied without MPA + noInterceptorConn, err := grpc.DialContext(ctx, srvAddr, + grpc.WithTransportCredentials(clientCreds), + ) + if err != nil { + t.Fatal(err) + } + if _, err := healthcheck.NewHealthCheckClient(noInterceptorConn).Ok(ctx, &emptypb.Empty{}); status.Code(err) != codes.PermissionDenied { + t.Fatalf("got something other than permission denied: %v", err) + } + read, err := localfile.NewLocalFileClient(noInterceptorConn).Read(ctx, &localfile.ReadActionRequest{ + Request: &localfile.ReadActionRequest_File{File: &localfile.ReadRequest{Filename: "/etc/hosts"}}, + }) + if err != nil { + t.Fatal(err) + } + if _, err := read.Recv(); status.Code(err) != codes.PermissionDenied { + t.Fatalf("got something other than permission denied: %v", err) + } + + var g errgroup.Group + g.Go(func() error { + // Set up an approver loop + conn, err := grpc.DialContext(ctx, srvAddr, grpc.WithTransportCredentials(approverCreds)) + if err != nil { + return err + } + m := mpa.NewMpaClient(conn) + + healthcheckAction, err := pollForAction(ctx, m, "/HealthCheck.HealthCheck/Ok") + if err != nil { + return err + } + if _, err := m.Approve(ctx, &mpa.ApproveRequest{Action: healthcheckAction}); err != nil { + return fmt.Errorf("unable to approve %v: %v", healthcheckAction, err) + } + + fileReadAction, err := pollForAction(ctx, m, "/LocalFile.LocalFile/Read") + if err != nil { + return err + } + if _, err := m.Approve(ctx, &mpa.ApproveRequest{Action: fileReadAction}); err != nil { + return fmt.Errorf("unable to approve %v: %v", healthcheckAction, err) + } + return nil + }) + + // Make our calls + conn, err := grpc.DialContext(ctx, srvAddr, + grpc.WithTransportCredentials(clientCreds), + grpc.WithChainStreamInterceptor(mpahooks.StreamClientIntercepter()), + grpc.WithChainUnaryInterceptor(mpahooks.UnaryClientIntercepter()), + ) + if err != nil { + t.Error(err) + } + hc := healthcheck.NewHealthCheckClient(conn) + if _, err := hc.Ok(ctx, &emptypb.Empty{}); err != nil { + t.Error(err) + } + + file := localfile.NewLocalFileClient(conn) + bytes, err := file.Read(ctx, &localfile.ReadActionRequest{ + Request: &localfile.ReadActionRequest_File{File: &localfile.ReadRequest{Filename: "/etc/hosts"}}, + }) + if err != nil { + t.Error(err) + } else { + for { + _, err := bytes.Recv() + if err != nil { + if err != io.EOF { + t.Error(err) + } + break + } + } + } + + if err := g.Wait(); err != nil { + t.Error(err) + } +} diff --git a/services/mpa/server/server.go b/services/mpa/server/server.go new file mode 100644 index 00000000..08e20b1d --- /dev/null +++ b/services/mpa/server/server.go @@ -0,0 +1,295 @@ +/* Copyright (c) 2023 Snowflake Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +// Package server implements the sansshell 'Mpa' service. +package server + +import ( + "context" + "crypto/sha256" + "fmt" + "reflect" + "sort" + "sync" + "time" + + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" + "github.com/Snowflake-Labs/sansshell/services" + "github.com/Snowflake-Labs/sansshell/services/mpa" + "github.com/Snowflake-Labs/sansshell/services/mpa/mpahooks" + "github.com/gowebpki/jcs" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protojson" +) + +var ( + // We hardcode a maximum number of approvals to remember to avoid + // providing an easy way to overload the server. + maxMPAApprovals = 1000 + // We also hardcode a max age to prevent unreasonably-old approvals from being used. + maxMPAApprovedAge = 24 * time.Hour +) + +// ServerMPAAuthzHook populates approver information based on an internal MPA store. +func ServerMPAAuthzHook() rpcauth.RPCAuthzHook { + return rpcauth.RPCAuthzHookFunc(func(ctx context.Context, input *rpcauth.RPCAuthInput) error { + mpaID, ok := mpahooks.MPAFromIncomingContext(ctx) + if !ok { + // Nothing to look up if MPA wasn't requested + return nil + } + resp, err := serverSingleton.Get(ctx, &mpa.GetRequest{Id: mpaID}) + if err != nil { + return err + } + + if err := mpahooks.ActionMatchesInput(resp.Action, input); err != nil { + return err + } + for _, a := range resp.Approver { + input.Approvers = append(input.Approvers, &rpcauth.PrincipalAuthInput{ + ID: a.Id, + Groups: a.Groups, + }) + } + return nil + }) +} + +// actionId generates the id for an action by hashing it +func actionId(action *mpa.Action) (string, error) { + // Binary proto encoding doesn't provide any guarantees about deterministic + // output for the same input. Go provides a deterministic marshalling option, + // but this marshalling isn't guaranteed to be stable over time. + // JSON encoding can be made deterministic by canonicalizing. + b, err := protojson.Marshal(action) + if err != nil { + return "", err + } + canonical, err := jcs.Transform(b) + if err != nil { + return "", err + } + h := sha256.New() + h.Write(canonical) + sum := h.Sum(nil) + // Humans are going to need to deal with these, so let's shorten them + // and make them a bit prettier. + return fmt.Sprintf("%x-%x-%x", sum[0:4], sum[4:8], sum[8:12]), nil + +} + +type storedAction struct { + action *mpa.Action + lastModified time.Time + approvers []*mpa.Principal + approved chan struct{} +} + +// server is used to implement the gRPC server +type server struct { + actions map[string]*storedAction + mu sync.Mutex +} + +func callerIdentity(ctx context.Context) (*rpcauth.PrincipalAuthInput, bool) { + // TODO(#346): Prefer using a proxied identity if provided + peer := rpcauth.PeerInputFromContext(ctx) + if peer != nil { + return peer.Principal, true + } + return nil, false +} + +func (s *server) clearOutdatedApprovals() { + s.mu.Lock() + defer s.mu.Unlock() + + staleTime := time.Now().Add(-maxMPAApprovedAge) + for id, act := range s.actions { + if act.lastModified.Before(staleTime) { + delete(s.actions, id) + } + } + +} + +func (s *server) Store(ctx context.Context, in *mpa.StoreRequest) (*mpa.StoreResponse, error) { + var justification string + if md, found := metadata.FromIncomingContext(ctx); found && len(md[rpcauth.ReqJustKey]) > 0 { + justification = md[rpcauth.ReqJustKey][0] + } + + p, ok := callerIdentity(ctx) + if !ok || p == nil { + return nil, status.Error(codes.FailedPrecondition, "unable to determine caller's identity") + } + + action := &mpa.Action{ + User: p.ID, + Justification: justification, + Method: in.Method, + Message: in.Message, + } + id, err := actionId(action) + if err != nil { + return nil, err + } + s.mu.Lock() + defer s.mu.Unlock() + + // Time to clear out excessive approvals! + for len(s.actions) >= maxMPAApprovals { + var oldestID string + oldestTime := time.Now() + for id, act := range s.actions { + if act.lastModified.Before(oldestTime) { + oldestID = id + oldestTime = act.lastModified + } + } + delete(s.actions, oldestID) + } + + act, ok := s.actions[id] + if !ok { + act = &storedAction{ + action: action, + approved: make(chan struct{}), + lastModified: time.Now(), + } + s.actions[id] = act + } + return &mpa.StoreResponse{ + Id: id, + Action: action, + Approver: act.approvers, + }, nil +} + +func containsPrincipal(principals []*mpa.Principal, p *rpcauth.PrincipalAuthInput) bool { + for _, s := range principals { + if s.Id == p.ID && reflect.DeepEqual(s.Groups, p.Groups) { + return true + } + } + return false +} + +func (s *server) Approve(ctx context.Context, in *mpa.ApproveRequest) (*mpa.ApproveResponse, error) { + p, ok := callerIdentity(ctx) + if !ok { + return nil, status.Error(codes.FailedPrecondition, "unable to determine caller's identity") + } + id, err := actionId(in.Action) + if err != nil { + return nil, err + } + s.mu.Lock() + defer s.mu.Unlock() + act, ok := s.actions[id] + if !ok { + return nil, status.Error(codes.NotFound, "MPA request with provided input not found") + } + if act.action.User == p.ID { + return nil, status.Error(codes.InvalidArgument, "MPA requests cannot be approved by their requestor") + } + act.lastModified = time.Now() + // Only add the approver if it's new compared to existing approvals + if !containsPrincipal(act.approvers, p) { + act.approvers = append(act.approvers, &mpa.Principal{ + Id: p.ID, + Groups: p.Groups, + }) + // The first approval lets any WaitForApproval calls finish immediately. + if len(act.approvers) == 1 { + close(act.approved) + } + } + return &mpa.ApproveResponse{}, nil +} +func (s *server) WaitForApproval(ctx context.Context, in *mpa.WaitForApprovalRequest) (*mpa.WaitForApprovalResponse, error) { + for { + s.mu.Lock() + act, ok := s.actions[in.Id] + if !ok { + return nil, status.Error(codes.NotFound, "MPA request not found") + } + s.mu.Unlock() + select { + case <-act.approved: + return &mpa.WaitForApprovalResponse{}, nil + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Minute): + // Loop around again so that we can make sure that the request still exists + } + } +} +func (s *server) List(ctx context.Context, in *mpa.ListRequest) (*mpa.ListResponse, error) { + s.clearOutdatedApprovals() + s.mu.Lock() + defer s.mu.Unlock() + var items []*mpa.ListResponse_Item + for id, action := range s.actions { + items = append(items, &mpa.ListResponse_Item{ + Id: id, + Action: action.action, + Approver: action.approvers, + }) + } + sort.Slice(items, func(i, j int) bool { + return items[i].Id < items[j].Id + }) + return &mpa.ListResponse{Item: items}, nil +} +func (s *server) Get(ctx context.Context, in *mpa.GetRequest) (*mpa.GetResponse, error) { + s.clearOutdatedApprovals() + s.mu.Lock() + defer s.mu.Unlock() + act, ok := s.actions[in.Id] + if !ok { + return nil, status.Error(codes.NotFound, "MPA request not found") + } + return &mpa.GetResponse{ + Action: act.action, + Approver: act.approvers, + }, nil +} +func (s *server) Clear(ctx context.Context, in *mpa.ClearRequest) (*mpa.ClearResponse, error) { + id, err := actionId(in.Action) + if err != nil { + return nil, err + } + s.mu.Lock() + defer s.mu.Unlock() + delete(s.actions, id) + return &mpa.ClearResponse{}, nil +} + +// Register is called to expose this handler to the gRPC server +func (s *server) Register(gs *grpc.Server) { + mpa.RegisterMpaServer(gs, s) +} + +var serverSingleton = &server{actions: make(map[string]*storedAction)} + +func init() { + services.RegisterSansShellService(serverSingleton) +} diff --git a/services/mpa/server/server_test.go b/services/mpa/server/server_test.go new file mode 100644 index 00000000..0d58fcdb --- /dev/null +++ b/services/mpa/server/server_test.go @@ -0,0 +1,223 @@ +/* Copyright (c) 2023 Snowflake Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +package server + +import ( + "context" + "reflect" + "strconv" + "testing" + + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" + "github.com/Snowflake-Labs/sansshell/services/mpa" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/emptypb" +) + +func mustAny(a *anypb.Any, err error) *anypb.Any { + if err != nil { + panic(err) + } + return a +} + +func TestAuthzHook(t *testing.T) { + ctx := context.Background() + + // Create a hook and make sure that hooking with no mpa request works + hook := ServerMPAAuthzHook() + newInput, err := rpcauth.NewRPCAuthInput(ctx, "foobar", &emptypb.Empty{}) + if err != nil { + t.Fatal(err) + } + if err := hook.Hook(ctx, newInput); err != nil { + t.Fatal(err) + } + + // Add a request + rCtx := rpcauth.AddPeerToContext(ctx, &rpcauth.PeerAuthInput{ + Principal: &rpcauth.PrincipalAuthInput{ID: "requester"}, + }) + if _, err := serverSingleton.Store(rCtx, &mpa.StoreRequest{ + Method: "foobar", + Message: mustAny(anypb.New(&emptypb.Empty{})), + }); err != nil { + t.Fatal(err) + } + + // Make sure we can't approve our own request + approvReq := &mpa.ApproveRequest{ + Action: &mpa.Action{ + User: "requester", + Method: "foobar", + Message: mustAny(anypb.New(&emptypb.Empty{})), + }, + } + if _, err := serverSingleton.Approve(rCtx, approvReq); status.Code(err) != codes.InvalidArgument { + t.Fatalf("expected failure when self-approving: %v", err) + } + + aCtx := rpcauth.AddPeerToContext(ctx, &rpcauth.PeerAuthInput{ + Principal: &rpcauth.PrincipalAuthInput{ID: "approver", Groups: []string{"g1"}}, + }) + // Approve the request twice to make sure approval is idempotent + for i := 0; i < 2; i++ { + if _, err := serverSingleton.Approve(aCtx, approvReq); err != nil { + t.Fatal(err) + } + } + + mpaCtx := metadata.NewIncomingContext(rCtx, map[string][]string{"sansshell-mpa-request-id": {"3e31b2b4-f8724bae-c1504987"}}) + passingInput, err := rpcauth.NewRPCAuthInput(mpaCtx, "foobar", &emptypb.Empty{}) + if err != nil { + t.Fatal(err) + } + if err := hook.Hook(mpaCtx, passingInput); err != nil { + t.Fatal(err) + } + wantApprovers := []*rpcauth.PrincipalAuthInput{ + {ID: "approver", Groups: []string{"g1"}}, + } + if !reflect.DeepEqual(passingInput.Approvers, wantApprovers) { + t.Errorf("got %+v, want %+v", passingInput.Approvers, wantApprovers) + } + + // An action not matching the input should fail + wrongInput, err := rpcauth.NewRPCAuthInput(mpaCtx, "foobaz", &emptypb.Empty{}) + if err != nil { + t.Fatal(err) + } + if err := hook.Hook(mpaCtx, wrongInput); err == nil { + t.Fatal("unexpectedly nil err") + } +} + +func TestMaxNumApprovals(t *testing.T) { + ctx := context.Background() + for i := 0; i < maxMPAApprovals+20; i++ { + rCtx := rpcauth.AddPeerToContext(ctx, &rpcauth.PeerAuthInput{ + Principal: &rpcauth.PrincipalAuthInput{ID: "requester"}, + }) + if _, err := serverSingleton.Store(rCtx, &mpa.StoreRequest{ + Method: "foobar" + strconv.Itoa(i), + Message: mustAny(anypb.New(&emptypb.Empty{})), + }); err != nil { + t.Fatal(err) + } + } + reqs, err := serverSingleton.List(ctx, &mpa.ListRequest{}) + if err != nil { + t.Fatal(err) + } + if len(reqs.Item) != maxMPAApprovals { + t.Fatalf("got %v requests, expected %v", len(reqs.Item), maxMPAApprovals) + } +} + +func TestWaitForApproval(t *testing.T) { + ctx := context.Background() + + rCtx := rpcauth.AddPeerToContext(ctx, &rpcauth.PeerAuthInput{ + Principal: &rpcauth.PrincipalAuthInput{ID: "requester"}, + }) + if _, err := serverSingleton.Store(rCtx, &mpa.StoreRequest{ + Method: "foobar", + Message: mustAny(anypb.New(&emptypb.Empty{})), + }); err != nil { + t.Fatal(err) + } + + var g errgroup.Group + g.Go(func() error { + aCtx := rpcauth.AddPeerToContext(ctx, &rpcauth.PeerAuthInput{ + Principal: &rpcauth.PrincipalAuthInput{ID: "approver", Groups: []string{"g1"}}, + }) + _, err := serverSingleton.Approve(aCtx, &mpa.ApproveRequest{ + Action: &mpa.Action{ + User: "requester", + Method: "foobar", + Message: mustAny(anypb.New(&emptypb.Empty{})), + }, + }) + return err + }) + + _, err := serverSingleton.WaitForApproval(ctx, &mpa.WaitForApprovalRequest{ + Id: "3e31b2b4-f8724bae-c1504987", + }) + if err != nil { + t.Fatal(err) + } + if err := g.Wait(); err != nil { + t.Fatal(err) + } +} + +func TestActionIdIsDeterministic(t *testing.T) { + for _, tc := range []struct { + desc string + action *mpa.Action + wantID string + }{ + { + desc: "empty action", + action: &mpa.Action{}, + wantID: "44136fa3-55b3678a-1146ad16", + }, + { + desc: "simple action", + action: &mpa.Action{ + User: "requester", + Method: "foobar", + Message: mustAny(anypb.New(&emptypb.Empty{})), + }, + wantID: "3e31b2b4-f8724bae-c1504987", + }, + { + desc: "complex action", + action: &mpa.Action{ + User: "user", + Method: "method", + Justification: "justification", + Message: mustAny(anypb.New(&mpa.Action{ + User: "so", + Method: "meta", + Justification: "nested", + Message: mustAny(anypb.New(&mpa.Principal{ + Id: "approver", + Groups: []string{"g1", "g2"}, + })), + })), + }, + wantID: "66bc8827-d4fab1bf-b51181f1", + }, + } { + t.Run(tc.desc, func(t *testing.T) { + id, err := actionId(tc.action) + if err != nil { + t.Error(err) + } + if id != tc.wantID { + t.Errorf("got %v, want %v", id, tc.wantID) + } + }) + } +} diff --git a/services/mpa/testdata/README.md b/services/mpa/testdata/README.md new file mode 100644 index 00000000..10dbe460 --- /dev/null +++ b/services/mpa/testdata/README.md @@ -0,0 +1,13 @@ +This directory contains additional certificates for testing MPA. See [/auth/mtls/testdata/](/auth/mtls/testdata/) for more info on these. + +``` +openssl genrsa -out approver.key +openssl req -new -key approver.key -out approver.csr -subj "/O=Acme Co/OU=group1/OU=group2/CN=approver" +openssl x509 -req -days 3000 -in approver.csr -CA ../../../auth/mtls/testdata/root.pem -CAkey ../../../auth/mtls/testdata/root.key -out approver.pem -extensions req_ext -extfile /dev/stdin <