From 5067420d0120cc68ab0512b68ebed0a1300a3dfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C5=82awek=20Rudnicki?= Date: Thu, 24 Oct 2024 14:13:50 +0200 Subject: [PATCH] Pass peer credentials from Unix domain socket to authorization engine (#501) In commit 39af558b10, we have added the possibility of exposing a Unix domain socket to the Sansshell server. In this commit, we enhance the authentication and authorization possibilities around this new method of communication. On Linux systems, we can get information about the process which has initiated a connection over a Unix socket by means of the `getsockopt` call with the `SO_PEERCRED` option. This way, we get the UID, GID and PID of the calling process. We pass this information into the input structure of the OPA rules engine, so that rules can be written to only allow certain local users to access a particular Sansshell gRPC method, for example. An equivalent mechanism via the `getsockopt` option `LOCAL_PEERCRED`` is available on Darwin (macOS) systems, so we include Unix credentials in the auth input there as well for the sake of completeness. This also aids testability during development. In addition to the numeric UID and GID values in the credentials of a peer talking to the Sansshell server over a Unix socket, we provide the human-readable user and group names. This will enable writing more reader-friendly OPA rules based on the Unix credentials. --- auth/opa/rpcauth/input.go | 32 +++++++ auth/opa/rpcauth/rpcauth_test.go | 35 +++++++ auth/opa/rpcauth/unix_peer_auth_info.go | 40 ++++++++ cmd/sansshell-server/server/server.go | 2 +- server/server_test.go | 120 ++++++++++++++++++++++++ server/unix_peer_credentials.go | 85 +++++++++++++++++ server/unix_peer_credentials_darwin.go | 89 ++++++++++++++++++ server/unix_peer_credentials_linux.go | 96 +++++++++++++++++++ 8 files changed, 498 insertions(+), 1 deletion(-) create mode 100644 auth/opa/rpcauth/unix_peer_auth_info.go create mode 100644 server/unix_peer_credentials.go create mode 100644 server/unix_peer_credentials_darwin.go create mode 100644 server/unix_peer_credentials_linux.go diff --git a/auth/opa/rpcauth/input.go b/auth/opa/rpcauth/input.go index 4b20dd2b..bfff6908 100644 --- a/auth/opa/rpcauth/input.go +++ b/auth/opa/rpcauth/input.go @@ -84,6 +84,9 @@ type PeerAuthInput struct { // Network information about the peer Net *NetAuthInput `json:"net"` + // Unix peer credentials if peer connects via Unix socket, nil otherwise + Unix *UnixAuthInput `json:"unix"` + // Information about the certificate presented by the peer, if any Cert *CertAuthInput `json:"cert"` @@ -103,6 +106,21 @@ type NetAuthInput struct { Port string `json:"port"` } +// UnixAuthInput contains information about a Unix socket peer. +type UnixAuthInput struct { + // The user ID of the peer. + Uid int `json:"uid"` + + // The username of the peer, or the stringified UID if user is not known. + UserName string `json:"username"` + + // The group IDs (primary and supplementary) of the peer. + Gids []int `json:"gids"` + + // The group names of the peer. If not available, the stringified IDs is used. + GroupNames []string `json:"groupnames"` +} + // HostAuthInput contains policy-relevant information about the system receiving // an RPC type HostAuthInput struct { @@ -189,6 +207,7 @@ func PeerInputFromContext(ctx context.Context) *PeerAuthInput { } out.Net = NetInputFromAddr(p.Addr) + out.Unix = UnixInputFrom(p.AuthInfo) out.Cert = CertInputFrom(p.AuthInfo) // If this runs after rpcauth hooks, we can return richer data that includes @@ -222,6 +241,19 @@ func NetInputFromAddr(addr net.Addr) *NetAuthInput { return out } +// UnixInputFrom returns UnixAuthInput from the supplied credentials, if available. +func UnixInputFrom(authInfo credentials.AuthInfo) *UnixAuthInput { + if unixInfo, ok := authInfo.(UnixPeerAuthInfo); ok { + return &UnixAuthInput{ + Uid: unixInfo.Credentials.Uid, + UserName: unixInfo.Credentials.UserName, + Gids: unixInfo.Credentials.Gids, + GroupNames: unixInfo.Credentials.GroupNames, + } + } + return nil +} + // CertInputFrom populates certificate information from the supplied // credentials, if available. func CertInputFrom(authInfo credentials.AuthInfo) *CertAuthInput { diff --git a/auth/opa/rpcauth/rpcauth_test.go b/auth/opa/rpcauth/rpcauth_test.go index 8c14a0b9..0e921d90 100644 --- a/auth/opa/rpcauth/rpcauth_test.go +++ b/auth/opa/rpcauth/rpcauth_test.go @@ -552,6 +552,41 @@ func TestRpcAuthInput(t *testing.T) { }, }, }, + { + name: "method and a peer context with unix creds", + ctx: peer.NewContext(ctx, &peer.Peer{ + Addr: &net.UnixAddr{Net: "unix", Name: "@"}, + AuthInfo: UnixPeerAuthInfo{ + CommonAuthInfo: credentials.CommonAuthInfo{ + SecurityLevel: credentials.NoSecurity, + }, + Credentials: UnixPeerCredentials{ + Uid: 1, + UserName: "george", + Gids: []int{1001, 2}, + GroupNames: []string{"george", "the_gang"}, + }, + }, + }), + method: "/AMethod", + compare: &RPCAuthInput{ + Method: "/AMethod", + Peer: &PeerAuthInput{ + Net: &NetAuthInput{ + Network: "unix", + Address: "@", + Port: "", + }, + Unix: &UnixAuthInput{ + Uid: 1, + UserName: "george", + Gids: []int{1001, 2}, + GroupNames: []string{"george", "the_gang"}, + }, + Cert: &CertAuthInput{}, + }, + }, + }, } { tc := tc t.Run(tc.name, func(t *testing.T) { diff --git a/auth/opa/rpcauth/unix_peer_auth_info.go b/auth/opa/rpcauth/unix_peer_auth_info.go new file mode 100644 index 00000000..a9d7d286 --- /dev/null +++ b/auth/opa/rpcauth/unix_peer_auth_info.go @@ -0,0 +1,40 @@ +/* Copyright (c) 2024 Snowflake Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +package rpcauth + +import ( + "google.golang.org/grpc/credentials" +) + +// UnixPeerCreds represents the credentials of a Unix peer. +type UnixPeerCredentials struct { + Uid int + Gids []int // Primary and supplementary group IDs. + UserName string + GroupNames []string +} + +// UnixPeerAuthInfo contains the authentication information for a Unix peer, +// in a form suitable for authentication info returned by gRPC transport credentials. +type UnixPeerAuthInfo struct { + credentials.CommonAuthInfo + Credentials UnixPeerCredentials +} + +func (UnixPeerAuthInfo) AuthType() string { + return "insecure_with_unix_creds" +} diff --git a/cmd/sansshell-server/server/server.go b/cmd/sansshell-server/server/server.go index 2d6a9a97..1e47df2d 100644 --- a/cmd/sansshell-server/server/server.go +++ b/cmd/sansshell-server/server/server.go @@ -401,7 +401,7 @@ func runTCPServer(ctx context.Context, rs *runState) error { func runInsecureUnixSocketServer(_ context.Context, rs *runState) error { serverOpts := extractCommonOptionsFromRunState(rs) - serverOpts = append(serverOpts, server.WithInsecure()) + serverOpts = append(serverOpts, server.WithCredentials(server.NewUnixPeerTransportCredentials())) return server.ServeUnix(rs.unixSocket, rs.unixSocketConfigHook, serverOpts...) } diff --git a/server/server_test.go b/server/server_test.go index 66ce2fea..6118a5e4 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -19,10 +19,13 @@ package server import ( "bytes" "context" + "fmt" "io" "log" "net" "os" + "os/user" + "strconv" "strings" "testing" "time" @@ -31,8 +34,10 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" + "google.golang.org/protobuf/types/known/emptypb" "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" + hcpb "github.com/Snowflake-Labs/sansshell/services/healthcheck" _ "github.com/Snowflake-Labs/sansshell/services/healthcheck/server" lfpb "github.com/Snowflake-Labs/sansshell/services/localfile" _ "github.com/Snowflake-Labs/sansshell/services/localfile/server" @@ -174,6 +179,121 @@ func TestServeUnix(t *testing.T) { testutil.FatalOnErr("ServeUnix with socket config hook", err, t) } +func TestServerWithUnixCredentials(t *testing.T) { + socketPath := t.TempDir() + "/test.sock" + policyTemplateWithUnixCreds := ` +package sansshell.authz + +default allow = false + +allow { + %s + input.method = "/HealthCheck.HealthCheck/Ok" +} +` + + // This function produces an on-server-start listener which connects to the + // server over the Unix socket and calls a gRPC method. + runHealthCheck := func(t *testing.T, expectedSuccess bool) func(*grpc.Server) { + return func(s *grpc.Server) { + defer s.Stop() + + conn, err := grpc.NewClient("passthrough:///unix://"+socketPath, grpc.WithTransportCredentials(insecure.NewCredentials())) + testutil.FatalOnErr("Failed to dial bufnet", err, t) + defer conn.Close() + + client := hcpb.NewHealthCheckClient(conn) + _, err = client.Ok(context.Background(), &emptypb.Empty{}) + if expectedSuccess { + testutil.FatalOnErr("Failed to call Ok", err, t) + } else { + testutil.FatalOnNoErr("Ok should have failed", err, t) + } + } + } + + // In the test environment, the process connecting over the Unix socket + // will be the same process as the one running the server. We can rely on + // that to get the expected values for the "peer's" Unix credentials. + currentUser, err := user.Current() + testutil.FatalOnErr("Failed to get current user", err, t) + currentUid, err := strconv.Atoi(currentUser.Uid) + testutil.FatalOnErr("Failed to convert current user UID to int", err, t) + + // We will check that both the primary and supplementary groups can be + // used in OPA policies. + groupIdStrings, err := currentUser.GroupIds() + testutil.FatalOnErr("Failed to get group IDs of current user", err, t) + groupNameStrings := []string{} + for _, groupIdString := range groupIdStrings { + groupInfo, err := user.LookupGroupId(groupIdString) + testutil.FatalOnErr("Failed to get group info", err, t) + groupNameStrings = append(groupNameStrings, groupInfo.Name) + } + + for _, tc := range []struct { + name string + policyFragment string + expectedSuccess bool + }{ + { + name: "UID match", + policyFragment: fmt.Sprintf("input.peer.unix.uid == %d", currentUid), + expectedSuccess: true, + }, + { + name: "UID mismatch", + policyFragment: fmt.Sprintf("input.peer.unix.uid == %d", currentUid+1), + expectedSuccess: false, + }, + { + name: "username match", + policyFragment: fmt.Sprintf("input.peer.unix.username == \"%s\"", currentUser.Username), + expectedSuccess: true, + }, + { + name: "username mismatch", + policyFragment: fmt.Sprintf("input.peer.unix.username == \"%s\"", currentUser.Username+"x"), + expectedSuccess: false, + }, + { + name: "primary GID match", + policyFragment: fmt.Sprintf("input.peer.unix.gids[_] == %s", groupIdStrings[0]), + expectedSuccess: true, + }, + { + name: "supplementary GID match", + policyFragment: fmt.Sprintf("input.peer.unix.gids[_] == %s", groupIdStrings[len(groupIdStrings)-1]), + expectedSuccess: true, + }, + { + name: "primary group name match", + policyFragment: fmt.Sprintf("input.peer.unix.groupnames[_] == \"%s\"", groupNameStrings[0]), + expectedSuccess: true, + }, + { + name: "supplementary group name match", + policyFragment: fmt.Sprintf("input.peer.unix.groupnames[_] == \"%s\"", groupNameStrings[len(groupNameStrings)-1]), + expectedSuccess: true, + }, + { + name: "group name mismatch", + policyFragment: fmt.Sprintf("input.peer.unix.groupnames[_] == \"%s\"", groupNameStrings[0]+"x"), + expectedSuccess: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + policy := fmt.Sprintf(policyTemplateWithUnixCreds, tc.policyFragment) + err := ServeUnix(socketPath, + nil, + WithPolicy(policy), + WithCredentials(NewUnixPeerTransportCredentials()), + WithOnStartListener(runHealthCheck(t, tc.expectedSuccess))) + testutil.FatalOnErr("ServeUnix with Unix creds policy", err, t) + }) + } +} + func TestRead(t *testing.T) { var err error ctx := context.Background() diff --git a/server/unix_peer_credentials.go b/server/unix_peer_credentials.go new file mode 100644 index 00000000..1b7b8953 --- /dev/null +++ b/server/unix_peer_credentials.go @@ -0,0 +1,85 @@ +/* Copyright (c) 2024 Snowflake Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +package server + +import ( + "context" + "fmt" + "net" + + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" +) + +// unixPeerTransportCredentials is a TransportCredentials implementation that fetches the +// peer's credentials from the Unix domain socket. Otherwise, the channel is insecure (no TLS). +type unixPeerTransportCredentials struct { + insecureCredentials credentials.TransportCredentials +} + +func (uc *unixPeerTransportCredentials) ClientHandshake(ctx context.Context, authority string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return uc.insecureCredentials.ClientHandshake(ctx, authority, conn) +} + +func (uc *unixPeerTransportCredentials) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + conn, insecureAuthInfo, err := uc.insecureCredentials.ServerHandshake(conn) + if err != nil { + return nil, nil, err + } + + unixCreds, err := getUnixPeerCredentials(conn) + if err != nil { + return nil, nil, fmt.Errorf("failed to get unix peer credentials: %w", err) + } + if unixCreds == nil { + // This means Unix credentials are not available (not a Unix system). + // We treat this connection as a basic insecure connection, with the + // authentication info coming from the 'insecure' module. + return conn, insecureAuthInfo, nil + } + + unixPeerAuthInfo := rpcauth.UnixPeerAuthInfo{ + CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}, + Credentials: *unixCreds, + } + return conn, unixPeerAuthInfo, nil +} + +func (uc *unixPeerTransportCredentials) Info() credentials.ProtocolInfo { + return uc.insecureCredentials.Info() +} + +func (uc *unixPeerTransportCredentials) Clone() credentials.TransportCredentials { + return &unixPeerTransportCredentials{ + insecureCredentials: uc.insecureCredentials.Clone(), + } +} + +func (uc *unixPeerTransportCredentials) OverrideServerName(serverName string) error { + // This is the same as the insecure implementation, but does not use + // its deprecated method. + return nil +} + +// NewUnixPeerCredentials returns a new TransportCredentials that disables transport security, +// but fetches the peer's credentials from the Unix domain socket. +func NewUnixPeerTransportCredentials() credentials.TransportCredentials { + return &unixPeerTransportCredentials{ + insecureCredentials: insecure.NewCredentials(), + } +} diff --git a/server/unix_peer_credentials_darwin.go b/server/unix_peer_credentials_darwin.go new file mode 100644 index 00000000..e1074b0f --- /dev/null +++ b/server/unix_peer_credentials_darwin.go @@ -0,0 +1,89 @@ +//go:build darwin + +/* Copyright (c) 2024 Snowflake Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +package server + +import ( + "fmt" + "net" + "os/user" + "strconv" + + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" + "golang.org/x/sys/unix" +) + +// getUnixPeerCredentials indicates missing Unix credentials on non-Linux systems. +// +// This is needed so that the rpcauth package compiles on non-Linux systems, +// where Unix credentials cannot be fetched. +func getUnixPeerCredentials(conn net.Conn) (*rpcauth.UnixPeerCredentials, error) { + uc, ok := conn.(*net.UnixConn) + if !ok { + return nil, fmt.Errorf("called getUnixPeerCredentials on non-Unix connection") + } + + rawConn, err := uc.SyscallConn() + if err != nil { + return nil, fmt.Errorf("failed to get raw connection: %w", err) + } + + var ucred *unix.Xucred + err2 := rawConn.Control(func(fd uintptr) { + ucred, err = unix.GetsockoptXucred(int(fd), + unix.SOL_LOCAL, + unix.LOCAL_PEERCRED) + }) + if err != nil { + return nil, fmt.Errorf("failed to get peer credentials - getsockopt error: %w", err) + } + if err2 != nil { + return nil, fmt.Errorf("failed to get peer credentials - socket Control error: %w", err2) + } + + // Convert UID and GIDs to user & group names. If any lookup fails, use the numeric value. + uid := int(ucred.Uid) + userName := strconv.Itoa(uid) + userInfo, err := user.LookupId(userName) + if err == nil { + userName = userInfo.Username + } + + groupIds := []int{} + for _, groupId := range ucred.Groups[0:ucred.Ngroups] { + groupIds = append(groupIds, int(groupId)) + } + groupNames := []string{} + + for _, groupId := range groupIds { + groupIdString := strconv.Itoa(groupId) + groupInfo, err := user.LookupGroupId(groupIdString) + if err == nil { + groupNames = append(groupNames, groupInfo.Name) + } else { + groupNames = append(groupNames, groupIdString) + } + } + + return &rpcauth.UnixPeerCredentials{ + Uid: uid, + Gids: groupIds, + UserName: userName, + GroupNames: groupNames, + }, nil +} diff --git a/server/unix_peer_credentials_linux.go b/server/unix_peer_credentials_linux.go new file mode 100644 index 00000000..8a83e0b3 --- /dev/null +++ b/server/unix_peer_credentials_linux.go @@ -0,0 +1,96 @@ +//go:build linux + +/* Copyright (c) 2024 Snowflake Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +package server + +import ( + "fmt" + "net" + "os/user" + "strconv" + + "github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" + "golang.org/x/sys/unix" +) + +// getUnixPeerCredentials returns the peer's Unix credentials from the given network connection. +// +// The provided connection should be established over a Unix domain socket. +func getUnixPeerCredentials(conn net.Conn) (*rpcauth.UnixPeerCredentials, error) { + uc, ok := conn.(*net.UnixConn) + if !ok { + return nil, fmt.Errorf("called getUnixPeerCredentials on non-Unix connection") + } + + rawConn, err := uc.SyscallConn() + if err != nil { + return nil, fmt.Errorf("failed to get raw connection: %w", err) + } + + var ucred *unix.Ucred + err2 := rawConn.Control(func(fd uintptr) { + ucred, err = unix.GetsockoptUcred(int(fd), + unix.SOL_SOCKET, + unix.SO_PEERCRED) + }) + + if err != nil { + return nil, fmt.Errorf("failed to get peer credentials - getsockopt error: %w", err) + } + if err2 != nil { + return nil, fmt.Errorf("failed to get peer credentials - socket Control error: %w", err2) + } + + // Convert UID to user name, fetch primary & supplementary group IDs and convert them to group + // names. If any user/group lookup fails, use the numeric value. + uid := int(ucred.Uid) + userName := strconv.Itoa(uid) + userInfo, err := user.LookupId(userName) + if err == nil { + userName = userInfo.Username + } + + groupIdStrings, err := userInfo.GroupIds() + if err != nil { + return nil, fmt.Errorf("failed to get group IDs for user %s: %w", userName, err) + } + groupIds := []int{} + groupNames := []string{} + + for _, groupIdString := range groupIdStrings { + groupId, err := strconv.Atoi(groupIdString) + if err != nil { + return nil, fmt.Errorf("failed to convert group ID %s to int: %w", groupIdString, err) + } + groupIds = append(groupIds, groupId) + + groupInfo, err := user.LookupGroupId(groupIdString) + if err == nil { + groupNames = append(groupNames, groupInfo.Name) + } else { + groupNames = append(groupNames, groupIdString) + } + } + + return &rpcauth.UnixPeerCredentials{ + Uid: uid, + Gids: groupIds, + UserName: userName, + GroupNames: groupNames, + }, nil +}