Skip to content

Commit

Permalink
Add a justification optional requirement. (#75)
Browse files Browse the repository at this point in the history
* Add a justification optional requirement.

Servers can require a justification string to be passed in from client side metadata.

They can also provide a user defined function to validate this as well.

Plumb into servers, sanssh and integration tests.

* Update usage information and important flags

* Refactor into authz hook style.

Telemetry just extracts anything which is sansshell-* from metadata and logs it.

If it gets an error from the handler (which is where authz hooks) it'll bail at that point.

So we get logging in one place and authz handled correctly in it's place.

* Remove debugging

* Convert server startup to take lists of authz hooks.
  • Loading branch information
sfc-gh-jchacon authored Feb 16, 2022
1 parent 4b4f696 commit fa85312
Show file tree
Hide file tree
Showing 14 changed files with 300 additions and 50 deletions.
3 changes: 2 additions & 1 deletion auth/mtls/mtls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"google.golang.org/grpc/test/bufconn"
"google.golang.org/protobuf/types/known/emptypb"

"github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth"
"github.com/Snowflake-Labs/sansshell/server"
hcpb "github.com/Snowflake-Labs/sansshell/services/healthcheck"
_ "github.com/Snowflake-Labs/sansshell/services/healthcheck/server"
Expand Down Expand Up @@ -100,7 +101,7 @@ func serverWithPolicy(t *testing.T, policy string, CAPool *x509.CertPool) (*bufc
creds, err := LoadServerTLS("testdata/leaf.pem", "testdata/leaf.key", CAPool)
testutil.FatalOnErr("Failed to load client cert", err, t)
lis := bufconn.Listen(bufSize)
s, err := server.BuildServer(creds, policy, lis.Addr(), logr.Discard())
s, err := server.BuildServer(creds, policy, logr.Discard(), rpcauth.HostNetHook(lis.Addr()))
testutil.FatalOnErr("Could not build server", err, t)
listening := make(chan struct{})
go func() {
Expand Down
40 changes: 39 additions & 1 deletion auth/opa/rpcauth/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ package rpcauth
import (
"context"
"net"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// RPCAuthzHookFunc implements RpcAuthzHook for a simple function
Expand Down Expand Up @@ -53,7 +56,7 @@ func (c *conditionalHook) Hook(ctx context.Context, input *RPCAuthInput) error {
return nil
}

// HostNetHook returns an RpcAuthzHook that sets host networking information.
// HostNetHook returns an RPCAuthzHook that sets host networking information.
func HostNetHook(addr net.Addr) RPCAuthzHook {
return RPCAuthzHookFunc(func(ctx context.Context, input *RPCAuthInput) error {
if input.Host == nil {
Expand All @@ -63,3 +66,38 @@ func HostNetHook(addr net.Addr) RPCAuthzHook {
return nil
})
}

const (
// ReqJustKey is the key name that must exist in the incoming
// context metadata if client side provided justification is required.
ReqJustKey = "sansshell-justification"
)

var (
// ErrJustification is the error returned for missing justification.
ErrJustification = status.Error(codes.FailedPrecondition, "missing justification")
)

// JustificationHook takes the given optional justification function and returns an RPCAuthzHook
// that validates if justification was included. If it is required and passes the optional validation function
// it will return nil, otherwise an error.
func JustificationHook(justificationFunc func(string) error) RPCAuthzHook {
return RPCAuthzHookFunc(func(ctx context.Context, input *RPCAuthInput) error {
// See if we got any metadata and if it contains the justification
var j string
v := input.Metadata[ReqJustKey]
if len(v) > 0 {
j = v[0]
}

if j == "" {
return ErrJustification
}
if justificationFunc != nil {
if err := justificationFunc(j); err != nil {
return status.Errorf(codes.FailedPrecondition, "justification failed: %v", err)
}
}
return nil
})
}
53 changes: 53 additions & 0 deletions auth/opa/rpcauth/rpcauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,59 @@ func TestAuthzHook(t *testing.T) {
},
errFunc: wantStatusCode(codes.OK),
},
{
name: "network data allow with justification (no func)",
input: &RPCAuthInput{
Method: "/Foo.Bar/Foo",
Metadata: metadata.MD{
ReqJustKey: []string{"justification"},
},
},
hooks: []RPCAuthzHook{
HostNetHook(tcp),
JustificationHook(nil),
},
errFunc: wantStatusCode(codes.OK),
},
{
name: "network data allow with justification req but none given (no func)",
input: &RPCAuthInput{
Method: "/Foo.Bar/Foo",
},
hooks: []RPCAuthzHook{
HostNetHook(tcp),
JustificationHook(nil),
},
errFunc: wantStatusCode(codes.FailedPrecondition),
},
{
name: "network data allow with justification (with func)",
input: &RPCAuthInput{
Method: "/Foo.Bar/Foo",
Metadata: metadata.MD{
ReqJustKey: []string{"justification"},
},
},
hooks: []RPCAuthzHook{
HostNetHook(tcp),
JustificationHook(func(string) error { return nil }),
},
errFunc: wantStatusCode(codes.OK),
},
{
name: "network data allow with justification req given and func fails",
input: &RPCAuthInput{
Method: "/Foo.Bar/Foo",
Metadata: metadata.MD{
ReqJustKey: []string{"justification"},
},
},
hooks: []RPCAuthzHook{
HostNetHook(tcp),
JustificationHook(func(string) error { return errors.New("error") }),
},
errFunc: wantStatusCode(codes.FailedPrecondition),
},
{
name: "conditional hook, triggered",
input: &RPCAuthInput{Method: "/Some.Random/Method"},
Expand Down
23 changes: 13 additions & 10 deletions cmd/proxy-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/Snowflake-Labs/sansshell/auth/mtls"
mtlsFlags "github.com/Snowflake-Labs/sansshell/auth/mtls/flags"
"github.com/Snowflake-Labs/sansshell/auth/opa"
"github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth"
"github.com/Snowflake-Labs/sansshell/cmd/proxy-server/server"
"github.com/Snowflake-Labs/sansshell/cmd/util"
"github.com/go-logr/stdr"
Expand All @@ -37,12 +38,13 @@ var (
//go:embed default-policy.rego
defaultPolicy string

policyFlag = flag.String("policy", defaultPolicy, "Local OPA policy governing access. If empty, use builtin policy.")
policyFile = flag.String("policy-file", "", "Path to a file with an OPA policy. If empty, uses --policy.")
hostport = flag.String("hostport", "localhost:50043", "Where to listen for connections.")
credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS creds (one of [%s])", strings.Join(mtls.Loaders(), ",")))
verbosity = flag.Int("v", 0, "Verbosity level. > 0 indicates more extensive logging")
validate = flag.Bool("validate", false, "If true will evaluate the policy and then exit (non-zero on error)")
policyFlag = flag.String("policy", defaultPolicy, "Local OPA policy governing access. If empty, use builtin policy.")
policyFile = flag.String("policy-file", "", "Path to a file with an OPA policy. If empty, uses --policy.")
hostport = flag.String("hostport", "localhost:50043", "Where to listen for connections.")
credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS creds (one of [%s])", strings.Join(mtls.Loaders(), ",")))
verbosity = flag.Int("v", 0, "Verbosity level. > 0 indicates more extensive logging")
validate = flag.Bool("validate", false, "If true will evaluate the policy and then exit (non-zero on error)")
justification = flag.Bool("justification", false, "If true then justification (which is logged and possibly validated) must be passed along in the client context Metadata with the key '"+rpcauth.ReqJustKey+"'")
)

func main() {
Expand All @@ -65,10 +67,11 @@ func main() {
}

rs := server.RunState{
Logger: logger,
Policy: policy,
CredSource: *credSource,
Hostport: *hostport,
Logger: logger,
Policy: policy,
CredSource: *credSource,
Hostport: *hostport,
Justification: *justification,
}
server.Run(ctx, rs)
}
13 changes: 12 additions & 1 deletion cmd/proxy-server/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ type RunState struct {
CredSource string
// Hostport is the host:port to run the server.
Hostport string
// Justification if true requires justification to be set in the
// incoming RPC context Metadata (to the key defined in the telemetry package).
Justification bool
// JustificationFunc will be called if Justication is true and a justification
// entry is found. The supplied function can then do any validation it wants
// in order to ensure it's compliant.
JustificationFunc func(string) error
}

// Run takes the given context and RunState along with any authz hooks and starts up a sansshell proxy server
Expand Down Expand Up @@ -80,7 +87,11 @@ func Run(ctx context.Context, rs RunState, hooks ...rpcauth.RPCAuthzHook) {
addressHook := rpcauth.HookIf(rpcauth.HostNetHook(lis.Addr()), func(input *rpcauth.RPCAuthInput) bool {
return input.Host == nil || input.Host.Net == nil
})
h := []rpcauth.RPCAuthzHook{addressHook}
justificationHook := rpcauth.HookIf(rpcauth.JustificationHook(rs.JustificationFunc), func(input *rpcauth.RPCAuthInput) bool {
return rs.Justification
})

h := []rpcauth.RPCAuthzHook{addressHook, justificationHook}
h = append(h, hooks...)
authz, err := rpcauth.NewWithPolicy(ctx, rs.Policy, h...)
if err != nil {
Expand Down
19 changes: 14 additions & 5 deletions cmd/sanssh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,22 @@ import (

"github.com/Snowflake-Labs/sansshell/auth/mtls"
mtlsFlags "github.com/Snowflake-Labs/sansshell/auth/mtls/flags"
"github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth"
"github.com/Snowflake-Labs/sansshell/cmd/sanssh/client"
"github.com/Snowflake-Labs/sansshell/services/util"
"github.com/google/subcommands"
"google.golang.org/grpc/metadata"
)

var (
defaultAddress = "localhost:50042"
defaultTimeout = 3 * time.Second

proxyAddr = flag.String("proxy", "", "Address to contact for proxy to sansshell-server. If blank a direct connection to the first entry in --targets will be made")
timeout = flag.Duration("timeout", defaultTimeout, "How long to wait for the command to complete")
credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS credentials (one of [%s])", strings.Join(mtls.Loaders(), ",")))
outputsDir = flag.String("output-dir", "", "If set defines a directory to emit output/errors from commands. Files will be generated based on target as destination/0 destination/0.error, etc.")
proxyAddr = flag.String("proxy", "", "Address to contact for proxy to sansshell-server. If blank a direct connection to the first entry in --targets will be made")
timeout = flag.Duration("timeout", defaultTimeout, "How long to wait for the command to complete")
credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS credentials (one of [%s])", strings.Join(mtls.Loaders(), ",")))
outputsDir = flag.String("output-dir", "", "If set defines a directory to emit output/errors from commands. Files will be generated based on target as destination/0 destination/0.error, etc.")
justification = flag.String("justification", "", "If non-empty will add the key '"+rpcauth.ReqJustKey+"' to the outgoing context Metadata to be passed along to the server for possbile validation and logging.")

// targets will be bound to --targets for sending a single request to N nodes.
targetsFlag util.StringSliceFlag
Expand All @@ -62,6 +65,8 @@ func init() {
subcommands.ImportantFlag("proxy")
subcommands.ImportantFlag("targets")
subcommands.ImportantFlag("outputs")
subcommands.ImportantFlag("output-dir")
subcommands.ImportantFlag("justification")
}

func main() {
Expand All @@ -75,5 +80,9 @@ func main() {
CredSource: *credSource,
Timeout: *timeout,
}
client.Run(context.Background(), rs)
ctx := context.Background()
if *justification != "" {
ctx = metadata.AppendToOutgoingContext(ctx, rpcauth.ReqJustKey, *justification)
}
client.Run(ctx, rs)
}
23 changes: 13 additions & 10 deletions cmd/sansshell-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/Snowflake-Labs/sansshell/auth/mtls"
mtlsFlags "github.com/Snowflake-Labs/sansshell/auth/mtls/flags"
"github.com/Snowflake-Labs/sansshell/auth/opa"
"github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth"
"github.com/Snowflake-Labs/sansshell/cmd/sansshell-server/server"
"github.com/Snowflake-Labs/sansshell/cmd/util"
)
Expand All @@ -42,12 +43,13 @@ var (
//go:embed default-policy.rego
defaultPolicy string

policyFlag = flag.String("policy", defaultPolicy, "Local OPA policy governing access. If empty, use builtin policy.")
policyFile = flag.String("policy-file", "", "Path to a file with an OPA policy. If empty, uses --policy.")
hostport = flag.String("hostport", "localhost:50042", "Where to listen for connections.")
credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS credentials (one of [%s])", strings.Join(mtls.Loaders(), ",")))
verbosity = flag.Int("v", 0, "Verbosity level. > 0 indicates more extensive logging")
validate = flag.Bool("validate", false, "If true will evaluate the policy and then exit (non-zero on error)")
policyFlag = flag.String("policy", defaultPolicy, "Local OPA policy governing access. If empty, use builtin policy.")
policyFile = flag.String("policy-file", "", "Path to a file with an OPA policy. If empty, uses --policy.")
hostport = flag.String("hostport", "localhost:50042", "Where to listen for connections.")
credSource = flag.String("credential-source", mtlsFlags.Name(), fmt.Sprintf("Method used to obtain mTLS credentials (one of [%s])", strings.Join(mtls.Loaders(), ",")))
verbosity = flag.Int("v", 0, "Verbosity level. > 0 indicates more extensive logging")
validate = flag.Bool("validate", false, "If true will evaluate the policy and then exit (non-zero on error)")
justification = flag.Bool("justification", false, "If true then justification (which is logged and possibly validated) must be passed along in the client context Metadata with the key '"+rpcauth.ReqJustKey+"'")
)

func main() {
Expand All @@ -70,10 +72,11 @@ func main() {
}

rs := server.RunState{
Logger: logger,
CredSource: *credSource,
Hostport: *hostport,
Policy: policy,
Logger: logger,
CredSource: *credSource,
Hostport: *hostport,
Policy: policy,
Justification: *justification,
}
server.Run(ctx, rs)
}
13 changes: 12 additions & 1 deletion cmd/sansshell-server/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"os"

"github.com/Snowflake-Labs/sansshell/auth/mtls"
"github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth"
"github.com/Snowflake-Labs/sansshell/server"
"github.com/go-logr/logr"

Expand All @@ -46,6 +47,13 @@ type RunState struct {
Hostport string
// Policy is an OPA policy for determining authz decisions.
Policy string
// Justification if true requires justification to be set in the
// incoming RPC context Metadata (to the key defined in the telemetry package).
Justification bool
// JustificationFunc will be called if Justication is true and a justification
// entry is found. The supplied function can then do any validation it wants
// in order to ensure it's compliant.
JustificationFunc func(string) error
}

// Run takes the given context and RunState and starts up a sansshell server.
Expand All @@ -57,7 +65,10 @@ func Run(ctx context.Context, rs RunState) {
os.Exit(1)
}

if err := server.Serve(rs.Hostport, creds, rs.Policy, rs.Logger); err != nil {
justificationHook := rpcauth.HookIf(rpcauth.JustificationHook(rs.JustificationFunc), func(input *rpcauth.RPCAuthInput) bool {
return rs.Justification
})
if err := server.Serve(rs.Hostport, creds, rs.Policy, rs.Logger, justificationHook); err != nil {
rs.Logger.Error(err, "server.Serve", "hostport", rs.Hostport)
os.Exit(1)
}
Expand Down
12 changes: 9 additions & 3 deletions proxy/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,16 @@ func (p *Conn) createStreams(ctx context.Context, method string) (proxypb.Proxy_
}

err = stream.Send(req)
if err != nil {
// If Send reports an error and is EOF we have to use Recv to get the actual error according to documentation
// for SendMsg. However it appears SendMsg will return actual errors "sometimes" when it's the first stream
// a server has ever handled so account for that here.
if err != nil && err != io.EOF {
return nil, nil, status.Errorf(codes.Internal, "can't send request for %s on stream - %v", method, err)
}
if err != nil {
_, err := stream.Recv()
return nil, nil, status.Errorf(codes.Internal, "remote error from Send for %s - %v", method, err)
}
resp, err := stream.Recv()
if err != nil {
return nil, nil, status.Errorf(codes.Internal, "can't get response for %s on stream - %v", method, err)
Expand Down Expand Up @@ -391,8 +398,7 @@ func (p *Conn) InvokeOneMany(ctx context.Context, method string, args interface{
if err := s.send(requestMsg); err != nil {
return nil, err
}
// TODO(): Put this back when the race in server.go is figured out. Causes a send and close
// on the channel processing server side which isn't allowed.

if err := s.closeClients(); err != nil {
return nil, err
}
Expand Down
16 changes: 10 additions & 6 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,19 @@ var (
mu sync.Mutex
)

// Serve wraps up BuildServer in a succinct API for callers
func Serve(hostport string, c credentials.TransportCredentials, policy string, logger logr.Logger) error {
// Serve wraps up BuildServer in a succinct API for callers passing along various parameters. It will automatically add
// an authz hook for HostNet based on the listener address. Additional hooks are passed along after this one.
func Serve(hostport string, c credentials.TransportCredentials, policy string, logger logr.Logger, authzHooks ...rpcauth.RPCAuthzHook) error {
lis, err := net.Listen("tcp", hostport)
if err != nil {
return fmt.Errorf("failed to listen: %v", err)
}

mu.Lock()
srv, err = BuildServer(c, policy, lis.Addr(), logger)
h := []rpcauth.RPCAuthzHook{rpcauth.HostNetHook(lis.Addr())}
h = append(h, authzHooks...)

srv, err = BuildServer(c, policy, logger, h...)
mu.Unlock()
if err != nil {
return err
Expand All @@ -62,11 +66,11 @@ func getSrv() *grpc.Server {
return srv
}

// BuildServer creates a gRPC server, attaches the OPA policy interceptor,
// BuildServer creates a gRPC server, attaches the OPA policy interceptor with supplied args and then
// registers all of the imported SansShell modules. Separating this from Serve
// primarily facilitates testing.
func BuildServer(c credentials.TransportCredentials, policy string, address net.Addr, logger logr.Logger) (*grpc.Server, error) {
authz, err := rpcauth.NewWithPolicy(context.Background(), policy, rpcauth.HostNetHook(address))
func BuildServer(c credentials.TransportCredentials, policy string, logger logr.Logger, authzHooks ...rpcauth.RPCAuthzHook) (*grpc.Server, error) {
authz, err := rpcauth.NewWithPolicy(context.Background(), policy, authzHooks...)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit fa85312

Please sign in to comment.