From 727bc9c72acc31c61d3eb015a02c2e6f028dd288 Mon Sep 17 00:00:00 2001 From: cweld510 Date: Wed, 2 Oct 2024 15:24:31 +0000 Subject: [PATCH 1/3] Add and implement option to close unsaveable gofer-backed unix sockets on save --- pkg/sentry/fsimpl/gofer/socket.go | 12 ++-- pkg/sentry/fsimpl/testutil/kernel.go | 2 + pkg/sentry/kernel/kernel.go | 10 ++- pkg/sentry/socket/netlink/socket.go | 2 +- .../socket/unix/transport/connectioned.go | 18 ++--- .../socket/unix/transport/connectionless.go | 13 ++-- pkg/sentry/socket/unix/transport/host.go | 68 +++++++++++++++++-- pkg/sentry/socket/unix/transport/unix.go | 17 +++-- pkg/sentry/socket/unix/unix.go | 6 +- runsc/boot/BUILD | 1 + runsc/boot/loader.go | 5 ++ 11 files changed, 120 insertions(+), 34 deletions(-) diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go index 34c34f6a6b..64e9a995a6 100644 --- a/pkg/sentry/fsimpl/gofer/socket.go +++ b/pkg/sentry/fsimpl/gofer/socket.go @@ -54,7 +54,7 @@ type endpoint struct { } // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. -func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error { +func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint), opts transport.UnixSocketOpts) *syserr.Error { // No lock ordering required as only the ConnectingEndpoint has a mutex. 
ce.Lock() @@ -68,7 +68,7 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec return syserr.ErrInvalidEndpointState } - c, err := e.newConnectedEndpoint(ctx, ce.Type(), ce.WaiterQueue()) + c, err := e.newConnectedEndpoint(ctx, ce.Type(), ce.WaiterQueue(), opts) if err != nil { ce.Unlock() return err @@ -85,8 +85,8 @@ func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.Connec // UnidirectionalConnect implements // transport.BoundEndpoint.UnidirectionalConnect. -func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.ConnectedEndpoint, *syserr.Error) { - c, err := e.newConnectedEndpoint(ctx, linux.SOCK_DGRAM, &waiter.Queue{}) +func (e *endpoint) UnidirectionalConnect(ctx context.Context, opts transport.UnixSocketOpts) (transport.ConnectedEndpoint, *syserr.Error) { + c, err := e.newConnectedEndpoint(ctx, linux.SOCK_DGRAM, &waiter.Queue{}, opts) if err != nil { return nil, err } @@ -102,7 +102,7 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect return c, nil } -func (e *endpoint) newConnectedEndpoint(ctx context.Context, sockType linux.SockType, queue *waiter.Queue) (*transport.SCMConnectedEndpoint, *syserr.Error) { +func (e *endpoint) newConnectedEndpoint(ctx context.Context, sockType linux.SockType, queue *waiter.Queue, opts transport.UnixSocketOpts) (*transport.SCMConnectedEndpoint, *syserr.Error) { e.dentry.fs.renameMu.RLock() hostSockFD, err := e.dentry.connect(ctx, sockType) e.dentry.fs.renameMu.RUnlock() @@ -110,7 +110,7 @@ func (e *endpoint) newConnectedEndpoint(ctx context.Context, sockType linux.Sock return nil, syserr.ErrConnectionRefused } - c, serr := transport.NewSCMEndpoint(hostSockFD, queue, e.path) + c, serr := transport.NewSCMEndpoint(hostSockFD, queue, e.path, opts) if serr != nil { unix.Close(hostSockFD) log.Warningf("NewSCMEndpoint failed: path=%q, err=%v", e.path, serr) diff --git a/pkg/sentry/fsimpl/testutil/kernel.go 
b/pkg/sentry/fsimpl/testutil/kernel.go index 19f53f2f14..3272cfa6e2 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/seccheck" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -106,6 +107,7 @@ func Boot() (*kernel.Kernel, error) { RootUTSNamespace: kernel.NewUTSNamespace("hostname", "domain", creds.UserNamespace), RootIPCNamespace: kernel.NewIPCNamespace(creds.UserNamespace), PIDNamespace: kernel.NewRootPIDNamespace(creds.UserNamespace), + UnixSocketOpts: transport.UnixSocketOpts{}, }); err != nil { return nil, fmt.Errorf("initializing kernel: %v", err) } diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 72703b42d4..232afa40c9 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -70,6 +70,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port" + unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/unimpl" uspb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto" @@ -387,6 +388,9 @@ type Kernel struct { // attempt succeeded, after which at least one more checkpoint attempt was // made and failed with this error. It's protected by checkpointMu. lastCheckpointStatus error `state:"nosave"` + + // UnixSocketOpts stores configuration options for management of unix sockets. + UnixSocketOpts unixsocket.UnixSocketOpts } // Saver is an interface for saving the kernel. @@ -445,13 +449,16 @@ type InitKernelArgs struct { // used by processes. If it is zero, the limit will be set to // unlimited. 
MaxFDLimit int32 + + // UnixSocketOpts contains configuration options for unix sockets. + UnixSocketOpts unixsocket.UnixSocketOpts } // Init initialize the Kernel with no tasks. // // Callers must manually set Kernel.Platform and call Kernel.SetMemoryFile // before calling Init. -func (k *Kernel) Init(args InitKernelArgs) error { +func (k *Kernel) Init(args InitKernelArgs) error { // TODO (colin) propagate up if args.Timekeeper == nil { return fmt.Errorf("args.Timekeeper is nil") } @@ -567,6 +574,7 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.sockets = make(map[*vfs.FileDescription]*SocketRecord) k.cgroupRegistry = newCgroupRegistry() + k.UnixSocketOpts = args.UnixSocketOpts return nil } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index e84ca545eb..fd7fd3e96a 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -131,7 +131,7 @@ func New(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *sy } // Create a connection from which the kernel can write messages. - connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t) + connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t, t.Kernel().UnixSocketOpts) if err != nil { ep.Close(t) return nil, err diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 4c57804f22..deda6ac71f 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -275,7 +275,7 @@ func (e *connectionedEndpoint) Close(ctx context.Context) { } // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. 
-func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error { +func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint), opts UnixSocketOpts) *syserr.Error { if ce.Type() != e.stype { return syserr.ErrWrongProtocolForSocket } @@ -378,13 +378,13 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn } // UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect. -func (e *connectionedEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) { +func (e *connectionedEndpoint) UnidirectionalConnect(ctx context.Context, opts UnixSocketOpts) (ConnectedEndpoint, *syserr.Error) { return nil, syserr.ErrConnectionRefused } // Connect attempts to directly connect to another Endpoint. // Implements Endpoint.Connect. -func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error { +func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint, opts UnixSocketOpts) *syserr.Error { returnConnect := func(r Receiver, ce ConnectedEndpoint) { e.receiver = r e.connected = ce @@ -396,7 +396,7 @@ func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint } } - return server.BidirectionalConnect(ctx, e, returnConnect) + return server.BidirectionalConnect(ctx, e, returnConnect, opts) } // Listen starts listening on the connection. @@ -405,7 +405,7 @@ func (e *connectionedEndpoint) Listen(ctx context.Context, backlog int) *syserr. defer e.Unlock() if e.ListeningLocked() { // Adjust the size of the channel iff we can fix existing - // pending connections into the new one. 
+ // pending connections into the new one if len(e.acceptedChan) > backlog { return syserr.ErrInvalidEndpointState } @@ -438,7 +438,7 @@ func (e *connectionedEndpoint) Listen(ctx context.Context, backlog int) *syserr. } // Accept accepts a new connection. -func (e *connectionedEndpoint) Accept(ctx context.Context, peerAddr *Address) (Endpoint, *syserr.Error) { +func (e *connectionedEndpoint) Accept(ctx context.Context, peerAddr *Address, opts UnixSocketOpts) (Endpoint, *syserr.Error) { e.Lock() if !e.ListeningLocked() { @@ -446,7 +446,7 @@ func (e *connectionedEndpoint) Accept(ctx context.Context, peerAddr *Address) (E return nil, syserr.ErrInvalidEndpointState } - ne, err := e.getAcceptedEndpointLocked(ctx) + ne, err := e.getAcceptedEndpointLocked(ctx, opts) e.Unlock() if err != nil { return nil, err @@ -470,7 +470,7 @@ func (e *connectionedEndpoint) Accept(ctx context.Context, peerAddr *Address) (E // Preconditions: // - e.Listening() // - e is locked. -func (e *connectionedEndpoint) getAcceptedEndpointLocked(ctx context.Context) (*connectionedEndpoint, *syserr.Error) { +func (e *connectionedEndpoint) getAcceptedEndpointLocked(ctx context.Context, opts UnixSocketOpts) (*connectionedEndpoint, *syserr.Error) { // Accept connections from within the sentry first, since this avoids // an RPC to the gofer on the common path. 
select { @@ -493,7 +493,7 @@ func (e *connectionedEndpoint) getAcceptedEndpointLocked(ctx context.Context) (* return nil, syserr.FromError(err) } q := &waiter.Queue{} - scme, serr := NewSCMEndpoint(nfd, q, e.path) + scme, serr := NewSCMEndpoint(nfd, q, e.path, opts) if serr != nil { unix.Close(nfd) return nil, serr diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 8b5b4ecd1c..7890c73dd5 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -78,12 +78,12 @@ func (e *connectionlessEndpoint) Close(ctx context.Context) { } // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. -func (e *connectionlessEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error { +func (e *connectionlessEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint), opts UnixSocketOpts) *syserr.Error { return syserr.ErrConnectionRefused } // UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect. -func (e *connectionlessEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) { +func (e *connectionlessEndpoint) UnidirectionalConnect(ctx context.Context, opts UnixSocketOpts) (ConnectedEndpoint, *syserr.Error) { e.Lock() r := e.receiver e.Unlock() @@ -107,7 +107,8 @@ func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c C return e.baseEndpoint.SendMsg(ctx, data, c, nil) } - connected, err := to.UnidirectionalConnect(ctx) + opts := UnixSocketOpts{} + connected, err := to.UnidirectionalConnect(ctx, opts) if err != nil { return 0, nil, syserr.ErrInvalidEndpointState } @@ -131,8 +132,8 @@ func (e *connectionlessEndpoint) Type() linux.SockType { } // Connect attempts to connect directly to server. 
-func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error { - connected, err := server.UnidirectionalConnect(ctx) +func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoint, opts UnixSocketOpts) *syserr.Error { + connected, err := server.UnidirectionalConnect(ctx, opts) if err != nil { return err } @@ -153,7 +154,7 @@ func (*connectionlessEndpoint) Listen(context.Context, int) *syserr.Error { } // Accept accepts a new connection. -func (*connectionlessEndpoint) Accept(context.Context, *Address) (Endpoint, *syserr.Error) { +func (*connectionlessEndpoint) Accept(context.Context, *Address, UnixSocketOpts) (Endpoint, *syserr.Error) { return nil, syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/unix/transport/host.go b/pkg/sentry/socket/unix/transport/host.go index a80bf10f12..d37a0bee5e 100644 --- a/pkg/sentry/socket/unix/transport/host.go +++ b/pkg/sentry/socket/unix/transport/host.go @@ -98,6 +98,12 @@ func (c *HostConnectedEndpoint) init() *syserr.Error { } func (c *HostConnectedEndpoint) initFromOptions() *syserr.Error { + + if c.fd < 0 { + // There is no underlying FD to restore; nothing to do + return nil + } + family, err := unix.GetsockoptInt(c.fd, unix.SOL_SOCKET, unix.SO_DOMAIN) if err != nil { return syserr.FromError(err) @@ -163,6 +169,10 @@ func (c *HostConnectedEndpoint) Send(ctx context.Context, data [][]byte, control return 0, false, syserr.ErrInvalidEndpointState } + if c.IsSendClosed() { + return 0, false, syserr.ErrClosedForSend + } + // Since stream sockets don't preserve message boundaries, we can write // only as much of the message as fits in the send buffer. truncate := c.stype == linux.SOCK_STREAM @@ -192,6 +202,15 @@ func (c *HostConnectedEndpoint) SendNotify() {} func (c *HostConnectedEndpoint) CloseSend() { c.mu.Lock() defer c.mu.Unlock() + c.closeSendLocked() +} + +// Preconditions: c.mu must be held. 
+func (c *HostConnectedEndpoint) closeSendLocked() { + + if c.IsSendClosed() { + return + } if err := unix.Shutdown(c.fd, unix.SHUT_WR); err != nil { // A well-formed UDS shutdown can't fail. See @@ -245,6 +264,10 @@ func (c *HostConnectedEndpoint) Recv(ctx context.Context, data [][]byte, args Re c.mu.RLock() defer c.mu.RUnlock() + if c.IsRecvClosed() { + return RecvOutput{}, false, syserr.ErrClosedForReceive + } + var cm unet.ControlMessage if args.NumRights > 0 { cm.EnableFDs(int(args.NumRights)) @@ -300,6 +323,15 @@ func (c *HostConnectedEndpoint) RecvNotify() {} func (c *HostConnectedEndpoint) CloseRecv() { c.mu.Lock() defer c.mu.Unlock() + c.closeRecvLocked() +} + +// Preconditions: c.mu must be held. +func (c *HostConnectedEndpoint) closeRecvLocked() { + + if c.IsRecvClosed() { + return + } if err := unix.Shutdown(c.fd, unix.SHUT_RD); err != nil { // A well-formed UDS shutdown can't fail. See @@ -382,13 +414,34 @@ func (c *HostConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) { // SCMConnectedEndpoint represents an endpoint backed by a host fd that was // passed through a gofer Unix socket. It resembles HostConnectedEndpoint, with the // following differences: -// - SCMConnectedEndpoint is not saveable, because the host cannot guarantee -// the same descriptor number across S/R. +// - SCMConnectedEndpoint is not saveable by default, because the host +// cannot guarantee the same descriptor number across S/R. +// However, it can optionally be placed in a closed state before save. // - SCMConnectedEndpoint holds ownership of its fd and notification queue. +// +// +stateify savable type SCMConnectedEndpoint struct { HostConnectedEndpoint queue *waiter.Queue + opts UnixSocketOpts +} + +// beforeSave is invoked by stateify. 
+func (e *SCMConnectedEndpoint) beforeSave() { + if !e.opts.DisconnectOnSave { + panic("socket cannot be saved in a connected state") + } + + e.mu.Lock() + defer e.mu.Unlock() + fdnotifier.RemoveFD(int32(e.fd)) + e.closeRecvLocked() + e.closeSendLocked() + if err := unix.Close(e.fd); err != nil { + log.Warningf("Failed to close host fd %d: %v", e.fd, err) + } + e.destroyLocked() } // Init will do the initialization required without holding other locks. @@ -400,12 +453,18 @@ func (e *SCMConnectedEndpoint) Init() error { func (e *SCMConnectedEndpoint) Release(ctx context.Context) { e.DecRef(func() { e.mu.Lock() + defer e.mu.Unlock() + + if e.fd == -1 { + return + } + fdnotifier.RemoveFD(int32(e.fd)) if err := unix.Close(e.fd); err != nil { log.Warningf("Failed to close host fd %d: %v", err) } e.destroyLocked() - e.mu.Unlock() + }) } @@ -415,13 +474,14 @@ func (e *SCMConnectedEndpoint) Release(ctx context.Context) { // The caller is responsible for calling Init(). Additionally, Release needs to // be called twice because ConnectedEndpoint is both a Receiver and // ConnectedEndpoint. -func NewSCMEndpoint(hostFD int, queue *waiter.Queue, addr string) (*SCMConnectedEndpoint, *syserr.Error) { +func NewSCMEndpoint(hostFD int, queue *waiter.Queue, addr string, opts UnixSocketOpts) (*SCMConnectedEndpoint, *syserr.Error) { e := SCMConnectedEndpoint{ HostConnectedEndpoint: HostConnectedEndpoint{ fd: hostFD, addr: addr, }, queue: queue, + opts: opts, } if err := e.init(); err != nil { diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 0fb070f00b..f83090bcf8 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -137,6 +137,15 @@ type RecvOutput struct { UnusedRights []RightsControlMessage } +// UnixSocketOpts is a container for configuration options for gvisor's management of +// unix sockets.
+// +stateify savable +type UnixSocketOpts struct { + // DisconnectOnSave, if true, puts the endpoint in a closed state before save; if false, an + // attempt to save will panic. + DisconnectOnSave bool +} + // Endpoint is the interface implemented by Unix transport protocol // implementations that expose functionality like sendmsg, recvmsg, connect, // etc. to Unix socket implementations. @@ -169,7 +178,7 @@ type Endpoint interface { // endpoint passed in as a parameter. // // The error codes are the same as Connect. - Connect(ctx context.Context, server BoundEndpoint) *syserr.Error + Connect(ctx context.Context, server BoundEndpoint, opts UnixSocketOpts) *syserr.Error // Shutdown closes the read and/or write end of the endpoint connection // to its peer. @@ -187,7 +196,7 @@ type Endpoint interface { // // peerAddr if not nil will be populated with the address of the connected // peer on a successful accept. - Accept(ctx context.Context, peerAddr *Address) (Endpoint, *syserr.Error) + Accept(ctx context.Context, peerAddr *Address, opts UnixSocketOpts) (Endpoint, *syserr.Error) // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. @@ -262,7 +271,7 @@ type BoundEndpoint interface { // // This method will return syserr.ErrConnectionRefused on endpoints with a // type that isn't SockStream or SockSeqpacket. - BidirectionalConnect(ctx context.Context, ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error + BidirectionalConnect(ctx context.Context, ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint), opts UnixSocketOpts) *syserr.Error // UnidirectionalConnect establishes a write-only connection to a unix // endpoint. @@ -272,7 +281,7 @@ type BoundEndpoint interface { // // This method will return syserr.ErrConnectionRefused on a non-SockDgram // endpoint.
- UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) + UnidirectionalConnect(ctx context.Context, opts UnixSocketOpts) (ConnectedEndpoint, *syserr.Error) // Passcred returns whether or not the SO_PASSCRED socket option is // enabled on this end. diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index c7baf57939..7b38191f00 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -149,7 +149,7 @@ func (s *Socket) blockingAccept(t *kernel.Task, peerAddr *transport.Address) (tr // Try to accept the connection; if it fails, then wait until we get a // notification. for { - if ep, err := s.ep.Accept(t, peerAddr); err != syserr.ErrWouldBlock { + if ep, err := s.ep.Accept(t, peerAddr, t.Kernel().UnixSocketOpts); err != syserr.ErrWouldBlock { return ep, err } @@ -166,7 +166,7 @@ func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking if peerRequested { peerAddr = &transport.Address{} } - ep, err := s.ep.Accept(t, peerAddr) + ep, err := s.ep.Accept(t, peerAddr, t.Kernel().UnixSocketOpts) if err != nil { if err != syserr.ErrWouldBlock || !blocking { return 0, nil, 0, err @@ -582,7 +582,7 @@ func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr defer ep.Release(t) // Connect the server endpoint. 
- err = s.ep.Connect(t, ep) + err = s.ep.Connect(t, ep, t.Kernel().UnixSocketOpts) if err == syserr.ErrWrongProtocolForSocket { // Linux for abstract sockets returns ErrConnectionRefused diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 1931437963..fef0763de1 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -96,6 +96,7 @@ go_library( "//pkg/sentry/socket/netstack", "//pkg/sentry/socket/plugin", "//pkg/sentry/socket/unix", + "//pkg/sentry/socket/unix/transport", "//pkg/sentry/state", "//pkg/sentry/strace", "//pkg/sentry/time", diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index aefd82fb04..4b09612ef7 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -56,6 +56,7 @@ import ( pb "gvisor.dev/gvisor/pkg/sentry/seccheck/points/points_go_proto" "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/socket/plugin" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -583,6 +584,9 @@ func New(args Args) (*Loader, error) { } // Initiate the Kernel object, which is required by the Context passed // to createVFS in order to mount (among other things) procfs. 
+ unixSocketOpts := transport.UnixSocketOpts{ + DisconnectOnSave: args.Conf.NetDisconnectOk, + } if err = l.k.Init(kernel.InitKernelArgs{ FeatureSet: cpuid.HostFeatureSet().Fixed(), Timekeeper: tk, @@ -595,6 +599,7 @@ func New(args Args) (*Loader, error) { RootIPCNamespace: kernel.NewIPCNamespace(creds.UserNamespace), PIDNamespace: kernel.NewRootPIDNamespace(creds.UserNamespace), MaxFDLimit: maxFDLimit, + UnixSocketOpts: unixSocketOpts, }); err != nil { return nil, fmt.Errorf("initializing kernel: %w", err) } From db4ffada100028c7e9b796d4d90442f59db22a70 Mon Sep 17 00:00:00 2001 From: cweld510 Date: Mon, 7 Oct 2024 22:39:13 +0000 Subject: [PATCH 2/3] style feedback: remove newlines, fix import, remove stray comment --- pkg/sentry/kernel/kernel.go | 8 ++++---- pkg/sentry/socket/unix/transport/host.go | 6 +----- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 232afa40c9..aa3ae8fd90 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -70,7 +70,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port" - unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/unimpl" uspb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto" @@ -390,7 +390,7 @@ type Kernel struct { lastCheckpointStatus error `state:"nosave"` // UnixSocketOpts stores configuration options for management of unix sockets. - UnixSocketOpts unixsocket.UnixSocketOpts + UnixSocketOpts transport.UnixSocketOpts } // Saver is an interface for saving the kernel. @@ -451,14 +451,14 @@ type InitKernelArgs struct { MaxFDLimit int32 // UnixSocketOpts contains configuration options for unix sockets. 
- UnixSocketOpts unixsocket.UnixSocketOpts + UnixSocketOpts transport.UnixSocketOpts } // Init initialize the Kernel with no tasks. // // Callers must manually set Kernel.Platform and call Kernel.SetMemoryFile // before calling Init. -func (k *Kernel) Init(args InitKernelArgs) error { // TODO (colin) propagate up +func (k *Kernel) Init(args InitKernelArgs) error { if args.Timekeeper == nil { return fmt.Errorf("args.Timekeeper is nil") } diff --git a/pkg/sentry/socket/unix/transport/host.go b/pkg/sentry/socket/unix/transport/host.go index d37a0bee5e..ddb67af68a 100644 --- a/pkg/sentry/socket/unix/transport/host.go +++ b/pkg/sentry/socket/unix/transport/host.go @@ -98,7 +98,6 @@ func (c *HostConnectedEndpoint) init() *syserr.Error { } func (c *HostConnectedEndpoint) initFromOptions() *syserr.Error { - if c.fd < 0 { // There is no underlying FD to restore; nothing to do return nil @@ -207,7 +206,6 @@ func (c *HostConnectedEndpoint) CloseSend() { // Preconditions: c.mu must be held. func (c *HostConnectedEndpoint) closeSendLocked() { - if c.IsSendClosed() { return } @@ -328,7 +326,6 @@ func (c *HostConnectedEndpoint) CloseRecv() { // Preconditions: c.mu must be held. 
func (c *HostConnectedEndpoint) closeRecvLocked() { - if c.IsRecvClosed() { return } @@ -455,7 +452,7 @@ func (e *SCMConnectedEndpoint) Release(ctx context.Context) { e.mu.Lock() defer e.mu.Unlock() - if e.fd == -1 { + if e.fd < 0 { return } @@ -464,7 +461,6 @@ func (e *SCMConnectedEndpoint) Release(ctx context.Context) { log.Warningf("Failed to close host fd %d: %v", err) } e.destroyLocked() - }) } From befd16ec5a740dfa9c14ad3772f7adc8f7383ec0 Mon Sep 17 00:00:00 2001 From: cweld510 Date: Mon, 7 Oct 2024 22:59:53 +0000 Subject: [PATCH 3/3] Update config/flags documentation --- runsc/config/config.go | 5 ++--- runsc/config/flags.go | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/runsc/config/config.go b/runsc/config/config.go index a7cecb85e1..934cf79a96 100644 --- a/runsc/config/config.go +++ b/runsc/config/config.go @@ -364,9 +364,8 @@ type Config struct { // present, and reproduce them in the sandbox. ReproduceNftables bool `flag:"reproduce-nftables"` - // NetDisconnectOk indicates whether the link endpoint capability - // CapabilityDisconnectOk should be set. This allows open connections to be - // disconnected upon save. + // NetDisconnectOk indicates whether open network connections and open + // unix domain sockets should be disconnected upon save. NetDisconnectOk bool `flag:"net-disconnect-ok"` // TestOnlyAutosaveImagePath if not empty enables auto save for syscall tests diff --git a/runsc/config/flags.go b/runsc/config/flags.go index d801a8b9bf..1efa70638b 100644 --- a/runsc/config/flags.go +++ b/runsc/config/flags.go @@ -128,7 +128,7 @@ func RegisterFlags(flagSet *flag.FlagSet) { flagSet.Bool("EXPERIMENTAL-xdp-need-wakeup", true, "EXPERIMENTAL. Use XDP_USE_NEED_WAKEUP with XDP sockets.") // TODO(b/240191988): Figure out whether this helps and remove it as a flag.
flagSet.Bool("reproduce-nat", false, "Scrape the host netns NAT table and reproduce it in the sandbox.") flagSet.Bool("reproduce-nftables", false, "Attempt to scrape and reproduce nftable rules inside the sandbox. Overrides reproduce-nat when true.") - flagSet.Bool("net-disconnect-ok", false, "Indicates whether the link endpoint capability CapabilityDisconnectOk should be set. This allows open connections to be disconnected upon save.") + flagSet.Bool("net-disconnect-ok", false, "Indicates whether open network connections and open unix domain sockets should be disconnected upon save.") // Flags that control sandbox runtime behavior: accelerator related. flagSet.Bool("nvproxy", false, "EXPERIMENTAL: enable support for Nvidia GPUs")