Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 14 additions & 18 deletions client/internal/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
Expand All @@ -34,7 +35,6 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version"
)

Expand Down Expand Up @@ -289,15 +289,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}

<-engineCtx.Done()

c.engineMutex.Lock()
if c.engine != nil && c.engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
if err := c.engine.Stop(); err != nil {
engine := c.engine
c.engine = nil
c.engineMutex.Unlock()

if engine != nil && engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
c.engine = nil
}
c.engineMutex.Unlock()
c.statusRecorder.ClientTeardown()

backOff.Reset()
Expand Down Expand Up @@ -382,19 +385,12 @@ func (c *ConnectClient) Status() StatusType {
}

func (c *ConnectClient) Stop() error {
if c == nil {
return nil
}
c.engineMutex.Lock()
defer c.engineMutex.Unlock()

if c.engine == nil {
return nil
}
if err := c.engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
engine := c.Engine()
if engine != nil {
if err := engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
}

return nil
}

Expand Down
9 changes: 6 additions & 3 deletions client/internal/dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface {

// DefaultServer dns server object
type DefaultServer struct {
ctx context.Context
ctxCancel context.CancelFunc
ctx context.Context
ctxCancel context.CancelFunc
shutdownWg sync.WaitGroup
// disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running.
// This is different from ServiceEnable=false from management which completely disables the DNS service.
disableSys bool
Expand Down Expand Up @@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr {
// Stop stops the server
func (s *DefaultServer) Stop() {
s.ctxCancel()
s.shutdownWg.Wait()

s.mux.Lock()
defer s.mux.Unlock()
Expand Down Expand Up @@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {

s.applyHostConfig()

s.shutdownWg.Add(1)
go func() {
// persist dns state right away
defer s.shutdownWg.Done()
if err := s.stateManager.PersistState(s.ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}
Expand Down
119 changes: 90 additions & 29 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ type Engine struct {

// syncMsgMux is used to guarantee sequential Management Service message processing
syncMsgMux *sync.Mutex
// sshMux protects sshServer field access
sshMux sync.Mutex

config *EngineConfig
mobileDep MobileDependency
Expand Down Expand Up @@ -200,8 +202,10 @@ type Engine struct {
flowManager nftypes.FlowManager

// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup
wgIfaceMonitor *WGIfaceMonitor

// shutdownWg tracks all long-running goroutines to ensure clean shutdown
shutdownWg sync.WaitGroup

probeStunTurn *relay.StunTurnProbe
}
Expand Down Expand Up @@ -325,19 +329,13 @@ func (e *Engine) Stop() error {
e.cancel()
}

// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously
time.Sleep(500 * time.Millisecond)

e.close()

// stop flow manager after wg interface is gone
if e.flowManager != nil {
e.flowManager.Close()
}

log.Infof("stopped Netbird Engine")

ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()

Expand All @@ -348,12 +346,52 @@ func (e *Engine) Stop() error {
log.Errorf("failed to persist state: %v", err)
}

// Stop WireGuard interface monitor and wait for it to exit
e.wgIfaceMonitorWg.Wait()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}

log.Infof("stopped Netbird Engine")

return nil
}

// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
func (e *Engine) calculateShutdownTimeout() time.Duration {
peerCount := len(e.peerStore.PeersPubKey())

baseTimeout := 10 * time.Second
perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond
timeout := baseTimeout + perPeerTimeout

maxTimeout := 30 * time.Second
if timeout > maxTimeout {
timeout = maxTimeout
}

return timeout
}

// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout.
func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()

select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
Expand Down Expand Up @@ -483,14 +521,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)

// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.wgIfaceMonitorWg.Add(1)
e.shutdownWg.Add(1)

go func() {
defer e.wgIfaceMonitorWg.Done()
defer e.shutdownWg.Done()

if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.restartEngine()
e.triggerClientRestart()
} else if err != nil {
log.Warnf("WireGuard interface monitor: %s", err)
}
Expand Down Expand Up @@ -674,9 +712,11 @@ func (e *Engine) removeAllPeers() error {
func (e *Engine) removePeer(peerKey string) error {
log.Debugf("removing peer from engine %s", peerKey)

e.sshMux.Lock()
if !isNil(e.sshServer) {
e.sshServer.RemoveAuthorizedKey(peerKey)
}
e.sshMux.Unlock()

e.connMgr.RemovePeerConn(peerKey)

Expand Down Expand Up @@ -878,41 +918,50 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
return nil
}
e.sshMux.Lock()
// start SSH server if it wasn't running
if isNil(e.sshServer) {
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
if nbnetstack.IsEnabled() {
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
}
// nil sshServer means it has not yet been started
var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)

server, err := e.sshServerFunc(e.config.SSHKey, listenAddr)
if err != nil {
e.sshMux.Unlock()
return fmt.Errorf("create ssh server: %w", err)
}

e.sshServer = server
e.sshMux.Unlock()

go func() {
// blocking
err = e.sshServer.Start()
err = server.Start()
if err != nil {
// will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err)
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.sshMux.Lock()
e.sshServer = nil
e.sshMux.Unlock()
log.Infof("stopped SSH server")
}()
} else {
e.sshMux.Unlock()
log.Debugf("SSH server is already running")
}
} else if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
} else {
e.sshMux.Lock()
if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
}
e.sshServer = nil
}
e.sshServer = nil
e.sshMux.Unlock()
}
return nil
}
Expand Down Expand Up @@ -949,7 +998,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
Expand Down Expand Up @@ -1125,6 +1176,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.statusRecorder.FinishPeerListModifications()

// update SSHServer by adding remote peer SSH keys
e.sshMux.Lock()
if !isNil(e.sshServer) {
for _, config := range networkMap.GetRemotePeers() {
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
Expand All @@ -1135,6 +1187,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
}
}
e.sshMux.Unlock()
}

// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
Expand Down Expand Up @@ -1377,7 +1430,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV

// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
func (e *Engine) receiveSignalEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
// connect to a stream of messages coming from the signal server
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
e.syncMsgMux.Lock()
Expand Down Expand Up @@ -1494,12 +1549,14 @@ func (e *Engine) close() {
e.statusRecorder.SetWgIface(nil)
}

e.sshMux.Lock()
if !isNil(e.sshServer) {
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed stopping the SSH server: %v", err)
}
}
e.sshMux.Unlock()

if e.firewall != nil {
err := e.firewall.Close(e.stateManager)
Expand Down Expand Up @@ -1730,8 +1787,10 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
return allHealthy
}

// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() {
// triggerClientRestart triggers a full client restart by cancelling the client context.
// Note: This does NOT just restart the engine - it cancels the entire client context,
// which causes the connect client's retry loop to create a completely new engine.
func (e *Engine) triggerClientRestart() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()

Expand All @@ -1753,7 +1812,9 @@ func (e *Engine) startNetworkMonitor() {
}

e.networkMonitor = networkmonitor.New()
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
if err := e.networkMonitor.Listen(e.ctx); err != nil {
if errors.Is(err, context.Canceled) {
log.Infof("network monitor stopped")
Expand All @@ -1763,8 +1824,8 @@ func (e *Engine) startNetworkMonitor() {
return
}

log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
log.Infof("Network monitor: detected network change, triggering client restart")
e.triggerClientRestart()
}()
}

Expand Down
Loading
Loading