Skip to content

Commit

Permalink
We only need the certificate in ConnectionState
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus committed Aug 15, 2023
1 parent 5a131b2 commit d01cb2f
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 51 deletions.
6 changes: 3 additions & 3 deletions connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
}

certState := n.intf.pki.GetCertState()
return bytes.Equal(current.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature)
return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
}

func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
Expand Down Expand Up @@ -465,7 +465,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {

func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
certState := n.intf.pki.GetCertState()
if bytes.Equal(hostinfo.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) {
if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
return
}

Expand All @@ -474,7 +474,7 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
Info("Re-handshaking with remote")

//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo)
newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp)
if !newHostinfo.HandshakeReady {
ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
}
Expand Down
19 changes: 10 additions & 9 deletions connection_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ func Test_NewConnectionManagerTest(t *testing.T) {
remoteIndexId: 9901,
}
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
H: &noise.HandshakeState{},
myCert: &cert.NebulaCertificate{},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)

Expand Down Expand Up @@ -159,8 +159,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
remoteIndexId: 9901,
}
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
H: &noise.HandshakeState{},
myCert: &cert.NebulaCertificate{},
H: &noise.HandshakeState{},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)

Expand Down Expand Up @@ -222,7 +222,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
PublicKey: pubCA,
},
}
caCert.Sign(cert.Curve_CURVE25519, privCA)

assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA))
ncp := &cert.NebulaCAPool{
CAs: cert.NewCAPool().CAs,
}
Expand All @@ -241,7 +242,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
Issuer: "ca",
},
}
peerCert.Sign(cert.Curve_CURVE25519, privCA)
assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA))

cs := &CertState{
RawCertificate: []byte{},
Expand Down Expand Up @@ -275,9 +276,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
hostinfo := &HostInfo{
vpnIp: vpnIp,
ConnectionState: &ConnectionState{
certState: cs,
peerCert: &peerCert,
H: &noise.HandshakeState{},
myCert: &cert.NebulaCertificate{},
peerCert: &peerCert,
H: &noise.HandshakeState{},
},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
Expand Down
20 changes: 11 additions & 9 deletions connection_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type ConnectionState struct {
eKey *NebulaCipherState
dKey *NebulaCipherState
H *noise.HandshakeState
certState *CertState
myCert *cert.NebulaCertificate
peerCert *cert.NebulaCertificate
initiator bool
messageCounter atomic.Uint64
Expand All @@ -28,25 +28,27 @@ type ConnectionState struct {
ready bool
}

func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
var dhFunc noise.DHFunc
curCertState := f.pki.GetCertState()

switch curCertState.Certificate.Details.Curve {
switch certState.Certificate.Details.Curve {
case cert.Curve_CURVE25519:
dhFunc = noise.DH25519
case cert.Curve_P256:
dhFunc = noiseutil.DHP256
default:
l.Errorf("invalid curve: %s", curCertState.Certificate.Details.Curve)
l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
return nil
}
cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
if f.cipher == "chachapoly" {

var cs noise.CipherSuite
if cipher == "chachapoly" {
cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
} else {
cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
}

static := noise.DHKey{Private: curCertState.PrivateKey, Public: curCertState.PublicKey}
static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}

b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
Expand All @@ -72,7 +74,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
initiator: initiator,
window: b,
ready: false,
certState: curCertState,
myCert: certState.Certificate,
}

return ci
Expand Down
2 changes: 1 addition & 1 deletion control_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (c *Control) GetCert() *cert.NebulaCertificate {
}

func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp, c.f.initHostInfo)
hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp)
ixHandshakeStage0(c.f, vpnIp, hostinfo)

// If this is a static host, we don't need to wait for the HostQueryReply
Expand Down
11 changes: 7 additions & 4 deletions handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
return
}

ci := hostinfo.ConnectionState
certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
hostinfo.ConnectionState = ci

hsProto := &NebulaHandshakeDetails{
InitiatorIndex: hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()),
Cert: ci.certState.RawCertificateNoKey,
Cert: certState.RawCertificateNoKey,
}

hsBytes := []byte{}
Expand Down Expand Up @@ -69,7 +71,8 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
}

func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)

Expand Down Expand Up @@ -155,7 +158,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
Info("Handshake message received")

hs.Details.ResponderIndex = myIndex
hs.Details.Cert = ci.certState.RawCertificateNoKey
hs.Details.Cert = certState.RawCertificateNoKey
// Update the time in case their clock is way off from ours
hs.Details.Time = uint64(time.Now().UnixNano())

Expand Down
6 changes: 1 addition & 5 deletions handshake_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
}

// AddVpnIp will try to handshake with the provided vpn ip and return the hostinfo for it.
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo {
func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
// A write lock is used to avoid having to recheck the map and trading a read lock for a write lock
c.Lock()
defer c.Unlock()
Expand All @@ -317,10 +317,6 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *H
},
}

if init != nil {
init(hostinfo)
}

c.vpnIps[vpnIp] = hostinfo
c.metricInitiated.Inc(1)
c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
Expand Down
13 changes: 2 additions & 11 deletions handshake_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,8 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
now := time.Now()
blah.NextOutboundHandshakeTimerTick(now, mw)

var initCalled bool
initFunc := func(*HostInfo) {
initCalled = true
}

i := blah.AddVpnIp(ip, initFunc)
assert.True(t, initCalled)

initCalled = false
i2 := blah.AddVpnIp(ip, initFunc)
assert.False(t, initCalled)
i := blah.AddVpnIp(ip)
i2 := blah.AddVpnIp(ip)
assert.Same(t, i, i2)

i.remotes = NewRemoteList(nil)
Expand Down
9 changes: 1 addition & 8 deletions inside.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package nebula

import (
"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
Expand Down Expand Up @@ -124,7 +123,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {

hostinfo := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
if hostinfo == nil {
hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo)
hostinfo = f.handshakeManager.AddVpnIp(vpnIp)
}
ci := hostinfo.ConnectionState

Expand Down Expand Up @@ -168,12 +167,6 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
return hostinfo
}

// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that
// will create the initial Noise ConnectionState
func (f *Interface) initHostInfo(hostinfo *HostInfo) {
hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
}

func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
fp := &firewall.Packet{}
err := newPacket(p, false, fp)
Expand Down
2 changes: 1 addition & 1 deletion ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
}
}

hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp, ifce.initHostInfo)
hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp)
if addr != nil {
hostInfo.SetRemote(addr)
}
Expand Down

0 comments on commit d01cb2f

Please sign in to comment.