Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to package netip internally #120

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions internal/addrmon/addrmon.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package addrmon

import (
"fmt"
"net"
"net/netip"

log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
Expand All @@ -12,7 +12,7 @@ import (
// Update is an address update.
type Update struct {
Add bool
Address net.IPNet
Address netip.Prefix
Index int
}

Expand Down Expand Up @@ -72,8 +72,16 @@ func (a *AddrMon) start() {
}

// forward event as address update
ip, ok := netip.AddrFromSlice(e.LinkAddress.IP)
if !ok || !ip.IsValid() {
log.WithField("LinkAddress", e.LinkAddress).
Error("AddrMon got invalid IP in addr event")
continue
}
ones, _ := e.LinkAddress.Mask.Size()
addr := netip.PrefixFrom(ip, ones)
u := &Update{
Address: e.LinkAddress,
Address: addr,
Index: e.LinkIndex,
Add: e.NewAddr,
}
Expand Down
8 changes: 7 additions & 1 deletion internal/addrmon/addrmon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package addrmon

import (
"log"
"net"
"testing"

"github.com/vishvananda/netlink"
Expand Down Expand Up @@ -44,7 +45,12 @@ func TestAddrMonStartStop(t *testing.T) {
// helper function for AddrUpdates
addrUpdates := func(updates chan netlink.AddrUpdate, done chan struct{}) {
for {
up := netlink.AddrUpdate{}
up := netlink.AddrUpdate{
LinkAddress: net.IPNet{
IP: net.IPv4(192, 168, 1, 1),
Mask: net.CIDRMask(24, 32),
},
}
select {
case updates <- up:
case <-done:
Expand Down
13 changes: 8 additions & 5 deletions internal/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"fmt"
"net"
"net/netip"
"reflect"
"slices"
"strconv"
Expand Down Expand Up @@ -66,7 +67,7 @@ type Daemon struct {
disableTrafPol bool

// serverIP is the IP address of the current VPN server
serverIP net.IP
serverIP netip.Addr

// serverIPAllowed indicates whether server IP was added to
// the allowed addresses
Expand Down Expand Up @@ -307,7 +308,9 @@ func (d *Daemon) connectVPN(login *logininfo.LoginInfo) {
}

// set server address
d.serverIP = net.ParseIP(strings.Trim(login.Host, "[]"))
if serverIP, err := netip.ParseAddr(strings.Trim(login.Host, "[]")); err == nil {
d.serverIP = serverIP
}

// update status
d.setStatusOCRunning(true)
Expand All @@ -316,7 +319,7 @@ func (d *Daemon) connectVPN(login *logininfo.LoginInfo) {
d.setStatusConnectionState(vpnstatus.ConnectionStateConnecting)

// add server address to allowed addrs in trafpol
if d.trafpol != nil && d.serverIP != nil {
if d.trafpol != nil && d.serverIP.IsValid() {
d.serverIPAllowed = d.trafpol.AddAllowedAddr(d.serverIP)
}

Expand Down Expand Up @@ -542,7 +545,7 @@ func (d *Daemon) handleRunnerDisconnect() {
if d.trafpol != nil && d.serverIPAllowed {
d.trafpol.RemoveAllowedAddr(d.serverIP)
}
d.serverIP = nil
d.serverIP = netip.Addr{}
d.serverIPAllowed = false
}

Expand Down Expand Up @@ -743,7 +746,7 @@ func (d *Daemon) startTrafPol() error {
d.setStatusTrafPolState(vpnstatus.TrafPolStateActive)
d.setStatusAllowedHosts(c.AllowedHosts)

if d.serverIP != nil {
if d.serverIP.IsValid() {
// VPN connection active, allow server IP
d.serverIPAllowed = d.trafpol.AddAllowedAddr(d.serverIP)
}
Expand Down
17 changes: 15 additions & 2 deletions internal/dnsproxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dnsproxy

import (
"math/rand"
"net/netip"

"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -101,7 +102,13 @@ func (p *Proxy) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
log.Error("DNS-Proxy received invalid A record in reply")
return
}
report := NewReport(rr.Hdr.Name, rr.A, rr.Hdr.Ttl)
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
log.WithField("A", rr.A).
Error("DNS-Proxy received invalid IP in A record in reply")
return
}
report := NewReport(rr.Hdr.Name, addr, rr.Hdr.Ttl)
p.sendReport(report)
p.waitReport(report)
}
Expand All @@ -114,7 +121,13 @@ func (p *Proxy) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
log.Error("DNS-Proxy received invalid AAAA record in reply")
return
}
report := NewReport(rr.Hdr.Name, rr.AAAA, rr.Hdr.Ttl)
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
log.WithField("AAAA", rr.AAAA).
Error("DNS-Proxy received invalid IP in AAAA record in reply")
return
}
report := NewReport(rr.Hdr.Name, addr, rr.Hdr.Ttl)
p.sendReport(report)
p.waitReport(report)
}
Expand Down
9 changes: 5 additions & 4 deletions internal/dnsproxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dnsproxy
import (
"errors"
"net"
"net/netip"
"testing"

"github.com/miekg/dns"
Expand Down Expand Up @@ -127,8 +128,8 @@ func TestProxyHandleRequest(t *testing.T) {
if r.Name != "example.com." {
t.Errorf("invalid domain name: %s", r.Name)
}
if !r.IP.Equal(net.IPv4(127, 0, 0, 1)) &&
!r.IP.Equal(net.ParseIP("::1")) {
if r.IP != netip.MustParseAddr("127.0.0.1") &&
r.IP != netip.MustParseAddr("::1") {
t.Errorf("invalid IP: %s", r.IP)
}
}
Expand Down Expand Up @@ -205,8 +206,8 @@ func TestProxyHandleRequestRecords(t *testing.T) {
t.Fatalf("invalid reports for run %d: %v", i, reports)
}
for _, r := range reports {
if !r.IP.Equal(net.ParseIP("127.0.0.1")) &&
!r.IP.Equal(net.ParseIP("::1")) {
if r.IP != netip.MustParseAddr("127.0.0.1") &&
r.IP != netip.MustParseAddr("::1") {

t.Errorf("invalid report for run %d: %v", i, r)
}
Expand Down
6 changes: 3 additions & 3 deletions internal/dnsproxy/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ package dnsproxy

import (
"fmt"
"net"
"net/netip"
)

// Report is a report for a watched domain.
type Report struct {
Name string
IP net.IP
IP netip.Addr
TTL uint32

// done is used to signal that the report has been handled by
Expand All @@ -32,7 +32,7 @@ func (r *Report) Done() <-chan struct{} {
}

// NewReport returns a new report with domain name, IP and TTL.
func NewReport(name string, ip net.IP, ttl uint32) *Report {
func NewReport(name string, ip netip.Addr, ttl uint32) *Report {
return &Report{
Name: name,
IP: ip,
Expand Down
10 changes: 5 additions & 5 deletions internal/dnsproxy/report_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package dnsproxy

import (
"net"
"net/netip"
"testing"
)

// TestReportString tests String of Report.
func TestReportString(t *testing.T) {
name := "example.com."
ip := net.IPv4(192, 168, 1, 1)
ip := netip.MustParseAddr("192.168.1.1")
ttl := uint32(300)
r := NewReport(name, ip, ttl)

Expand All @@ -22,7 +22,7 @@ func TestReportString(t *testing.T) {
// TestReportDone tests Wait and Done of Report.
func TestReportWaitDone(_ *testing.T) {
name := "example.com."
ip := net.IPv4(192, 168, 1, 1)
ip := netip.MustParseAddr("192.168.1.1")
ttl := uint32(300)
r := NewReport(name, ip, ttl)

Expand All @@ -33,7 +33,7 @@ func TestReportWaitDone(_ *testing.T) {
// TestNewReport tests NewReport.
func TestNewReport(t *testing.T) {
name := "example.com."
ip := net.IPv4(192, 168, 1, 1)
ip := netip.MustParseAddr("192.168.1.1")
ttl := uint32(300)
r := NewReport(name, ip, ttl)

Expand All @@ -43,7 +43,7 @@ func TestNewReport(t *testing.T) {
if r.Name != name {
t.Errorf("got %s, want %s", r.Name, name)
}
if !r.IP.Equal(ip) {
if r.IP != ip {
t.Errorf("got %s, want %s", r.IP, ip)
}
if r.TTL != ttl {
Expand Down
6 changes: 3 additions & 3 deletions internal/splitrt/addresses.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package splitrt

import (
"net"
"net/netip"

"github.com/telekom-mms/oc-daemon/internal/addrmon"
)
Expand Down Expand Up @@ -51,9 +51,9 @@ func (a *Addresses) Remove(addr *addrmon.Update) {
}

// Get returns the addresses of the device identified by index.
func (a *Addresses) Get(index int) (addrs []*net.IPNet) {
func (a *Addresses) Get(index int) (addrs []netip.Prefix) {
for _, x := range a.m[index] {
addrs = append(addrs, &x.Address)
addrs = append(addrs, x.Address)
}
return
}
Expand Down
18 changes: 9 additions & 9 deletions internal/splitrt/addresses_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package splitrt

import (
"net"
"net/netip"
"reflect"
"testing"

Expand All @@ -10,14 +10,14 @@ import (

// getTestAddrMonUpdate returns an AddrMon update for testing.
func getTestAddrMonUpdate(t *testing.T, addr string) *addrmon.Update {
_, ipnet, err := net.ParseCIDR(addr)
prefix, err := netip.ParsePrefix(addr)
if err != nil {
t.Fatal(err)
}

return &addrmon.Update{
Add: true,
Address: *ipnet,
Address: prefix,
Index: 1,
}
}
Expand Down Expand Up @@ -72,16 +72,16 @@ func TestAddressesGet(t *testing.T) {
update2 := getTestAddrMonUpdate(t, "192.168.2.0/24")

// get empty
var want []*net.IPNet
var want []netip.Prefix
got := a.Get(1)
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}

// get with one address
a.Add(update1)
want = []*net.IPNet{
&update1.Address,
want = []netip.Prefix{
update1.Address,
}
got = a.Get(1)
if !reflect.DeepEqual(got, want) {
Expand All @@ -97,9 +97,9 @@ func TestAddressesGet(t *testing.T) {

// get with multiple addresses
a.Add(update2)
want = []*net.IPNet{
&update1.Address,
&update2.Address,
want = []netip.Prefix{
update1.Address,
update2.Address,
}
got = a.Get(1)
if !reflect.DeepEqual(got, want) {
Expand Down
Loading
Loading