From 55676971693f0959a392a718fc6575a1671597ee Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 19 Sep 2024 21:49:16 -0500 Subject: [PATCH] Fixes --- inside.go | 4 +- interface.go | 103 +++++++++++++++++++++------------------------------ outside.go | 2 +- pki.go | 33 ++++++++++------- 4 files changed, 66 insertions(+), 76 deletions(-) diff --git a/inside.go b/inside.go index 1b75f0f46..6813237ed 100644 --- a/inside.go +++ b/inside.go @@ -21,7 +21,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet // Ignore local broadcast packets if f.dropLocalBroadcast { - _, found := f.myBroadcastAddr.Lookup(fwPacket.RemoteIP) + _, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteIP) if found { return } @@ -129,7 +129,7 @@ func (f *Interface) Handshake(vpnIp netip.Addr) { // getOrHandshake returns nil if the vpnIp is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - _, found := f.myVpnNetworks.Lookup(vpnIp) + _, found := f.myVpnNetworksTable.Lookup(vpnIp) if !found { vpnIp = f.inside.RouteFor(vpnIp) if !vpnIp.IsValid() { diff --git a/interface.go b/interface.go index 9686d128b..a403f5d03 100644 --- a/interface.go +++ b/interface.go @@ -14,7 +14,6 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -52,26 +51,27 @@ type InterfaceConfig struct { } type Interface struct { - hostMap *HostMap - outside udp.Conn - inside overlay.Device - pki *PKI - firewall *Firewall - connectionManager *connectionManager - handshakeManager *HandshakeManager - serveDns bool - createTime time.Time - lightHouse *LightHouse - myBroadcastAddr *bart.Table[struct{}] - myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate - myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate - myVpnNetworks *bart.Table[struct{}] // A table of networks assigned to us via our certificate - dropLocalBroadcast bool - dropMulticast bool - routines int - disconnectInvalid atomic.Bool - closed atomic.Bool - relayManager *relayManager + hostMap *HostMap + outside udp.Conn + inside overlay.Device + pki *PKI + firewall *Firewall + connectionManager *connectionManager + handshakeManager *HandshakeManager + serveDns bool + createTime time.Time + lightHouse *LightHouse + myBroadcastAddrsTable *bart.Table[struct{}] + myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate + myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate + myVpnNetworks []netip.Prefix // A table of networks assigned to us via our certificate + myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate + dropLocalBroadcast bool + dropMulticast bool + routines int + disconnectInvalid atomic.Bool + closed atomic.Bool + relayManager *relayManager tryPromoteEvery atomic.Uint32 reQueryEvery atomic.Uint32 @@ -157,25 +157,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { return nil, errors.New("no firewall rules") } + cs := c.pki.getCertState() ifce := &Interface{ - pki: c.pki, - hostMap: c.HostMap, - outside: c.Outside, - inside: c.Inside, - firewall: c.Firewall, - serveDns: c.ServeDns, - handshakeManager: c.HandshakeManager, - createTime: time.Now(), - lightHouse: c.lightHouse, - dropLocalBroadcast: c.DropLocalBroadcast, - dropMulticast: c.DropMulticast, - routines: c.routines, - version: c.version, - writers: make([]udp.Conn, c.routines), - readers: make([]io.ReadWriteCloser, c.routines), - myVpnNetworks: new(bart.Table[struct{}]), - myVpnAddrsTable: new(bart.Table[struct{}]), - relayManager: c.relayManager, + pki: c.pki, + hostMap: c.HostMap, + outside: c.Outside, + inside: c.Inside, + firewall: c.Firewall, + serveDns: c.ServeDns, + handshakeManager: c.HandshakeManager, + createTime: time.Now(), + lightHouse: c.lightHouse, + dropLocalBroadcast: c.DropLocalBroadcast, + dropMulticast: c.DropMulticast, + routines: c.routines, + version: c.version, + writers: make([]udp.Conn, c.routines), + readers: make([]io.ReadWriteCloser, c.routines), + myVpnNetworks: cs.myVpnNetworks, + myVpnNetworksTable: cs.myVpnNetworksTable, + myVpnAddrs: cs.myVpnAddrs, + myVpnAddrsTable: cs.myVpnAddrsTable, + myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable, + relayManager: c.relayManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -189,27 +193,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } - var crt cert.Certificate - cs := c.pki.getCertState() - crt = cs.getCertificate(cert.Version2) - if crt == nil { - // v2 certificates are a superset, only look at v1 if its all we have - crt = cs.getCertificate(cert.Version1) - } - - for _, network := range crt.Networks() { - ifce.myVpnNetworks.Insert(network, struct{}{}) - ifce.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{}) - ifce.myVpnAddrs = append(ifce.myVpnAddrs, network.Addr()) - - if network.Addr().Is4() { - //TODO: finish calculating the broadcast ips - //addr := network.Masked().Addr().As4() - //binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask)) - //ifce.myBroadcastAddr = netip.AddrFrom4(addr) - } - } - ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait)) diff --git a/outside.go b/outside.go index dd2ae2520..f7dbbd32e 100644 --- a/outside.go +++ b/outside.go @@ -49,7 +49,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] //l.Error("in packet ", header, packet[HeaderLen:]) if ip.IsValid() { - _, found := f.myVpnNetworks.Lookup(ip.Addr()) + _, found := f.myVpnNetworksTable.Lookup(ip.Addr()) if found { if f.l.Level >= logrus.DebugLevel { f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") diff --git a/pki.go b/pki.go index 25d4a0e14..c4160d5a8 100644 --- a/pki.go +++ b/pki.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "net" "net/netip" "os" "slices" @@ -37,10 +38,11 @@ type CertState struct { pkcs11Backed bool cipher string - myVpnNetworks []netip.Prefix - myVpnNetworksTable *bart.Table[struct{}] - myVpnAddrs []netip.Addr - myVpnAddrsTable *bart.Table[struct{}] + myVpnNetworks []netip.Prefix + myVpnNetworksTable *bart.Table[struct{}] + myVpnAddrs []netip.Addr + myVpnAddrsTable *bart.Table[struct{}] + myVpnBroadcastAddrsTable *bart.Table[struct{}] } func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { @@ -294,7 +296,7 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { } var crt, v1, v2 cert.Certificate - for len(rawCert) != 0 { + for { // Load the certificate crt, rawCert, err = loadCertificate(rawCert) if err != nil { @@ -316,6 +318,10 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { default: return nil, fmt.Errorf("unknown certificate version %v", crt.Version()) } + + if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" { + break + } } rawDefaultVersion := c.GetUint32("pki.default_version", 1) @@ -334,10 +340,11 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { cs := CertState{ - privateKey: privateKey, - pkcs11Backed: pkcs11backed, - myVpnNetworksTable: new(bart.Table[struct{}]), - myVpnAddrsTable: new(bart.Table[struct{}]), + privateKey: privateKey, + pkcs11Backed: pkcs11backed, + myVpnNetworksTable: new(bart.Table[struct{}]), + myVpnAddrsTable: new(bart.Table[struct{}]), + myVpnBroadcastAddrsTable: new(bart.Table[struct{}]), } if v1 != nil && v2 != nil { @@ -409,10 +416,10 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{}) if network.Addr().Is4() { - //TODO: finish calculating the broadcast ips - //addr := network.Masked().Addr().As4() - //binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask)) - //ifce.myBroadcastAddr = netip.AddrFrom4(addr) + addr := network.Masked().Addr().As4() + mask := net.CIDRMask(network.Bits(), network.Addr().BitLen()) + binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask)) + cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{}) } }