Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add VPN server IP to allowed addresses in TrafPol #105

Merged
merged 1 commit into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions internal/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net"
"reflect"
"strconv"
"strings"
"syscall"
"time"

Expand Down Expand Up @@ -62,6 +63,13 @@ type Daemon struct {
// disableTrafPol determines if traffic policing should be disabled,
// overrides other traffic policing settings
disableTrafPol bool

// serverIP is the IP address of the current VPN server
serverIP net.IP

// serverIPAllowed indicates whether server IP was added to
// the allowed addresses
serverIPAllowed bool
}

// setStatusTrustedNetwork sets the trusted network status in status.
Expand Down Expand Up @@ -224,6 +232,12 @@ func (d *Daemon) connectVPN(login *logininfo.LoginInfo) {
d.setStatusServer(login.Server)
d.setStatusConnectionState(vpnstatus.ConnectionStateConnecting)

// set server address and add it to allowed addrs in trafpol
d.serverIP = net.ParseIP(strings.Trim(login.Host, "[]"))
if d.trafpol != nil && d.serverIP != nil {
d.serverIPAllowed = d.trafpol.AddAllowedAddr(d.serverIP)
}

// connect using runner
env := []string{
"oc_daemon_token=" + d.token,
Expand Down Expand Up @@ -437,6 +451,13 @@ func (d *Daemon) handleRunnerDisconnect() {

// make sure the vpn config is not active any more
d.updateVPNConfigDown()

// remove server ip from allowed addrs and delete it
if d.trafpol != nil && d.serverIPAllowed {
d.trafpol.RemoveAllowedAddr(d.serverIP)
}
d.serverIP = nil
d.serverIPAllowed = false
}

// handleRunnerEvent handles a connect event from the OC runner.
Expand Down Expand Up @@ -622,6 +643,12 @@ func (d *Daemon) startTrafPol() error {
if err := d.trafpol.Start(); err != nil {
return fmt.Errorf("Daemon could not start TrafPol: %w", err)
}

if d.serverIP != nil {
// VPN connection active, allow server IP
d.serverIPAllowed = d.trafpol.AddAllowedAddr(d.serverIP)
}

return nil
}

Expand All @@ -633,6 +660,7 @@ func (d *Daemon) stopTrafPol() {
log.Info("Daemon stopping TrafPol")
d.trafpol.Stop()
d.trafpol = nil
d.serverIPAllowed = false
}

// checkTrafPol checks if traffic policing should be running and
Expand Down
82 changes: 82 additions & 0 deletions internal/trafpol/trafpol.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ import (
"github.com/telekom-mms/oc-daemon/internal/dnsmon"
)

// trafPolAddrCmd is a TrafPol address command.
type trafPolAddrCmd struct {
add bool
ip net.IP
ok bool
done chan struct{}
}

// TrafPol is a traffic policing component.
type TrafPol struct {
config *Config
Expand All @@ -31,6 +39,9 @@ type TrafPol struct {
resolver *Resolver
resolvUp chan *ResolvedName

// address commands channel
cmds chan *trafPolAddrCmd

loopDone chan struct{}
done chan struct{}
}
Expand Down Expand Up @@ -124,6 +135,39 @@ func (t *TrafPol) handleResolverUpdate(ctx context.Context, update *ResolvedName
setAllowedIPs(ctx, t.getAllowedHostsIPs())
}

// handleAddressCommand handles an address command.
func (t *TrafPol) handleAddressCommand(ctx context.Context, cmd *trafPolAddrCmd) {
defer close(cmd.done)

// convert to ipnet
ipnet := &net.IPNet{IP: cmd.ip, Mask: net.CIDRMask(32, 32)}
if cmd.ip.To4() == nil {
ipnet.Mask = net.CIDRMask(128, 128)
}

// update allowed addrs
s := ipnet.String()
if cmd.add {
if _, ok := t.allowAddrs[s]; ok {
// ip already in allowed addrs
return
}
t.allowAddrs[s] = ipnet
} else {
if _, ok := t.allowAddrs[s]; !ok {
// ip not in allowed addrs
return
}
delete(t.allowAddrs, s)
}

// set new filter rules
setAllowedIPs(ctx, t.getAllowedHostsIPs())

// added/removed successfully
cmd.ok = true
}

// start starts the traffic policing component.
func (t *TrafPol) start(ctx context.Context) {
defer close(t.loopDone)
Expand Down Expand Up @@ -156,6 +200,11 @@ func (t *TrafPol) start(ctx context.Context) {
log.WithField("update", u).Debug("TrafPol got Resolver update")
t.handleResolverUpdate(ctx, u)

case c := <-t.cmds:
// Address Command
log.WithField("command", c).Debug("TrafPol got address command")
t.handleAddressCommand(ctx, c)

case <-t.done:
// shutdown
return
Expand Down Expand Up @@ -219,6 +268,37 @@ func (t *TrafPol) Stop() {
log.Debug("TrafPol stopped")
}

// AddAllowedAddr adds addr to the allowed addresses.
func (t *TrafPol) AddAllowedAddr(addr net.IP) (ok bool) {
log.WithField("addr", addr).
Debug("TrafPol adding IP to allowed addresses")

c := &trafPolAddrCmd{
add: true,
ip: addr,
done: make(chan struct{}),
}
t.cmds <- c
<-c.done

return c.ok
}

// RemoveAllowedAddr removes addr from the allowed addresses.
func (t *TrafPol) RemoveAllowedAddr(addr net.IP) (ok bool) {
log.WithField("addr", addr).
Debug("TrafPol removing IP from allowed addresses")

c := &trafPolAddrCmd{
ip: addr,
done: make(chan struct{}),
}
t.cmds <- c
<-c.done

return c.ok
}

// parseAllowedHosts parses the allowed hosts and returns IP addresses and DNS names
func parseAllowedHosts(hosts []string) (addrs []*net.IPNet, names []string) {
for _, h := range hosts {
Expand Down Expand Up @@ -282,6 +362,8 @@ func NewTrafPol(config *Config) *TrafPol {
resolver: NewResolver(config, n, resolvUp),
resolvUp: resolvUp,

cmds: make(chan *trafPolAddrCmd),

loopDone: make(chan struct{}),
done: make(chan struct{}),
}
Expand Down
60 changes: 60 additions & 0 deletions internal/trafpol/trafpol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,65 @@ func TestTrafPolStartStop(t *testing.T) {
tp.Stop()
}

// TestTrafPolAddRemoveAllowedAddr tests AddAllowedAddr and RemoveAllowedAddr of Trafpol.
func TestTrafPolAddRemoveAllowedAddr(t *testing.T) {
// set dummy low level function for devmon
oldRegisterLinkUpdates := devmon.RegisterLinkUpdates
devmon.RegisterLinkUpdates = func(*devmon.DevMon) (chan netlink.LinkUpdate, error) {
return nil, nil
}
defer func() { devmon.RegisterLinkUpdates = oldRegisterLinkUpdates }()

tp := NewTrafPol(NewConfig())
if err := tp.Start(); err != nil {
t.Fatal(err)
}

// add ipv4 address
_, ipnet, _ := net.ParseCIDR("192.168.1.1/32")
if ok := tp.AddAllowedAddr(ipnet.IP); !ok {
t.Errorf("address not added")
}

want := ipnet.String()
got := tp.allowAddrs[ipnet.String()].String()
if got != want {
t.Errorf("got %s, want %s", got, want)
}

// add ipv4 address again
if ok := tp.AddAllowedAddr(ipnet.IP); ok {
t.Errorf("existing address should not be added again")
}

// remove ipv4 address
if ok := tp.RemoveAllowedAddr(ipnet.IP); !ok {
t.Errorf("address not removed")
}

want = "<nil>"
got = tp.allowAddrs[ipnet.String()].String()
if got != want {
t.Errorf("got %s, want %s", got, want)
}

// remove ipv4 address again
if ok := tp.RemoveAllowedAddr(ipnet.IP); ok {
t.Errorf("not existing address should not be removed")
}

// add/remove ipv6 address
ip := net.ParseIP("2001:DB8:1::1")
if ok := tp.AddAllowedAddr(ip); !ok {
t.Errorf("address not added")
}
if ok := tp.RemoveAllowedAddr(ip); !ok {
t.Errorf("address not removed")
}

tp.Stop()
}

// TestNewTrafPol tests NewTrafPol.
func TestNewTrafPol(t *testing.T) {
c := NewConfig()
Expand All @@ -197,6 +256,7 @@ func TestNewTrafPol(t *testing.T) {
tp.allowNames == nil ||
tp.resolver == nil ||
tp.resolvUp == nil ||
tp.cmds == nil ||
tp.loopDone == nil ||
tp.done == nil {

Expand Down
Loading