From 6a89fb471801026ab42447db0c331a7f305713e0 Mon Sep 17 00:00:00 2001 From: hwipl <33433250+hwipl@users.noreply.github.com> Date: Thu, 15 Aug 2024 15:41:51 +0200 Subject: [PATCH] Switch to package netip internally Switch from using the types net.IP, net.IPNet and net.IPMask to netip.Addr and netip.Prefix internally. Signed-off-by: hwipl <33433250+hwipl@users.noreply.github.com> --- internal/addrmon/addrmon.go | 14 +++++-- internal/addrmon/addrmon_test.go | 8 +++- internal/daemon/daemon.go | 13 +++--- internal/dnsproxy/proxy.go | 17 +++++++- internal/dnsproxy/proxy_test.go | 9 ++-- internal/dnsproxy/report.go | 6 +-- internal/dnsproxy/report_test.go | 10 ++--- internal/splitrt/addresses.go | 6 +-- internal/splitrt/addresses_test.go | 18 ++++---- internal/splitrt/excludes.go | 42 +++++++------------ internal/splitrt/excludes_test.go | 14 +++---- internal/splitrt/filter.go | 13 +++--- internal/splitrt/splitrt.go | 67 +++++++++++++----------------- internal/splitrt/splitrt_test.go | 7 ++-- internal/trafpol/filter.go | 6 +-- internal/trafpol/filter_test.go | 8 ++-- internal/trafpol/resolver.go | 9 ++-- internal/trafpol/resolver_test.go | 4 +- internal/trafpol/trafpol.go | 63 +++++++++++----------------- internal/trafpol/trafpol_test.go | 34 ++++++++------- internal/vpnsetup/vpnsetup.go | 25 ++++++----- internal/vpnsetup/vpnsetup_test.go | 5 ++- 22 files changed, 199 insertions(+), 199 deletions(-) diff --git a/internal/addrmon/addrmon.go b/internal/addrmon/addrmon.go index 0b39444a..60794ac0 100644 --- a/internal/addrmon/addrmon.go +++ b/internal/addrmon/addrmon.go @@ -3,7 +3,7 @@ package addrmon import ( "fmt" - "net" + "net/netip" log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" @@ -12,7 +12,7 @@ import ( // Update is an address update. type Update struct { Add bool - Address net.IPNet + Address netip.Prefix Index int } @@ -72,8 +72,16 @@ func (a *AddrMon) start() { } // forward event as address update + ip, ok := netip.AddrFromSlice(e.LinkAddress.IP) + if !ok || !ip.IsValid() { + log.WithField("LinkAddress", e.LinkAddress). + Error("AddrMon got invalid IP in addr event") + continue + } + ones, _ := e.LinkAddress.Mask.Size() + addr := netip.PrefixFrom(ip, ones) u := &Update{ - Address: e.LinkAddress, + Address: addr, Index: e.LinkIndex, Add: e.NewAddr, } diff --git a/internal/addrmon/addrmon_test.go b/internal/addrmon/addrmon_test.go index bd0b1fc4..dc122901 100644 --- a/internal/addrmon/addrmon_test.go +++ b/internal/addrmon/addrmon_test.go @@ -2,6 +2,7 @@ package addrmon import ( "log" + "net" "testing" "github.com/vishvananda/netlink" @@ -44,7 +45,12 @@ func TestAddrMonStartStop(t *testing.T) { // helper function for AddrUpdates addrUpdates := func(updates chan netlink.AddrUpdate, done chan struct{}) { for { - up := netlink.AddrUpdate{} + up := netlink.AddrUpdate{ + LinkAddress: net.IPNet{ + IP: net.IPv4(192, 168, 1, 1), + Mask: net.CIDRMask(24, 32), + }, + } select { case updates <- up: case <-done: diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 22b50b59..a2f3cd0e 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "net" + "net/netip" "reflect" "slices" "strconv" @@ -66,7 +67,7 @@ type Daemon struct { disableTrafPol bool // serverIP is the IP address of the current VPN server - serverIP net.IP + serverIP netip.Addr // serverIPAllowed indicates whether server IP was added to // the allowed addresses @@ -307,7 +308,9 @@ func (d *Daemon) connectVPN(login *logininfo.LoginInfo) { } // set server address - d.serverIP = net.ParseIP(strings.Trim(login.Host, "[]")) + if serverIP, err := netip.ParseAddr(strings.Trim(login.Host, "[]")); err == nil { + d.serverIP = serverIP + } // update status d.setStatusOCRunning(true) @@ -316,7 +319,7 @@ func (d *Daemon) connectVPN(login *logininfo.LoginInfo) { d.setStatusConnectionState(vpnstatus.ConnectionStateConnecting) // add server address to allowed addrs in trafpol - if d.trafpol != nil && d.serverIP != nil { + if d.trafpol != nil && d.serverIP.IsValid() { d.serverIPAllowed = d.trafpol.AddAllowedAddr(d.serverIP) } @@ -542,7 +545,7 @@ func (d *Daemon) handleRunnerDisconnect() { if d.trafpol != nil && d.serverIPAllowed { d.trafpol.RemoveAllowedAddr(d.serverIP) } - d.serverIP = nil + d.serverIP = netip.Addr{} d.serverIPAllowed = false } @@ -743,7 +746,7 @@ func (d *Daemon) startTrafPol() error { d.setStatusTrafPolState(vpnstatus.TrafPolStateActive) d.setStatusAllowedHosts(c.AllowedHosts) - if d.serverIP != nil { + if d.serverIP.IsValid() { // VPN connection active, allow server IP d.serverIPAllowed = d.trafpol.AddAllowedAddr(d.serverIP) } diff --git a/internal/dnsproxy/proxy.go b/internal/dnsproxy/proxy.go index e3f13f60..be836fd8 100644 --- a/internal/dnsproxy/proxy.go +++ b/internal/dnsproxy/proxy.go @@ -3,6 +3,7 @@ package dnsproxy import ( "math/rand" + "net/netip" "github.com/miekg/dns" log "github.com/sirupsen/logrus" @@ -101,7 +102,13 @@ func (p *Proxy) handleRequest(w dns.ResponseWriter, r *dns.Msg) { log.Error("DNS-Proxy received invalid A record in reply") return } - report := NewReport(rr.Hdr.Name, rr.A, rr.Hdr.Ttl) + addr, ok := netip.AddrFromSlice(rr.A) + if !ok { + log.WithField("A", rr.A). + Error("DNS-Proxy received invalid IP in A record in reply") + return + } + report := NewReport(rr.Hdr.Name, addr, rr.Hdr.Ttl) p.sendReport(report) p.waitReport(report) } @@ -114,7 +121,13 @@ func (p *Proxy) handleRequest(w dns.ResponseWriter, r *dns.Msg) { log.Error("DNS-Proxy received invalid AAAA record in reply") return } - report := NewReport(rr.Hdr.Name, rr.AAAA, rr.Hdr.Ttl) + addr, ok := netip.AddrFromSlice(rr.AAAA) + if !ok { + log.WithField("AAAA", rr.AAAA). + Error("DNS-Proxy received invalid IP in AAAA record in reply") + return + } + report := NewReport(rr.Hdr.Name, addr, rr.Hdr.Ttl) p.sendReport(report) p.waitReport(report) } diff --git a/internal/dnsproxy/proxy_test.go b/internal/dnsproxy/proxy_test.go index 23ff58f9..217b7ee2 100644 --- a/internal/dnsproxy/proxy_test.go +++ b/internal/dnsproxy/proxy_test.go @@ -3,6 +3,7 @@ package dnsproxy import ( "errors" "net" + "net/netip" "testing" "github.com/miekg/dns" @@ -127,8 +128,8 @@ func TestProxyHandleRequest(t *testing.T) { if r.Name != "example.com." { t.Errorf("invalid domain name: %s", r.Name) } - if !r.IP.Equal(net.IPv4(127, 0, 0, 1)) && - !r.IP.Equal(net.ParseIP("::1")) { + if r.IP != netip.MustParseAddr("127.0.0.1") && + r.IP != netip.MustParseAddr("::1") { t.Errorf("invalid IP: %s", r.IP) } } @@ -205,8 +206,8 @@ func TestProxyHandleRequestRecords(t *testing.T) { t.Fatalf("invalid reports for run %d: %v", i, reports) } for _, r := range reports { - if !r.IP.Equal(net.ParseIP("127.0.0.1")) && - !r.IP.Equal(net.ParseIP("::1")) { + if r.IP != netip.MustParseAddr("127.0.0.1") && + r.IP != netip.MustParseAddr("::1") { t.Errorf("invalid report for run %d: %v", i, r) } diff --git a/internal/dnsproxy/report.go b/internal/dnsproxy/report.go index 23a4cec7..72d1567c 100644 --- a/internal/dnsproxy/report.go +++ b/internal/dnsproxy/report.go @@ -2,13 +2,13 @@ package dnsproxy import ( "fmt" - "net" + "net/netip" ) // Report is a report for a watched domain. type Report struct { Name string - IP net.IP + IP netip.Addr TTL uint32 // done is used to signal that the report has been handled by @@ -32,7 +32,7 @@ func (r *Report) Done() <-chan struct{} { } // NewReport returns a new report with domain name, IP and TTL. -func NewReport(name string, ip net.IP, ttl uint32) *Report { +func NewReport(name string, ip netip.Addr, ttl uint32) *Report { return &Report{ Name: name, IP: ip, diff --git a/internal/dnsproxy/report_test.go b/internal/dnsproxy/report_test.go index 6420a911..925d054a 100644 --- a/internal/dnsproxy/report_test.go +++ b/internal/dnsproxy/report_test.go @@ -1,14 +1,14 @@ package dnsproxy import ( - "net" + "net/netip" "testing" ) // TestReportString tests String of Report. func TestReportString(t *testing.T) { name := "example.com." - ip := net.IPv4(192, 168, 1, 1) + ip := netip.MustParseAddr("192.168.1.1") ttl := uint32(300) r := NewReport(name, ip, ttl) @@ -22,7 +22,7 @@ func TestReportString(t *testing.T) { // TestReportDone tests Wait and Done of Report. func TestReportWaitDone(_ *testing.T) { name := "example.com." - ip := net.IPv4(192, 168, 1, 1) + ip := netip.MustParseAddr("192.168.1.1") ttl := uint32(300) r := NewReport(name, ip, ttl) @@ -33,7 +33,7 @@ func TestReportWaitDone(_ *testing.T) { // TestNewReport tests NewReport. func TestNewReport(t *testing.T) { name := "example.com." - ip := net.IPv4(192, 168, 1, 1) + ip := netip.MustParseAddr("192.168.1.1") ttl := uint32(300) r := NewReport(name, ip, ttl) @@ -43,7 +43,7 @@ func TestNewReport(t *testing.T) { if r.Name != name { t.Errorf("got %s, want %s", r.Name, name) } - if !r.IP.Equal(ip) { + if r.IP != ip { t.Errorf("got %s, want %s", r.IP, ip) } if r.TTL != ttl { diff --git a/internal/splitrt/addresses.go b/internal/splitrt/addresses.go index 41c2e565..e009db88 100644 --- a/internal/splitrt/addresses.go +++ b/internal/splitrt/addresses.go @@ -1,7 +1,7 @@ package splitrt import ( - "net" + "net/netip" "github.com/telekom-mms/oc-daemon/internal/addrmon" ) @@ -51,9 +51,9 @@ func (a *Addresses) Remove(addr *addrmon.Update) { } // Get returns the addresses of the device identified by index. -func (a *Addresses) Get(index int) (addrs []*net.IPNet) { +func (a *Addresses) Get(index int) (addrs []netip.Prefix) { for _, x := range a.m[index] { - addrs = append(addrs, &x.Address) + addrs = append(addrs, x.Address) } return } diff --git a/internal/splitrt/addresses_test.go b/internal/splitrt/addresses_test.go index 9b110904..bf69bed9 100644 --- a/internal/splitrt/addresses_test.go +++ b/internal/splitrt/addresses_test.go @@ -1,7 +1,7 @@ package splitrt import ( - "net" + "net/netip" "reflect" "testing" @@ -10,14 +10,14 @@ import ( // getTestAddrMonUpdate returns an AddrMon update for testing. func getTestAddrMonUpdate(t *testing.T, addr string) *addrmon.Update { - _, ipnet, err := net.ParseCIDR(addr) + prefix, err := netip.ParsePrefix(addr) if err != nil { t.Fatal(err) } return &addrmon.Update{ Add: true, - Address: *ipnet, + Address: prefix, Index: 1, } } @@ -72,7 +72,7 @@ func TestAddressesGet(t *testing.T) { update2 := getTestAddrMonUpdate(t, "192.168.2.0/24") // get empty - var want []*net.IPNet + var want []netip.Prefix got := a.Get(1) if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) @@ -80,8 +80,8 @@ func TestAddressesGet(t *testing.T) { // get with one address a.Add(update1) - want = []*net.IPNet{ - &update1.Address, + want = []netip.Prefix{ + update1.Address, } got = a.Get(1) if !reflect.DeepEqual(got, want) { @@ -97,9 +97,9 @@ func TestAddressesGet(t *testing.T) { // get with multiple addresses a.Add(update2) - want = []*net.IPNet{ - &update1.Address, - &update2.Address, + want = []netip.Prefix{ + update1.Address, + update2.Address, } got = a.Get(1) if !reflect.DeepEqual(got, want) { diff --git a/internal/splitrt/excludes.go b/internal/splitrt/excludes.go index 7c3378d5..63c3ad54 100644 --- a/internal/splitrt/excludes.go +++ b/internal/splitrt/excludes.go @@ -2,7 +2,6 @@ package splitrt import ( "context" - "net" "net/netip" "sync" "time" @@ -24,7 +23,7 @@ type dynExclude struct { // Excludes contains split Excludes. type Excludes struct { sync.Mutex - s map[string]*netip.Prefix + s map[string]netip.Prefix d map[netip.Addr]*dynExclude done chan struct{} closed chan struct{} @@ -34,31 +33,20 @@ type Excludes struct { func (e *Excludes) setFilter(ctx context.Context) { log.Debug("SplitRouting resetting excludes in netfilter") - addresses := []*netip.Prefix{} + addresses := []netip.Prefix{} for _, v := range e.s { addresses = append(addresses, v) } for k := range e.d { prefix := netip.PrefixFrom(k, k.BitLen()) - addresses = append(addresses, &prefix) + addresses = append(addresses, prefix) } setExcludes(ctx, addresses) } -// prefixFromIPNet returns ipnet as netip.Prefix. -func prefixFromIPNet(ipnet *net.IPNet) netip.Prefix { - addr, _ := netip.AddrFromSlice(ipnet.IP) - bits, _ := ipnet.Mask.Size() - return netip.PrefixFrom(addr.Unmap(), bits) -} - // AddStatic adds a static entry to the split excludes. -func (e *Excludes) AddStatic(ctx context.Context, address *net.IPNet) { +func (e *Excludes) AddStatic(ctx context.Context, address netip.Prefix) { log.WithField("address", address).Debug("SplitRouting adding static exclude") - - // convert address - a := prefixFromIPNet(address) - e.Lock() defer e.Unlock() @@ -66,11 +54,11 @@ func (e *Excludes) AddStatic(ctx context.Context, address *net.IPNet) { // prefixes in static excludes removed := false for k, v := range e.s { - if !v.Overlaps(a) { + if !v.Overlaps(address) { // no overlap continue } - if v.Bits() <= a.Bits() { + if v.Bits() <= address.Bits() { // new prefix is already in existing prefix, // do not add it return @@ -83,7 +71,7 @@ func (e *Excludes) AddStatic(ctx context.Context, address *net.IPNet) { // add new prefix to static excludes key := address.String() - e.s[key] = &a + e.s[key] = address // add to netfilter if removed { @@ -92,23 +80,21 @@ func (e *Excludes) AddStatic(ctx context.Context, address *net.IPNet) { return } // single new entry, add it - addExclude(ctx, &a) + addExclude(ctx, address) } // AddDynamic adds a dynamic entry to the split excludes. -func (e *Excludes) AddDynamic(ctx context.Context, address *net.IPNet, ttl uint32) { +func (e *Excludes) AddDynamic(ctx context.Context, address netip.Prefix, ttl uint32) { log.WithFields(log.Fields{ "address": address, "ttl": ttl, }).Debug("SplitRouting adding dynamic exclude") - // convert address - prefix := prefixFromIPNet(address) - if !prefix.IsSingleIP() { + if !address.IsSingleIP() { log.Error("SplitRouting error adding dynamic exclude with multiple IPs") return } - a := prefix.Addr() + a := address.Addr() e.Lock() defer e.Unlock() @@ -135,11 +121,11 @@ func (e *Excludes) AddDynamic(ctx context.Context, address *net.IPNet, ttl uint3 } // add to netfilter - addExclude(ctx, &prefix) + addExclude(ctx, address) } // RemoveStatic removes a static entry from the split excludes. -func (e *Excludes) RemoveStatic(ctx context.Context, address *net.IPNet) { +func (e *Excludes) RemoveStatic(ctx context.Context, address netip.Prefix) { e.Lock() defer e.Unlock() @@ -214,7 +200,7 @@ func (e *Excludes) Stop() { // NewExcludes returns new split excludes. func NewExcludes() *Excludes { return &Excludes{ - s: make(map[string]*netip.Prefix), + s: make(map[string]netip.Prefix), d: make(map[netip.Addr]*dynExclude), done: make(chan struct{}), closed: make(chan struct{}), diff --git a/internal/splitrt/excludes_test.go b/internal/splitrt/excludes_test.go index f3a2bce1..ee3a7f4a 100644 --- a/internal/splitrt/excludes_test.go +++ b/internal/splitrt/excludes_test.go @@ -3,7 +3,7 @@ package splitrt import ( "context" "errors" - "net" + "net/netip" "reflect" "testing" @@ -11,10 +11,10 @@ import ( ) // getTestExcludes returns excludes for testing. -func getTestExcludes(t *testing.T, es []string) []*net.IPNet { - excludes := []*net.IPNet{} +func getTestExcludes(t *testing.T, es []string) []netip.Prefix { + excludes := []netip.Prefix{} for _, s := range es { - _, exclude, err := net.ParseCIDR(s) + exclude, err := netip.ParsePrefix(s) if err != nil { t.Fatal(err) } @@ -24,7 +24,7 @@ func getTestExcludes(t *testing.T, es []string) []*net.IPNet { } // getTestStaticExcludes returns static excludes for testing. -func getTestStaticExcludes(t *testing.T) []*net.IPNet { +func getTestStaticExcludes(t *testing.T) []netip.Prefix { return getTestExcludes(t, []string{ "192.168.1.0/24", "2001::/64", @@ -32,7 +32,7 @@ func getTestStaticExcludes(t *testing.T) []*net.IPNet { } // getTestStaticExcludesOverlap returns static excludes that overlap for testing. -func getTestStaticExcludesOverlap(t *testing.T) []*net.IPNet { +func getTestStaticExcludesOverlap(t *testing.T) []netip.Prefix { return getTestExcludes(t, []string{ "192.168.1.0/26", "192.168.1.64/26", @@ -52,7 +52,7 @@ func getTestStaticExcludesOverlap(t *testing.T) []*net.IPNet { } // getTestDynamicExcludes returns dynamic excludes for testing. -func getTestDynamicExcludes(t *testing.T) []*net.IPNet { +func getTestDynamicExcludes(t *testing.T) []netip.Prefix { return getTestExcludes(t, []string{ "192.168.1.1/32", "2001::1/128", diff --git a/internal/splitrt/filter.go b/internal/splitrt/filter.go index 92344c7c..c2a22148 100644 --- a/internal/splitrt/filter.go +++ b/internal/splitrt/filter.go @@ -3,7 +3,6 @@ package splitrt import ( "context" "fmt" - "net" "net/netip" "strings" @@ -126,10 +125,10 @@ func unsetRoutingRules(ctx context.Context) { // addLocalAddresses adds rules for device and its family (ip, ip6) addresses, // that drop non-local traffic from other devices to device's network // addresses; used to filter non-local traffic to vpn addresses. -func addLocalAddresses(ctx context.Context, device, family string, addresses []*net.IPNet) { +func addLocalAddresses(ctx context.Context, device, family string, addresses []netip.Prefix) { nftconf := "" for _, addr := range addresses { - if addr == nil || len(addr.IP) == 0 || len(addr.Mask) == 0 { + if !addr.IsValid() { continue } nftconf += "add rule inet oc-daemon-routing preraw iifname != " @@ -151,14 +150,14 @@ func addLocalAddresses(ctx context.Context, device, family string, addresses []* // addLocalAddressesIPv4 adds rules for device and its addresses, that drop // non-local traffic from other devices to device's network addresses; used to // filter non-local traffic to vpn addresses. -func addLocalAddressesIPv4(ctx context.Context, device string, addresses []*net.IPNet) { +func addLocalAddressesIPv4(ctx context.Context, device string, addresses []netip.Prefix) { addLocalAddresses(ctx, device, "ip", addresses) } // addLocalAddressesIPv6 adds rules for device and its addresses, that drop // non-local traffic from other devices to device's network addresses; used to // filter non-local traffic to vpn addresses. -func addLocalAddressesIPv6(ctx context.Context, device string, addresses []*net.IPNet) { +func addLocalAddressesIPv6(ctx context.Context, device string, addresses []netip.Prefix) { addLocalAddresses(ctx, device, "ip6", addresses) } @@ -197,7 +196,7 @@ func rejectIPv4(ctx context.Context, device string) { } // addExclude adds exclude address to netfilter. -func addExclude(ctx context.Context, address *netip.Prefix) { +func addExclude(ctx context.Context, address netip.Prefix) { log.WithField("address", address).Debug("SplitRouting adding exclude to netfilter") set := "excludes4" @@ -217,7 +216,7 @@ func addExclude(ctx context.Context, address *netip.Prefix) { } // setExcludes resets the excludes to addresses in netfilter. -func setExcludes(ctx context.Context, addresses []*netip.Prefix) { +func setExcludes(ctx context.Context, addresses []netip.Prefix) { // flush existing entries nftconf := "" nftconf += "flush set inet oc-daemon-routing excludes4\n" diff --git a/internal/splitrt/splitrt.go b/internal/splitrt/splitrt.go index cbc66c67..f77820a0 100644 --- a/internal/splitrt/splitrt.go +++ b/internal/splitrt/splitrt.go @@ -4,7 +4,7 @@ package splitrt import ( "context" "fmt" - "net" + "net/netip" log "github.com/sirupsen/logrus" "github.com/telekom-mms/oc-daemon/internal/addrmon" @@ -21,7 +21,7 @@ type SplitRouting struct { addrmon *addrmon.AddrMon devices *Devices addrs *Addresses - locals []*net.IPNet + locals []netip.Prefix excludes *Excludes dnsreps chan *dnsproxy.Report done chan struct{} @@ -30,28 +30,30 @@ type SplitRouting struct { // setupRouting sets up routing using config. func (s *SplitRouting) setupRouting(ctx context.Context) { - // get vpn network addresses - ipnet4 := &net.IPNet{ - IP: s.vpnconfig.IPv4.Address, - Mask: s.vpnconfig.IPv4.Netmask, - } - ipnet6 := &net.IPNet{ - IP: s.vpnconfig.IPv6.Address, - Mask: s.vpnconfig.IPv6.Netmask, - } - // prepare netfilter and excludes setRoutingRules(ctx, s.config.FirewallMark) + // convert to netip + pre4 := netip.Prefix{} + if ipv4, ok := netip.AddrFromSlice(s.vpnconfig.IPv4.Address.To4()); ok { + one4, _ := s.vpnconfig.IPv4.Netmask.Size() + pre4 = netip.PrefixFrom(ipv4, one4) + } + pre6 := netip.Prefix{} + if ipv6, ok := netip.AddrFromSlice(s.vpnconfig.IPv6.Address); ok { + one6, _ := s.vpnconfig.IPv6.Netmask.Size() + pre6 = netip.PrefixFrom(ipv6, one6) + } + // filter non-local traffic to vpn addresses - addLocalAddressesIPv4(ctx, s.vpnconfig.Device.Name, []*net.IPNet{ipnet4}) - addLocalAddressesIPv6(ctx, s.vpnconfig.Device.Name, []*net.IPNet{ipnet6}) + addLocalAddressesIPv4(ctx, s.vpnconfig.Device.Name, []netip.Prefix{pre4}) + addLocalAddressesIPv6(ctx, s.vpnconfig.Device.Name, []netip.Prefix{pre6}) // reject unsupported ip versions on vpn - if len(s.vpnconfig.IPv6.Address) == 0 { + if !pre6.IsValid() { rejectIPv6(ctx, s.vpnconfig.Device.Name) } - if len(s.vpnconfig.IPv4.Address) == 0 { + if !pre4.IsValid() { rejectIPv4(ctx, s.vpnconfig.Device.Name) } @@ -59,21 +61,19 @@ func (s *SplitRouting) setupRouting(ctx context.Context) { s.excludes.Start() // add gateway to static excludes - gateway := &net.IPNet{ - IP: s.vpnconfig.Gateway, - Mask: net.CIDRMask(32, 32), - } - if gateway.IP.To4() == nil { - gateway.Mask = net.CIDRMask(128, 128) + if s.vpnconfig.Gateway != nil { + g := netip.MustParseAddr(s.vpnconfig.Gateway.String()) + gateway := netip.PrefixFrom(g, g.BitLen()) + s.excludes.AddStatic(ctx, gateway) } - s.excludes.AddStatic(ctx, gateway) // add static IPv4 excludes for _, e := range s.vpnconfig.Split.ExcludeIPv4 { if e.String() == "0.0.0.0/32" { continue } - s.excludes.AddStatic(ctx, e) + p := netip.MustParsePrefix(e.String()) + s.excludes.AddStatic(ctx, p) } // add static IPv6 excludes @@ -82,7 +82,8 @@ func (s *SplitRouting) setupRouting(ctx context.Context) { if e.String() == "::/128" { continue } - s.excludes.AddStatic(ctx, e) + p := netip.MustParsePrefix(e.String()) + s.excludes.AddStatic(ctx, p) } // setup routing @@ -132,14 +133,14 @@ func (s *SplitRouting) updateLocalNetworkExcludes(ctx context.Context) { } // get addresses of these devices - excludes := []*net.IPNet{} + excludes := []netip.Prefix{} for _, d := range devs { excludes = append(excludes, s.addrs.Get(d)...) } // determine changes // TODO: move s.locals into excludes? - isIn := func(n *net.IPNet, nets []*net.IPNet) bool { + isIn := func(n netip.Prefix, nets []netip.Prefix) bool { for _, net := range nets { if n.String() == net.String() { return true @@ -205,17 +206,7 @@ func (s *SplitRouting) handleDNSReport(ctx context.Context, r *dnsproxy.Report) defer r.Close() log.WithField("report", r).Debug("SplitRouting handling DNS report") - if r.IP.To4() != nil { - s.excludes.AddDynamic(ctx, &net.IPNet{ - IP: r.IP, - Mask: net.CIDRMask(32, 32), - }, r.TTL) - return - } - s.excludes.AddDynamic(ctx, &net.IPNet{ - IP: r.IP, - Mask: net.CIDRMask(128, 128), - }, r.TTL) + s.excludes.AddDynamic(ctx, netip.PrefixFrom(r.IP, r.IP.BitLen()), r.TTL) } // start starts split routing. diff --git a/internal/splitrt/splitrt_test.go b/internal/splitrt/splitrt_test.go index 405854d9..54f8e433 100644 --- a/internal/splitrt/splitrt_test.go +++ b/internal/splitrt/splitrt_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "net/netip" "reflect" "strings" "testing" @@ -166,12 +167,12 @@ func TestSplitRoutingHandleDNSReport(t *testing.T) { defer func() { execs.RunCmd = oldRunCmd }() // test ipv4 - report := dnsproxy.NewReport("example.com", net.ParseIP("192.168.1.1"), 300) + report := dnsproxy.NewReport("example.com", netip.MustParseAddr("192.168.1.1"), 300) go s.handleDNSReport(ctx, report) <-report.Done() // test ipv6 - report = dnsproxy.NewReport("example.com", net.ParseIP("2001::1"), 300) + report = dnsproxy.NewReport("example.com", netip.MustParseAddr("2001::1"), 300) go s.handleDNSReport(ctx, report) <-report.Done() @@ -257,7 +258,7 @@ func TestSplitRoutingStartStop(t *testing.T) { } s.devmon.Updates() <- getTestDevMonUpdate() s.addrmon.Updates() <- getTestAddrMonUpdate(t, "192.168.1.1/32") - report := dnsproxy.NewReport("example.com", net.ParseIP("192.168.1.1"), 300) + report := dnsproxy.NewReport("example.com", netip.MustParseAddr("192.168.1.1"), 300) s.dnsreps <- report <-report.Done() s.Stop() diff --git a/internal/trafpol/filter.go b/internal/trafpol/filter.go index 3b169c43..dadcd314 100644 --- a/internal/trafpol/filter.go +++ b/internal/trafpol/filter.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "net" + "net/netip" "strconv" "strings" @@ -206,7 +206,7 @@ func removeAllowedDevice(ctx context.Context, device string) { } // setAllowedIPs set the allowed hosts. -func setAllowedIPs(ctx context.Context, ips []*net.IPNet) { +func setAllowedIPs(ctx context.Context, ips []netip.Prefix) { // we perform all nft commands separately here and not as one atomic // operation to avoid issues where the whole update fails because nft // runs into "file exists" errors even though we remove duplicates from @@ -234,7 +234,7 @@ func setAllowedIPs(ctx context.Context, ips []*net.IPNet) { fmt4 := "add element inet oc-daemon-filter allowhosts4 { %s }" fmt6 := "add element inet oc-daemon-filter allowhosts6 { %s }" for _, ip := range ips { - if ip.IP.To4() != nil { + if ip.Addr().Is4() { // ipv4 address nftconf := fmt.Sprintf(fmt4, ip) if stdout, stderr, err := execs.RunNft(ctx, nftconf); err != nil && diff --git a/internal/trafpol/filter_test.go b/internal/trafpol/filter_test.go index a44fde49..ad6ceb00 100644 --- a/internal/trafpol/filter_test.go +++ b/internal/trafpol/filter_test.go @@ -3,7 +3,7 @@ package trafpol import ( "context" "errors" - "net" + "net/netip" "testing" "github.com/telekom-mms/oc-daemon/internal/execs" @@ -28,9 +28,9 @@ func TestFilterFunctionsErrors(_ *testing.T) { removeAllowedDevice(ctx, "eth0") // allowed IPs - setAllowedIPs(ctx, []*net.IPNet{ - {IP: net.ParseIP("192.168.1.1"), Mask: net.CIDRMask(32, 32)}, - {IP: net.ParseIP("2000::1"), Mask: net.CIDRMask(128, 128)}, + setAllowedIPs(ctx, []netip.Prefix{ + netip.MustParsePrefix("192.168.1.1/32"), + netip.MustParsePrefix("2000::1/128"), }) // portal ports diff --git a/internal/trafpol/resolver.go b/internal/trafpol/resolver.go index 188bf4ee..cabf58e1 100644 --- a/internal/trafpol/resolver.go +++ b/internal/trafpol/resolver.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "net/netip" "sort" "sync" "time" @@ -12,7 +13,7 @@ import ( // ResolvedName is a resolved DNS name. type ResolvedName struct { Name string - IPs []net.IP + IPs []netip.Addr TTL time.Duration } @@ -48,8 +49,8 @@ func (r *ResolvedName) resolve(ctx context.Context, config *Config, updates chan // lookup IPv4 and IPv6 addresses in one call (argument // network == "ip"). So, resolve IPv4 and IPv6 addresses in // separate calls - ipv4s, err4 := resolver.LookupIP(ctxTO, "ip4", r.Name) - ipv6s, err6 := resolver.LookupIP(ctxTO, "ip6", r.Name) + ipv4s, err4 := resolver.LookupNetIP(ctxTO, "ip4", r.Name) + ipv6s, err6 := resolver.LookupNetIP(ctxTO, "ip6", r.Name) if err4 != nil && err6 != nil { // do not retry hostnames that are not found var dnsErr4 *net.DNSError @@ -79,7 +80,7 @@ func (r *ResolvedName) resolve(ctx context.Context, config *Config, updates chan return false } for i := range r.IPs { - if !r.IPs[i].Equal(ips[i]) { + if r.IPs[i] != ips[i] { return false } } diff --git a/internal/trafpol/resolver_test.go b/internal/trafpol/resolver_test.go index 898cfd51..ee0ce6df 100644 --- a/internal/trafpol/resolver_test.go +++ b/internal/trafpol/resolver_test.go @@ -1,7 +1,7 @@ package trafpol import ( - "net" + "net/netip" "testing" "time" ) @@ -45,7 +45,7 @@ func TestResolverResolve(_ *testing.T) { names = []string{"example.com"} r = NewResolver(config, names, updates) r.names["example.com"] = u - r.names["example.com"].IPs[0] = net.ParseIP("127.0.0.1") + r.names["example.com"].IPs[0] = netip.MustParseAddr("127.0.0.1") r.Start() diff --git a/internal/trafpol/trafpol.go b/internal/trafpol/trafpol.go index 21a0da4d..a904baeb 100644 --- a/internal/trafpol/trafpol.go +++ b/internal/trafpol/trafpol.go @@ -4,7 +4,7 @@ package trafpol import ( "context" "fmt" - "net" + "net/netip" log "github.com/sirupsen/logrus" "github.com/telekom-mms/oc-daemon/internal/cpd" @@ -15,7 +15,7 @@ import ( // trafPolAddrCmd is a TrafPol address command. type trafPolAddrCmd struct { add bool - ip net.IP + ip netip.Addr ok bool done chan struct{} } @@ -32,8 +32,8 @@ type TrafPol struct { // allowed devices, addresses, names allowDevs *AllowDevs - allowAddrs map[string]*net.IPNet - allowNames map[string][]net.IP + allowAddrs map[string]netip.Prefix + allowNames map[string][]netip.Addr // resolver for allowed names, channel for resolver updates resolver *Resolver @@ -96,21 +96,15 @@ func (t *TrafPol) handleCPDReport(ctx context.Context, report *cpd.Report) { // getAllowedHostsIPs returns the IPs of the allowed hosts, // used for filter rules -func (t *TrafPol) getAllowedHostsIPs() []*net.IPNet { +func (t *TrafPol) getAllowedHostsIPs() []netip.Prefix { // get a list of all unique ip addresses from // - allowed names // - allowed addrs - ipset := make(map[string]*net.IPNet) + ipset := make(map[string]netip.Prefix) for _, n := range t.allowNames { for _, ip := range n { - ipnet := &net.IPNet{ - IP: ip, - Mask: net.CIDRMask(32, 32), - } - if ip.To4() == nil { - ipnet.Mask = net.CIDRMask(128, 128) - } - ipset[ipnet.String()] = ipnet + prefix := netip.PrefixFrom(ip, ip.BitLen()) + ipset[prefix.String()] = prefix } } for _, a := range t.allowAddrs { @@ -118,7 +112,7 @@ func (t *TrafPol) getAllowedHostsIPs() []*net.IPNet { } // get resulting list of IPs - ips := []*net.IPNet{} + ips := []netip.Prefix{} for _, ip := range ipset { ips = append(ips, ip) } @@ -139,20 +133,17 @@ func (t *TrafPol) handleResolverUpdate(ctx context.Context, update *ResolvedName func (t *TrafPol) handleAddressCommand(ctx context.Context, cmd *trafPolAddrCmd) { defer close(cmd.done) - // convert to ipnet - ipnet := &net.IPNet{IP: cmd.ip, Mask: net.CIDRMask(32, 32)} - if cmd.ip.To4() == nil { - ipnet.Mask = net.CIDRMask(128, 128) - } + // convert to prefix + prefix := netip.PrefixFrom(cmd.ip, cmd.ip.BitLen()) // update allowed addrs - s := ipnet.String() + s := prefix.String() if cmd.add { if _, ok := t.allowAddrs[s]; ok { // ip already in allowed addrs return } - t.allowAddrs[s] = ipnet + t.allowAddrs[s] = prefix } else { if _, ok := t.allowAddrs[s]; !ok { // ip not in allowed addrs @@ -269,7 +260,7 @@ func (t *TrafPol) Stop() { } // AddAllowedAddr adds addr to the allowed addresses. -func (t *TrafPol) AddAllowedAddr(addr net.IP) (ok bool) { +func (t *TrafPol) AddAllowedAddr(addr netip.Addr) (ok bool) { log.WithField("addr", addr). Debug("TrafPol adding IP to allowed addresses") @@ -285,7 +276,7 @@ func (t *TrafPol) AddAllowedAddr(addr net.IP) (ok bool) { } // RemoveAllowedAddr removes addr from the allowed addresses. -func (t *TrafPol) RemoveAllowedAddr(addr net.IP) (ok bool) { +func (t *TrafPol) RemoveAllowedAddr(addr netip.Addr) (ok bool) { log.WithField("addr", addr). Debug("TrafPol removing IP from allowed addresses") @@ -300,23 +291,17 @@ func (t *TrafPol) RemoveAllowedAddr(addr net.IP) (ok bool) { } // parseAllowedHosts parses the allowed hosts and returns IP addresses and DNS names -func parseAllowedHosts(hosts []string) (addrs []*net.IPNet, names []string) { +func parseAllowedHosts(hosts []string) (addrs []netip.Prefix, names []string) { for _, h := range hosts { // check if it's an IP address - if ip := net.ParseIP(h); ip != nil { - ipnet := &net.IPNet{ - IP: ip, - Mask: net.CIDRMask(32, 32), - } - if ip.To4() == nil { - ipnet.Mask = net.CIDRMask(128, 128) - } - addrs = append(addrs, ipnet) + if ip, err := netip.ParseAddr(h); err == nil { + prefix := netip.PrefixFrom(ip, ip.BitLen()) + addrs = append(addrs, prefix) continue } // check if it's an IP network - if _, ipnet, err := net.ParseCIDR(h); err == nil { - addrs = append(addrs, ipnet) + if prefix, err := netip.ParsePrefix(h); err == nil { + addrs = append(addrs, prefix) continue } @@ -336,13 +321,13 @@ func NewTrafPol(config *Config) *TrafPol { a, n := parseAllowedHosts(hosts) // create allowed addrs and names - addrs := make(map[string]*net.IPNet) - names := make(map[string][]net.IP) + addrs := make(map[string]netip.Prefix) + names := make(map[string][]netip.Addr) for _, addr := range a { addrs[addr.String()] = addr } for _, name := range n { - names[name] = []net.IP{} + names[name] = []netip.Addr{} } // create channel for resolver updates diff --git a/internal/trafpol/trafpol_test.go b/internal/trafpol/trafpol_test.go index 3a91d37d..dfe6d911 100644 --- a/internal/trafpol/trafpol_test.go +++ b/internal/trafpol/trafpol_test.go @@ -2,7 +2,7 @@ package trafpol import ( "context" - "net" + "net/netip" "reflect" "sort" "sync" @@ -132,19 +132,21 @@ func TestTrafPolGetAllowedHostsIPs(t *testing.T) { tp := NewTrafPol(c) // add allowed names - tp.allowNames["example.com"] = []net.IP{net.ParseIP("192.168.1.1"), - net.ParseIP("2001:DB8:1::1")} + tp.allowNames["example.com"] = []netip.Addr{ + netip.MustParseAddr("192.168.1.1"), + netip.MustParseAddr("2001:DB8:1::1"), + } // wanted IPs - want := []*net.IPNet{} + want := []netip.Prefix{} for _, addr := range []string{ "192.168.1.1/32", "192.168.2.0/24", "2001:db8:1::1/128", "2001:db8:2::/64", } { - _, ipnet, _ := net.ParseCIDR(addr) - want = append(want, ipnet) + prefix := netip.MustParsePrefix(addr) + want = append(want, prefix) } // get IPs @@ -195,40 +197,40 @@ func TestTrafPolAddRemoveAllowedAddr(t *testing.T) { } // add ipv4 address - _, ipnet, _ := net.ParseCIDR("192.168.1.1/32") - if ok := tp.AddAllowedAddr(ipnet.IP); !ok { + prefix := netip.MustParsePrefix("192.168.1.1/32") + if ok := tp.AddAllowedAddr(prefix.Addr()); !ok { t.Errorf("address not added") } - want := ipnet.String() - got := tp.allowAddrs[ipnet.String()].String() + want := prefix.String() + got := tp.allowAddrs[prefix.String()].String() if got != want { t.Errorf("got %s, want %s", got, want) } // add ipv4 address again - if ok := tp.AddAllowedAddr(ipnet.IP); ok { + if ok := tp.AddAllowedAddr(prefix.Addr()); ok { t.Errorf("existing address should not be added again") } // remove ipv4 address - if ok := tp.RemoveAllowedAddr(ipnet.IP); !ok { + if ok := tp.RemoveAllowedAddr(prefix.Addr()); !ok { t.Errorf("address not removed") } - want = "" - got = tp.allowAddrs[ipnet.String()].String() + want = netip.Prefix{}.String() + got = tp.allowAddrs[prefix.String()].String() if got != want { t.Errorf("got %s, want %s", got, want) } // remove ipv4 address again - if ok := tp.RemoveAllowedAddr(ipnet.IP); ok { + if ok := tp.RemoveAllowedAddr(prefix.Addr()); ok { t.Errorf("not existing address should not be removed") } // add/remove ipv6 address - ip := net.ParseIP("2001:DB8:1::1") + ip := netip.MustParseAddr("2001:DB8:1::1") if ok := tp.AddAllowedAddr(ip); !ok { t.Errorf("address not added") } diff --git a/internal/vpnsetup/vpnsetup.go b/internal/vpnsetup/vpnsetup.go index cdfd81b6..a8a60ecc 100644 --- a/internal/vpnsetup/vpnsetup.go +++ b/internal/vpnsetup/vpnsetup.go @@ -3,7 +3,7 @@ package vpnsetup import ( "context" - "net" + "net/netip" "strconv" "strings" "time" @@ -70,13 +70,9 @@ func setupVPNDevice(ctx context.Context, c *vpnconfig.Config) { } // set ipv4 and ipv6 addresses on device - setupIP := func(ip net.IP, mask net.IPMask) { - ipnet := &net.IPNet{ - IP: ip, - Mask: mask, - } + setupIP := func(a netip.Prefix) { dev := c.Device.Name - addr := ipnet.String() + addr := a.String() if stdout, stderr, err := execs.RunIPAddress(ctx, "add", addr, "dev", dev); err != nil { log.WithError(err).WithFields(log.Fields{ "device": dev, @@ -88,11 +84,18 @@ func setupVPNDevice(ctx context.Context, c *vpnconfig.Config) { } } - if len(c.IPv4.Address) > 0 { - setupIP(c.IPv4.Address, c.IPv4.Netmask) + + if ipv4, ok := netip.AddrFromSlice(c.IPv4.Address.To4()); ok { + one4, _ := c.IPv4.Netmask.Size() + pre4 := netip.PrefixFrom(ipv4, one4) + + setupIP(pre4) } - if len(c.IPv6.Address) > 0 { - setupIP(c.IPv6.Address, c.IPv6.Netmask) + if ipv6, ok := netip.AddrFromSlice(c.IPv6.Address); ok { + one6, _ := c.IPv6.Netmask.Size() + pre6 := netip.PrefixFrom(ipv6, one6) + + setupIP(pre6) } } diff --git a/internal/vpnsetup/vpnsetup_test.go b/internal/vpnsetup/vpnsetup_test.go index 5fe26cb5..2cd577bd 100644 --- a/internal/vpnsetup/vpnsetup_test.go +++ b/internal/vpnsetup/vpnsetup_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "net/netip" "reflect" "strings" "testing" @@ -361,7 +362,7 @@ func TestVPNSetupSetupTeardown(_ *testing.T) { v.Setup(vpnconf) // send dns report while config is active - report := dnsproxy.NewReport("example.com", nil, 300) + report := dnsproxy.NewReport("example.com", netip.Addr{}, 300) v.dnsProxy.Reports() <- report // wait long enough for ensure timer @@ -371,7 +372,7 @@ func TestVPNSetupSetupTeardown(_ *testing.T) { v.Teardown(vpnconf) // send dns report while config is not active - v.dnsProxy.Reports() <- dnsproxy.NewReport("example.com", nil, 300) + v.dnsProxy.Reports() <- dnsproxy.NewReport("example.com", netip.Addr{}, 300) // stop vpn setup v.Stop()