Skip to content

Commit

Permalink
Merge pull request #10996 from cweld510:cweld/optionally-close-unix-s…
Browse files Browse the repository at this point in the history
…ockets-on-save

PiperOrigin-RevId: 684217787
  • Loading branch information
gvisor-bot committed Oct 9, 2024
2 parents 62eaadc + befd16e commit 41c56d4
Show file tree
Hide file tree
Showing 14 changed files with 115 additions and 37 deletions.
12 changes: 6 additions & 6 deletions pkg/sentry/fsimpl/gofer/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ type endpoint struct {
}

// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error {
func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint), opts transport.UnixSocketOpts) *syserr.Error {
// No lock ordering required as only the ConnectingEndpoint has a mutex.
ce.Lock()

Expand All @@ -68,7 +68,7 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec
return syserr.ErrInvalidEndpointState
}

c, err := e.newConnectedEndpoint(ctx, ce.Type(), ce.WaiterQueue())
c, err := e.newConnectedEndpoint(ctx, ce.Type(), ce.WaiterQueue(), opts)
if err != nil {
ce.Unlock()
return err
Expand All @@ -85,8 +85,8 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec

// UnidirectionalConnect implements
// transport.BoundEndpoint.UnidirectionalConnect.
func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.ConnectedEndpoint, *syserr.Error) {
c, err := e.newConnectedEndpoint(ctx, linux.SOCK_DGRAM, &waiter.Queue{})
func (e *endpoint) UnidirectionalConnect(ctx context.Context, opts transport.UnixSocketOpts) (transport.ConnectedEndpoint, *syserr.Error) {
c, err := e.newConnectedEndpoint(ctx, linux.SOCK_DGRAM, &waiter.Queue{}, opts)
if err != nil {
return nil, err
}
Expand All @@ -102,15 +102,15 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect
return c, nil
}

func (e *endpoint) newConnectedEndpoint(ctx context.Context, sockType linux.SockType, queue *waiter.Queue) (*transport.SCMConnectedEndpoint, *syserr.Error) {
func (e *endpoint) newConnectedEndpoint(ctx context.Context, sockType linux.SockType, queue *waiter.Queue, opts transport.UnixSocketOpts) (*transport.SCMConnectedEndpoint, *syserr.Error) {
e.dentry.fs.renameMu.RLock()
hostSockFD, err := e.dentry.connect(ctx, sockType)
e.dentry.fs.renameMu.RUnlock()
if err != nil {
return nil, syserr.ErrConnectionRefused
}

c, serr := transport.NewSCMEndpoint(hostSockFD, queue, e.path)
c, serr := transport.NewSCMEndpoint(hostSockFD, queue, e.path, opts)
if serr != nil {
unix.Close(hostSockFD)
log.Warningf("NewSCMEndpoint failed: path=%q, err=%v", e.path, serr)
Expand Down
1 change: 1 addition & 0 deletions pkg/sentry/fsimpl/testutil/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ go_library(
"//pkg/sentry/platform/kvm",
"//pkg/sentry/platform/ptrace",
"//pkg/sentry/seccheck",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/time",
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
Expand Down
2 changes: 2 additions & 0 deletions pkg/sentry/fsimpl/testutil/kernel.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/seccheck"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
Expand Down Expand Up @@ -106,6 +107,7 @@ func Boot() (*kernel.Kernel, error) {
RootUTSNamespace: kernel.NewUTSNamespace("hostname", "domain", creds.UserNamespace),
RootIPCNamespace: kernel.NewIPCNamespace(creds.UserNamespace),
PIDNamespace: kernel.NewRootPIDNamespace(creds.UserNamespace),
UnixSocketOpts: transport.UnixSocketOpts{},
}); err != nil {
return nil, fmt.Errorf("initializing kernel: %v", err)
}
Expand Down
8 changes: 8 additions & 0 deletions pkg/sentry/kernel/kernel.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink/port"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
uspb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto"
Expand Down Expand Up @@ -387,6 +388,9 @@ type Kernel struct {
// attempt succeeded, after which at least one more checkpoint attempt was
// made and failed with this error. It's protected by checkpointMu.
lastCheckpointStatus error `state:"nosave"`

// UnixSocketOpts stores configuration options for management of unix sockets.
UnixSocketOpts transport.UnixSocketOpts
}

// Saver is an interface for saving the kernel.
Expand Down Expand Up @@ -445,6 +449,9 @@ type InitKernelArgs struct {
// used by processes. If it is zero, the limit will be set to
// unlimited.
MaxFDLimit int32

// UnixSocketOpts contains configuration options for unix sockets.
UnixSocketOpts transport.UnixSocketOpts
}

// Init initializes the Kernel with no tasks.
Expand Down Expand Up @@ -567,6 +574,7 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.sockets = make(map[*vfs.FileDescription]*SocketRecord)

k.cgroupRegistry = newCgroupRegistry()
k.UnixSocketOpts = args.UnixSocketOpts
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/sentry/socket/netlink/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func New(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *sy
}

// Create a connection from which the kernel can write messages.
connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t)
connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t, t.Kernel().UnixSocketOpts)
if err != nil {
ep.Close(t)
return nil, err
Expand Down
18 changes: 9 additions & 9 deletions pkg/sentry/socket/unix/transport/connectioned.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ func (e *connectionedEndpoint) Close(ctx context.Context) {
}

// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint), opts UnixSocketOpts) *syserr.Error {
if ce.Type() != e.stype {
return syserr.ErrWrongProtocolForSocket
}
Expand Down Expand Up @@ -378,13 +378,13 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
}

// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
func (e *connectionedEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) {
func (e *connectionedEndpoint) UnidirectionalConnect(ctx context.Context, opts UnixSocketOpts) (ConnectedEndpoint, *syserr.Error) {
return nil, syserr.ErrConnectionRefused
}

// Connect attempts to directly connect to another Endpoint.
// Implements Endpoint.Connect.
func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error {
func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint, opts UnixSocketOpts) *syserr.Error {
returnConnect := func(r Receiver, ce ConnectedEndpoint) {
e.receiver = r
e.connected = ce
Expand All @@ -396,7 +396,7 @@ func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint
}
}

return server.BidirectionalConnect(ctx, e, returnConnect)
return server.BidirectionalConnect(ctx, e, returnConnect, opts)
}

// Listen starts listening on the connection.
Expand All @@ -405,7 +405,7 @@ func (e *connectionedEndpoint) Listen(ctx context.Context, backlog int) *syserr.
defer e.Unlock()
if e.ListeningLocked() {
// Adjust the size of the channel iff we can fix existing
// pending connections into the new one.
// pending connections into the new one
if len(e.acceptedChan) > backlog {
return syserr.ErrInvalidEndpointState
}
Expand Down Expand Up @@ -438,15 +438,15 @@ func (e *connectionedEndpoint) Listen(ctx context.Context, backlog int) *syserr.
}

// Accept accepts a new connection.
func (e *connectionedEndpoint) Accept(ctx context.Context, peerAddr *Address) (Endpoint, *syserr.Error) {
func (e *connectionedEndpoint) Accept(ctx context.Context, peerAddr *Address, opts UnixSocketOpts) (Endpoint, *syserr.Error) {
e.Lock()

if !e.ListeningLocked() {
e.Unlock()
return nil, syserr.ErrInvalidEndpointState
}

ne, err := e.getAcceptedEndpointLocked(ctx)
ne, err := e.getAcceptedEndpointLocked(ctx, opts)
e.Unlock()
if err != nil {
return nil, err
Expand All @@ -470,7 +470,7 @@ func (e *connectionedEndpoint) Accept(ctx context.Context, peerAddr *Address) (E
// Preconditions:
// - e.Listening()
// - e is locked.
func (e *connectionedEndpoint) getAcceptedEndpointLocked(ctx context.Context) (*connectionedEndpoint, *syserr.Error) {
func (e *connectionedEndpoint) getAcceptedEndpointLocked(ctx context.Context, opts UnixSocketOpts) (*connectionedEndpoint, *syserr.Error) {
// Accept connections from within the sentry first, since this avoids
// an RPC to the gofer on the common path.
select {
Expand All @@ -493,7 +493,7 @@ func (e *connectionedEndpoint) getAcceptedEndpointLocked(ctx context.Context) (*
return nil, syserr.FromError(err)
}
q := &waiter.Queue{}
scme, serr := NewSCMEndpoint(nfd, q, e.path)
scme, serr := NewSCMEndpoint(nfd, q, e.path, opts)
if serr != nil {
unix.Close(nfd)
return nil, serr
Expand Down
13 changes: 7 additions & 6 deletions pkg/sentry/socket/unix/transport/connectionless.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ func (e *connectionlessEndpoint) Close(ctx context.Context) {
}

// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
func (e *connectionlessEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
func (e *connectionlessEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint), opts UnixSocketOpts) *syserr.Error {
return syserr.ErrConnectionRefused
}

// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
func (e *connectionlessEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) {
func (e *connectionlessEndpoint) UnidirectionalConnect(ctx context.Context, opts UnixSocketOpts) (ConnectedEndpoint, *syserr.Error) {
e.Lock()
r := e.receiver
e.Unlock()
Expand All @@ -107,7 +107,8 @@ func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c C
return e.baseEndpoint.SendMsg(ctx, data, c, nil)
}

connected, err := to.UnidirectionalConnect(ctx)
opts := UnixSocketOpts{}
connected, err := to.UnidirectionalConnect(ctx, opts)
if err != nil {
return 0, nil, syserr.ErrInvalidEndpointState
}
Expand All @@ -131,8 +132,8 @@ func (e *connectionlessEndpoint) Type() linux.SockType {
}

// Connect attempts to connect directly to server.
func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error {
connected, err := server.UnidirectionalConnect(ctx)
func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoint, opts UnixSocketOpts) *syserr.Error {
connected, err := server.UnidirectionalConnect(ctx, opts)
if err != nil {
return err
}
Expand All @@ -153,7 +154,7 @@ func (*connectionlessEndpoint) Listen(context.Context, int) *syserr.Error {
}

// Accept accepts a new connection.
func (*connectionlessEndpoint) Accept(context.Context, *Address) (Endpoint, *syserr.Error) {
func (*connectionlessEndpoint) Accept(context.Context, *Address, UnixSocketOpts) (Endpoint, *syserr.Error) {
return nil, syserr.ErrNotSupported
}

Expand Down
60 changes: 56 additions & 4 deletions pkg/sentry/socket/unix/transport/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ func (c *HostConnectedEndpoint) init() *syserr.Error {
}

func (c *HostConnectedEndpoint) initFromOptions() *syserr.Error {
if c.fd < 0 {
// There is no underlying FD to restore; nothing to do
return nil
}

family, err := unix.GetsockoptInt(c.fd, unix.SOL_SOCKET, unix.SO_DOMAIN)
if err != nil {
return syserr.FromError(err)
Expand Down Expand Up @@ -163,6 +168,10 @@ func (c *HostConnectedEndpoint) Send(ctx context.Context, data [][]byte, control
return 0, false, syserr.ErrInvalidEndpointState
}

if c.IsSendClosed() {
return 0, false, syserr.ErrClosedForSend
}

// Since stream sockets don't preserve message boundaries, we can write
// only as much of the message as fits in the send buffer.
truncate := c.stype == linux.SOCK_STREAM
Expand Down Expand Up @@ -192,6 +201,14 @@ func (c *HostConnectedEndpoint) SendNotify() {}
// CloseSend takes c.mu and delegates to closeSendLocked, which requires the
// mutex to be held by its caller.
func (c *HostConnectedEndpoint) CloseSend() {
c.mu.Lock()
defer c.mu.Unlock()
c.closeSendLocked()
}

// Preconditions: c.mu must be held.
func (c *HostConnectedEndpoint) closeSendLocked() {
if c.IsSendClosed() {
return
}

if err := unix.Shutdown(c.fd, unix.SHUT_WR); err != nil {
// A well-formed UDS shutdown can't fail. See
Expand Down Expand Up @@ -300,6 +317,14 @@ func (c *HostConnectedEndpoint) RecvNotify() {}
// CloseRecv takes c.mu and delegates to closeRecvLocked, which requires the
// mutex to be held by its caller.
func (c *HostConnectedEndpoint) CloseRecv() {
c.mu.Lock()
defer c.mu.Unlock()
c.closeRecvLocked()
}

// Preconditions: c.mu must be held.
func (c *HostConnectedEndpoint) closeRecvLocked() {
if c.IsRecvClosed() {
return
}

if err := unix.Shutdown(c.fd, unix.SHUT_RD); err != nil {
// A well-formed UDS shutdown can't fail. See
Expand Down Expand Up @@ -382,13 +407,34 @@ func (c *HostConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) {
// SCMConnectedEndpoint represents an endpoint backed by a host fd that was
// passed through a gofer Unix socket. It resembles HostConnectedEndpoint, with the
// following differences:
// - SCMConnectedEndpoint is not saveable, because the host cannot guarantee
// the same descriptor number across S/R.
// - SCMConnectedEndpoint is not saveable by default, because the host
// cannot guarantee the same descriptor number across S/R.
// However, it can optionally be placed in a closed state before save.
// - SCMConnectedEndpoint holds ownership of its fd and notification queue.
//
// +stateify savable
type SCMConnectedEndpoint struct {
HostConnectedEndpoint

// queue is the notification queue supplied to NewSCMEndpoint.
queue *waiter.Queue
// opts holds the Unix socket options supplied to NewSCMEndpoint; beforeSave
// reads opts.DisconnectOnSave to decide whether the endpoint may be saved.
opts UnixSocketOpts
}

// beforeSave is invoked by stateify.
//
// If opts.DisconnectOnSave is not set, beforeSave panics, since the endpoint
// cannot be saved while connected. Otherwise, with e.mu held, it unregisters
// the host FD from fdnotifier, closes the receive and send sides, closes the
// FD itself, and calls destroyLocked.
func (e *SCMConnectedEndpoint) beforeSave() {
	if !e.opts.DisconnectOnSave {
		panic("socket cannot be saved in a connected state")
	}

	e.mu.Lock()
	defer e.mu.Unlock()

	// Unregister from the notifier before the FD is shut down and closed.
	fdnotifier.RemoveFD(int32(e.fd))
	e.closeRecvLocked()
	e.closeSendLocked()
	if err := unix.Close(e.fd); err != nil {
		// A failed close is logged rather than treated as fatal. The format
		// string has two verbs, so both the FD and the error must be supplied.
		log.Warningf("Failed to close host fd %d: %v", e.fd, err)
	}
	e.destroyLocked()
}

// Init will do the initialization required without holding other locks.
Expand All @@ -400,12 +446,17 @@ func (e *SCMConnectedEndpoint) Init() error {
// Release calls DecRef with a cleanup callback. The callback, with e.mu held,
// returns immediately if the host FD is already negative; otherwise it
// unregisters the FD from fdnotifier, closes it, and calls destroyLocked.
func (e *SCMConnectedEndpoint) Release(ctx context.Context) {
	e.DecRef(func() {
		e.mu.Lock()
		// The deferred unlock covers both the early return and the normal
		// path. No explicit Unlock may follow it, or the mutex would be
		// unlocked twice.
		defer e.mu.Unlock()

		if e.fd < 0 {
			return
		}

		fdnotifier.RemoveFD(int32(e.fd))
		if err := unix.Close(e.fd); err != nil {
			// A failed close is logged rather than treated as fatal. The
			// format string has two verbs, so both the FD and the error must
			// be supplied.
			log.Warningf("Failed to close host fd %d: %v", e.fd, err)
		}
		e.destroyLocked()
	})
}

Expand All @@ -415,13 +466,14 @@ func (e *SCMConnectedEndpoint) Release(ctx context.Context) {
// The caller is responsible for calling Init(). Additionally, Release needs to
// be called twice because ConnectedEndpoint is both a Receiver and
// ConnectedEndpoint.
func NewSCMEndpoint(hostFD int, queue *waiter.Queue, addr string) (*SCMConnectedEndpoint, *syserr.Error) {
func NewSCMEndpoint(hostFD int, queue *waiter.Queue, addr string, opts UnixSocketOpts) (*SCMConnectedEndpoint, *syserr.Error) {
e := SCMConnectedEndpoint{
HostConnectedEndpoint: HostConnectedEndpoint{
fd: hostFD,
addr: addr,
},
queue: queue,
opts: opts,
}

if err := e.init(); err != nil {
Expand Down
Loading

0 comments on commit 41c56d4

Please sign in to comment.