diff --git a/cert/cert.go b/cert/cert.go index 5d5abd646..ed94ec043 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -113,7 +113,7 @@ type CachedCertificate struct { func UnmarshalCertificate(b []byte) (Certificate, error) { //TODO: you left off here, no one uses this function but it might be beneficial to export _something_ that someone can use, maybe the Versioned unmarshallsers? var c Certificate - c, err := unmarshalCertificateV2(b, nil) + c, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519) if err == nil { return c, nil } @@ -129,7 +129,7 @@ func UnmarshalCertificate(b []byte) (Certificate, error) { // UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake. // Handshakes save space by placing the peers public key in a different part of the packet, we have to // reassemble the actual certificate structure with that in mind. -func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte) (Certificate, error) { +func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) { var c Certificate var err error @@ -137,7 +137,7 @@ func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte) (C case VersionPre1, Version1: c, err = unmarshalCertificateV1(b, publicKey) case Version2: - c, err = unmarshalCertificateV2(b, publicKey) + c, err = unmarshalCertificateV2(b, publicKey, curve) default: //TODO: make a static var return nil, fmt.Errorf("unknown certificate version %d", v) @@ -146,10 +146,15 @@ func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte) (C if err != nil { return nil, err } + + if c.Curve() != curve { + return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String()) + } + return c, nil } -func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, caPool *CAPool) (*CachedCertificate, error) { +func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) { if publicKey == nil { return nil, ErrNoPeerStaticKey } @@ -158,7 +163,7 @@ func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, caPool *CAP return nil, ErrNoPayload } - c, err := UnmarshalCertificateFromHandshake(v, rawCertBytes, publicKey) + c, err := UnmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve) if err != nil { return nil, fmt.Errorf("error unmarshaling cert: %w", err) } diff --git a/cert/cert_v1.go b/cert/cert_v1.go index 4dc38fcef..9f88723bd 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -14,7 +14,6 @@ import ( "fmt" "net" "net/netip" - "slices" "time" "golang.org/x/crypto/curve25519" @@ -393,8 +392,7 @@ func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) } } - slices.SortFunc(nc.details.networks, comparePrefix) - slices.SortFunc(nc.details.unsafeNetworks, comparePrefix) + //do not sort the subnets field for V1 certs return &nc, nil } diff --git a/cert/cert_v2.go b/cert/cert_v2.go index 2c6f14a00..11aa666bd 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -455,7 +455,7 @@ func (d *detailsV2) Marshal() ([]byte, error) { return b.Bytes() } -func unmarshalCertificateV2(b []byte, publicKey []byte) (*certificateV2, error) { +func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) { l := len(b) if l == 0 || l > MaxCertificateSize { return nil, ErrBadFormat @@ -473,11 +473,12 @@ func unmarshalCertificateV2(b []byte, publicKey []byte) (*certificateV2, error) return nil, ErrBadFormat } + //Maybe grab the curve var rawCurve byte - if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(Curve_CURVE25519)) { + if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) { return nil, ErrBadFormat } - curve := Curve(rawCurve) + curve = Curve(rawCurve) // Maybe grab the public key var rawPublicKey cryptobyte.String diff --git a/cert/pem.go b/cert/pem.go index 8f9fe8e99..249b63917 100644 --- a/cert/pem.go +++ b/cert/pem.go @@ -37,7 +37,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { case CertificateBanner: c, err = unmarshalCertificateV1(p.Bytes, nil) case CertificateV2Banner: - c, err = unmarshalCertificateV2(p.Bytes, nil) + c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519) default: return nil, r, ErrInvalidPEMCertificateBanner } diff --git a/connection_state.go b/connection_state.go index cfd86eb35..a13164d5c 100644 --- a/connection_state.go +++ b/connection_state.go @@ -91,3 +91,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) { "message_counter": cs.messageCounter.Load(), }) } + +func (cs *ConnectionState) Curve() cert.Curve { + return cs.myCert.Curve() +} diff --git a/dns_server.go b/dns_server.go index 991f27068..1b97d7a4f 100644 --- a/dns_server.go +++ b/dns_server.go @@ -85,6 +85,15 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { m.Answer = append(m.Answer, rr) } } + case dns.TypeAAAA: + l.Debugf("Query for AAAA %s", q.Name) + ip := dnsR.Query(q.Name) + if ip != "" { + rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip)) + if err == nil { + m.Answer = append(m.Answer, rr) + } + } case dns.TypeTXT: a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) b, err := netip.ParseAddr(a) diff --git a/handshake_ix.go b/handshake_ix.go index e20ac2f24..55b857e46 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -81,7 +81,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), f.pki.GetCAPool()) + remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) if err != nil { e := f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) @@ -404,7 +404,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha return true } - remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), f.pki.GetCAPool()) + remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) if err != nil { e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) diff --git a/lighthouse.go b/lighthouse.go index 623ed0b9e..ab81d3c66 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -1170,6 +1170,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnAdd useVersion = 2 } + //todo hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4? + //todo why do we care about the vpnip in the packet? We know where it came from, right? + if detailsVpnIp != vpnAddrs[0] { if lhh.l.Level >= logrus.DebugLevel { lhh.l.WithField("vpnAddrs", vpnAddrs).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") diff --git a/outside.go b/outside.go index f7dbbd32e..6eb6ea011 100644 --- a/outside.go +++ b/outside.go @@ -7,6 +7,9 @@ import ( "net/netip" "time" + "github.com/google/gopacket/layers" + "golang.org/x/net/ipv6" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -297,14 +300,104 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { - // Do we at least have an ipv4 header worth of data? - if len(data) < ipv4.HeaderLen { - return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen) + if len(data) < 1 { + return errors.New("packet too short") + } + + version := int((data[0] >> 4) & 0x0f) + switch version { + case ipv4.Version: + return parseV4(data, incoming, fp) + case ipv6.Version: + return parseV6(data, incoming, fp) } + return fmt.Errorf("packet is an unknown ip version: %v", version) +} + +func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { + dataLen := len(data) + if dataLen < ipv6.HeaderLen { + return fmt.Errorf("ipv6 packet is less than %v bytes", ipv4.HeaderLen) + } + + if incoming { + fp.RemoteIP, _ = netip.AddrFromSlice(data[8:24]) + fp.LocalIP, _ = netip.AddrFromSlice(data[24:40]) + } else { + fp.LocalIP, _ = netip.AddrFromSlice(data[8:24]) + fp.RemoteIP, _ = netip.AddrFromSlice(data[24:40]) + } + + //TODO: whats a reasonable number of extension headers to attempt to parse? + //https://www.ietf.org/archive/id/draft-ietf-6man-eh-limits-00.html + protoAt := 6 + offset := 40 + for i := 0; i < 24; i++ { + if dataLen < offset { + break + } + + proto := layers.IPProtocol(data[protoAt]) + //fmt.Println(proto, protoAt) + switch proto { + case layers.IPProtocolICMPv6: + //TODO: we need a new protocol in config language "icmpv6" + fp.Protocol = uint8(proto) + fp.RemotePort = 0 + fp.LocalPort = 0 + fp.Fragment = false + return nil - // Is it an ipv4 packet? - if int((data[0]>>4)&0x0f) != 4 { - return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f)) + case layers.IPProtocolTCP: + if dataLen < offset+4 { + return fmt.Errorf("ipv6 packet was too small") + } + fp.Protocol = uint8(proto) + fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + fp.Fragment = false + return nil + + case layers.IPProtocolUDP: + if dataLen < offset+4 { + return fmt.Errorf("ipv6 packet was too small") + } + fp.Protocol = uint8(proto) + fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + fp.Fragment = false + return nil + + case layers.IPProtocolIPv6Fragment: + //TODO: can we determine the protocol? + fp.RemotePort = 0 + fp.LocalPort = 0 + fp.Fragment = true + return nil + + default: + if dataLen < offset+1 { + break + } + + next := int(data[offset+1]) * 8 + if next == 0 { + // each extension is at least 8 bytes + next = 8 + } + + protoAt = offset + offset = offset + next + } + } + + return fmt.Errorf("could not find payload in ipv6 packet") +} + +func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { + // Do we at least have an ipv4 header worth of data? + if len(data) < ipv4.HeaderLen { + return fmt.Errorf("ipv4 packet is less than %v bytes", ipv4.HeaderLen) } // Adjust our start position based on the advertised ip header length @@ -312,7 +405,7 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { // Well formed ip header length? if ihl < ipv4.HeaderLen { - return fmt.Errorf("packet had an invalid header length: %v", ihl) + return fmt.Errorf("ipv4 packet had an invalid header length: %v", ihl) } // Check if this is the second or further fragment of a fragmented packet. @@ -328,12 +421,11 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { minLen += minFwPacketLen } if len(data) < minLen { - return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl) + return fmt.Errorf("ipv4 packet is less than %v bytes, ip header len: %v", minLen, ihl) } // Firewall packets are locally oriented if incoming { - //TODO: IPV6-WORK fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16]) fp.LocalIP, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { @@ -344,7 +436,6 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) } } else { - //TODO: IPV6-WORK fp.LocalIP, _ = netip.AddrFromSlice(data[12:16]) fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { diff --git a/outside_test.go b/outside_test.go index f9d4bfa48..aa5581f03 100644 --- a/outside_test.go +++ b/outside_test.go @@ -13,9 +13,15 @@ import ( func Test_newPacket(t *testing.T) { p := &firewall.Packet{} - // length fail - err := newPacket([]byte{0, 1}, true, p) - assert.EqualError(t, err, "packet is less than 20 bytes") + // length fails + err := newPacket([]byte{}, true, p) + assert.EqualError(t, err, "packet too short") + + err = newPacket([]byte{0x40}, true, p) + assert.EqualError(t, err, "ipv4 packet is less than 20 bytes") + + err = newPacket([]byte{0x60}, true, p) + assert.EqualError(t, err, "ipv6 packet is less than 20 bytes") // length fail with ip options h := ipv4.Header{ @@ -29,15 +35,15 @@ func Test_newPacket(t *testing.T) { b, _ := h.Marshal() err = newPacket(b, true, p) - assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24") + assert.EqualError(t, err, "ipv4 packet is less than 28 bytes, ip header len: 24") // not an ipv4 packet err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.EqualError(t, err, "packet is not ipv4, type: 0") + assert.EqualError(t, err, "packet is an unknown ip version: 0") // invalid ihl err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.EqualError(t, err, "packet had an invalid header length: 8") + assert.EqualError(t, err, "ipv4 packet had an invalid header length: 8") // account for variable ip header length - incoming h = ipv4.Header{ diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 98ad9b408..72a656500 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -18,14 +18,14 @@ import ( type tun struct { io.ReadWriteCloser - fd int - cidr netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + fd int + vpnNetworks []netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -33,7 +33,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix t := &tun{ ReadWriteCloser: file, fd: deviceFd, - cidr: cidr, + vpnNetworks: vpnNetworks, l: l, } @@ -52,7 +52,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } @@ -66,7 +66,7 @@ func (t tun) Activate() error { } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -86,8 +86,8 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index bdfeb5802..69690e948 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -46,12 +46,12 @@ type ifreqDestroy struct { } type tun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger io.ReadWriteCloser } @@ -78,11 +78,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var file *os.File var err error @@ -150,7 +150,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -170,16 +170,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err return t, nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -195,8 +195,18 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -237,8 +247,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { return r } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 20981f08c..e99d44718 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -21,20 +21,20 @@ import ( type tun struct { io.ReadWriteCloser - cidr netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + vpnNetworks []netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ - cidr: cidr, + vpnNetworks: vpnNetworks, ReadWriteCloser: &tunReadCloser{f: file}, l: l, } @@ -59,7 +59,7 @@ func (t *tun) Activate() error { } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -142,8 +142,8 @@ func (tr *tunReadCloser) Close() error { return tr.f.Close() } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 0e7e20d41..08c65b4e0 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -25,7 +25,7 @@ type tun struct { io.ReadWriteCloser fd int Device string - cidr netip.Prefix + vpnNetworks []netip.Prefix MaxMTU int DefaultMTU int TXQueueLen int @@ -40,18 +40,16 @@ type tun struct { l *logrus.Logger } +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks +} + type ifReq struct { Name [16]byte Flags uint16 pad [8]byte } -type ifreqAddr struct { - Name [16]byte - Addr unix.RawSockaddrInet4 - pad [8]byte -} - type ifreqMTU struct { Name [16]byte MTU int32 @@ -64,10 +62,10 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, cidr) + t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { return nil, err } @@ -77,7 +75,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix return t, nil } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -112,7 +110,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) ( name := strings.Trim(string(req.Name[:]), "\x00") file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, cidr) + t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { return nil, err } @@ -122,11 +120,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) ( return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), - cidr: cidr, + vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), l: l, @@ -148,7 +146,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref } func (t *tun) reload(c *config.C, initial bool) error { - routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -190,11 +188,13 @@ func (t *tun) reload(c *config.C, initial bool) error { } if oldDefaultMTU != newDefaultMTU { - err := t.setDefaultRoute() - if err != nil { - t.l.Warn(err) - } else { - t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + for i := range t.vpnNetworks { + err := t.setDefaultRoute(t.vpnNetworks[i]) + if err != nil { + t.l.Warn(err) + } else { + t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + } } } @@ -237,10 +237,10 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) Write(b []byte) (int, error) { var nn int - max := len(b) + maximum := len(b) for { - n, err := unix.Write(t.fd, b[nn:max]) + n, err := unix.Write(t.fd, b[nn:maximum]) if n > 0 { nn += n } @@ -265,6 +265,58 @@ func (t *tun) deviceBytes() (o [16]byte) { return } +func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool { + for i := range al { + if al[i].Equal(x) { + return true + } + } + return false +} + +// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there +func (t *tun) addIPs(link netlink.Link) error { + newAddrs := make([]*netlink.Addr, len(t.vpnNetworks)) + for i := range t.vpnNetworks { + newAddrs[i] = &netlink.Addr{ + IPNet: &net.IPNet{ + IP: t.vpnNetworks[i].Addr().AsSlice(), + Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()), + }, + Label: t.vpnNetworks[i].Addr().Zone(), + } + } + + //add all new addresses + for i := range newAddrs { + //todo do we want to stack errors and try as many ops as possible? + //AddrReplace still adds new IPs, but if their properties change it will change them as well + if err := netlink.AddrReplace(link, newAddrs[i]); err != nil { + return err + } + } + + //iterate over remainder, remove whoever shouldn't be there + al, err := netlink.AddrList(link, netlink.FAMILY_ALL) + if err != nil { + return fmt.Errorf("failed to get tun address list: %s", err) + } + + for i := range al { + if hasNetlinkAddr(newAddrs, al[i]) { + continue + } + err = netlink.AddrDel(link, &al[i]) + if err != nil { + t.l.WithError(err).Error("failed to remove address from tun address list") + } else { + t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)") + } + } + + return nil +} + func (t *tun) Activate() error { devName := t.deviceBytes() @@ -272,15 +324,8 @@ func (t *tun) Activate() error { t.watchRoutes() } - var addr, mask [4]byte - - //TODO: IPV6-WORK - addr = t.cidr.Addr().As4() - tmask := net.CIDRMask(t.cidr.Bits(), 32) - copy(mask[:], tmask) - s, err := unix.Socket( - unix.AF_INET, + unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine unix.SOCK_DGRAM, unix.IPPROTO_IP, ) @@ -289,31 +334,19 @@ func (t *tun) Activate() error { } t.ioctlFd = uintptr(s) - ifra := ifreqAddr{ - Name: devName, - Addr: unix.RawSockaddrInet4{ - Family: unix.AF_INET, - Addr: addr, - }, - } - - // Set the device ip address - if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun address: %s", err) - } - - // Set the device network - ifra.Addr.Addr = mask - if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun netmask: %s", err) - } - // Set the device name ifrf := ifReq{Name: devName} if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to set tun device name: %s", err) } + link, err := netlink.LinkByName(t.Device) + if err != nil { + return fmt.Errorf("failed to get tun device link: %s", err) + } + + t.deviceIndex = link.Attrs().Index + // Setup our default MTU t.setMTU() @@ -324,20 +357,27 @@ func (t *tun) Activate() error { t.l.WithError(err).Error("Failed to set tun tx queue length") } + if err = t.addIPs(link); err != nil { + return err + } + // Bring up the interface ifrf.Flags = ifrf.Flags | unix.IFF_UP if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to bring the tun device up: %s", err) } - link, err := netlink.LinkByName(t.Device) - if err != nil { - return fmt.Errorf("failed to get tun device link: %s", err) + // Run the interface + ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING + if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + return fmt.Errorf("failed to run tun device: %s", err) } - t.deviceIndex = link.Attrs().Index - if err = t.setDefaultRoute(); err != nil { - return err + //set route MTU + for i := range t.vpnNetworks { + if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil { + return fmt.Errorf("failed to set default route MTU: %w", err) + } } // Set the routes @@ -345,11 +385,7 @@ func (t *tun) Activate() error { return err } - // Run the interface - ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING - if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { - return fmt.Errorf("failed to run tun device: %s", err) - } + //todo do we want to keep the link-local address? return nil } @@ -363,12 +399,12 @@ func (t *tun) setMTU() { } } -func (t *tun) setDefaultRoute() error { +func (t *tun) setDefaultRoute(cidr netip.Prefix) error { // Default route dr := &net.IPNet{ - IP: t.cidr.Masked().Addr().AsSlice(), - Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()), + IP: cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()), } nr := netlink.Route{ @@ -377,7 +413,7 @@ func (t *tun) setDefaultRoute() error { MTU: t.DefaultMTU, AdvMSS: t.advMSS(Route{}), Scope: unix.RT_SCOPE_LINK, - Src: net.IP(t.cidr.Addr().AsSlice()), + Src: net.IP(cidr.Addr().AsSlice()), Protocol: unix.RTPROT_KERNEL, Table: unix.RT_TABLE_MAIN, Type: unix.RTN_UNICAST, @@ -463,10 +499,6 @@ func (t *tun) removeRoutes(routes []Route) { } } -func (t *tun) Cidr() netip.Prefix { - return t.cidr -} - func (t *tun) Name() string { return t.Device } @@ -523,9 +555,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { } gwAddr = gwAddr.Unmap() - if !t.cidr.Contains(gwAddr) { + withinNetworks := false + for i := range t.vpnNetworks { + if t.vpnNetworks[i].Contains(gwAddr) { + withinNetworks = true + break + } + } + if !withinNetworks { // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, not in our network") + t.l.WithField("route", r).Debug("Ignoring route update, not in our networks") return } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 24ab24f78..54984910e 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -27,12 +27,12 @@ type ifreqDestroy struct { } type tun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger io.ReadWriteCloser } @@ -58,13 +58,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var file *os.File var err error @@ -84,7 +84,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -104,17 +104,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err return t, nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -130,8 +130,18 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -172,8 +182,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { return r } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { @@ -192,7 +202,7 @@ func (t *tun) addRoutes(logErrors bool) error { continue } - cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -213,7 +223,8 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String()) + //todo is this right? + cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 6463ccbba..dcba05478 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -21,12 +21,12 @@ import ( ) type tun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger io.ReadWriteCloser @@ -42,13 +42,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { deviceName := c.GetString("tun.dev", "") if deviceName == "" { return nil, fmt.Errorf("a device name in the format of tunN must be specified") @@ -66,7 +66,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -87,7 +87,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -123,10 +123,10 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) @@ -138,7 +138,7 @@ func (t *tun) Activate() error { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -148,6 +148,16 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) RouteFor(ip netip.Addr) netip.Addr { r, _ := t.routeTree.Load().Lookup(ip) return r @@ -160,8 +170,8 @@ func (t *tun) addRoutes(logErrors bool) error { // We don't allow route MTUs so only install routes with a via continue } - - cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String()) + //todo is this right? + cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -181,8 +191,8 @@ func (t *tun) removeRoutes(routes []Route) error { if !r.Install { continue } - - cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String()) + //todo is this right? + cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") @@ -193,8 +203,8 @@ func (t *tun) removeRoutes(routes []Route) error { return nil } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go index d78f564cf..73252c71d 100644 --- a/overlay/tun_water_windows.go +++ b/overlay/tun_water_windows.go @@ -17,22 +17,22 @@ import ( ) type waterTun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger - f *net.Interface + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger + f *net.Interface *water.Interface } -func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) { +func newWaterTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*waterTun, error) { // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() t := &waterTun{ - cidr: cidr, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, } err := t.reload(c, true) @@ -52,11 +52,13 @@ func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*wat func (t *waterTun) Activate() error { var err error + //TODO multi-ip support + cidr := t.vpnNetworks[0] t.Interface, err = water.New(water.Config{ DeviceType: water.TUN, PlatformSpecificParams: water.PlatformSpecificParams{ ComponentID: "tap0901", - Network: t.cidr.String(), + Network: cidr.String(), }, }) if err != nil { @@ -70,8 +72,8 @@ func (t *waterTun) Activate() error { `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address", fmt.Sprintf("name=%s", t.Device), "source=static", - fmt.Sprintf("addr=%s", t.cidr.Addr()), - fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())), + fmt.Sprintf("addr=%s", cidr.Addr()), + fmt.Sprintf("mask=%s", net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen())), "gateway=none", ).Run() if err != nil { @@ -100,7 +102,7 @@ func (t *waterTun) Activate() error { } func (t *waterTun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -187,8 +189,8 @@ func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr { return r } -func (t *waterTun) Cidr() netip.Prefix { - return t.cidr +func (t *waterTun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *waterTun) Name() string { diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 3d883093c..125d72bbe 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -15,11 +15,11 @@ import ( "github.com/slackhq/nebula/config" ) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (Device, error) { useWintun := true if err := checkWinTunExists(); err != nil { l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver") @@ -27,14 +27,14 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) ( } if useWintun { - device, err := newWinTun(c, l, cidr, multiqueue) + device, err := newWinTun(c, l, vpnNetworks, multiqueue) if err != nil { return nil, fmt.Errorf("create Wintun interface failed, %w", err) } return device, nil } - device, err := newWaterTun(c, l, cidr, multiqueue) + device, err := newWaterTun(c, l, vpnNetworks, multiqueue) if err != nil { return nil, fmt.Errorf("create wintap driver failed, %w", err) } diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go index d0103879a..5a801c318 100644 --- a/overlay/tun_wintun_windows.go +++ b/overlay/tun_wintun_windows.go @@ -20,12 +20,12 @@ import ( const tunGUIDLabel = "Fixed Nebula Windows GUID v1" type winTun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger tun *wintun.NativeTun } @@ -49,7 +49,7 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) { return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil } -func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) { +func newWinTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { deviceName := c.GetString("tun.dev", "") guid, err := generateGUIDByDeviceName(deviceName) if err != nil { @@ -57,10 +57,10 @@ func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTu } t := &winTun{ - Device: deviceName, - cidr: cidr, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + Device: deviceName, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, } err = t.reload(c, true) @@ -92,7 +92,7 @@ func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTu } func (t *winTun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -131,7 +131,7 @@ func (t *winTun) reload(c *config.C, initial bool) error { func (t *winTun) Activate() error { luid := winipcfg.LUID(t.tun.LUID()) - err := luid.SetIPAddresses([]netip.Prefix{t.cidr}) + err := luid.SetIPAddresses(t.vpnNetworks) if err != nil { return fmt.Errorf("failed to set address: %w", err) } @@ -216,8 +216,8 @@ func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { return r } -func (t *winTun) Cidr() netip.Prefix { - return t.cidr +func (t *winTun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *winTun) Name() string {