Skip to content

Commit

Permalink
Simplify the getOrHandshake flow, avoid races
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus committed Aug 15, 2023
1 parent a5ac2c5 commit 3656e8c
Show file tree
Hide file tree
Showing 12 changed files with 163 additions and 219 deletions.
15 changes: 1 addition & 14 deletions connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,18 +473,5 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")

//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp)
if !newHostinfo.HandshakeReady {
ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
}

//If this is a static host, we don't need to wait for the HostQueryReply
//We can trigger the handshake right now
if _, ok := n.intf.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
select {
case n.intf.handshakeManager.trigger <- hostinfo.vpnIp:
default:
}
}
n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
}
6 changes: 3 additions & 3 deletions connection_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
firewall: &Firewall{},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
}
ifce.pki.cs.Store(cs)
Expand Down Expand Up @@ -138,7 +138,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
firewall: &Firewall{},
lightHouse: lh,
pki: &PKI{},
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
}
ifce.pki.cs.Store(cs)
Expand Down Expand Up @@ -258,7 +258,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
outside: &udp.NoopConn{},
firewall: &Firewall{},
lightHouse: lh,
handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
l: l,
disconnectInvalid: true,
pki: &PKI{},
Expand Down
1 change: 0 additions & 1 deletion connection_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ type ConnectionState struct {
initiator bool
messageCounter atomic.Uint64
window *Bits
queueLock sync.Mutex
writeLock sync.Mutex
ready bool
}
Expand Down
12 changes: 1 addition & 11 deletions control_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,5 @@ func (c *Control) GetCert() *cert.NebulaCertificate {
}

func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp)
ixHandshakeStage0(c.f, vpnIp, hostinfo)

// If this is a static host, we don't need to wait for the HostQueryReply
// We can trigger the handshake right now
if _, ok := c.f.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
select {
case c.f.handshakeManager.trigger <- hostinfo.vpnIp:
default:
}
}
c.f.handshakeManager.StartHandshake(vpnIp, nil)
}
70 changes: 30 additions & 40 deletions handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,12 @@ import (

// This function constructs a handshake packet, but does not actually send it
// Sending is done by the handshake manager
func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
// This queries the lighthouse if we don't know a remote for the host
// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
// more quickly, effect is a quicker handshake.
if hostinfo.remote == nil {
f.lightHouse.QueryServer(vpnIp, f)
}

err := f.handshakeManager.AddIndexHostInfo(hostinfo)
func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {
err := f.handshakeManager.allocateIndex(hostinfo)
if err != nil {
f.l.WithError(err).WithField("vpnIp", vpnIp).
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return
return false
}

certState := f.pki.GetCertState()
Expand All @@ -46,19 +39,19 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
hsBytes, err = hs.Marshal()

if err != nil {
f.l.WithError(err).WithField("vpnIp", vpnIp).
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return
return false
}

h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
ci.messageCounter.Add(1)

msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnIp", vpnIp).
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return
return false
}

// We are sending handshake packet 1, so we don't expect to receive
Expand All @@ -68,6 +61,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
hostinfo.HandshakePacket[0] = msg
hostinfo.HandshakeReady = true
hostinfo.handshakeStart = time.Now()
return true
}

func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
Expand Down Expand Up @@ -428,31 +422,27 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
f.handshakeManager.DeleteHostInfo(hostinfo)

// Create a new hostinfo/handshake for the intended vpn ip
//TODO: this adds it to the timer wheel in a way that aggressively retries
newHostInfo := f.getOrHandshake(hostinfo.vpnIp)
newHostInfo.Lock()

// Block the current used address
newHostInfo.remotes = hostinfo.remotes
newHostInfo.remotes.BlockRemote(addr)

// Get the correct remote list for the host we did handshake with
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)

f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
Info("Blocked addresses for handshakes")

// Swap the packet store to benefit the original intended recipient
hostinfo.ConnectionState.queueLock.Lock()
newHostInfo.packetStore = hostinfo.packetStore
hostinfo.packetStore = []*cachedPacket{}
hostinfo.ConnectionState.queueLock.Unlock()

// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.vpnIp = vpnIp
f.sendCloseTunnel(hostinfo)
newHostInfo.Unlock()
f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHostInfo *HostInfo) {
//TODO: this doesnt know if its being added or is being used for caching a packet
// Block the current used address
newHostInfo.remotes = hostinfo.remotes
newHostInfo.remotes.BlockRemote(addr)

// Get the correct remote list for the host we did handshake with
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)

f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
Info("Blocked addresses for handshakes")

// Swap the packet store to benefit the original intended recipient
newHostInfo.packetStore = hostinfo.packetStore
hostinfo.packetStore = []*cachedPacket{}

// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.vpnIp = vpnIp
f.sendCloseTunnel(hostinfo)
})

return true
}
Expand Down
Loading

0 comments on commit 3656e8c

Please sign in to comment.