diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index 5cf0a028a..e0d26a28d 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -4,10 +4,14 @@ import ( "flag" "fmt" "os" + "os/signal" + "syscall" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/port_forwarder" + "github.com/slackhq/nebula/service" "github.com/slackhq/nebula/util" ) @@ -52,16 +56,58 @@ func main() { os.Exit(1) } - ctrl, err := nebula.Main(c, *configTest, Build, l, nil) - if err != nil { - util.LogWithContextIfNeeded("Failed to start", err, l) - os.Exit(1) + fwd_list := port_forwarder.NewPortForwardingList() + disabled_tun := c.GetBool("tun.disabled", false) + activate_service_anyway := c.GetBool("port_forwarding.enable_without_rules", false) + if disabled_tun { + port_forwarder.ParseConfig(l, c, fwd_list) } - if !*configTest { - ctrl.Start() - notifyReady(l) - ctrl.ShutdownBlock() + if !*configTest && disabled_tun && (activate_service_anyway || !fwd_list.IsEmpty()) { + l.Infof("Configuring user-tun instead of disabled-tun as port forwarding is configured") + + service, err := service.New(c, l) + if err != nil { + util.LogWithContextIfNeeded("Failed to create service", err, l) + os.Exit(1) + } + + // initialize port forwarding: + pf_service, err := port_forwarder.ConstructFromInitialFwdList(service, l, &fwd_list) + if err != nil { + util.LogWithContextIfNeeded("Failed to start", err, l) + os.Exit(1) + } + + c.RegisterReloadCallback(func(c *config.C) { + pf_service.ReloadConfigAndApplyChanges(c) + }) + + pf_service.Activate() + + // wait for termination request + signalChannel := make(chan os.Signal, 1) + signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM) + fmt.Println("Running, press ctrl+c to shutdown...") + <-signalChannel + + // shutdown: + service.CloseAndWait() + + } else { + + l.Info("Configuring for disabled or kernel tun. no port forwarding provided") + ctrl, err := nebula.Main(c, *configTest, Build, l, nil) + if err != nil { + util.LogWithContextIfNeeded("Failed to start", err, l) + os.Exit(1) + } + + if !*configTest { + ctrl.Start() + notifyReady(l) + ctrl.ShutdownBlock() + } } os.Exit(0) diff --git a/connection_manager.go b/connection_manager.go index d2e861647..9a2d310d4 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -26,6 +26,11 @@ const ( sendTestPacket trafficDecision = 6 ) +// The data written into this variable is never used. +// Its there to avoid a fresh dynamic memory allocation of 1 byte +// for each time its used. +var BYTE_SLICE_ONE []byte = []byte{1} + type connectionManager struct { in map[uint32]struct{} inLock *sync.RWMutex @@ -463,12 +468,12 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { if n.punchy.GetTargetEverything() { hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { n.metricsTxPunchy.Inc(1) - n.intf.outside.WriteTo([]byte{1}, addr) + n.intf.outside.WriteTo(BYTE_SLICE_ONE, addr) }) } else if hostinfo.remote.IsValid() { n.metricsTxPunchy.Inc(1) - n.intf.outside.WriteTo([]byte{1}, hostinfo.remote) + n.intf.outside.WriteTo(BYTE_SLICE_ONE, hostinfo.remote) } } diff --git a/e2e/forwarding/.gitignore b/e2e/forwarding/.gitignore new file mode 100644 index 000000000..ccdfcbd6d --- /dev/null +++ b/e2e/forwarding/.gitignore @@ -0,0 +1,3 @@ +*.out +*.crt +*.key diff --git a/e2e/forwarding/README.md b/e2e/forwarding/README.md new file mode 100644 index 000000000..3b6d39b63 --- /dev/null +++ b/e2e/forwarding/README.md @@ -0,0 +1,27 @@ +# Userspace port forwarding +A simple speedtest for userspace port forwarding that can run without root access. + +## A side +Nebula running at port 10000, forwarding inbound TCP connections on port 5201 to 127.0.0.1:15001. + +## B side +Nebula running at port 10001, forwarding outbound TCP connections from 127.0.0.1:15002 to port 5201 of the A side. + +## Speedtest + + ┌──────────────────────┐:10001 :10002┌──────────────────────┐ + │ Nebula A side ├─────────────────┤ Nebula B side │ + │ │ │ │ + │ 192.168.100.1 │ TCP 5201 │ 192.168.100.2 │ + │ ┌───────────┼─────────────────┼──────────┐ │ + │ │ ├─────────────────┤ │ │ + └──────────▼───────────┘ └──────────▲───────────┘ + │ │ 127.0.0.1:15002 + │ │ + ┌──────────▼───────────┐ ┌──────────┴───────────┐ + │ │ │ │ + │ │ │ │ + │ iperf3 -s -p 15001 │ │ iperf3 -c -p 15001 │ + │ │ │ │ + │ │ │ │ + └──────────────────────┘ └──────────────────────┘ diff --git a/e2e/forwarding/a_config.yml b/e2e/forwarding/a_config.yml new file mode 100644 index 000000000..01b8e120e --- /dev/null +++ b/e2e/forwarding/a_config.yml @@ -0,0 +1,35 @@ +pki: + ca: ca.crt + cert: a.crt + key: a.key + +static_host_map: + "192.168.100.2": ["127.0.0.1:10002"] + +logging: + level: info + +listen: + host: 127.0.0.1 + port: 10001 + +port_forwarding: + enable_without_rules: true + inbound: + - listen_port: 5201 + dial_address: "127.0.0.1:15001" + protocols: [tcp, udp] + +tun: + disabled: true + mtu: 1300 + +firewall: + outbound: + - port: any + proto: udp + host: any + inbound: + - port: 5201 + proto: any + host: any diff --git a/e2e/forwarding/b_config.yml b/e2e/forwarding/b_config.yml new file mode 100644 index 000000000..b7e486cbc --- /dev/null +++ b/e2e/forwarding/b_config.yml @@ -0,0 +1,34 @@ +pki: + ca: ca.crt + cert: b.crt + key: b.key + +static_host_map: + "192.168.100.1": ["127.0.0.1:10001"] + +logging: + level: info + +listen: + host: 127.0.0.1 + port: 10002 + +port_forwarding: + enable_without_rules: true + outbound: + - listen_address: "127.0.0.1:15002" + dial_address: "192.168.100.1:5201" + protocols: [tcp, udp] + +tun: + disabled: true + mtu: 1300 + +firewall: + outbound: + - port: any + proto: udp + host: any + - port: 5201 + proto: any + host: any diff --git a/e2e/forwarding/generate_certificates.sh b/e2e/forwarding/generate_certificates.sh new file mode 100755 index 000000000..4df299c4e --- /dev/null +++ b/e2e/forwarding/generate_certificates.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +../../nebula-cert ca -name "E2E test CA" +../../nebula-cert sign -name "A" -ip "192.168.100.1/24" -out-crt a.crt -out-key a.key +../../nebula-cert sign -name "B" -ip "192.168.100.2/24" -out-crt b.crt -out-key b.key diff --git a/e2e/forwarding/speedtest.sh b/e2e/forwarding/speedtest.sh new file mode 100755 index 000000000..9d166de99 --- /dev/null +++ b/e2e/forwarding/speedtest.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +cd "$(dirname "$0")" + +if ! test -f ca.key; then + echo "Generating new test certificates" + ./generate_certificates.sh +fi + +../../nebula -config "$(pwd)/a_config.yml" &>a.out & +A_PID=$! +../../nebula -config "$(pwd)/b_config.yml" &>b.out & +B_PID=$! + +iperf3 -s -p 15001 & +IPERF_SERVER_PID=$! + +sleep 1 +iperf3 -c 127.0.0.1 -p 15002 -P 10 "$@" + +# Cleanup +kill $IPERF_SERVER_PID $A_PID $B_PID + +# wait for shutdown logs are written to files +sleep 1 + +echo "##########################################" +echo "A side logs:" +echo "##########################################" +cat a.out + +echo "##########################################" +echo "B side logs:" +echo "##########################################" +cat b.out +rm a.out b.out diff --git a/e2e/forwarding/speedtest_udp.sh b/e2e/forwarding/speedtest_udp.sh new file mode 100755 index 000000000..cb42d05e8 --- /dev/null +++ b/e2e/forwarding/speedtest_udp.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +cd "$(dirname "$0")" + +./speedtest.sh --udp --bidir --bitrate=100MiB "$@" diff --git a/examples/config.yml b/examples/config.yml index c74ffc68f..62c800573 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -206,7 +206,11 @@ relay: # Configure the private interface. Note: addr is baked into the nebula certificate tun: - # When tun is disabled, a lighthouse can be started without a local tun interface (and therefore without root) + # When tun is disabled, a feature limited Nebula can be started without root privileges. + # In this limited mode, Nebula can + # - run a lighthouse node + # - offer access from and to the nebula network via port forwarding + # - respond to ping requests disabled: false # Name of the device. If not set, a default will be chosen by the OS. # For macOS: if set, must be in the form `utun[0-9]+`. @@ -368,3 +372,26 @@ firewall: proto: tcp group: remote_client local_cidr: 192.168.100.1/24 + +# By using port port forwarding (port tunnels) its possible to establish connections +# from/into the nebula-network without using a tun/tap device and thus without requiring root access +# on the host. Port forwarding is only supported when setting "tun.disabled" is set to true. +# In this case, if a user-tun instead of a real one is instantiated to allow the port forwarding. +# IMPORTANT: For incoming tunnels, don't forget to also open the firewall for the relevant ports. +port_forwarding: + # Forces activation of the user tun, even when there is no rule specified. + # This can be useful, when rules are planned to be added later by reload. + # Reload config can't consider a change on tun-type. + enable_without_rules: false + outbound: + # format of listen- and dial-address: : + #- listen_address: 127.0.0.1:3399 + # dial_address: 192.168.100.92:4499 + # format of protocols lists (yml-list): [tcp], [udp], [tcp, udp] + # protocols: [tcp, udp] + inbound: + # format of dial_address: : + #- listen_port: 5599 + # dial_address: 127.0.0.1:5599 + # format of protocols lists (yml-list): [tcp], [udp], [tcp, udp] + # protocols: [tcp, udp] diff --git a/examples/go_service/main.go b/examples/go_service/main.go index 30178c034..810703f9b 100644 --- a/examples/go_service/main.go +++ b/examples/go_service/main.go @@ -5,7 +5,9 @@ import ( "fmt" "log" "net" + "os" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/service" ) @@ -59,7 +61,9 @@ pki: if err := cfg.LoadString(configStr); err != nil { return err } - svc, err := service.New(&cfg) + l := logrus.New() + l.Out = os.Stdout + svc, err := service.New(&cfg, l) if err != nil { return err } diff --git a/interface.go b/interface.go index f2519076c..26c25340d 100644 --- a/interface.go +++ b/interface.go @@ -301,7 +301,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { for { n, err := reader.Read(packet) if err != nil { - if errors.Is(err, os.ErrClosed) && f.closed.Load() { + if (errors.Is(err, os.ErrClosed) && f.closed.Load()) || errors.Is(err, io.EOF) { return } diff --git a/overlay/user.go b/overlay/user.go index 1bb4ef5f7..632ab1817 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -3,60 +3,116 @@ package overlay import ( "io" "net/netip" + "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "gvisor.dev/gvisor/pkg/buffer" ) func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { - return NewUserDevice(tunCidr) + d, err := NewUserDevice(tunCidr) + if err != nil { + return nil, err + } + + _, routes, err := getAllRoutesFromConfig(c, tunCidr, true) + if err != nil { + return nil, err + } + + routeTree, err := makeRouteTree(l, routes, true) + if err != nil { + return nil, err + } + + newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU) + for i, r := range routes { + if r.MTU == 0 { + routes[i].MTU = newDefaultMTU + } + } + + // this is needed to enable the "unsafe_routes" feature in combination with port forwarding. + d.routeTree.Store(routeTree) + + return d, nil } -func NewUserDevice(tunCidr netip.Prefix) (Device, error) { +func NewUserDevice(tunCidr netip.Prefix) (*UserDevice, error) { // these pipes guarantee each write/read will match 1:1 - or, ow := io.Pipe() - ir, iw := io.Pipe() return &UserDevice{ - tunCidr: tunCidr, - outboundReader: or, - outboundWriter: ow, - inboundReader: ir, - inboundWriter: iw, + tunCidr: tunCidr, + outboundChannel: make(chan *buffer.View, 16), + inboundChannel: make(chan *buffer.View, 16), }, nil } type UserDevice struct { tunCidr netip.Prefix - outboundReader *io.PipeReader - outboundWriter *io.PipeWriter + // using channel of *buffer.View significantly improves performance + outboundChannel chan *buffer.View + inboundChannel chan *buffer.View - inboundReader *io.PipeReader - inboundWriter *io.PipeWriter + routeTree atomic.Pointer[bart.Table[netip.Addr]] } func (d *UserDevice) Activate() error { return nil } -func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr } -func (d *UserDevice) Name() string { return "faketun0" } -func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } +func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr } +func (d *UserDevice) Name() string { return "faketun0" } +func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { + ptr := d.routeTree.Load() + if ptr != nil { + r, _ := d.routeTree.Load().Lookup(ip) + return r + } else { + return ip + } +} func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return d, nil } -func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) { - return d.inboundReader, d.outboundWriter +func (d *UserDevice) Pipe() (<-chan *buffer.View, chan<- *buffer.View) { + return d.inboundChannel, d.outboundChannel } func (d *UserDevice) Read(p []byte) (n int, err error) { - return d.outboundReader.Read(p) + view, ok := <-d.outboundChannel + if !ok { + return 0, io.EOF + } + return view.Read(p) +} +func (d *UserDevice) WriteTo(w io.Writer) (n int64, err error) { + view, ok := <-d.outboundChannel + if !ok { + return 0, io.EOF + } + return view.WriteTo(w) } + func (d *UserDevice) Write(p []byte) (n int, err error) { - return d.inboundWriter.Write(p) + view := buffer.NewViewWithData(p) + d.inboundChannel <- view + return view.Size(), nil } +func (d *UserDevice) ReadFrom(r io.Reader) (n int64, err error) { + view := buffer.NewViewSize(2048) + n, err = view.ReadFrom(r) + if n > 0 { + d.inboundChannel <- view + } + return +} + func (d *UserDevice) Close() error { - d.inboundWriter.Close() - d.outboundWriter.Close() + // There is nothing to be done for the UserDevice. + // It doesn't start any goroutines on its own. + // It doesn't manage any resources that needs closing. return nil } diff --git a/port_forwarder/builder.go b/port_forwarder/builder.go new file mode 100644 index 000000000..1bd997e82 --- /dev/null +++ b/port_forwarder/builder.go @@ -0,0 +1,236 @@ +package port_forwarder + +import ( + "fmt" + "io" + "strconv" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/service" +) + +func ymlGetStringOfNode(node interface{}) string { + return fmt.Sprintf("%v", node) +} + +func ymlMapGetStringEntry(k string, m map[interface{}]interface{}) string { + v, ok := m[k] + if !ok { + return "" + } + return fmt.Sprintf("%v", v) +} + +type ymlListNode = []interface{} +type ymlMapNode = map[interface{}]interface{} +type configFactoryFn = func(yml_node ymlMapNode) error +type configFactoryFnMap = map[string]configFactoryFn + +type builderData struct { + l *logrus.Logger + target ConfigList + factories map[string]configFactoryFnMap +} + +func ParseConfig( + l *logrus.Logger, + c *config.C, + target ConfigList, +) error { + builder := builderData{ + l: l, + target: target, + factories: map[string]configFactoryFnMap{}, + } + + in := configFactoryFnMap{} + in["udp"] = func(yml_node ymlMapNode) error { + return builder.convertToForwardConfigIncoming(l, yml_node, false) + } + in["tcp"] = func(yml_node ymlMapNode) error { + return builder.convertToForwardConfigIncoming(l, yml_node, true) + } + builder.factories["inbound"] = in + + out := configFactoryFnMap{} + out["udp"] = func(yml_node ymlMapNode) error { + return builder.convertToForwardConfigOutgoing(l, yml_node, false) + } + out["tcp"] = func(yml_node ymlMapNode) error { + return builder.convertToForwardConfigOutgoing(l, yml_node, true) + } + builder.factories["outbound"] = out + + for _, direction := range [...]string{"inbound", "outbound"} { + cfg_fwds := c.Get("port_forwarding." + direction) + if cfg_fwds == nil { + continue + } + + cfg_fwds_list, ok := cfg_fwds.(ymlListNode) + if !ok { + return fmt.Errorf("yml node \"port_forwarding.%s\" needs to be a list", direction) + } + + for fwd_idx, node := range cfg_fwds_list { + node_map, ok := node.(ymlMapNode) + if !ok { + return fmt.Errorf("child yml node of \"port_forwarding.%s\" needs to be a map", direction) + } + + protocols, ok := node_map["protocols"] + if !ok { + l.Infof("child yml node of \"port_forwarding.%s\" should have a child \"protocols\"", direction) + continue + } + + protocols_list, ok := protocols.(ymlListNode) + if !ok { + return fmt.Errorf("child yml node of \"port_forwarding.%s\" needs to have a child \"protocols\" that is a yml list", direction) + } + + for _, proto := range protocols_list { + proto_str := ymlGetStringOfNode(proto) + factoryFn, ok := builder.factories[direction][proto_str] + if !ok { + return fmt.Errorf("child yml node of \"port_forwarding.%s.%d.protocols\" doesn't support: %s", direction, fwd_idx, proto_str) + } + + err := factoryFn(node_map) + if err != nil { + return fmt.Errorf("child yml node of \"port_forwarding.%s.%d.protocols\" with proto %s - failed to instantiate forwarder: %v", direction, fwd_idx, proto_str, err) + } + } + } + } + + return nil +} + +func ConstructFromInitialFwdList( + tunService *service.Service, + l *logrus.Logger, + fwd_list *PortForwardingList, +) (*PortForwardingService, error) { + + pfService := &PortForwardingService{ + l: l, + tunService: tunService, + configPortForwardings: fwd_list.configPortForwardings, + portForwardings: make(map[string]io.Closer), + } + + return pfService, nil +} + +func NewPortForwardingList() PortForwardingList { + return PortForwardingList{ + configPortForwardings: map[string]ForwardConfig{}, + } +} + +type PortForwardingList struct { + configPortForwardings map[string]ForwardConfig +} + +func (pfl PortForwardingList) AddConfig(cfg ForwardConfig) { + pfl.configPortForwardings[cfg.ConfigDescriptor()] = cfg +} + +func (pfl PortForwardingList) IsEmpty() bool { + return len(pfl.configPortForwardings) == 0 +} + +func (s *PortForwardingService) ReloadConfigAndApplyChanges( + c *config.C, +) error { + + s.l.Infof("reloading port forwarding configuration...") + + pflNew := NewPortForwardingList() + + err := ParseConfig(s.l, c, pflNew) + if err != nil { + return err + } + + return s.ApplyChangesByNewFwdList(&pflNew) +} + +func (s *PortForwardingService) ApplyChangesByNewFwdList( + pflNew *PortForwardingList, +) error { + + to_be_closed := []string{} + for old := range s.configPortForwardings { + _, corresponding_new_exists := pflNew.configPortForwardings[old] + if !corresponding_new_exists { + to_be_closed = append(to_be_closed, old) + } + } + + s.CloseSelective(to_be_closed) + + to_be_added := map[string]ForwardConfig{} + for new, cfg := range pflNew.configPortForwardings { + _, corresponding_old_exists := s.configPortForwardings[new] + if !corresponding_old_exists { + to_be_added[cfg.ConfigDescriptor()] = cfg + } + } + + s.ActivateNew(to_be_added) + + return nil +} + +func (builder *builderData) convertToForwardConfigOutgoing( + _ *logrus.Logger, + m ymlMapNode, + isTcp bool, +) error { + fwd_port := ForwardConfigOutgoing{ + localListen: ymlMapGetStringEntry("listen_address", m), + remoteConnect: ymlMapGetStringEntry("dial_address", m), + } + + var cfg ForwardConfig + if isTcp { + cfg = ForwardConfigOutgoingTcp{fwd_port} + } else { + cfg = ForwardConfigOutgoingUdp{fwd_port} + } + + builder.target.AddConfig(cfg) + + return nil +} + +func (builder *builderData) convertToForwardConfigIncoming( + _ *logrus.Logger, + m ymlMapNode, + isTcp bool, +) error { + + v, err := strconv.ParseUint(ymlMapGetStringEntry("listen_port", m), 10, 32) + if err != nil { + return err + } + + fwd_port := ForwardConfigIncoming{ + port: uint32(v), + forwardLocalAddress: ymlMapGetStringEntry("dial_address", m), + } + + var cfg ForwardConfig + if isTcp { + cfg = ForwardConfigIncomingTcp{fwd_port} + } else { + cfg = ForwardConfigIncomingUdp{fwd_port} + } + + builder.target.AddConfig(cfg) + + return nil +} diff --git a/port_forwarder/config.go b/port_forwarder/config.go new file mode 100644 index 000000000..6eeb07aa4 --- /dev/null +++ b/port_forwarder/config.go @@ -0,0 +1,27 @@ +package port_forwarder + +import ( + "io" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/service" +) + +type ForwardConfig interface { + SetupPortForwarding(tunService *service.Service, l *logrus.Logger) (io.Closer, error) + ConfigDescriptor() string +} + +type ConfigList interface { + AddConfig(cfg ForwardConfig) +} + +type ForwardConfigOutgoing struct { + localListen string + remoteConnect string +} + +type ForwardConfigIncoming struct { + port uint32 + forwardLocalAddress string +} diff --git a/port_forwarder/config_test.go b/port_forwarder/config_test.go new file mode 100644 index 000000000..944f5a16c --- /dev/null +++ b/port_forwarder/config_test.go @@ -0,0 +1,275 @@ +package port_forwarder + +import ( + "testing" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/stretchr/testify/assert" +) + +func TestEmptyConfig(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString("bla:") + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 0) + assert.True(t, fwd_list.IsEmpty()) +} + +func TestConfigWithNoProtocols(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [] + inbound: + - listen_port: 5599 + dial_address: 127.0.0.1:5599 + protocols: [] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 0) + assert.True(t, fwd_list.IsEmpty()) +} + +func TestConfigWithNoProtocols_commentedProtos(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + # protocols: [tcp, udp] + inbound: + - listen_port: 5599 + dial_address: 127.0.0.1:5599 + # protocols: [tc, udp] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 0) + assert.True(t, fwd_list.IsEmpty()) +} + +func TestConfigWithNoProtocols_missing_in_out(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 0) + assert.True(t, fwd_list.IsEmpty()) +} + +func TestConfigWithTcpIn(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [] + inbound: + - listen_port: 5580 + dial_address: 127.0.0.1:5599 + protocols: [tcp] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 1) + assert.False(t, fwd_list.IsEmpty()) + + fwd1 := fwd_list.configPortForwardings["inbound.tcp.5580.127.0.0.1:5599"].(ForwardConfigIncomingTcp) + assert.NotNil(t, fwd1) + assert.Equal(t, fwd1.forwardLocalAddress, "127.0.0.1:5599") + assert.Equal(t, int(fwd1.port), 5580) +} + +func TestConfigWithTcpOut(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [tcp] + inbound: + - listen_port: 5580 + dial_address: 127.0.0.1:5599 + protocols: [] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 1) + assert.False(t, fwd_list.IsEmpty()) + + fwd1 := fwd_list.configPortForwardings["outbound.tcp.127.0.0.1:3399.192.168.100.92:4499"].(ForwardConfigOutgoingTcp) + assert.NotNil(t, fwd1) + assert.Equal(t, fwd1.localListen, "127.0.0.1:3399") + assert.Equal(t, fwd1.remoteConnect, "192.168.100.92:4499") +} + +func TestConfigWithUdpIn(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [] + inbound: + - listen_port: 5580 + dial_address: 127.0.0.1:5599 + protocols: [udp] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 1) + assert.False(t, fwd_list.IsEmpty()) + + fwd1 := fwd_list.configPortForwardings["inbound.udp.5580.127.0.0.1:5599"].(ForwardConfigIncomingUdp) + assert.NotNil(t, fwd1) + assert.Equal(t, fwd1.forwardLocalAddress, "127.0.0.1:5599") + assert.Equal(t, int(fwd1.port), 5580) +} + +func TestConfigWithUdpOut(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [udp] + inbound: + - listen_port: 5580 + dial_address: 127.0.0.1:5599 + protocols: [] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 1) + assert.False(t, fwd_list.IsEmpty()) + + fwd1 := fwd_list.configPortForwardings["outbound.udp.127.0.0.1:3399.192.168.100.92:4499"].(ForwardConfigOutgoingUdp) + assert.NotNil(t, fwd1) + assert.Equal(t, fwd1.localListen, "127.0.0.1:3399") + assert.Equal(t, fwd1.remoteConnect, "192.168.100.92:4499") +} + +func TestConfigWithMultipleMixed(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [udp, tcp] + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:5499 + protocols: [tcp] + inbound: + - listen_port: 5580 + dial_address: 127.0.0.1:5599 + protocols: [tcp, udp] + - listen_port: 5570 + dial_address: 127.0.0.1:5555 + protocols: [udp] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 6) + assert.False(t, fwd_list.IsEmpty()) + + assert.NotNil(t, fwd_list.configPortForwardings["outbound.udp.127.0.0.1:3399.192.168.100.92:4499"].(ForwardConfigOutgoingUdp)) + assert.NotNil(t, fwd_list.configPortForwardings["outbound.tcp.127.0.0.1:3399.192.168.100.92:4499"].(ForwardConfigOutgoingTcp)) + assert.NotNil(t, fwd_list.configPortForwardings["outbound.tcp.127.0.0.1:3399.192.168.100.92:5499"].(ForwardConfigOutgoingTcp)) + assert.NotNil(t, fwd_list.configPortForwardings["inbound.tcp.5580.127.0.0.1:5599"].(ForwardConfigIncomingTcp)) + assert.NotNil(t, fwd_list.configPortForwardings["inbound.udp.5580.127.0.0.1:5599"].(ForwardConfigIncomingUdp)) + assert.NotNil(t, fwd_list.configPortForwardings["inbound.udp.5570.127.0.0.1:5555"].(ForwardConfigIncomingUdp)) +} + +func TestConfigWithOverlappingRulesNoDuplicatesInResult(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [udp, tcp, udp] + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [tcp] + inbound: + - listen_port: 5580 + dial_address: 127.0.0.1:5599 + protocols: [tcp, udp] + - listen_port: 5580 + dial_address: 127.0.0.1:5599 + protocols: [udp, udp] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 4) + assert.False(t, fwd_list.IsEmpty()) + + assert.NotNil(t, fwd_list.configPortForwardings["outbound.udp.127.0.0.1:3399.192.168.100.92:4499"].(ForwardConfigOutgoingUdp)) + assert.NotNil(t, fwd_list.configPortForwardings["outbound.tcp.127.0.0.1:3399.192.168.100.92:4499"].(ForwardConfigOutgoingTcp)) + assert.NotNil(t, fwd_list.configPortForwardings["inbound.tcp.5580.127.0.0.1:5599"].(ForwardConfigIncomingTcp)) + assert.NotNil(t, fwd_list.configPortForwardings["inbound.udp.5580.127.0.0.1:5599"].(ForwardConfigIncomingUdp)) +} diff --git a/port_forwarder/fwd_tcp.go b/port_forwarder/fwd_tcp.go new file mode 100644 index 000000000..184031ffc --- /dev/null +++ b/port_forwarder/fwd_tcp.go @@ -0,0 +1,240 @@ +package port_forwarder + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/service" + "golang.org/x/sync/errgroup" +) + +type ForwardConfigOutgoingTcp struct { + ForwardConfigOutgoing +} + +func (cfg ForwardConfigOutgoingTcp) ConfigDescriptor() string { + return fmt.Sprintf("outbound.tcp.%s.%s", cfg.localListen, cfg.remoteConnect) +} + +type ForwardConfigIncomingTcp struct { + ForwardConfigIncoming +} + +func (cfg ForwardConfigIncomingTcp) ConfigDescriptor() string { + return fmt.Sprintf("inbound.tcp.%d.%s", cfg.port, cfg.forwardLocalAddress) +} + +type PortForwardingCommonTcp struct { + ctx context.Context + wg *sync.WaitGroup + l *logrus.Logger + tunService *service.Service + localListenConnection net.Listener +} + +func (fwd PortForwardingCommonTcp) Close() error { + fwd.localListenConnection.Close() + fwd.wg.Wait() + return nil +} + +type PortForwardingOutgoingTcp struct { + PortForwardingCommonTcp + cfg ForwardConfigOutgoingTcp +} + +func (cf ForwardConfigOutgoingTcp) SetupPortForwarding( + tunService *service.Service, + l *logrus.Logger, +) (io.Closer, error) { + localTcpListenAddr, err := net.ResolveTCPAddr("tcp", cf.localListen) + if err != nil { + return nil, err + } + localListenPort, err := net.ListenTCP("tcp", localTcpListenAddr) + if err != nil { + return nil, err + } + + l.Infof("TCP port forwarding to '%v': listening on local TCP addr: '%v'", + cf.remoteConnect, localTcpListenAddr) + + ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} + + portForwarding := &PortForwardingOutgoingTcp{ + PortForwardingCommonTcp: PortForwardingCommonTcp{ + ctx: ctx, + wg: wg, + l: l, + tunService: tunService, + localListenConnection: localListenPort, + }, + cfg: cf, + } + + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + portForwarding.acceptOnLocalListenPort_generic(portForwarding.handleClientConnectionWithErrorReturn) + }() + + return portForwarding, nil +} + +func (pt *PortForwardingCommonTcp) acceptOnLocalListenPort_generic( + handleClientConnectionWithErrorReturn func(localConnection net.Conn) error, +) error { + for { + pt.l.Debug("listening on local TCP port ...") + connection, err := pt.localListenConnection.Accept() + if err != nil { + fmt.Println(err) + return err + } + + pt.l.Debugf("accept TCP connect from local TCP port: %v", connection.RemoteAddr()) + + pt.wg.Add(1) + go func() { + defer pt.wg.Done() + defer connection.Close() + <-pt.ctx.Done() + }() + + pt.wg.Add(1) + go func() { + defer pt.wg.Done() + err := handleClientConnectionWithErrorReturn(connection) + if err != nil { + pt.l.Debugf("Closed TCP client connection %s. Err: %+v", + connection.LocalAddr().String(), err) + } + }() + } +} + +func (pt *PortForwardingOutgoingTcp) handleClientConnectionWithErrorReturn(localConnection net.Conn) error { + + remoteConnection, err := pt.tunService.DialContext(context.Background(), "tcp", pt.cfg.remoteConnect) + if err != nil { + return err + } + return handleTcpClientConnectionPair_generic(pt.l, localConnection, remoteConnection) +} + +func handleTcpClientConnectionPair_generic(l *logrus.Logger, connA, connB net.Conn) error { + + dataTransferHandler := func(from, to net.Conn) error { + + name := fmt.Sprintf("%s -> %s", from.LocalAddr().String(), to.LocalAddr().String()) + + defer from.Close() + defer to.Close() + + // no write/read timeout + to.SetDeadline(time.Time{}) + from.SetDeadline(time.Time{}) + megabyte := (1 << 20) + buf := make([]byte, 1*megabyte) + if false { + // this variant seems to be slightly slower on the local speed-test. 1.60GiB/s vs. 1.70GiB/s + n, err := io.CopyBuffer(to, from, buf) + l.WithError(err). + WithField("payloadSize", n). + WithField("from", from.RemoteAddr()). + WithField("to", to.RemoteAddr()). + WithField("localFrom", from.LocalAddr()). + WithField("localTo", to.LocalAddr()). + Debug("stopped data forwarding") + return err + } else { + for { + rn, r_err := from.Read(buf) + l.Tracef("%s read(%d), err: %v", name, rn, r_err) + for i := 0; i < rn; { + wn, w_err := to.Write(buf[i:rn]) + if w_err != nil { + l.Debugf("%s writing(%d) to to-connection failed: %v", name, rn, w_err) + return w_err + } + i += wn + } + if r_err != nil { + l.Debugf("%s reading(%d) from from-connection failed: %v", name, rn, r_err) + return r_err + } + } + } + } + + errGroup := errgroup.Group{} + + errGroup.Go(func() error { return dataTransferHandler(connA, connB) }) + errGroup.Go(func() error { return dataTransferHandler(connB, connA) }) + + return errGroup.Wait() +} + +type PortForwardingIncomingTcp struct { + PortForwardingCommonTcp + cfg ForwardConfigIncomingTcp +} + +func (cf ForwardConfigIncomingTcp) SetupPortForwarding( + tunService *service.Service, + l *logrus.Logger, +) (io.Closer, error) { + + localListenPort, err := tunService.Listen("tcp", fmt.Sprintf(":%d", cf.port)) + if err != nil { + return nil, err + } + + l.Infof("TCP port forwarding to '%v': listening on local, outside TCP addr: ':%d'", + cf.forwardLocalAddress, cf.port) + + ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} + + portForwarding := &PortForwardingIncomingTcp{ + PortForwardingCommonTcp: PortForwardingCommonTcp{ + ctx: ctx, + wg: wg, + l: l, + tunService: tunService, + localListenConnection: localListenPort, + }, + cfg: cf, + } + + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + portForwarding.acceptOnLocalListenPort_generic(portForwarding.handleClientConnectionWithErrorReturn) + }() + + return portForwarding, nil +} + +func (pt *PortForwardingIncomingTcp) handleClientConnectionWithErrorReturn(outsideConnection net.Conn) error { + + fwdAddr, err := net.ResolveTCPAddr("tcp", pt.cfg.forwardLocalAddress) + if err != nil { + return err + } + + localConnection, err := net.DialTCP("tcp", nil, fwdAddr) + if err != nil { + return err + } + + return handleTcpClientConnectionPair_generic(pt.l, outsideConnection, localConnection) +} diff --git a/port_forwarder/fwd_udp.go b/port_forwarder/fwd_udp.go new file mode 100644 index 000000000..4eeb0c07b --- /dev/null +++ b/port_forwarder/fwd_udp.go @@ -0,0 +1,376 @@ +package port_forwarder + +import ( + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/service" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" +) + +type ForwardConfigOutgoingUdp struct { + ForwardConfigOutgoing +} + +func (cfg ForwardConfigOutgoingUdp) ConfigDescriptor() string { + return fmt.Sprintf("outbound.udp.%s.%s", cfg.localListen, cfg.remoteConnect) +} + +type ForwardConfigIncomingUdp struct { + ForwardConfigIncoming +} + +func (cfg ForwardConfigIncomingUdp) ConfigDescriptor() string { + return fmt.Sprintf("inbound.udp.%d.%s", cfg.port, cfg.forwardLocalAddress) +} + +// use UDP timeout of 300 seconds according to +// https://support.goto.com/connect/help/what-are-the-recommended-nat-keep-alive-settings +var UDP_CONNECTION_TIMEOUT_SECONDS uint32 = 300 + +type udpConnInterface interface { + io.Closer + WriteTo(b []byte, addr net.Addr) (int, error) + Write(b []byte) (int, error) + ReadFrom(b []byte) (int, net.Addr, error) + LocalAddr() net.Addr +} + +func resetTimer(t *time.Timer, d time.Duration) { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + t.Reset(d) +} + +func handleUdpDestinationPortResponseReading[destConn udpConnInterface, srcConn udpConnInterface]( + l *logrus.Logger, + loggingFields logrus.Fields, + closedConnections *chan string, + sourceAddr net.Addr, + destConnection destConn, + localListenConnection srcConn, +) error { + // net.Conn is thread-safe according to: https://pkg.go.dev/net#Conn + // no need for remoteConnection to protect by mutex + + defer func() { (*closedConnections) <- sourceAddr.String() }() + + l.WithFields(loggingFields).Debug("begin reading responses ...") + wg := &sync.WaitGroup{} + defer wg.Wait() + + timeout := time.Second * time.Duration(UDP_CONNECTION_TIMEOUT_SECONDS) + timer := time.NewTimer(timeout) + + rr := newUdpPortReader(wg, l, loggingFields, destConnection) + defer close(rr.receivedDataDone) + for { + select { + case <-timer.C: + destConnection.Close() + l.WithFields(loggingFields).Debug("response read - closed due to timeout") + return nil + case data, ok := <-rr.receivedData: + if !ok { + return nil + } + resetTimer(timer, timeout) + + l.WithFields(loggingFields). + WithField("payloadSize", data.n). + Debug("response forward") + n, err := localListenConnection.WriteTo(rr.buf[:data.n], sourceAddr) + rr.receivedDataDone <- 1 + if (n == 0) && (err != nil) { + l.WithFields(loggingFields).WithError(err).Debugf("response forward - write error") + return err + } + } + } +} + +type PortForwardingCommonUdp struct { + wg *sync.WaitGroup + l *logrus.Logger + tunService *service.Service + // net.Conn is thread-safe according to: https://pkg.go.dev/net#Conn + // no need for localListenConnection to protect by mutex + localListenConnection io.Closer +} + +func (fwd PortForwardingCommonUdp) Close() error { + fwd.localListenConnection.Close() + fwd.wg.Wait() + return nil +} + +type PortForwardingOutgoingUdp struct { + PortForwardingCommonUdp + cfg ForwardConfigOutgoingUdp +} + +func (cfg ForwardConfigOutgoingUdp) SetupPortForwarding( + tunService *service.Service, + l *logrus.Logger, +) (io.Closer, error) { + localUdpListenAddr, err := net.ResolveUDPAddr("udp", cfg.localListen) + if err != nil { + return nil, err + } + + localListenConnection, err := net.ListenUDP("udp", localUdpListenAddr) + if err != nil { + return nil, err + } + + l.Infof("UDP port forwarding to '%v': listening on local UDP addr: '%v'", + cfg.remoteConnect, localUdpListenAddr) + + wg := &sync.WaitGroup{} + + portForwarding := &PortForwardingOutgoingUdp{ + PortForwardingCommonUdp: PortForwardingCommonUdp{ + wg: wg, + l: l, + tunService: tunService, + localListenConnection: localListenConnection, + }, + cfg: cfg, + } + + logPrefix := logrus.Fields{ + "a": "UDP fwd out", + "listen": localListenConnection.LocalAddr(), + "dial": cfg.remoteConnect, + } + + wg.Add(1) + go func() { + defer wg.Done() + err := listenLocalPort_generic( + wg, + l, + logPrefix, + localListenConnection, + func(address string) (*gonet.UDPConn, error) { + return tunService.DialUDP(address) + }, + cfg.remoteConnect, + ) + if err != nil { + l.WithFields(logPrefix).WithError(err). + Error("listening stopped with error") + } + }() + + return portForwarding, nil +} + +type readData struct { + n int + addr net.Addr +} + +type readerRoutine struct { + buf []byte + receivedData chan readData + receivedDataDone chan int +} + +func newUdpPortReader( + wg *sync.WaitGroup, + l *logrus.Logger, + loggingFields logrus.Fields, + conn udpConnInterface, +) *readerRoutine { + r := &readerRoutine{ + buf: make([]byte, 512*1024), + receivedData: make(chan readData), + receivedDataDone: make(chan int, 1), + } + r.receivedDataDone <- 1 + + wg.Add(1) + go func() { + defer wg.Done() + defer close(r.receivedData) + l.WithFields(loggingFields). + WithField("addr", conn.LocalAddr()). + Debug("start listening") + for { + _, ok := <-r.receivedDataDone + if !ok { + return + } + l.WithFields(loggingFields). + WithField("addr", conn.LocalAddr()). + Trace("reading data ...") + n, addr, err := conn.ReadFrom(r.buf[0:]) + if err != nil { + if errors.Is(err, io.EOF) { + return + } + l.WithFields(loggingFields). + WithField("addr", conn.LocalAddr()). + WithError(err).Error("listen for data failed. stop.") + return + } + r.receivedData <- readData{ + n: n, + addr: addr, + } + } + }() + + return r +} + +func listenLocalPort_generic[destConn udpConnInterface]( + wg *sync.WaitGroup, + l *logrus.Logger, + loggingFields logrus.Fields, + localListenConnection udpConnInterface, + dial func(address string) (destConn, error), + remoteConnect string, +) error { + dialConnResponseReaders := make(map[string]bool) + dialConnections := make(map[string]destConn) + closedConnections := make(chan string, 5) + mr := newUdpPortReader(wg, l, loggingFields, localListenConnection) + defer close(mr.receivedDataDone) + + defer func() { + // close and wait for remaining connections + for _, connection := range dialConnections { + connection.Close() + } + for range dialConnResponseReaders { + <-closedConnections + } + }() + + for { + select { + case closedOne := <-closedConnections: + l.Debugf("closing connection to %s", closedOne) + delete(dialConnections, closedOne) + delete(dialConnResponseReaders, closedOne) + case data, ok := <-mr.receivedData: + if !ok { + return nil + } + l.WithFields(loggingFields). + WithField("source", data.addr). + WithField("payloadSize", data.n). + Trace("read data") + dialConnection, ok := dialConnections[data.addr.String()] + if !ok { + newConnection, err := dial(remoteConnect) + if err != nil { + l.WithFields(loggingFields).WithError(err).Error("dialing dial address failed") + continue + } + dialConnections[data.addr.String()] = newConnection + dialConnection = newConnection + } + + l.WithFields(loggingFields). + WithField("source", data.addr). + WithField("dialSource", dialConnection.LocalAddr()). + WithField("payloadSize", data.n). + Debug("forward") + + dialConnection.Write(mr.buf[:data.n]) + mr.receivedDataDone <- 1 + + _, ok = dialConnResponseReaders[data.addr.String()] + if !ok { + loggingFieldsRsp := logrus.Fields{ + "source": data.addr, + "dialSource": dialConnection.LocalAddr(), + } + for k, v := range loggingFields { + loggingFieldsRsp[k] = v + } + dialConnResponseReaders[data.addr.String()] = true + go func() error { + return handleUdpDestinationPortResponseReading( + l, loggingFieldsRsp, &closedConnections, data.addr, + dialConnection, localListenConnection) + }() + } + } + } +} + +type PortForwardingIncomingUdp struct { + PortForwardingCommonUdp + cfg ForwardConfigIncomingUdp +} + +func (cfg ForwardConfigIncomingUdp) SetupPortForwarding( + tunService *service.Service, + l *logrus.Logger, +) (io.Closer, error) { + + conn, err := tunService.ListenUDP(fmt.Sprintf(":%d", cfg.port)) + if err != nil { + return nil, err + } + + l.Infof("UDP port forwarding to '%v': listening on outside UDP addr: ':%d'", + cfg.forwardLocalAddress, cfg.port) + + logPrefix := logrus.Fields{ + "a": "UDP fwd in", + "listenPort": cfg.port, + "dial": cfg.forwardLocalAddress, + } + + wg := &sync.WaitGroup{} + + forwarding := &PortForwardingIncomingUdp{ + PortForwardingCommonUdp: PortForwardingCommonUdp{ + wg: wg, + l: l, + tunService: tunService, + localListenConnection: conn, + }, + cfg: cfg, + } + + wg.Add(1) + go func() { + defer wg.Done() + err := listenLocalPort_generic( + wg, + l, + logPrefix, + conn, + func(address string) (*net.UDPConn, error) { + fwdAddr, err := net.ResolveUDPAddr("udp", cfg.forwardLocalAddress) + if err != nil { + l.WithFields(logPrefix).Error("resolve of dial address failed") + return nil, err + } + return net.DialUDP("udp", nil, fwdAddr) + }, + cfg.forwardLocalAddress, + ) + if err != nil { + l.WithFields(logPrefix).WithError(err). + Error("listening stopped with error") + } + }() + + return forwarding, nil +} diff --git a/port_forwarder/port_forwarder_tcp_test.go b/port_forwarder/port_forwarder_tcp_test.go new file mode 100644 index 000000000..ca090d0a1 --- /dev/null +++ b/port_forwarder/port_forwarder_tcp_test.go @@ -0,0 +1,356 @@ +package port_forwarder + +import ( + "fmt" + "net" + "testing" + + "github.com/slackhq/nebula/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func startReadToChannel(receiverConn net.Conn) <-chan []byte { + rcv_chan := make(chan []byte, 10) + r := make(chan bool, 1) + go func() { + defer close(rcv_chan) + r <- true + for { + buf := make([]byte, 100) + n, err := receiverConn.Read(buf) + if err != nil { + break + } + rcv_chan <- buf[0:n] + } + }() + <-r + return rcv_chan +} + +func doTestTcpCommunication( + t *testing.T, + msg string, + senderConn net.Conn, + receiverConn <-chan []byte, +) { + var n int = 0 + var err error = nil + data_sent := []byte(msg) + var buf []byte = nil + for { + fmt.Println("sending ...") + t.Log("sending ...") + n, err = senderConn.Write(data_sent) + require.Nil(t, err) + assert.Equal(t, n, len(data_sent)) + + fmt.Println("receiving ...") + t.Log("receiving ...") + var ok bool = false + buf, ok = <-receiverConn + if ok { + break + } + } + fmt.Println("DONE") + t.Log("DONE") + require.Nil(t, err) + assert.Equal(t, n, len(data_sent)) + assert.Equal(t, data_sent, buf[:n]) +} + +func doTestTcpCommunicationFail( + t *testing.T, + msg string, + senderConn net.Conn, + receiverConn net.Conn, +) { + data_sent := []byte(msg) + n, err := senderConn.Write(data_sent) + if err != nil { + return + } + require.Nil(t, err) + assert.Equal(t, n, len(data_sent)) + + buf := make([]byte, 100) + _, err = receiverConn.Read(buf) + assert.NotNil(t, err) +} + +func tcpListenerNAccept(t *testing.T, listener *net.TCPListener, n int) <-chan net.Conn { + c := make(chan net.Conn, 1) + r := make(chan bool, 1) + go func() { + defer close(c) + r <- true + for range n { + conn, err := listener.Accept() + require.Nil(t, err) + c <- conn + } + }() + + <-r + + return c +} + +func TestTcpInOut2Clients(t *testing.T) { + server, sl, client, cl := service.CreateTwoConnectedServices(t, 4247) + + server_pf, err := createPortForwarderFromConfigString(t, sl, server, ` +port_forwarding: + inbound: + - listen_port: 4495 + dial_address: 127.0.0.1:5595 + protocols: [tcp] +`) + require.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + client_pf, err := createPortForwarderFromConfigString(t, cl, client, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3395 + dial_address: 10.0.0.1:4495 + protocols: [tcp] +`) + require.Nil(t, err) + + assert.Len(t, client_pf.portForwardings, 1) + + client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3395") + require.Nil(t, err) + server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5595") + require.Nil(t, err) + + server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) + require.Nil(t, err) + defer server_listen_conn.Close() + server_listen_conn_accepts := tcpListenerNAccept(t, server_listen_conn, 2) + + client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) + require.Nil(t, err) + defer client1_conn.Close() + + client1_rcv_chan := startReadToChannel(client1_conn) + client1_server_side_conn := <-server_listen_conn_accepts + client1_server_side_rcv_chan := startReadToChannel(client1_server_side_conn) + + client2_conn, err := net.DialTCP("tcp", nil, client_conn_addr) + require.Nil(t, err) + defer client2_conn.Close() + + client2_rcv_chan := startReadToChannel(client2_conn) + client2_server_side_conn := <-server_listen_conn_accepts + client2_server_side_rcv_chan := startReadToChannel(client2_server_side_conn) + + doTestTcpCommunication(t, "Hello from client 1 side!", + client1_conn, client1_server_side_rcv_chan) + doTestTcpCommunication(t, "Hello from client two side!", + client2_conn, client2_server_side_rcv_chan) + + doTestTcpCommunication(t, "Hello from server first side!", + client1_server_side_conn, client1_rcv_chan) + doTestTcpCommunication(t, "Hello from server second side!", + client2_server_side_conn, client2_rcv_chan) + doTestTcpCommunication(t, "Hello from server third side!", + client1_server_side_conn, client1_rcv_chan) + + doTestTcpCommunication(t, "Hello from client two side AGAIN!", + client2_conn, client2_server_side_rcv_chan) + +} + +func TestTcpInOut1ClientConfigReload(t *testing.T) { + server, sl, client, cl := service.CreateTwoConnectedServices(t, 4246) + + server_pf, err := createPortForwarderFromConfigString(t, sl, server, ` +port_forwarding: + inbound: + - listen_port: 4497 + dial_address: 127.0.0.1:5597 + protocols: [tcp] +`) + require.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + client_pf, err := createPortForwarderFromConfigString(t, cl, client, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3397 + dial_address: 10.0.0.1:4497 + protocols: [tcp] +`) + require.Nil(t, err) + + assert.Len(t, client_pf.portForwardings, 1) + + client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3397") + require.Nil(t, err) + server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5597") + require.Nil(t, err) + + server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) + require.Nil(t, err) + defer server_listen_conn.Close() + + server_listen_conn_accepts := tcpListenerNAccept(t, server_listen_conn, 1) + + client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) + require.Nil(t, err) + defer client1_conn.Close() + client1_rcv_chan := startReadToChannel(client1_conn) + + client1_server_side_conn := <-server_listen_conn_accepts + defer client1_server_side_conn.Close() + client1_server_side_rcv_chan := startReadToChannel(client1_server_side_conn) + + doTestTcpCommunication(t, "Hello from client 1 side!", + client1_conn, client1_server_side_rcv_chan) + + doTestTcpCommunication(t, "Hello from server first side!", + client1_server_side_conn, client1_rcv_chan) + doTestTcpCommunication(t, "Hello from server third side!", + client1_server_side_conn, client1_rcv_chan) + + doTestTcpCommunication(t, "Hello from client one side AGAIN!", + client1_conn, client1_server_side_rcv_chan) + + new_server_fwd_list, err := loadPortFwdConfigFromString(sl, ` +port_forwarding: + inbound: + - listen_port: 4496 + dial_address: 127.0.0.1:5596 + protocols: [tcp] +`) + require.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + new_client_fwd_list, err := loadPortFwdConfigFromString(cl, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3396 + dial_address: 10.0.0.1:4496 + protocols: [tcp] +`) + require.Nil(t, err) + + err = client_pf.ApplyChangesByNewFwdList(new_client_fwd_list) + require.Nil(t, err) + + doTestTcpCommunicationFail(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunicationFail(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) + + err = server_pf.ApplyChangesByNewFwdList(new_server_fwd_list) + require.Nil(t, err) + + doTestTcpCommunicationFail(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunicationFail(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) +} + +func TestTcpInOut1ClientConfigReload_inverseCloseOrder(t *testing.T) { + server, sl, client, cl := service.CreateTwoConnectedServices(t, 4245) + + server_pf, err := createPortForwarderFromConfigString(t, sl, server, ` +port_forwarding: + inbound: + - listen_port: 4499 + dial_address: 127.0.0.1:5599 + protocols: [tcp] +`) + require.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + client_pf, err := createPortForwarderFromConfigString(t, cl, client, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 10.0.0.1:4499 + protocols: [tcp] +`) + require.Nil(t, err) + + assert.Len(t, client_pf.portForwardings, 1) + + client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3399") + require.Nil(t, err) + server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5599") + require.Nil(t, err) + + server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) + require.Nil(t, err) + defer server_listen_conn.Close() + server_listen_conn_accepts := tcpListenerNAccept(t, server_listen_conn, 1) + + client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) + require.Nil(t, err) + defer client1_conn.Close() + client1_rcv_chan := startReadToChannel(client1_conn) + + client1_server_side_conn := <-server_listen_conn_accepts + defer client1_server_side_conn.Close() + client1_server_side_rcv_chan := startReadToChannel(client1_server_side_conn) + + doTestTcpCommunication(t, "Hello from client 1 side!", + client1_conn, client1_server_side_rcv_chan) + + doTestTcpCommunication(t, "Hello from server first side!", + client1_server_side_conn, client1_rcv_chan) + doTestTcpCommunication(t, "Hello from server third side!", + client1_server_side_conn, client1_rcv_chan) + + doTestTcpCommunication(t, "Hello from client one side AGAIN!", + client1_conn, client1_server_side_rcv_chan) + + new_server_fwd_list, err := loadPortFwdConfigFromString(sl, ` +port_forwarding: + inbound: + - listen_port: 4498 + dial_address: 127.0.0.1:5598 + protocols: [tcp] +`) + require.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + new_client_fwd_list, err := loadPortFwdConfigFromString(cl, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3398 + dial_address: 10.0.0.1:4498 + protocols: [tcp] +`) + require.Nil(t, err) + + err = server_pf.ApplyChangesByNewFwdList(new_server_fwd_list) + require.Nil(t, err) + + doTestTcpCommunicationFail(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunicationFail(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) + + err = client_pf.ApplyChangesByNewFwdList(new_client_fwd_list) + require.Nil(t, err) + + doTestTcpCommunicationFail(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunicationFail(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) +} diff --git a/port_forwarder/port_forwarder_udp_test.go b/port_forwarder/port_forwarder_udp_test.go new file mode 100644 index 000000000..4a629a3ea --- /dev/null +++ b/port_forwarder/port_forwarder_udp_test.go @@ -0,0 +1,163 @@ +package port_forwarder + +import ( + "net" + "testing" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func loadPortFwdConfigFromString(l *logrus.Logger, configStr string) (*PortForwardingList, error) { + c := config.NewC(l) + err := c.LoadString(configStr) + if err != nil { + return nil, err + } + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + if err != nil { + return nil, err + } + + return &fwd_list, nil +} + +func createPortForwarderFromConfigString(t *testing.T, l *logrus.Logger, srv *service.Service, configStr string) (*PortForwardingService, error) { + + fwd_list, err := loadPortFwdConfigFromString(l, configStr) + if err != nil { + return nil, err + } + + pf, err := ConstructFromInitialFwdList(srv, l, fwd_list) + if err != nil { + return nil, err + } + + err = pf.Activate() + if err != nil { + return nil, err + } + + t.Cleanup(func() { + pf.CloseAll() + }) + + return pf, nil +} + +func doTestUdpCommunication( + t *testing.T, + msg string, + senderConn *net.UDPConn, + toAddr net.Addr, + receiverConn <-chan Pair[[]byte, net.Addr], +) net.Addr { + data_sent := []byte(msg) + var n int + var err error + if toAddr != nil { + n, err = senderConn.WriteTo(data_sent, toAddr) + } else { + n, err = senderConn.Write(data_sent) + } + require.Nil(t, err) + assert.Equal(t, n, len(data_sent)) + + pair := <-receiverConn + require.Nil(t, err) + assert.Equal(t, data_sent, pair.a) + return pair.b +} + +type Pair[A any, B any] struct { + a A + b B +} + +func readUdpConnectionToChannel(conn *net.UDPConn) <-chan Pair[[]byte, net.Addr] { + rcv_chan := make(chan Pair[[]byte, net.Addr]) + + go func() { + defer close(rcv_chan) + for { + buf := make([]byte, 100) + n, addr, err := conn.ReadFrom(buf) + if err != nil { + return + } + rcv_chan <- Pair[[]byte, net.Addr]{buf[0:n], addr} + } + }() + + return rcv_chan +} + +func TestUdpInOut2Clients(t *testing.T) { + server, sl, client, cl := service.CreateTwoConnectedServices(t, 4244) + + server_pf, err := createPortForwarderFromConfigString(t, sl, server, ` +port_forwarding: + inbound: + - listen_port: 4499 + dial_address: 127.0.0.1:5599 + protocols: [udp] +`) + require.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + client_pf, err := createPortForwarderFromConfigString(t, cl, client, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 10.0.0.1:4499 + protocols: [udp] +`) + require.Nil(t, err) + + assert.Len(t, client_pf.portForwardings, 1) + + client_conn_addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3399") + require.Nil(t, err) + server_conn_addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5599") + require.Nil(t, err) + + server_listen_conn, err := net.ListenUDP("udp", server_conn_addr) + require.Nil(t, err) + defer server_listen_conn.Close() + server_listen_rcv_chan := readUdpConnectionToChannel(server_listen_conn) + + client1_conn, err := net.DialUDP("udp", nil, client_conn_addr) + require.Nil(t, err) + defer client1_conn.Close() + client1_rcv_chan := readUdpConnectionToChannel(client1_conn) + + client2_conn, err := net.DialUDP("udp", nil, client_conn_addr) + require.Nil(t, err) + defer client2_conn.Close() + client2_rcv_chan := readUdpConnectionToChannel(client2_conn) + + client1_addr := doTestUdpCommunication(t, "Hello from client 1 side!", + client1_conn, nil, server_listen_rcv_chan) + assert.NotNil(t, client1_addr) + client2_addr := doTestUdpCommunication(t, "Hello from client two side!", + client2_conn, nil, server_listen_rcv_chan) + assert.NotNil(t, client2_addr) + + doTestUdpCommunication(t, "Hello from server first side!", + server_listen_conn, client1_addr, client1_rcv_chan) + doTestUdpCommunication(t, "Hello from server second side!", + server_listen_conn, client2_addr, client2_rcv_chan) + doTestUdpCommunication(t, "Hello from server third side!", + server_listen_conn, client1_addr, client1_rcv_chan) + + doTestUdpCommunication(t, "Hello from client two side AGAIN!", + client2_conn, nil, server_listen_rcv_chan) + +} diff --git a/port_forwarder/port_forwarding_service.go b/port_forwarder/port_forwarding_service.go new file mode 100644 index 000000000..ed2910d72 --- /dev/null +++ b/port_forwarder/port_forwarding_service.go @@ -0,0 +1,65 @@ +package port_forwarder + +import ( + "io" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/service" +) + +type PortForwardingService struct { + l *logrus.Logger + tunService *service.Service + + configPortForwardings map[string]ForwardConfig + portForwardings map[string]io.Closer +} + +func (t *PortForwardingService) AddConfig(cfg ForwardConfig) { + t.configPortForwardings[cfg.ConfigDescriptor()] = cfg +} + +func (t *PortForwardingService) Activate() error { + return t.ActivateNew(t.configPortForwardings) +} + +func (t *PortForwardingService) ActivateNew(newForwards map[string]ForwardConfig) error { + + for descriptor, config := range newForwards { + fwd_instance, err := config.SetupPortForwarding(t.tunService, t.l) + if err == nil { + t.configPortForwardings[config.ConfigDescriptor()] = config + t.portForwardings[config.ConfigDescriptor()] = fwd_instance + } else { + t.l.Errorf("failed to setup port forwarding #%s: %s", descriptor, config.ConfigDescriptor()) + } + } + + return nil +} + +func (t *PortForwardingService) CloseSelective(descriptors []string) error { + + for _, descriptor := range descriptors { + delete(t.configPortForwardings, descriptor) + pf, ok := t.portForwardings[descriptor] + if ok { + t.l.Infof("closing port forwarding: %s", descriptor) + pf.Close() + delete(t.portForwardings, descriptor) + } + } + + return nil +} + +func (t *PortForwardingService) CloseAll() error { + + for descriptor, pf := range t.portForwardings { + t.l.Infof("closing port forwarding: %s", descriptor) + pf.Close() + delete(t.portForwardings, descriptor) + } + + return nil +} diff --git a/service/service.go b/service/service.go index 4ddd30182..469b55e31 100644 --- a/service/service.go +++ b/service/service.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "io" "log" "math" "net" @@ -17,6 +18,7 @@ import ( "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay" + "github.com/slackhq/nebula/util" "golang.org/x/sync/errgroup" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" @@ -35,6 +37,7 @@ import ( const nicID = 1 type Service struct { + l *logrus.Logger eg *errgroup.Group control *nebula.Control ipstack *stack.Stack @@ -46,10 +49,7 @@ type Service struct { } } -func New(config *config.C) (*Service, error) { - logger := logrus.New() - logger.Out = os.Stdout - +func New(config *config.C, logger *logrus.Logger) (*Service, error) { control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) if err != nil { return nil, err @@ -59,6 +59,7 @@ func New(config *config.C) (*Service, error) { ctx := control.Context() eg, ctx := errgroup.WithContext(ctx) s := Service{ + l: logger, eg: eg, control: control, } @@ -107,34 +108,28 @@ func New(config *config.C) (*Service, error) { tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler) s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket) - reader, writer := device.Pipe() - - go func() { - <-ctx.Done() - reader.Close() - writer.Close() - }() + nebula_tun_reader, nebula_tun_writer := device.Pipe() // create Goroutines to forward packets between Nebula and Gvisor eg.Go(func() error { - buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize) + defer linkEP.Close() for { - // this will read exactly one packet - n, err := reader.Read(buf) - if err != nil { - return err - } - packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithData(bytes.Clone(buf[:n])), - }) - linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) - - if err := ctx.Err(); err != nil { - return err + select { + case <-ctx.Done(): + return nil + case view, ok := <-nebula_tun_reader: + if !ok { + return nil + } + packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithView(view), + }) + linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) } } }) eg.Go(func() error { + defer close(nebula_tun_writer) for { packet := linkEP.ReadContext(ctx) if packet == nil { @@ -143,11 +138,7 @@ func New(config *config.C) (*Service, error) { } continue } - bufView := packet.ToView() - if _, err := bufView.WriteTo(writer); err != nil { - return err - } - bufView.Release() + nebula_tun_writer <- packet.ToView() } }) @@ -198,6 +189,21 @@ func (s *Service) Dial(network, address string) (net.Conn, error) { return s.DialContext(context.Background(), network, address) } +func (s *Service) DialUDP(address string) (*gonet.UDPConn, error) { + addr, err := net.ResolveUDPAddr("udp", address) + if err != nil { + return nil, err + } + + fullAddr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(addr.IP), + Port: uint16(addr.Port), + } + + return gonet.DialUDP(s.ipstack, nil, &fullAddr, ipv4.ProtocolNumber) +} + // Listen listens on the provided address. Currently only TCP with wildcard // addresses are supported. func (s *Service) Listen(network, address string) (net.Listener, error) { @@ -237,8 +243,25 @@ func (s *Service) Listen(network, address string) (net.Listener, error) { return l, nil } +func (s *Service) ListenUDP(address string) (*gonet.UDPConn, error) { + addr, err := net.ResolveUDPAddr("udp", address) + if err != nil { + return nil, err + } + return gonet.DialUDP(s.ipstack, &tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(addr.IP), + Port: uint16(addr.Port), + LinkAddr: "", + }, nil, ipv4.ProtocolNumber) +} + func (s *Service) Wait() error { - return s.eg.Wait() + err := s.eg.Wait() + + s.ipstack.Destroy() + + return err } func (s *Service) Close() error { @@ -246,6 +269,23 @@ func (s *Service) Close() error { return nil } +func (s *Service) CloseAndWait() error { + s.Close() + if err := s.Wait(); err != nil { + if errors.Is(err, os.ErrClosed) || + errors.Is(err, io.EOF) || + errors.Is(err, context.Canceled) { + s.l.Debugf("Stop of nebula service returned: %v", err) + return nil + } else { + util.LogWithContextIfNeeded("Unclean stop", err, s.l) + return err + } + } + + return nil +} + func (s *Service) tcpHandler(r *tcp.ForwarderRequest) { endpointID := r.ID() diff --git a/service/service_test.go b/service/service_test.go index 31762090d..e1c1d4cc4 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -4,101 +4,13 @@ import ( "bytes" "context" "errors" - "net/netip" "testing" - "time" - "dario.cat/mergo" - "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/e2e" "golang.org/x/sync/errgroup" - "gopkg.in/yaml.v2" ) -type m map[string]interface{} - -func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { - _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{}) - caB, err := caCrt.MarshalToPEM() - if err != nil { - panic(err) - } - - mc := m{ - "pki": m{ - "ca": string(caB), - "cert": string(myPEM), - "key": string(myPrivKey), - }, - //"tun": m{"disabled": true}, - "firewall": m{ - "outbound": []m{{ - "proto": "any", - "port": "any", - "host": "any", - }}, - "inbound": []m{{ - "proto": "any", - "port": "any", - "host": "any", - }}, - }, - "timers": m{ - "pending_deletion_interval": 2, - "connection_alive_interval": 2, - }, - "handshakes": m{ - "try_interval": "200ms", - }, - } - - if overrides != nil { - err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) - if err != nil { - panic(err) - } - mc = overrides - } - - cb, err := yaml.Marshal(mc) - if err != nil { - panic(err) - } - - var c config.C - if err := c.LoadString(string(cb)); err != nil { - panic(err) - } - - s, err := New(&c) - if err != nil { - panic(err) - } - return s -} - func TestService(t *testing.T) { - ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ - "static_host_map": m{}, - "lighthouse": m{ - "am_lighthouse": true, - }, - "listen": m{ - "host": "0.0.0.0", - "port": 4243, - }, - }) - b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{ - "static_host_map": m{ - "10.0.0.1": []string{"localhost:4243"}, - }, - "lighthouse": m{ - "hosts": []string{"10.0.0.1"}, - "interval": 1, - }, - }) + a, _, b, _ := CreateTwoConnectedServices(t, 4243) ln, err := a.Listen("tcp", ":1234") if err != nil { @@ -150,12 +62,4 @@ func TestService(t *testing.T) { if !bytes.Equal(data, []byte("server msg")) { t.Fatal("got invalid message from client") } - - if err := c.Close(); err != nil { - t.Fatal(err) - } - - if err := eg.Wait(); err != nil { - t.Fatal(err) - } } diff --git a/service/service_testhelpers.go b/service/service_testhelpers.go new file mode 100644 index 000000000..c77c1c574 --- /dev/null +++ b/service/service_testhelpers.go @@ -0,0 +1,134 @@ +package service + +import ( + "fmt" + "io" + "math/rand" + "net/netip" + "testing" + "time" + + "dario.cat/mergo" + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/e2e" + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v2" +) + +type m map[string]interface{} + +type LogOutputWithPrefix struct { + prefix string + out io.Writer +} + +func (o LogOutputWithPrefix) Write(p []byte) (n int, err error) { + fmt.Fprintf(o.out, "[%s] ", o.prefix) + return o.out.Write(p) +} + +func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) (*Service, *logrus.Logger) { + _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, name, + time.Now().Add(-3*time.Minute), + time.Now().Add(30*time.Minute), + netip.PrefixFrom(udpIp, 24), nil, []string{}) + caB, err := caCrt.MarshalToPEM() + if err != nil { + panic(err) + } + + mc := m{ + "pki": m{ + "ca": string(caB), + "cert": string(myPEM), + "key": string(myPrivKey), + }, + //"tun": m{"disabled": true}, + "firewall": m{ + "outbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + }}, + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + }}, + }, + "timers": m{ + "pending_deletion_interval": 2, + "connection_alive_interval": 2, + }, + "handshakes": m{ + "try_interval": "200ms", + }, + } + + if overrides != nil { + err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + mc = overrides + } + + cb, err := yaml.Marshal(mc) + if err != nil { + panic(err) + } + + var c config.C + if err := c.LoadString(string(cb)); err != nil { + panic(err) + } + + l := logrus.New() + prefixWriter := LogOutputWithPrefix{ + prefix: name, + out: l.Out, + } + l.SetOutput(prefixWriter) + + s, err := New(&c, l) + if err != nil { + panic(err) + } + return s, l +} + +func CreateTwoConnectedServices(t *testing.T, port int) (*Service, *logrus.Logger, *Service, *logrus.Logger) { + port += 100 * (rand.Int() % 10) + ca, _, caKey, _ := e2e.NewTestCaCert( + time.Now().Add(-9*time.Minute), // ensure that there is no issue due to rounding + time.Now().Add(40*time.Minute), // ensure that the certificate is valid for at least the time ot the test execution + nil, nil, []string{}) + a, al := newSimpleService(ca, caKey, fmt.Sprintf("a_port_%d_test_name_%s", port, t.Name()), netip.MustParseAddr("10.0.0.1"), m{ + "static_host_map": m{}, + "lighthouse": m{ + "am_lighthouse": true, + }, + "listen": m{ + "host": "0.0.0.0", + "port": port, + }, + }) + t.Cleanup(func() { + assert.NoError(t, a.CloseAndWait()) + }) + b, bl := newSimpleService(ca, caKey, fmt.Sprintf("b_port_%d_test_name_%s", port, t.Name()), netip.MustParseAddr("10.0.0.2"), m{ + "static_host_map": m{ + "10.0.0.1": []string{fmt.Sprintf("localhost:%d", port)}, + }, + "lighthouse": m{ + "hosts": []string{"10.0.0.1"}, + "interval": 1, + }, + }) + t.Cleanup(func() { + assert.NoError(t, b.CloseAndWait()) + }) + return a, al, b, bl +} diff --git a/timeout.go b/timeout.go index c1b4c398b..26beb25b5 100644 --- a/timeout.go +++ b/timeout.go @@ -16,7 +16,8 @@ type TimerWheel[T any] struct { wheelLen int // Last time we ticked, since we are lazy ticking - lastTick *time.Time + lastTickValid bool + lastTick time.Time // Durations of a tick and the entire wheel tickDuration time.Duration @@ -168,13 +169,14 @@ func (tw *TimerWheel[T]) findWheel(timeout time.Duration) (i int) { // Advance will move the wheel forward by the appropriate number of ticks for the provided time and all items // passed over will be moved to the expired list. Calling Purge is necessary to remove them entirely. -func (tw *TimerWheel[T]) Advance(now time.Time) { - if tw.lastTick == nil { - tw.lastTick = &now +func (tw *TimerWheel[T]) Advance(now1 time.Time) { + if !tw.lastTickValid { + tw.lastTick = now1 + tw.lastTickValid = true } // We want to round down - ticks := int(now.Sub(*tw.lastTick) / tw.tickDuration) + ticks := int(now1.Sub(tw.lastTick) / tw.tickDuration) adv := ticks if ticks > tw.wheelLen { ticks = tw.wheelLen @@ -203,7 +205,7 @@ func (tw *TimerWheel[T]) Advance(now time.Time) { // Advance the tick based on duration to avoid losing some accuracy newTick := tw.lastTick.Add(tw.tickDuration * time.Duration(adv)) - tw.lastTick = &newTick + tw.lastTick = newTick } func (lw *LockingTimerWheel[T]) Add(v T, timeout time.Duration) *TimeoutItem[T] { diff --git a/timeout_test.go b/timeout_test.go index 4c6364ef5..616d83f74 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -14,7 +14,7 @@ func TestNewTimerWheel(t *testing.T) { tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) assert.Equal(t, 12, tw.wheelLen) assert.Equal(t, 0, tw.current) - assert.Nil(t, tw.lastTick) + assert.Equal(t, false, tw.lastTickValid) assert.Equal(t, time.Second*1, tw.tickDuration) assert.Equal(t, time.Second*10, tw.wheelDuration) assert.Len(t, tw.wheel, 12) @@ -110,9 +110,9 @@ func TestTimerWheel_Add(t *testing.T) { func TestTimerWheel_Purge(t *testing.T) { // First advance should set the lastTick and do nothing else tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) - assert.Nil(t, tw.lastTick) + assert.Equal(t, false, tw.lastTickValid) tw.Advance(time.Now()) - assert.NotNil(t, tw.lastTick) + assert.Equal(t, true, tw.lastTickValid) assert.Equal(t, 0, tw.current) fps := []firewall.Packet{ @@ -128,7 +128,7 @@ func TestTimerWheel_Purge(t *testing.T) { tw.Add(fps[3], time.Second*2) ta := time.Now().Add(time.Second * 3) - lastTick := *tw.lastTick + lastTick := tw.lastTick tw.Advance(ta) assert.Equal(t, 3, tw.current) assert.True(t, tw.lastTick.After(lastTick)) diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 2eee76ee2..ae44e582d 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -8,6 +8,8 @@ import ( "fmt" "net" "net/netip" + "sync" + "sync/atomic" "syscall" "unsafe" @@ -22,10 +24,12 @@ import ( //TODO: make it support reload as best you can! type StdConn struct { - sysFd int - isV4 bool - l *logrus.Logger - batch int + sysFd int + closed atomic.Bool + wg *sync.WaitGroup + isV4 bool + l *logrus.Logger + batch int } func maybeIPV4(ip net.IP) (net.IP, bool) { @@ -79,7 +83,14 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) //l.Println(v, err) - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + return &StdConn{ + sysFd: fd, + closed: atomic.Bool{}, + wg: &sync.WaitGroup{}, + isV4: ip.Is4(), + l: l, + batch: batch, + }, err } func (u *StdConn) Rebind() error { @@ -121,6 +132,15 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) { } func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { + + u.wg.Add(1) + defer func() { + u.wg.Done() + }() + if u.closed.Load() { + return + } + plaintext := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} @@ -142,6 +162,11 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } + if u.closed.Load() { + u.l.Debug("flag for closing connection is set, exiting read loop") + return + } + //metric.Update(int64(n)) for i := 0; i < n; i++ { if u.isV4 { @@ -314,7 +339,20 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { } func (u *StdConn) Close() error { - //TODO: this will not interrupt the read loop + if !u.closed.CompareAndSwap(false, true) { + // already closed by e.g. other thread + return nil + } + err := syscall.Shutdown(u.sysFd, syscall.SHUT_RDWR) + if err != nil { + errno, ok := err.(syscall.Errno) + // connection might have been terminated by remote before + wasDisconnected := ok && (errno == syscall.ENOTCONN) + if !wasDisconnected { + panic(fmt.Sprintf("error while shutdown of UDP socket: %v", err)) + } + } + u.wg.Wait() return syscall.Close(u.sysFd) } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index ee7e1e002..eec96c91f 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -14,6 +14,7 @@ import ( "sync" "sync/atomic" "syscall" + "time" "unsafe" "github.com/sirupsen/logrus" @@ -178,10 +179,11 @@ func (u *RIOConn) receive(buf []byte) (int, windows.RawSockaddrInet6, error) { retry: count = 0 for tries := 0; count == 0 && tries < receiveSpins; tries++ { + if !u.isOpen.Load() { // might have changed since first check before the mutex lock + return 0, windows.RawSockaddrInet6{}, net.ErrClosed + } + if tries > 0 { - if !u.isOpen.Load() { - return 0, windows.RawSockaddrInet6{}, net.ErrClosed - } procyield(1) } @@ -247,6 +249,10 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error { u.tx.mu.Lock() defer u.tx.mu.Unlock() + if !u.isOpen.Load() { // might have changed since first check before the mutex lock + return net.ErrClosed + } + count := winrio.DequeueCompletion(u.tx.cq, u.results[:]) if count == 0 && u.tx.isFull { err := winrio.Notify(u.tx.cq) @@ -323,6 +329,14 @@ func (u *RIOConn) Close() error { windows.PostQueuedCompletionStatus(u.rx.iocp, 0, 0, nil) windows.PostQueuedCompletionStatus(u.tx.iocp, 0, 0, nil) + u.rx.mu.Lock() // for waiting till active reader is done + time.Sleep(time.Millisecond * 0) // avoid warning about empty critical section + u.rx.mu.Unlock() + + u.tx.mu.Lock() // for waiting till active writer is done + time.Sleep(time.Millisecond * 0) // avoid warning about empty critical section + u.tx.mu.Unlock() + u.rx.CloseAndZero() u.tx.CloseAndZero() if u.sock != 0 {