From 08134116a5605e38e74ce9e7e226a111752d77c5 Mon Sep 17 00:00:00 2001 From: Moses Narrow <36607567+0pcom@users.noreply.github.com> Date: Tue, 12 Nov 2024 17:27:09 -0600 Subject: [PATCH] remove replace directives in go.mod --- go.mod | 6 +- go.sum | 8 +- .../wireguard/conn/bind_std.go | 544 ++++++++++ .../wireguard/conn/bind_windows.go | 601 +++++++++++ .../wireguard/conn/boundif_android.go | 34 + .../golang.zx2c4.com/wireguard/conn/conn.go | 133 +++ .../wireguard/conn/controlfns.go | 43 + .../wireguard/conn/controlfns_linux.go | 69 ++ .../wireguard/conn/controlfns_unix.go | 35 + .../wireguard/conn/controlfns_windows.go | 23 + .../wireguard/conn/default.go | 10 + .../wireguard/conn/errors_default.go | 12 + .../wireguard/conn/errors_linux.go | 26 + .../wireguard/conn/features_default.go | 15 + .../wireguard/conn/features_linux.go | 29 + .../wireguard/conn/gso_default.go | 21 + .../wireguard/conn/gso_linux.go | 65 ++ .../wireguard/conn/mark_default.go | 12 + .../wireguard/conn/mark_unix.go | 65 ++ .../wireguard/conn/sticky_default.go | 42 + .../wireguard/conn/sticky_linux.go | 112 ++ .../wireguard/conn/winrio/rio_windows.go | 254 +++++ .../wireguard/rwcancel/rwcancel.go | 2 +- .../wireguard/rwcancel/rwcancel_stub.go | 2 +- .../wireguard/tun/checksum.go | 118 +++ .../golang.zx2c4.com/wireguard/tun/errors.go | 12 + .../wireguard/tun/offload_linux.go | 993 ++++++++++++++++++ vendor/golang.zx2c4.com/wireguard/tun/tun.go | 40 +- .../wireguard/tun/tun_darwin.go | 67 +- .../wireguard/tun/tun_freebsd.go | 55 +- .../wireguard/tun/tun_linux.go | 285 +++-- .../wireguard/tun/tun_openbsd.go | 60 +- .../wireguard/tun/tun_windows.go | 58 +- vendor/modules.txt | 8 +- 34 files changed, 3666 insertions(+), 193 deletions(-) create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/bind_std.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/bind_windows.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/boundif_android.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/conn.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/controlfns.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/controlfns_linux.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/controlfns_unix.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/controlfns_windows.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/default.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/errors_default.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/errors_linux.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/features_default.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/features_linux.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/gso_default.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/gso_linux.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/mark_default.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/mark_unix.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/sticky_default.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/sticky_linux.go create mode 100644 vendor/golang.zx2c4.com/wireguard/conn/winrio/rio_windows.go create mode 100644 vendor/golang.zx2c4.com/wireguard/tun/checksum.go create mode 100644 vendor/golang.zx2c4.com/wireguard/tun/errors.go create mode 100644 vendor/golang.zx2c4.com/wireguard/tun/offload_linux.go diff --git a/go.mod b/go.mod index adbf47fc2a..0d79e246b2 100644 --- a/go.mod +++ b/go.mod @@ -176,9 +176,11 @@ require ( 
mvdan.cc/sh/v3 v3.9.0 // indirect ) -replace github.com/xxxserxxx/gotop/v4 => github.com/ersonp/gotop/v4 v4.2.1 +// issues with gotop on riscv64 +//replace github.com/xxxserxxx/gotop/v4 => github.com/ersonp/gotop/v4 v4.2.1 -replace golang.zx2c4.com/wireguard => golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 +// WireGuard version must match the version below; do not update + // replace golang.zx2c4.com/wireguard => golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 // Uncomment for tests with local sources //replace github.com/skycoin/dmsg => ../dmsg diff --git a/go.sum b/go.sum index c9c0f20895..b2b8c6d676 100644 --- a/go.sum +++ b/go.sum @@ -318,6 +318,8 @@ github.com/gomarkdown/markdown v0.0.0-20240930133441-72d49d9543d8 h1:4txT5G2kqVA github.com/gomarkdown/markdown v0.0.0-20240930133441-72d49d9543d8/go.mod h1:JDGcbDT52eL4fju3sZ4TeHGsQwhG9nbDV21aMyhwPoA= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= +github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -1082,8 +1084,8 @@ golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 h1:/J/RVnr7ng4fWPRH3xa4WtBJ1Jp+Auu4YNLmGiPv5QU= -golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675/go.mod h1:whfbyDBt09xhCYQWtO2+3UVjlaq6/9hDZrjg2ZE6SyA= +golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= +golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= @@ -1287,6 +1289,8 @@ gorm.io/driver/postgres v1.5.9 h1:DkegyItji119OlcaLjqN11kHoUgZ/j13E0jkJZgD6A8= gorm.io/driver/postgres v1.5.9/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI= gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git 
a/vendor/golang.zx2c4.com/wireguard/conn/bind_std.go b/vendor/golang.zx2c4.com/wireguard/conn/bind_std.go new file mode 100644 index 0000000000..46df7fd4ef --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/bind_std.go @@ -0,0 +1,544 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "runtime" + "strconv" + "sync" + "syscall" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +var ( + _ Bind = (*StdNetBind)(nil) +) + +// StdNetBind implements Bind for all platforms. While Windows has its own Bind +// (see bind_windows.go), it may fall back to StdNetBind. +// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable +// methods for sending and receiving multiple datagrams per-syscall. See the +// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. +type StdNetBind struct { + mu sync.Mutex // protects all fields except as specified + ipv4 *net.UDPConn + ipv6 *net.UDPConn + ipv4PC *ipv4.PacketConn // will be nil on non-Linux + ipv6PC *ipv6.PacketConn // will be nil on non-Linux + ipv4TxOffload bool + ipv4RxOffload bool + ipv6TxOffload bool + ipv6RxOffload bool + + // these two fields are not guarded by mu + udpAddrPool sync.Pool + msgsPool sync.Pool + + blackhole4 bool + blackhole6 bool +} + +func NewStdNetBind() Bind { + return &StdNetBind{ + udpAddrPool: sync.Pool{ + New: func() any { + return &net.UDPAddr{ + IP: make([]byte, 16), + } + }, + }, + + msgsPool: sync.Pool{ + New: func() any { + // ipv6.Message and ipv4.Message are interchangeable as they are + // both aliases for x/net/internal/socket.Message. + msgs := make([]ipv6.Message, IdealBatchSize) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize) + } + return &msgs + }, + }, + } +} + +type StdNetEndpoint struct { + // AddrPort is the endpoint destination. + netip.AddrPort + // src is the current sticky source address and interface index, if + // supported. Typically this is a PKTINFO structure from/for control + // messages, see unix.PKTINFO for an example. + src []byte +} + +var ( + _ Bind = (*StdNetBind)(nil) + _ Endpoint = &StdNetEndpoint{} +) + +func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { + e, err := netip.ParseAddrPort(s) + if err != nil { + return nil, err + } + return &StdNetEndpoint{ + AddrPort: e, + }, nil +} + +func (e *StdNetEndpoint) ClearSrc() { + if e.src != nil { + // Truncate src, no need to reallocate. + e.src = e.src[:0] + } +} + +func (e *StdNetEndpoint) DstIP() netip.Addr { + return e.AddrPort.Addr() +} + +// See control_default,linux, etc for implementations of SrcIP and SrcIfidx. + +func (e *StdNetEndpoint) DstToBytes() []byte { + b, _ := e.AddrPort.MarshalBinary() + return b +} + +func (e *StdNetEndpoint) DstToString() string { + return e.AddrPort.String() +} + +func listenNet(network string, port int) (*net.UDPConn, int, error) { + conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) + if err != nil { + return nil, 0, err + } + + // Retrieve port. 
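+	// When uport is 0 the kernel picks a free port, so the bound local
+	// address is resolved back to a *net.UDPAddr to learn which port was chosen.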
+ laddr := conn.LocalAddr() + uaddr, err := net.ResolveUDPAddr( + laddr.Network(), + laddr.String(), + ) + if err != nil { + return nil, 0, err + } + return conn.(*net.UDPConn), uaddr.Port, nil +} + +func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { + s.mu.Lock() + defer s.mu.Unlock() + + var err error + var tries int + + if s.ipv4 != nil || s.ipv6 != nil { + return nil, 0, ErrBindAlreadyOpen + } + + // Attempt to open ipv4 and ipv6 listeners on the same port. + // If uport is 0, we can retry on failure. +again: + port := int(uport) + var v4conn, v6conn *net.UDPConn + var v4pc *ipv4.PacketConn + var v6pc *ipv6.PacketConn + + v4conn, port, err = listenNet("udp4", port) + if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + return nil, 0, err + } + + // Listen on the same port as we're using for ipv4. + v6conn, port, err = listenNet("udp6", port) + if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { + v4conn.Close() + tries++ + goto again + } + if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + v4conn.Close() + return nil, 0, err + } + var fns []ReceiveFunc + if v4conn != nil { + s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn) + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + v4pc = ipv4.NewPacketConn(v4conn) + s.ipv4PC = v4pc + } + fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload)) + s.ipv4 = v4conn + } + if v6conn != nil { + s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn) + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + v6pc = ipv6.NewPacketConn(v6conn) + s.ipv6PC = v6pc + } + fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload)) + s.ipv6 = v6conn + } + if len(fns) == 0 { + return nil, 0, syscall.EAFNOSUPPORT + } + + return fns, uint16(port), nil +} + +func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) { + for i := range *msgs { + (*msgs)[i].OOB = (*msgs)[i].OOB[:0] + (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} + } + s.msgsPool.Put(msgs) +} + +func (s *StdNetBind) getMessages() *[]ipv6.Message { + return s.msgsPool.Get().(*[]ipv6.Message) +} + +var ( + // If compilation fails here these are no longer the same underlying type. 
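+	// The blank assignment below is a compile-time check that the two
+	// message types remain interchangeable (both alias
+	// x/net/internal/socket.Message, as noted in NewStdNetBind above).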
+ _ ipv6.Message = ipv4.Message{} +) + +type batchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) +} + +type batchWriter interface { + WriteBatch([]ipv6.Message, int) (int, error) +} + +func (s *StdNetBind) receiveIP( + br batchReader, + conn *net.UDPConn, + rxOffload bool, + bufs [][]byte, + sizes []int, + eps []Endpoint, +) (n int, err error) { + msgs := s.getMessages() + for i := range bufs { + (*msgs)[i].Buffers[0] = bufs[i] + (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] + } + defer s.putMessages(msgs) + var numMsgs int + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + if rxOffload { + readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams) + numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) + if err != nil { + return 0, err + } + numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize) + if err != nil { + return 0, err + } + } else { + numMsgs, err = br.ReadBatch(*msgs, 0) + if err != nil { + return 0, err + } + } + } else { + msg := &(*msgs)[0] + msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + if err != nil { + return 0, err + } + numMsgs = 1 + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + if sizes[i] == 0 { + continue + } + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation + getSrcFromControl(msg.OOB[:msg.NN], ep) + eps[i] = ep + } + return numMsgs, nil +} + +func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + } +} + +func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + } +} + +// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and +// rename the IdealBatchSize constant to BatchSize. 
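+// BatchSize reports IdealBatchSize where recvmmsg/sendmmsg-style batching
+// is available (linux, android) and 1 elsewhere.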
+func (s *StdNetBind) BatchSize() int { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + return IdealBatchSize + } + return 1 +} + +func (s *StdNetBind) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + var err1, err2 error + if s.ipv4 != nil { + err1 = s.ipv4.Close() + s.ipv4 = nil + s.ipv4PC = nil + } + if s.ipv6 != nil { + err2 = s.ipv6.Close() + s.ipv6 = nil + s.ipv6PC = nil + } + s.blackhole4 = false + s.blackhole6 = false + s.ipv4TxOffload = false + s.ipv4RxOffload = false + s.ipv6TxOffload = false + s.ipv6RxOffload = false + if err1 != nil { + return err1 + } + return err2 +} + +type ErrUDPGSODisabled struct { + onLaddr string + RetryErr error +} + +func (e ErrUDPGSODisabled) Error() string { + return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr) +} + +func (e ErrUDPGSODisabled) Unwrap() error { + return e.RetryErr +} + +func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { + s.mu.Lock() + blackhole := s.blackhole4 + conn := s.ipv4 + offload := s.ipv4TxOffload + br := batchWriter(s.ipv4PC) + is6 := false + if endpoint.DstIP().Is6() { + blackhole = s.blackhole6 + conn = s.ipv6 + br = s.ipv6PC + is6 = true + offload = s.ipv6TxOffload + } + s.mu.Unlock() + + if blackhole { + return nil + } + if conn == nil { + return syscall.EAFNOSUPPORT + } + + msgs := s.getMessages() + defer s.putMessages(msgs) + ua := s.udpAddrPool.Get().(*net.UDPAddr) + defer s.udpAddrPool.Put(ua) + if is6 { + as16 := endpoint.DstIP().As16() + copy(ua.IP, as16[:]) + ua.IP = ua.IP[:16] + } else { + as4 := endpoint.DstIP().As4() + copy(ua.IP, as4[:]) + ua.IP = ua.IP[:4] + } + ua.Port = int(endpoint.(*StdNetEndpoint).Port()) + var ( + retried bool + err error + ) +retry: + if offload { + n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize) + err = s.send(conn, br, (*msgs)[:n]) + if err != nil && offload && errShouldDisableUDPGSO(err) { + offload = false + s.mu.Lock() + if is6 { + s.ipv6TxOffload = false + } else { + s.ipv4TxOffload = false + } + s.mu.Unlock() + retried = true + goto retry + } + } else { + for i := range bufs { + (*msgs)[i].Addr = ua + (*msgs)[i].Buffers[0] = bufs[i] + setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint)) + } + err = s.send(conn, br, (*msgs)[:len(bufs)]) + } + if retried { + return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err} + } + return err +} + +func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error { + var ( + n int + err error + start int + ) + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + for { + n, err = pc.WriteBatch(msgs[start:], 0) + if err != nil || n == len(msgs[start:]) { + break + } + start += n + } + } else { + for _, msg := range msgs { + _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr)) + if err != nil { + break + } + } + } + return err +} + +const ( + // Exceeding these values results in EMSGSIZE. They account for layer3 and + // layer4 headers. IPv6 does not need to account for itself as the payload + // length field is self excluding. + maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 + maxIPv6PayloadLen = 1<<16 - 1 - 8 + + // This is a hard limit imposed by the kernel. 
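+	// (UDP_MAX_SEGMENTS in the Linux kernel's include/linux/udp.h)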
+ udpSegmentMaxDatagrams = 64 +) + +type setGSOFunc func(control *[]byte, gsoSize uint16) + +func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int { + var ( + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of bufs + ) + maxPayloadLen := maxIPv4PayloadLen + if ep.DstIP().Is6() { + maxPayloadLen = maxIPv6PayloadLen + } + for i, buf := range bufs { + if i > 0 { + msgLen := len(buf) + baseLenBefore := len(msgs[base].Buffers[0]) + freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore + if msgLen+baseLenBefore <= maxPayloadLen && + msgLen <= gsoSize && + msgLen <= freeBaseCap && + dgramCnt < udpSegmentMaxDatagrams && + !endBatch { + msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...) + if i == len(bufs)-1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + dgramCnt++ + if msgLen < gsoSize { + // A smaller than gsoSize packet on the tail is legal, but + // it must end the batch. + endBatch = true + } + continue + } + } + if dgramCnt > 1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + // Reset prior to incrementing base since we are preparing to start a + // new potential batch. + endBatch = false + base++ + gsoSize = len(buf) + setSrcControl(&msgs[base].OOB, ep) + msgs[base].Buffers[0] = buf + msgs[base].Addr = addr + dgramCnt = 1 + } + return base + 1 +} + +type getGSOFunc func(control []byte) (int, error) + +func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { + for i := firstMsgAt; i < len(msgs); i++ { + msg := &msgs[i] + if msg.N == 0 { + return n, err + } + var ( + gsoSize int + start int + end = msg.N + numToSplit = 1 + ) + gsoSize, err = getGSO(msg.OOB[:msg.NN]) + if err != nil { + return n, err + } + if gsoSize > 0 { + numToSplit = (msg.N + gsoSize - 1) / gsoSize + end = gsoSize + } + for j := 0; j < numToSplit; j++ { + if n > i { + return n, errors.New("splitting coalesced packet resulted in overflow") + } + copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) + msgs[n].N = copied + msgs[n].Addr = msg.Addr + start = end + end += gsoSize + if end > msg.N { + end = msg.N + } + n++ + } + if i != n-1 { + // It is legal for bytes to move within msg.Buffers[0] as a result + // of splitting, so we only zero the source msg len when it is not + // the destination of the last split operation above. + msg.N = 0 + } + } + return n, nil +} diff --git a/vendor/golang.zx2c4.com/wireguard/conn/bind_windows.go b/vendor/golang.zx2c4.com/wireguard/conn/bind_windows.go new file mode 100644 index 0000000000..d5095e004b --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/bind_windows.go @@ -0,0 +1,601 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package conn + +import ( + "encoding/binary" + "io" + "net" + "net/netip" + "strconv" + "sync" + "sync/atomic" + "unsafe" + + "golang.org/x/sys/windows" + + "golang.zx2c4.com/wireguard/conn/winrio" +) + +const ( + packetsPerRing = 1024 + bytesPerPacket = 2048 - 32 + receiveSpins = 15 +) + +type ringPacket struct { + addr WinRingEndpoint + data [bytesPerPacket]byte +} + +type ringBuffer struct { + packets uintptr + head, tail uint32 + id winrio.BufferId + iocp windows.Handle + isFull bool + cq winrio.Cq + mu sync.Mutex + overlapped windows.Overlapped +} + +func (rb *ringBuffer) Push() *ringPacket { + for rb.isFull { + panic("ring is full") + } + ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) + rb.tail += 1 + if rb.tail%packetsPerRing == rb.head%packetsPerRing { + rb.isFull = true + } + return ret +} + +func (rb *ringBuffer) Return(count uint32) { + if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull { + return + } + rb.head += count + rb.isFull = false +} + +type afWinRingBind struct { + sock windows.Handle + rx, tx ringBuffer + rq winrio.Rq + mu sync.Mutex + blackhole bool +} + +// WinRingBind uses Windows registered I/O for fast ring buffered networking. +type WinRingBind struct { + v4, v6 afWinRingBind + mu sync.RWMutex + isOpen atomic.Uint32 // 0, 1, or 2 +} + +func NewDefaultBind() Bind { return NewWinRingBind() } + +func NewWinRingBind() Bind { + if !winrio.Initialize() { + return NewStdNetBind() + } + return new(WinRingBind) +} + +type WinRingEndpoint struct { + family uint16 + data [30]byte +} + +var ( + _ Bind = (*WinRingBind)(nil) + _ Endpoint = (*WinRingEndpoint)(nil) +) + +func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { + host, port, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + host16, err := windows.UTF16PtrFromString(host) + if err != nil { + return nil, err + } + port16, err := windows.UTF16PtrFromString(port) + if err != nil { + return nil, err + } + hints := windows.AddrinfoW{ + Flags: windows.AI_NUMERICHOST, + Family: windows.AF_UNSPEC, + Socktype: windows.SOCK_DGRAM, + Protocol: windows.IPPROTO_UDP, + } + var addrinfo *windows.AddrinfoW + err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo) + if err != nil { + return nil, err + } + defer windows.FreeAddrInfoW(addrinfo) + if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) { + return nil, windows.ERROR_INVALID_ADDRESS + } + var dst [unsafe.Sizeof(WinRingEndpoint{})]byte + copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen)) + return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil +} + +func (*WinRingEndpoint) ClearSrc() {} + +func (e *WinRingEndpoint) DstIP() netip.Addr { + switch e.family { + case windows.AF_INET: + return netip.AddrFrom4(*(*[4]byte)(e.data[2:6])) + case windows.AF_INET6: + return netip.AddrFrom16(*(*[16]byte)(e.data[6:22])) + } + return netip.Addr{} +} + +func (e *WinRingEndpoint) SrcIP() netip.Addr { + return netip.Addr{} // not supported +} + +func (e *WinRingEndpoint) DstToBytes() []byte { + switch e.family { + case windows.AF_INET: + b := make([]byte, 0, 6) + b = append(b, e.data[2:6]...) + b = append(b, e.data[1], e.data[0]) + return b + case windows.AF_INET6: + b := make([]byte, 0, 18) + b = append(b, e.data[6:22]...) 
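+		// The port is stored big-endian in data[0:2]; it is appended
+		// low byte first, mirroring the AF_INET case above.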
+ b = append(b, e.data[1], e.data[0]) + return b + } + return nil +} + +func (e *WinRingEndpoint) DstToString() string { + switch e.family { + case windows.AF_INET: + return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() + case windows.AF_INET6: + var zone string + if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { + zone = strconv.FormatUint(uint64(scope), 10) + } + return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String() + } + return "" +} + +func (e *WinRingEndpoint) SrcToString() string { + return "" +} + +func (ring *ringBuffer) CloseAndZero() { + if ring.cq != 0 { + winrio.CloseCompletionQueue(ring.cq) + ring.cq = 0 + } + if ring.iocp != 0 { + windows.CloseHandle(ring.iocp) + ring.iocp = 0 + } + if ring.id != 0 { + winrio.DeregisterBuffer(ring.id) + ring.id = 0 + } + if ring.packets != 0 { + windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) + ring.packets = 0 + } + ring.head = 0 + ring.tail = 0 + ring.isFull = false +} + +func (bind *afWinRingBind) CloseAndZero() { + bind.rx.CloseAndZero() + bind.tx.CloseAndZero() + if bind.sock != 0 { + windows.CloseHandle(bind.sock) + bind.sock = 0 + } + bind.blackhole = false +} + +func (bind *WinRingBind) closeAndZero() { + bind.isOpen.Store(0) + bind.v4.CloseAndZero() + bind.v6.CloseAndZero() +} + +func (ring *ringBuffer) Open() error { + var err error + packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing + ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) + if err != nil { + return err + } + ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen)) + if err != nil { + return err + } + ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + if err != nil { + return err + } + ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped) + if err != nil { + return err + } + return nil +} + +func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) { + var err error + bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP) + if err != nil { + return nil, err + } + err = bind.rx.Open() + if err != nil { + return nil, err + } + err = bind.tx.Open() + if err != nil { + return nil, err + } + bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0) + if err != nil { + return nil, err + } + err = windows.Bind(bind.sock, sa) + if err != nil { + return nil, err + } + sa, err = windows.Getsockname(bind.sock) + if err != nil { + return nil, err + } + return sa, nil +} + +func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) { + bind.mu.Lock() + defer bind.mu.Unlock() + defer func() { + if err != nil { + bind.closeAndZero() + } + }() + if bind.isOpen.Load() != 0 { + return nil, 0, ErrBindAlreadyOpen + } + var sa windows.Sockaddr + sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)}) + if err != nil { + return nil, 0, err + } + sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port}) + if err != nil { + return nil, 0, err + } + selectedPort = uint16(sa.(*windows.SockaddrInet6).Port) + for i := 0; i < packetsPerRing; i++ { + err = bind.v4.InsertReceiveRequest() + if err != nil { + return nil, 0, err + } 
+ err = bind.v6.InsertReceiveRequest() + if err != nil { + return nil, 0, err + } + } + bind.isOpen.Store(1) + return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err +} + +func (bind *WinRingBind) Close() error { + bind.mu.RLock() + if bind.isOpen.Load() != 1 { + bind.mu.RUnlock() + return nil + } + bind.isOpen.Store(2) + windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil) + windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil) + windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil) + windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil) + bind.mu.RUnlock() + bind.mu.Lock() + defer bind.mu.Unlock() + bind.closeAndZero() + return nil +} + +// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and +// rename the IdealBatchSize constant to BatchSize. +func (bind *WinRingBind) BatchSize() int { + // TODO: implement batching in and out of the ring + return 1 +} + +func (bind *WinRingBind) SetMark(mark uint32) error { + return nil +} + +func (bind *afWinRingBind) InsertReceiveRequest() error { + packet := bind.rx.Push() + dataBuffer := &winrio.Buffer{ + Id: bind.rx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets), + Length: uint32(len(packet.data)), + } + addressBuffer := &winrio.Buffer{ + Id: bind.rx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets), + Length: uint32(unsafe.Sizeof(packet.addr)), + } + bind.mu.Lock() + defer bind.mu.Unlock() + return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet))) +} + +//go:linkname procyield runtime.procyield +func procyield(cycles uint32) + +func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) { + if isOpen.Load() != 1 { + return 0, nil, net.ErrClosed + } + bind.rx.mu.Lock() + defer bind.rx.mu.Unlock() + + var err error + var count uint32 + var results [1]winrio.Result +retry: + count = 0 + for tries := 0; count == 0 && tries < receiveSpins; tries++ { + if tries > 0 { + if isOpen.Load() != 1 { + return 0, nil, net.ErrClosed + } + procyield(1) + } + count = winrio.DequeueCompletion(bind.rx.cq, results[:]) + } + if count == 0 { + err = winrio.Notify(bind.rx.cq) + if err != nil { + return 0, nil, err + } + var bytes uint32 + var key uintptr + var overlapped *windows.Overlapped + err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE) + if err != nil { + return 0, nil, err + } + if isOpen.Load() != 1 { + return 0, nil, net.ErrClosed + } + count = winrio.DequeueCompletion(bind.rx.cq, results[:]) + if count == 0 { + return 0, nil, io.ErrNoProgress + } + } + bind.rx.Return(1) + err = bind.InsertReceiveRequest() + if err != nil { + return 0, nil, err + } + // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us + // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to + // attacker bandwidth, just like the rest of the receive path. 
+ if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { + if isOpen.Load() != 1 { + return 0, nil, net.ErrClosed + } + goto retry + } + if results[0].Status != 0 { + return 0, nil, windows.Errno(results[0].Status) + } + packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext))) + ep := packet.addr + n := copy(buf, packet.data[:results[0].BytesTransferred]) + return n, &ep, nil +} + +func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { + bind.mu.RLock() + defer bind.mu.RUnlock() + n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen) + sizes[0] = n + eps[0] = ep + return 1, err +} + +func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { + bind.mu.RLock() + defer bind.mu.RUnlock() + n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen) + sizes[0] = n + eps[0] = ep + return 1, err +} + +func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error { + if isOpen.Load() != 1 { + return net.ErrClosed + } + if len(buf) > bytesPerPacket { + return io.ErrShortBuffer + } + bind.tx.mu.Lock() + defer bind.tx.mu.Unlock() + var results [packetsPerRing]winrio.Result + count := winrio.DequeueCompletion(bind.tx.cq, results[:]) + if count == 0 && bind.tx.isFull { + err := winrio.Notify(bind.tx.cq) + if err != nil { + return err + } + var bytes uint32 + var key uintptr + var overlapped *windows.Overlapped + err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE) + if err != nil { + return err + } + if isOpen.Load() != 1 { + return net.ErrClosed + } + count = winrio.DequeueCompletion(bind.tx.cq, results[:]) + if count == 0 { + return io.ErrNoProgress + } + } + if count > 0 { + bind.tx.Return(count) + } + packet := bind.tx.Push() + packet.addr = *nend + copy(packet.data[:], buf) + dataBuffer := &winrio.Buffer{ + Id: bind.tx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets), + Length: uint32(len(buf)), + } + addressBuffer := &winrio.Buffer{ + Id: bind.tx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets), + Length: uint32(unsafe.Sizeof(packet.addr)), + } + bind.mu.Lock() + defer bind.mu.Unlock() + return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) +} + +func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error { + nend, ok := endpoint.(*WinRingEndpoint) + if !ok { + return ErrWrongEndpointType + } + bind.mu.RLock() + defer bind.mu.RUnlock() + for _, buf := range bufs { + switch nend.family { + case windows.AF_INET: + if bind.v4.blackhole { + continue + } + if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil { + return err + } + case windows.AF_INET6: + if bind.v6.blackhole { + continue + } + if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil { + return err + } + } + } + return nil +} + +func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + s.mu.Lock() + defer s.mu.Unlock() + sysconn, err := s.ipv4.SyscallConn() + if err != nil { + return err + } + err2 := sysconn.Control(func(fd uintptr) { + err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex) + }) + if err2 != nil { + return err2 + } + if err != nil { + return err + } + s.blackhole4 = blackhole + return nil +} + +func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + s.mu.Lock() + defer s.mu.Unlock() + sysconn, err := s.ipv6.SyscallConn() + if err != nil { + return err + } + 
err2 := sysconn.Control(func(fd uintptr) { + err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex) + }) + if err2 != nil { + return err2 + } + if err != nil { + return err + } + s.blackhole6 = blackhole + return nil +} + +func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + bind.mu.RLock() + defer bind.mu.RUnlock() + if bind.isOpen.Load() != 1 { + return net.ErrClosed + } + err := bindSocketToInterface4(bind.v4.sock, interfaceIndex) + if err != nil { + return err + } + bind.v4.blackhole = blackhole + return nil +} + +func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + bind.mu.RLock() + defer bind.mu.RUnlock() + if bind.isOpen.Load() != 1 { + return net.ErrClosed + } + err := bindSocketToInterface6(bind.v6.sock, interfaceIndex) + if err != nil { + return err + } + bind.v6.blackhole = blackhole + return nil +} + +func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error { + const IP_UNICAST_IF = 31 + /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ + var bytes [4]byte + binary.BigEndian.PutUint32(bytes[:], interfaceIndex) + interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) + err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex)) + if err != nil { + return err + } + return nil +} + +func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error { + const IPV6_UNICAST_IF = 31 + return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex)) +} diff --git a/vendor/golang.zx2c4.com/wireguard/conn/boundif_android.go b/vendor/golang.zx2c4.com/wireguard/conn/boundif_android.go new file mode 100644 index 0000000000..dd3ca5b076 --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/boundif_android.go @@ -0,0 +1,34 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { + sysconn, err := s.ipv4.SyscallConn() + if err != nil { + return -1, err + } + err = sysconn.Control(func(f uintptr) { + fd = int(f) + }) + if err != nil { + return -1, err + } + return +} + +func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { + sysconn, err := s.ipv6.SyscallConn() + if err != nil { + return -1, err + } + err = sysconn.Control(func(f uintptr) { + fd = int(f) + }) + if err != nil { + return -1, err + } + return +} diff --git a/vendor/golang.zx2c4.com/wireguard/conn/conn.go b/vendor/golang.zx2c4.com/wireguard/conn/conn.go new file mode 100644 index 0000000000..a1f57d2b1d --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/conn.go @@ -0,0 +1,133 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +// Package conn implements WireGuard's network connections. +package conn + +import ( + "errors" + "fmt" + "net/netip" + "reflect" + "runtime" + "strings" +) + +const ( + IdealBatchSize = 128 // maximum number of packets handled per read and write +) + +// A ReceiveFunc receives at least one packet from the network and writes them +// into packets. On a successful read it returns the number of elements of +// sizes, packets, and endpoints that should be evaluated. Some elements of +// sizes may be zero, and callers should ignore them. Callers must pass a sizes +// and eps slice with a length greater than or equal to the length of packets. 
+// These lengths must not exceed the length of the associated Bind.BatchSize(). +type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error) + +// A Bind listens on a port for both IPv6 and IPv4 UDP traffic. +// +// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface, +// depending on the platform-specific implementation. +type Bind interface { + // Open puts the Bind into a listening state on a given port and reports the actual + // port that it bound to. Passing zero results in a random selection. + // fns is the set of functions that will be called to receive packets. + Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error) + + // Close closes the Bind listener. + // All fns returned by Open must return net.ErrClosed after a call to Close. + Close() error + + // SetMark sets the mark for each packet sent through this Bind. + // This mark is passed to the kernel as the socket option SO_MARK. + SetMark(mark uint32) error + + // Send writes one or more packets in bufs to address ep. The length of + // bufs must not exceed BatchSize(). + Send(bufs [][]byte, ep Endpoint) error + + // ParseEndpoint creates a new endpoint from a string. + ParseEndpoint(s string) (Endpoint, error) + + // BatchSize is the number of buffers expected to be passed to + // the ReceiveFuncs, and the maximum expected to be passed to SendBatch. + BatchSize() int +} + +// BindSocketToInterface is implemented by Bind objects that support being +// tied to a single network interface. Used by wireguard-windows. +type BindSocketToInterface interface { + BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error + BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error +} + +// PeekLookAtSocketFd is implemented by Bind objects that support having their +// file descriptor peeked at. Used by wireguard-android. +type PeekLookAtSocketFd interface { + PeekLookAtSocketFd4() (fd int, err error) + PeekLookAtSocketFd6() (fd int, err error) +} + +// An Endpoint maintains the source/destination caching for a peer. +// +// dst: the remote address of a peer ("endpoint" in uapi terminology) +// src: the local address from which datagrams originate going to the peer +type Endpoint interface { + ClearSrc() // clears the source address + SrcToString() string // returns the local source address (ip:port) + DstToString() string // returns the destination address (ip:port) + DstToBytes() []byte // used for mac2 cookie calculations + DstIP() netip.Addr + SrcIP() netip.Addr +} + +var ( + ErrBindAlreadyOpen = errors.New("bind is already open") + ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type") +) + +func (fn ReceiveFunc) PrettyName() string { + name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() + // 0. cheese/taco.beansIPv6.func12.func21218-fm + name = strings.TrimSuffix(name, "-fm") + // 1. cheese/taco.beansIPv6.func12.func21218 + if idx := strings.LastIndexByte(name, '/'); idx != -1 { + name = name[idx+1:] + // 2. taco.beansIPv6.func12.func21218 + } + for { + var idx int + for idx = len(name) - 1; idx >= 0; idx-- { + if name[idx] < '0' || name[idx] > '9' { + break + } + } + if idx == len(name)-1 { + break + } + const dotFunc = ".func" + if !strings.HasSuffix(name[:idx+1], dotFunc) { + break + } + name = name[:idx+1-len(dotFunc)] + // 3. taco.beansIPv6.func12 + // 4. taco.beansIPv6 + } + if idx := strings.LastIndexByte(name, '.'); idx != -1 { + name = name[idx+1:] + // 5. 
beansIPv6 + } + if name == "" { + return fmt.Sprintf("%p", fn) + } + if strings.HasSuffix(name, "IPv4") { + return "v4" + } + if strings.HasSuffix(name, "IPv6") { + return "v6" + } + return name +} diff --git a/vendor/golang.zx2c4.com/wireguard/conn/controlfns.go b/vendor/golang.zx2c4.com/wireguard/conn/controlfns.go new file mode 100644 index 0000000000..4f7d90fa10 --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/controlfns.go @@ -0,0 +1,43 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net" + "syscall" +) + +// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is +// the max supported by a default configuration of macOS. Some platforms will +// silently clamp the value to other maximums, such as linux clamping to +// net.core.{r,w}mem_max (see _linux.go for additional implementation that works +// around this limitation) +const socketBufferSize = 7 << 20 + +// controlFn is the callback function signature from net.ListenConfig.Control. +// It is used to apply platform specific configuration to the socket prior to +// bind. +type controlFn func(network, address string, c syscall.RawConn) error + +// controlFns is a list of functions that are called from the listen config +// that can apply socket options. +var controlFns = []controlFn{} + +// listenConfig returns a net.ListenConfig that applies the controlFns to the +// socket prior to bind. This is used to apply socket buffer sizing and packet +// information OOB configuration for sticky sockets. +func listenConfig() *net.ListenConfig { + return &net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + for _, fn := range controlFns { + if err := fn(network, address, c); err != nil { + return err + } + } + return nil + }, + } +} diff --git a/vendor/golang.zx2c4.com/wireguard/conn/controlfns_linux.go b/vendor/golang.zx2c4.com/wireguard/conn/controlfns_linux.go new file mode 100644 index 0000000000..f6ab1d2ec4 --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/controlfns_linux.go @@ -0,0 +1,69 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "fmt" + "runtime" + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + controlFns = append(controlFns, + + // Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by + // using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to + // fail silently - the result of failure is lower performance on very fast + // links or high latency links. + func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + // Set up to *mem_max + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize) + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize) + // Set beyond *mem_max if CAP_NET_ADMIN + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize) + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize) + }) + }, + + // Enable receiving of the packet information (IP_PKTINFO for IPv4, + // IPV6_PKTINFO for IPv6) that is used to implement sticky socket support. 
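+		// Android does not use sticky sockets (see the build tags in
+		// sticky_default.go), so packet-info options are skipped there.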
+ func(network, address string, c syscall.RawConn) error { + var err error + switch network { + case "udp4": + if runtime.GOOS != "android" { + c.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1) + }) + } + case "udp6": + c.Control(func(fd uintptr) { + if runtime.GOOS != "android" { + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1) + if err != nil { + return + } + } + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) + }) + default: + err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL) + } + return err + }, + + // Attempt to enable UDP_GRO + func(network, address string, c syscall.RawConn) error { + c.Control(func(fd uintptr) { + _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) + }) + return nil + }, + ) +} diff --git a/vendor/golang.zx2c4.com/wireguard/conn/controlfns_unix.go b/vendor/golang.zx2c4.com/wireguard/conn/controlfns_unix.go new file mode 100644 index 0000000000..91692c0a65 --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/controlfns_unix.go @@ -0,0 +1,35 @@ +//go:build !windows && !linux && !wasm + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + controlFns = append(controlFns, + func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize) + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize) + }) + }, + + func(network, address string, c syscall.RawConn) error { + var err error + if network == "udp6" { + c.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) + }) + } + return err + }, + ) +} diff --git a/vendor/golang.zx2c4.com/wireguard/conn/controlfns_windows.go b/vendor/golang.zx2c4.com/wireguard/conn/controlfns_windows.go new file mode 100644 index 0000000000..c3bdf7d3a9 --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/controlfns_windows.go @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "syscall" + + "golang.org/x/sys/windows" +) + +func init() { + controlFns = append(controlFns, + func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize) + _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize) + }) + }, + ) +} diff --git a/vendor/golang.zx2c4.com/wireguard/conn/default.go b/vendor/golang.zx2c4.com/wireguard/conn/default.go new file mode 100644 index 0000000000..b6f761b9ed --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/default.go @@ -0,0 +1,10 @@ +//go:build !windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package conn + +func NewDefaultBind() Bind { return NewStdNetBind() } diff --git a/vendor/golang.zx2c4.com/wireguard/conn/errors_default.go b/vendor/golang.zx2c4.com/wireguard/conn/errors_default.go new file mode 100644 index 0000000000..f1e5b90e5a --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/errors_default.go @@ -0,0 +1,12 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +func errShouldDisableUDPGSO(err error) bool { + return false +} diff --git a/vendor/golang.zx2c4.com/wireguard/conn/errors_linux.go b/vendor/golang.zx2c4.com/wireguard/conn/errors_linux.go new file mode 100644 index 0000000000..8e61000f8a --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/errors_linux.go @@ -0,0 +1,26 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +func errShouldDisableUDPGSO(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not have + // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT. + // See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false +} diff --git a/vendor/golang.zx2c4.com/wireguard/conn/features_default.go b/vendor/golang.zx2c4.com/wireguard/conn/features_default.go new file mode 100644 index 0000000000..d53ff5f7b6 --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/features_default.go @@ -0,0 +1,15 @@ +//go:build !linux +// +build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import "net" + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + return +} diff --git a/vendor/golang.zx2c4.com/wireguard/conn/features_linux.go b/vendor/golang.zx2c4.com/wireguard/conn/features_linux.go new file mode 100644 index 0000000000..8959d93582 --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/features_linux.go @@ -0,0 +1,29 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net" + + "golang.org/x/sys/unix" +) + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + rc, err := conn.SyscallConn() + if err != nil { + return + } + err = rc.Control(func(fd uintptr) { + _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) + txOffload = errSyscall == nil + opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO) + rxOffload = errSyscall == nil && opt == 1 + }) + if err != nil { + return false, false + } + return txOffload, rxOffload +} diff --git a/vendor/golang.zx2c4.com/wireguard/conn/gso_default.go b/vendor/golang.zx2c4.com/wireguard/conn/gso_default.go new file mode 100644 index 0000000000..57780dbb50 --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/gso_default.go @@ -0,0 +1,21 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package conn + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. +func setGSOSize(control *[]byte, gsoSize uint16) { +} + +// gsoControlSize returns the recommended buffer size for pooling sticky and UDP +// offloading control data. +const gsoControlSize = 0 diff --git a/vendor/golang.zx2c4.com/wireguard/conn/gso_linux.go b/vendor/golang.zx2c4.com/wireguard/conn/gso_linux.go new file mode 100644 index 0000000000..8596b292ec --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/gso_linux.go @@ -0,0 +1,65 @@ +//go:build linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "fmt" + "unsafe" + + "golang.org/x/sys/unix" +) + +const ( + sizeOfGSOData = 2 +) + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + var ( + hdr unix.Cmsghdr + data []byte + rem = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return 0, fmt.Errorf("error parsing socket control message: %w", err) + } + if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData { + var gso uint16 + copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) + return int(gso), nil + } + } + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing +// data in control untouched. +func setGSOSize(control *[]byte, gsoSize uint16) { + existingLen := len(*control) + avail := cap(*control) - existingLen + space := unix.CmsgSpace(sizeOfGSOData) + if avail < space { + return + } + *control = (*control)[:cap(*control)] + gsoControl := (*control)[existingLen:] + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) + copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) + *control = (*control)[:existingLen+space] +} + +// gsoControlSize returns the recommended buffer size for pooling UDP +// offloading control data. +var gsoControlSize = unix.CmsgSpace(sizeOfGSOData) diff --git a/vendor/golang.zx2c4.com/wireguard/conn/mark_default.go b/vendor/golang.zx2c4.com/wireguard/conn/mark_default.go new file mode 100644 index 0000000000..31023844a2 --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/mark_default.go @@ -0,0 +1,12 @@ +//go:build !linux && !openbsd && !freebsd + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +func (s *StdNetBind) SetMark(mark uint32) error { + return nil +} diff --git a/vendor/golang.zx2c4.com/wireguard/conn/mark_unix.go b/vendor/golang.zx2c4.com/wireguard/conn/mark_unix.go new file mode 100644 index 0000000000..d9e46eea7f --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/mark_unix.go @@ -0,0 +1,65 @@ +//go:build linux || openbsd || freebsd + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package conn + +import ( + "runtime" + + "golang.org/x/sys/unix" +) + +var fwmarkIoctl int + +func init() { + switch runtime.GOOS { + case "linux", "android": + fwmarkIoctl = 36 /* unix.SO_MARK */ + case "freebsd": + fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */ + case "openbsd": + fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */ + } +} + +func (s *StdNetBind) SetMark(mark uint32) error { + var operr error + if fwmarkIoctl == 0 { + return nil + } + if s.ipv4 != nil { + fd, err := s.ipv4.SyscallConn() + if err != nil { + return err + } + err = fd.Control(func(fd uintptr) { + operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) + }) + if err == nil { + err = operr + } + if err != nil { + return err + } + } + if s.ipv6 != nil { + fd, err := s.ipv6.SyscallConn() + if err != nil { + return err + } + err = fd.Control(func(fd uintptr) { + operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) + }) + if err == nil { + err = operr + } + if err != nil { + return err + } + } + return nil +} diff --git a/vendor/golang.zx2c4.com/wireguard/conn/sticky_default.go b/vendor/golang.zx2c4.com/wireguard/conn/sticky_default.go new file mode 100644 index 0000000000..0b213867d7 --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/sticky_default.go @@ -0,0 +1,42 @@ +//go:build !linux || android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import "net/netip" + +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return "" +} + +// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets +// {get,set}srcControl feature set, but use alternatively named flags and need +// ports and require testing. + +// getSrcFromControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func getSrcFromControl(control []byte, ep *StdNetEndpoint) { +} + +// setSrcControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func setSrcControl(control *[]byte, ep *StdNetEndpoint) { +} + +// stickyControlSize returns the recommended buffer size for pooling sticky +// offloading control data. +const stickyControlSize = 0 + +const StdNetSupportsStickySockets = false diff --git a/vendor/golang.zx2c4.com/wireguard/conn/sticky_linux.go b/vendor/golang.zx2c4.com/wireguard/conn/sticky_linux.go new file mode 100644 index 0000000000..8e206e90b2 --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/sticky_linux.go @@ -0,0 +1,112 @@ +//go:build linux && !android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net/netip" + "unsafe" + + "golang.org/x/sys/unix" +) + +func (e *StdNetEndpoint) SrcIP() netip.Addr { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return netip.AddrFrom4(info.Spec_dst) + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + // TODO: set zone. in order to do so we need to check if the address is + // link local, and if it is perform a syscall to turn the ifindex into a + // zone string because netip uses string zones. 
+ return netip.AddrFrom16(info.Addr) + } + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return info.Ifindex + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return int32(info.Ifindex) + } + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return e.SrcIP().String() +} + +// getSrcFromControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func getSrcFromControl(control []byte, ep *StdNetEndpoint) { + ep.ClearSrc() + + var ( + hdr unix.Cmsghdr + data []byte + rem []byte = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return + } + + if hdr.Level == unix.IPPROTO_IP && + hdr.Type == unix.IP_PKTINFO { + + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + } + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] + + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) + return + } + + if hdr.Level == unix.IPPROTO_IPV6 && + hdr.Type == unix.IPV6_PKTINFO { + + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + } + + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] + + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) + return + } + } +} + +// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address +// and source ifindex found in ep. control's len will be set to 0 in the event +// that ep is a default value. +func setSrcControl(control *[]byte, ep *StdNetEndpoint) { + if cap(*control) < len(ep.src) { + return + } + *control = (*control)[:0] + *control = append(*control, ep.src...) +} + +// stickyControlSize returns the recommended buffer size for pooling sticky +// offloading control data. +var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) + +const StdNetSupportsStickySockets = true diff --git a/vendor/golang.zx2c4.com/wireguard/conn/winrio/rio_windows.go b/vendor/golang.zx2c4.com/wireguard/conn/winrio/rio_windows.go new file mode 100644 index 0000000000..d1037bba90 --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/conn/winrio/rio_windows.go @@ -0,0 +1,254 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package winrio + +import ( + "log" + "sync" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + MsgDontNotify = 1 + MsgDefer = 2 + MsgWaitAll = 4 + MsgCommitOnly = 8 + + MaxCqSize = 0x8000000 + + invalidBufferId = 0xFFFFFFFF + invalidCq = 0 + invalidRq = 0 + corruptCq = 0xFFFFFFFF +) + +var extensionFunctionTable struct { + cbSize uint32 + rioReceive uintptr + rioReceiveEx uintptr + rioSend uintptr + rioSendEx uintptr + rioCloseCompletionQueue uintptr + rioCreateCompletionQueue uintptr + rioCreateRequestQueue uintptr + rioDequeueCompletion uintptr + rioDeregisterBuffer uintptr + rioNotify uintptr + rioRegisterBuffer uintptr + rioResizeCompletionQueue uintptr + rioResizeRequestQueue uintptr +} + +type Cq uintptr + +type Rq uintptr + +type BufferId uintptr + +type Buffer struct { + Id BufferId + Offset uint32 + Length uint32 +} + +type Result struct { + Status int32 + BytesTransferred uint32 + SocketContext uint64 + RequestContext uint64 +} + +type notificationCompletionType uint32 + +const ( + eventCompletion notificationCompletionType = 1 + iocpCompletion notificationCompletionType = 2 +) + +type eventNotificationCompletion struct { + completionType notificationCompletionType + event windows.Handle + notifyReset uint32 +} + +type iocpNotificationCompletion struct { + completionType notificationCompletionType + iocp windows.Handle + key uintptr + overlapped *windows.Overlapped +} + +var ( + initialized sync.Once + available bool +) + +func Initialize() bool { + initialized.Do(func() { + var ( + err error + socket windows.Handle + cq Cq + ) + defer func() { + if err == nil { + return + } + if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 { + return + } + log.Printf("Registered I/O is unavailable: %v", err) + }() + socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP) + if err != nil { + return + } + defer windows.CloseHandle(socket) + WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}} + const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024 + ob := uint32(0) + err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER, + (*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)), + (*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)), + &ob, nil, 0) + if err != nil { + return + } + + // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes + // failures in RIOCreateRequestQueue, so keep going to be certain this is supported. 
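The probe-once shape of Initialize is worth calling out: an expensive, possibly-failing capability check runs exactly once, and every later caller gets the cached answer. A minimal sketch of that pattern, with featureProbe as an illustrative stand-in:

    package main

    import (
        "fmt"
        "sync"
    )

    // featureProbe caches the outcome of a possibly-failing capability probe
    // so the fast path can be queried cheaply from any number of callers.
    type featureProbe struct {
        once      sync.Once
        available bool
    }

    func (p *featureProbe) Available(probe func() error) bool {
        p.once.Do(func() {
            p.available = probe() == nil
        })
        return p.available
    }

    func main() {
        var rio featureProbe
        ok := rio.Available(func() error {
            // Real code would create a socket and fetch the RIO extension
            // function table here; this stand-in always succeeds.
            return nil
        })
        fmt.Println("registered I/O available:", ok)
    }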
+ var iocp windows.Handle + iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + if err != nil { + return + } + defer windows.CloseHandle(iocp) + var overlapped windows.Overlapped + cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped) + if err != nil { + return + } + defer CloseCompletionQueue(cq) + _, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0) + if err != nil { + return + } + available = true + }) + return available +} + +func Socket(af, typ, proto int32) (windows.Handle, error) { + return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO) +} + +func CloseCompletionQueue(cq Cq) { + _, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0) +} + +func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) { + notificationCompletion := &eventNotificationCompletion{ + completionType: eventCompletion, + event: event, + } + if notifyReset { + notificationCompletion.notifyReset = 1 + } + ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) + if ret == invalidCq { + return 0, err + } + return Cq(ret), nil +} + +func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) { + notificationCompletion := &iocpNotificationCompletion{ + completionType: iocpCompletion, + iocp: iocp, + key: key, + overlapped: overlapped, + } + ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) + if ret == invalidCq { + return 0, err + } + return Cq(ret), nil +} + +func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) { + ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0) + if ret == invalidCq { + return 0, err + } + return Cq(ret), nil +} + +func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) { + ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0) + if ret == invalidRq { + return 0, err + } + return Rq(ret), nil +} + +func DequeueCompletion(cq Cq, results []Result) uint32 { + var array uintptr + if len(results) > 0 { + array = uintptr(unsafe.Pointer(&results[0])) + } + ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results))) + if ret == corruptCq { + panic("cq is corrupt") + } + return uint32(ret) +} + +func DeregisterBuffer(id BufferId) { + _, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0) +} + +func RegisterBuffer(buffer []byte) (BufferId, error) { + var buf unsafe.Pointer + if len(buffer) > 0 { + buf = unsafe.Pointer(&buffer[0]) + } + return RegisterPointer(buf, uint32(len(buffer))) +} + +func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) { + ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0) + if ret == invalidBufferId { + return 0, err + } + return BufferId(ret), nil +} + +func SendEx(rq Rq, buf 
*Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { + ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) + if ret == 0 { + return err + } + return nil +} + +func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { + ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) + if ret == 0 { + return err + } + return nil +} + +func Notify(cq Cq) error { + ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0) + if ret != 0 { + return windows.Errno(ret) + } + return nil +} diff --git a/vendor/golang.zx2c4.com/wireguard/rwcancel/rwcancel.go b/vendor/golang.zx2c4.com/wireguard/rwcancel/rwcancel.go index 63e1510b10..e397c0e8ae 100644 --- a/vendor/golang.zx2c4.com/wireguard/rwcancel/rwcancel.go +++ b/vendor/golang.zx2c4.com/wireguard/rwcancel/rwcancel.go @@ -1,4 +1,4 @@ -//go:build !windows && !js +//go:build !windows && !wasm /* SPDX-License-Identifier: MIT * diff --git a/vendor/golang.zx2c4.com/wireguard/rwcancel/rwcancel_stub.go b/vendor/golang.zx2c4.com/wireguard/rwcancel/rwcancel_stub.go index 182940b32e..2a98b2b4ad 100644 --- a/vendor/golang.zx2c4.com/wireguard/rwcancel/rwcancel_stub.go +++ b/vendor/golang.zx2c4.com/wireguard/rwcancel/rwcancel_stub.go @@ -1,4 +1,4 @@ -//go:build windows || js +//go:build windows || wasm // SPDX-License-Identifier: MIT diff --git a/vendor/golang.zx2c4.com/wireguard/tun/checksum.go b/vendor/golang.zx2c4.com/wireguard/tun/checksum.go new file mode 100644 index 0000000000..29a8fc8fc0 --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/tun/checksum.go @@ -0,0 +1,118 @@ +package tun + +import "encoding/binary" + +// TODO: Explore SIMD and/or other assembly optimizations. +// TODO: Test native endian loads. See RFC 1071 section 2 part B. 
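Before the unrolled code, a plain RFC 1071 reference implementation may help: sum 16-bit big-endian words into a wide accumulator, add a trailing odd byte as the high octet, then fold the carries back into 16 bits. Like checksum() below, this sketch returns the folded sum and leaves the final one's-complement (^) to the caller.

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    // internetChecksum is the textbook RFC 1071 routine that the unrolled
    // checksumNoFold/checksum pair below implements more efficiently.
    func internetChecksum(b []byte) uint16 {
        var sum uint64
        for ; len(b) >= 2; b = b[2:] {
            sum += uint64(binary.BigEndian.Uint16(b))
        }
        if len(b) == 1 {
            sum += uint64(b[0]) << 8 // odd trailing byte is the high octet
        }
        for sum>>16 != 0 {
            sum = (sum >> 16) + (sum & 0xffff)
        }
        return uint16(sum)
    }

    func main() {
        // The worked example from RFC 1071: the words 0001, f203, f4f5, f6f7
        // sum (with carries folded) to ddf2.
        fmt.Printf("%#04x\n", internetChecksum([]byte{0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7}))
    }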
+func checksumNoFold(b []byte, initial uint64) uint64 { + ac := initial + + for len(b) >= 128 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + ac += uint64(binary.BigEndian.Uint32(b[64:68])) + ac += uint64(binary.BigEndian.Uint32(b[68:72])) + ac += uint64(binary.BigEndian.Uint32(b[72:76])) + ac += uint64(binary.BigEndian.Uint32(b[76:80])) + ac += uint64(binary.BigEndian.Uint32(b[80:84])) + ac += uint64(binary.BigEndian.Uint32(b[84:88])) + ac += uint64(binary.BigEndian.Uint32(b[88:92])) + ac += uint64(binary.BigEndian.Uint32(b[92:96])) + ac += uint64(binary.BigEndian.Uint32(b[96:100])) + ac += uint64(binary.BigEndian.Uint32(b[100:104])) + ac += uint64(binary.BigEndian.Uint32(b[104:108])) + ac += uint64(binary.BigEndian.Uint32(b[108:112])) + ac += uint64(binary.BigEndian.Uint32(b[112:116])) + ac += uint64(binary.BigEndian.Uint32(b[116:120])) + ac += uint64(binary.BigEndian.Uint32(b[120:124])) + ac += uint64(binary.BigEndian.Uint32(b[124:128])) + b = b[128:] + } + if len(b) >= 64 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + b = b[64:] + } + if len(b) >= 32 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + b = b[32:] + } + if len(b) >= 16 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + b = b[16:] + } + if len(b) >= 8 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + b = b[8:] + } + if len(b) >= 4 { + ac += uint64(binary.BigEndian.Uint32(b)) + b = b[4:] + } + if len(b) >= 2 { + ac += uint64(binary.BigEndian.Uint16(b)) + b = b[2:] + } + if len(b) == 1 { + ac += uint64(b[0]) << 8 + } + + 
return ac +} + +func checksum(b []byte, initial uint64) uint16 { + ac := checksumNoFold(b, initial) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + return uint16(ac) +} + +func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { + sum := checksumNoFold(srcAddr, 0) + sum = checksumNoFold(dstAddr, sum) + sum = checksumNoFold([]byte{0, protocol}, sum) + tmp := make([]byte, 2) + binary.BigEndian.PutUint16(tmp, totalLen) + return checksumNoFold(tmp, sum) +} diff --git a/vendor/golang.zx2c4.com/wireguard/tun/errors.go b/vendor/golang.zx2c4.com/wireguard/tun/errors.go new file mode 100644 index 0000000000..75ae3a434a --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/tun/errors.go @@ -0,0 +1,12 @@ +package tun + +import ( + "errors" +) + +var ( + // ErrTooManySegments is returned by Device.Read() when segmentation + // overflows the length of supplied buffers. This error should not cause + // reads to cease. + ErrTooManySegments = errors.New("too many segments") +) diff --git a/vendor/golang.zx2c4.com/wireguard/tun/offload_linux.go b/vendor/golang.zx2c4.com/wireguard/tun/offload_linux.go new file mode 100644 index 0000000000..9ff7fea8f9 --- /dev/null +++ b/vendor/golang.zx2c4.com/wireguard/tun/offload_linux.go @@ -0,0 +1,993 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "unsafe" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" +) + +const tcpFlagsOffset = 13 + +const ( + tcpFlagFIN uint8 = 0x01 + tcpFlagPSH uint8 = 0x08 + tcpFlagACK uint8 = 0x10 +) + +// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The +// kernel symbol is virtio_net_hdr. +type virtioNetHdr struct { + flags uint8 + gsoType uint8 + hdrLen uint16 + gsoSize uint16 + csumStart uint16 + csumOffset uint16 +} + +func (v *virtioNetHdr) decode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) + return nil +} + +func (v *virtioNetHdr) encode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) + return nil +} + +const ( + // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the + // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). + virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) +) + +// tcpFlowKey represents the key for a TCP flow. +type tcpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. + isV6 bool +} + +// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO. 
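As a worked illustration of how pseudoHeaderChecksumNoFold above seeds a transport checksum: the pseudo header (addresses, a zero byte plus protocol, transport length) is summed first, then the UDP header and payload are folded on top and complemented. fold16 is an illustrative helper, not the vendored code.

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    // fold16 sums 16-bit big-endian words on top of an initial value and
    // folds the carries, mirroring checksum() above.
    func fold16(b []byte, initial uint64) uint16 {
        sum := initial
        for ; len(b) >= 2; b = b[2:] {
            sum += uint64(binary.BigEndian.Uint16(b))
        }
        if len(b) == 1 {
            sum += uint64(b[0]) << 8
        }
        for sum>>16 != 0 {
            sum = (sum >> 16) + (sum & 0xffff)
        }
        return uint16(sum)
    }

    func main() {
        src := []byte{10, 0, 0, 1}
        dst := []byte{10, 0, 0, 2}
        // UDP: sport 1234, dport 53, length 9, checksum zeroed, payload "A".
        udp := []byte{0x04, 0xd2, 0x00, 0x35, 0x00, 0x09, 0x00, 0x00, 'A'}
        pseudo := fold16(src, 0)                       // source address
        pseudo = fold16(dst, uint64(pseudo))           // destination address
        pseudo = fold16([]byte{0, 17}, uint64(pseudo)) // zero byte + protocol (UDP)
        pseudo = fold16([]byte{0, 9}, uint64(pseudo))  // transport length
        fmt.Printf("udp checksum: %#04x\n", ^fold16(udp, uint64(pseudo)))
    }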
+type tcpGROTable struct { + itemsByFlow map[tcpFlowKey][]tcpGROItem + itemsPool [][]tcpGROItem +} + +func newTCPGROTable() *tcpGROTable { + t := &tcpGROTable{ + itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize), + itemsPool: make([][]tcpGROItem, conn.IdealBatchSize), + } + for i := range t.itemsPool { + t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize) + } + return t +} + +func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey { + key := tcpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) + key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + items, ok := t.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. +func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + item := tcpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), + iphLen: uint8(tcphOffset), + tcphLen: uint8(tcphLen), + sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), + pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, + } + items, ok := t.itemsByFlow[key] + if !ok { + items = t.newItems() + } + items = append(items, item) + t.itemsByFlow[key] = items +} + +func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { + items, _ := t.itemsByFlow[item.key] + items[i] = item +} + +func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) { + items, _ := t.itemsByFlow[key] + items = append(items[:i], items[i+1:]...) + t.itemsByFlow[key] = items +} + +// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type tcpGROItem struct { + key tcpFlowKey + sentSeq uint32 // the sequence number + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + tcphLen uint8 // tcp header len + pshSet bool // psh flag is set +} + +func (t *tcpGROTable) newItems() []tcpGROItem { + var items []tcpGROItem + items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] + return items +} + +func (t *tcpGROTable) reset() { + for k, items := range t.itemsByFlow { + items = items[:0] + t.itemsPool = append(t.itemsPool, items) + delete(t.itemsByFlow, k) + } +} + +// udpFlowKey represents the key for a UDP flow. 
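The itemsPool mechanics above deserve a note: per-flow item slices are taken from a pool and handed back with their length reset by reset(), so a steady-state GRO batch allocates nothing. A minimal sketch of that recycle pattern (pool is illustrative):

    package main

    import "fmt"

    // pool hands out zero-length slices with pre-grown capacity and takes
    // them back with the length reset, keeping the backing arrays alive.
    type pool struct{ free [][]int }

    func newPool(n, c int) *pool {
        p := &pool{free: make([][]int, n)}
        for i := range p.free {
            p.free[i] = make([]int, 0, c)
        }
        return p
    }

    func (p *pool) get() []int {
        s := p.free[len(p.free)-1]
        p.free = p.free[:len(p.free)-1]
        return s
    }

    func (p *pool) put(s []int) { p.free = append(p.free, s[:0]) }

    func main() {
        p := newPool(8, 128)
        items := p.get()
        items = append(items, 1, 2, 3) // no allocation: capacity is 128
        p.put(items)                   // length reset, backing array retained
        fmt.Println("slices pooled:", len(p.free))
    }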
+type udpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + isV6 bool +} + +// udpGROTable holds flow and coalescing information for the purposes of UDP GRO. +type udpGROTable struct { + itemsByFlow map[udpFlowKey][]udpGROItem + itemsPool [][]udpGROItem +} + +func newUDPGROTable() *udpGROTable { + u := &udpGROTable{ + itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize), + itemsPool: make([][]udpGROItem, conn.IdealBatchSize), + } + for i := range u.itemsPool { + u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize) + } + return u +} + +func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey { + key := udpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + items, ok := u.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. +func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + item := udpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[udphOffset+udphLen:])), + iphLen: uint8(udphOffset), + cSumKnownInvalid: cSumKnownInvalid, + } + items, ok := u.itemsByFlow[key] + if !ok { + items = u.newItems() + } + items = append(items, item) + u.itemsByFlow[key] = items +} + +func (u *udpGROTable) updateAt(item udpGROItem, i int) { + items, _ := u.itemsByFlow[item.key] + items[i] = item +} + +// udpGROItem represents bookkeeping data for a UDP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type udpGROItem struct { + key udpFlowKey + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown. +} + +func (u *udpGROTable) newItems() []udpGROItem { + var items []udpGROItem + items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1] + return items +} + +func (u *udpGROTable) reset() { + for k, items := range u.itemsByFlow { + items = items[:0] + u.itemsPool = append(u.itemsPool, items) + delete(u.itemsByFlow, k) + } +} + +// canCoalesce represents the outcome of checking if two TCP packets are +// candidates for coalescing. 
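A compact sketch of the sequence-adjacency rule at the heart of tcpPacketsCanCoalesce further below: a packet is an append candidate when its seq lands exactly at the end of the payload the item already holds, and a prepend candidate when its payload ends exactly at the item's seq. adjacency is illustrative shorthand for the full check, which also compares headers and checksums.

    package main

    import "fmt"

    func adjacency(seq, itemSeq uint32, gsoSize, itemGSO, numMerged uint16) string {
        itemLen := uint32(itemGSO) * uint32(numMerged+1) // payload held by item
        switch {
        case seq == itemSeq+itemLen:
            return "append"
        case seq+uint32(gsoSize) == itemSeq:
            return "prepend"
        }
        return "unavailable"
    }

    func main() {
        // The item holds two 1400-byte segments starting at seq 1000.
        fmt.Println(adjacency(3800, 1000, 1400, 1400, 1)) // append
        fmt.Println(adjacency(0, 1000, 1000, 1400, 1))    // prepend
        fmt.Println(adjacency(5000, 1000, 1400, 1400, 1)) // unavailable
    }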
+type canCoalesce int + +const ( + coalescePrepend canCoalesce = -1 + coalesceUnavailable canCoalesce = 0 + coalesceAppend canCoalesce = 1 +) + +// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB +// meet all requirements to be merged as part of a GRO operation, otherwise it +// returns false. +func ipHeadersCanCoalesce(pktA, pktB []byte) bool { + if len(pktA) < 9 || len(pktB) < 9 { + return false + } + if pktA[0]>>4 == 6 { + if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 { + // cannot coalesce with unequal Traffic class values + return false + } + if pktA[7] != pktB[7] { + // cannot coalesce with unequal Hop limit values + return false + } + } else { + if pktA[1] != pktB[1] { + // cannot coalesce with unequal ToS values + return false + } + if pktA[6]>>5 != pktB[6]>>5 { + // cannot coalesce with unequal DF or reserved bits. MF is checked + // further up the stack. + return false + } + if pktA[8] != pktB[8] { + // cannot coalesce with unequal TTL values + return false + } + } + return true +} + +// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. iphLen and gsoSize describe pkt. bufs is the vector of +// packets involved in the current GRO evaluation. bufsOffset is the offset at +// which packet data begins within bufs. +func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend +} + +// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. This function makes considerations that match the kernel's +// GRO self tests, which can be found in tools/testing/selftests/net/gro.c. +func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if tcphLen != item.tcphLen { + // cannot coalesce with unequal tcp options len + return coalesceUnavailable + } + if tcphLen > 20 { + if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) { + // cannot coalesce with unequal tcp options + return coalesceUnavailable + } + } + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + // seq adjacency + lhsLen := item.gsoSize + lhsLen += item.numMerged * item.gsoSize + if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective + if item.pshSet { + // We cannot append to a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. 
+ return coalesceUnavailable + } + return coalesceAppend + } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective + if pshSet { + // We cannot prepend with a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if gsoSize < item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + if gsoSize > item.gsoSize && item.numMerged > 0 { + // There's at least one previous merge, and we're larger than all + // previous. This would put multiple smaller packets on the end. + return coalesceUnavailable + } + return coalescePrepend + } + return coalesceUnavailable +} + +func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool { + srcAddrAt := ipv4SrcAddrOffset + addrSize := 4 + if isV6 { + srcAddrAt = ipv6SrcAddrOffset + addrSize = 16 + } + lenForPseudo := uint16(len(pkt) - int(iphLen)) + cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo) + return ^checksum(pkt[iphLen:], cSum) == 0 +} + +// coalesceResult represents the result of attempting to coalesce two TCP +// packets. +type coalesceResult int + +const ( + coalesceInsufficientCap coalesceResult = iota + coalescePSHEnding + coalesceItemInvalidCSum + coalescePktInvalidCSum + coalesceSuccess +) + +// coalesceUDPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. +func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front + headersLen := item.iphLen + udphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalescePktInvalidCSum + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) + copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + + item.numMerged++ + return coalesceSuccess +} + +// coalesceTCPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. This function may swap bufs elements in the +// event of a prepend as item's bufs index is already being tracked for writing +// to a Device. +func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + var pktHead []byte // the packet that will end up at the front + headersLen := item.iphLen + item.tcphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + // Copy data + if mode == coalescePrepend { + pktHead = pkt + if cap(pkt)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. 
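These capacity checks enable an allocation-free grow: append() extends a slice in place whenever the backing array already has room, which is exactly what the extendBy pattern in both coalesce functions relies on. A tiny sketch demonstrating that the backing array is preserved:

    package main

    import "fmt"

    func main() {
        buf := make([]byte, 4, 16)                         // len 4, cap 16
        before := &buf[0]
        buf = append(buf, make([]byte, 8)...)              // extendBy pattern; cap suffices
        fmt.Println(len(buf), cap(buf), before == &buf[0]) // 12 16 true
    }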
+ return coalesceInsufficientCap + } + if pshSet { + return coalescePSHEnding + } + if item.numMerged == 0 { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalescePktInvalidCSum + } + item.sentSeq = seq + extendBy := coalescedLen - len(pktHead) + bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...) + copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):]) + // Flip the slice headers in bufs as part of prepend. The index of item + // is already being tracked for writing. + bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex] + } else { + pktHead = bufs[item.bufsIndex][bufsOffset:] + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalescePktInvalidCSum + } + if pshSet { + // We are appending a segment with PSH set. + item.pshSet = pshSet + pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) + copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + } + + if gsoSize > item.gsoSize { + item.gsoSize = gsoSize + } + + item.numMerged++ + return coalesceSuccess +} + +const ( + ipv4FlagMoreFragments uint8 = 0x20 +) + +const ( + ipv4SrcAddrOffset = 12 + ipv6SrcAddrOffset = 8 + maxUint16 = 1<<16 - 1 +) + +type groResult int + +const ( + groResultNoop groResult = iota + groResultTableInsert + groResultCoalesced +) + +// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. 
+ return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + tcphLen := int((pkt[iphLen+12] >> 4) * 4) + if tcphLen < 20 || tcphLen > 60 { + return groResultNoop + } + if len(pkt) < iphLen+tcphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + tcpFlags := pkt[iphLen+tcpFlagsOffset] + var pshSet bool + // not a candidate if any non-ACK flags (except PSH+ACK) are set + if tcpFlags != tcpFlagACK { + if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { + return groResultNoop + } + pshSet = true + } + gsoSize := uint16(len(pkt) - tcphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + if !existing { + return groResultTableInsert + } + for i := len(items) - 1; i >= 0; i-- { + // In the best case of packets arriving in order iterating in reverse is + // more efficient if there are multiple items for a given flow. This + // also enables a natural table.deleteAt() in the + // coalesceItemInvalidCSum case without the need for index tracking. + // This algorithm makes a best effort to coalesce in the event of + // unordered packets, where pkt may land anywhere in items from a + // sequence number perspective, however once an item is inserted into + // the table it is never compared across other items later. + item := items[i] + can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) + if can != coalesceUnavailable { + result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, i) + return groResultCoalesced + case coalesceItemInvalidCSum: + // delete the item with an invalid csum + table.deleteAt(item.key, i) + case coalescePktInvalidCSum: + // no point in inserting an item that we can't coalesce + return groResultNoop + default: + } + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + return groResultTableInsert +} + +// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + item.tcphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 16, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. 
+ if item.key.isV6 { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Calculate the pseudo header checksum and place it at the TCP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the tcp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + udphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 6, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4 + if item.key.isV6 { + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Recalculate the UDP len field value + binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:]))) + + // Calculate the pseudo header checksum and place it at the UDP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the udp header and payload checksum. 
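Both accounting passes share the same IPv4 fix-up once a packet has grown: store the new total length, zero the embedded header checksum, recompute the folded one's-complement over the header, and write it back. A self-contained sketch of just that step (fixIPv4Header is illustrative):

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    func fixIPv4Header(pkt []byte, iphLen int) {
        binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // new total length
        pkt[10], pkt[11] = 0, 0                               // clear old checksum
        var sum uint32
        for i := 0; i < iphLen; i += 2 {
            sum += uint32(binary.BigEndian.Uint16(pkt[i:]))
        }
        for sum>>16 != 0 {
            sum = (sum >> 16) + (sum & 0xffff)
        }
        binary.BigEndian.PutUint16(pkt[10:], ^uint16(sum))
    }

    func main() {
        pkt := make([]byte, 60) // 20-byte header + 40 bytes of coalesced payload
        pkt[0] = 0x45           // IPv4, IHL 5
        fixIPv4Header(pkt, 20)
        fmt.Printf("total len %d, header csum %#04x\n",
            binary.BigEndian.Uint16(pkt[2:]), binary.BigEndian.Uint16(pkt[10:]))
    }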
+ addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +type groCandidateType uint8 + +const ( + notGROCandidate groCandidateType = iota + tcp4GROCandidate + tcp6GROCandidate + udp4GROCandidate + udp6GROCandidate +) + +func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType { + if len(b) < 28 { + return notGROCandidate + } + if b[0]>>4 == 4 { + if b[0]&0x0F != 5 { + // IPv4 packets w/IP options do not coalesce + return notGROCandidate + } + if b[9] == unix.IPPROTO_TCP && len(b) >= 40 { + return tcp4GROCandidate + } + if b[9] == unix.IPPROTO_UDP && canUDPGRO { + return udp4GROCandidate + } + } else if b[0]>>4 == 6 { + if b[6] == unix.IPPROTO_TCP && len(b) >= 60 { + return tcp6GROCandidate + } + if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO { + return udp6GROCandidate + } + } + return notGROCandidate +} + +const ( + udphLen = 8 +) + +// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. + return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + if len(pkt) < iphLen+udphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + gsoSize := uint16(len(pkt) - udphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI) + if !existing { + return groResultTableInsert + } + // With UDP we only check the last item, otherwise we could reorder packets + // for a given flow. We must also always insert a new item, or successfully + // coalesce with an existing item, for the same reason. 
+ item := items[len(items)-1] + can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset) + var pktCSumKnownInvalid bool + if can == coalesceAppend { + result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, len(items)-1) + return groResultCoalesced + case coalesceItemInvalidCSum: + // If the existing item has an invalid csum we take no action. A new + // item will be stored after it, and the existing item will never be + // revisited as part of future coalescing candidacy checks. + case coalescePktInvalidCSum: + // We must insert a new item, but we also mark it as invalid csum + // to prevent a repeat checksum validation. + pktCSumKnownInvalid = true + default: + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid) + return groResultTableInsert +} + +// handleGRO evaluates bufs for GRO, and writes the indices of the resulting +// packets into toWrite. toWrite, tcpTable, and udpTable should initially be +// empty (but non-nil), and are passed in to save allocs as the caller may reset +// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is +// supported. +func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error { + for i := range bufs { + if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { + return errors.New("invalid offset") + } + var result groResult + switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) { + case tcp4GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, false) + case tcp6GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, true) + case udp4GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, false) + case udp6GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, true) + } + switch result { + case groResultNoop: + hdr := virtioNetHdr{} + err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + fallthrough + case groResultTableInsert: + *toWrite = append(*toWrite, i) + } + } + errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable) + errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable) + return errors.Join(errTCP, errUDP) +} + +// gsoSplit splits packets from in into outBuffs, writing the size of each +// element into sizes. It returns the number of buffers populated, and/or an +// error. 
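A minimal sketch of the segmentation walk gsoSplit performs below: carve the payload after hdrLen into gsoSize-sized pieces, allowing a short final segment, with the headers replicated onto every segment. segmentLens is illustrative and only reports the per-segment on-wire lengths:

    package main

    import "fmt"

    func segmentLens(pktLen, hdrLen, gsoSize int) []int {
        var sizes []int
        for at := hdrLen; at < pktLen; at += gsoSize {
            end := at + gsoSize
            if end > pktLen {
                end = pktLen
            }
            sizes = append(sizes, hdrLen+(end-at)) // headers repeated per segment
        }
        return sizes
    }

    func main() {
        // 52 bytes of TCP/IPv4 headers, 3000 bytes of payload, gso size 1400.
        fmt.Println(segmentLens(52+3000, 52, 1400)) // [1452 1452 252]
    }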
+func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) { + iphLen := int(hdr.csumStart) + srcAddrOffset := ipv6SrcAddrOffset + addrLen := 16 + if !isV6 { + in[10], in[11] = 0, 0 // clear ipv4 header checksum + srcAddrOffset = ipv4SrcAddrOffset + addrLen = 4 + } + transportCsumAt := int(hdr.csumStart + hdr.csumOffset) + in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum + var firstTCPSeqNum uint32 + var protocol uint8 + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 { + protocol = unix.IPPROTO_TCP + firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:]) + } else { + protocol = unix.IPPROTO_UDP + } + nextSegmentDataAt := int(hdr.hdrLen) + i := 0 + for ; nextSegmentDataAt < len(in); i++ { + if i == len(outBuffs) { + return i - 1, ErrTooManySegments + } + nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize) + if nextSegmentEnd > len(in) { + nextSegmentEnd = len(in) + } + segmentDataLen := nextSegmentEnd - nextSegmentDataAt + totalLen := int(hdr.hdrLen) + segmentDataLen + sizes[i] = totalLen + out := outBuffs[i][outOffset:] + + copy(out, in[:iphLen]) + if !isV6 { + // For IPv4 we are responsible for incrementing the ID field, + // updating the total len field, and recalculating the header + // checksum. + if i > 0 { + id := binary.BigEndian.Uint16(out[4:]) + id += uint16(i) + binary.BigEndian.PutUint16(out[4:], id) + } + binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) + ipv4CSum := ^checksum(out[:iphLen], 0) + binary.BigEndian.PutUint16(out[10:], ipv4CSum) + } else { + // For IPv6 we are responsible for updating the payload length field. + binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) + } + + // copy transport header + copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen]) + + if protocol == unix.IPPROTO_TCP { + // set TCP seq and adjust TCP flags + tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) + binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) + if nextSegmentEnd != len(in) { + // FIN and PSH should only be set on last segment + clearFlags := tcpFlagFIN | tcpFlagPSH + out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags + } + } else { + // set UDP header len + binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart)) + } + + // payload + copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) + + // transport checksum + transportHeaderLen := int(hdr.hdrLen - hdr.csumStart) + lenForPseudo := uint16(transportHeaderLen + segmentDataLen) + transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo) + transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold) + binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum) + + nextSegmentDataAt += int(hdr.gsoSize) + } + return i, nil +} + +func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { + cSumAt := cSumStart + cSumOffset + // The initial value at the checksum offset should be summed with the + // checksum we compute. This is typically the pseudo-header checksum. 
+ initial := binary.BigEndian.Uint16(in[cSumAt:]) + in[cSumAt], in[cSumAt+1] = 0, 0 + binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial))) + return nil +} diff --git a/vendor/golang.zx2c4.com/wireguard/tun/tun.go b/vendor/golang.zx2c4.com/wireguard/tun/tun.go index 01051b938e..0ae53d0733 100644 --- a/vendor/golang.zx2c4.com/wireguard/tun/tun.go +++ b/vendor/golang.zx2c4.com/wireguard/tun/tun.go @@ -18,12 +18,36 @@ const ( ) type Device interface { - File() *os.File // returns the file descriptor of the device - Read([]byte, int) (int, error) // read a packet from the device (without any additional headers) - Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers) - Flush() error // flush all previous writes to the device - MTU() (int, error) // returns the MTU of the device - Name() (string, error) // fetches and returns the current name - Events() <-chan Event // returns a constant channel of events related to the device - Close() error // stops the device and closes the event channel + // File returns the file descriptor of the device. + File() *os.File + + // Read one or more packets from the Device (without any additional headers). + // On a successful read it returns the number of packets read, and sets + // packet lengths within the sizes slice. len(sizes) must be >= len(bufs). + // A nonzero offset can be used to instruct the Device on where to begin + // reading into each element of the bufs slice. + Read(bufs [][]byte, sizes []int, offset int) (n int, err error) + + // Write one or more packets to the device (without any additional headers). + // On a successful write it returns the number of packets written. A nonzero + // offset can be used to instruct the Device on where to begin writing from + // each packet contained within the bufs slice. + Write(bufs [][]byte, offset int) (int, error) + + // MTU returns the MTU of the Device. + MTU() (int, error) + + // Name returns the current name of the Device. + Name() (string, error) + + // Events returns a channel of type Event, which is fed Device events. + Events() <-chan Event + + // Close stops the Device and closes the Event channel. + Close() error + + // BatchSize returns the preferred/max number of packets that can be read or + // written in a single read/write call. BatchSize must not change over the + // lifetime of a Device. 
+	BatchSize() int
 }
diff --git a/vendor/golang.zx2c4.com/wireguard/tun/tun_darwin.go b/vendor/golang.zx2c4.com/wireguard/tun/tun_darwin.go
index 7411a69463..c9a6c0bc45 100644
--- a/vendor/golang.zx2c4.com/wireguard/tun/tun_darwin.go
+++ b/vendor/golang.zx2c4.com/wireguard/tun/tun_darwin.go
@@ -8,6 +8,7 @@ package tun
 import (
 	"errors"
 	"fmt"
+	"io"
 	"net"
 	"os"
 	"sync"
@@ -15,7 +16,6 @@ import (
 	"time"
 	"unsafe"
 
-	"golang.org/x/net/ipv6"
 	"golang.org/x/sys/unix"
 )
 
@@ -33,7 +33,7 @@ type NativeTun struct {
 func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
 	for i := 0; i < 20; i++ {
 		iface, err = net.InterfaceByIndex(index)
-		if err != nil && errors.Is(err, syscall.ENOMEM) {
+		if err != nil && errors.Is(err, unix.ENOMEM) {
 			time.Sleep(time.Duration(i) * time.Second / 3)
 			continue
 		}
@@ -55,7 +55,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
 retry:
 		n, err := unix.Read(tun.routeSocket, data)
 		if err != nil {
-			if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
+			if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR {
 				goto retry
 			}
 			tun.errors <- err
@@ -217,45 +217,46 @@ func (tun *NativeTun) Events() <-chan Event {
 	return tun.events
 }
 
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+	// TODO: the BSDs look very similar in Read() and Write(). They should be
+	// collapsed, with platform-specific files containing the varying parts of
+	// their implementations.
 	select {
 	case err := <-tun.errors:
 		return 0, err
 	default:
-		buff := buff[offset-4:]
-		n, err := tun.tunFile.Read(buff[:])
+		buf := bufs[0][offset-4:]
+		n, err := tun.tunFile.Read(buf[:])
 		if n < 4 {
 			return 0, err
 		}
-		return n - 4, err
+		sizes[0] = n - 4
+		return 1, err
 	}
 }
 
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
-	// reserve space for header
-
-	buff = buff[offset-4:]
-
-	// add packet information header
-
-	buff[0] = 0x00
-	buff[1] = 0x00
-	buff[2] = 0x00
-
-	if buff[4]>>4 == ipv6.Version {
-		buff[3] = unix.AF_INET6
-	} else {
-		buff[3] = unix.AF_INET
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+	if offset < 4 {
+		return 0, io.ErrShortBuffer
 	}
-
-	// write
-
-	return tun.tunFile.Write(buff)
-}
-
-func (tun *NativeTun) Flush() error {
-	// TODO: can flushing be implemented by buffering and using sendmmsg?
-	return nil
+	for i, buf := range bufs {
+		buf = buf[offset-4:]
+		buf[0] = 0x00
+		buf[1] = 0x00
+		buf[2] = 0x00
+		switch buf[4] >> 4 {
+		case 4:
+			buf[3] = unix.AF_INET
+		case 6:
+			buf[3] = unix.AF_INET6
+		default:
+			return i, unix.EAFNOSUPPORT
+		}
+		if _, err := tun.tunFile.Write(buf); err != nil {
+			return i, err
+		}
+	}
+	return len(bufs), nil
 }
 
 func (tun *NativeTun) Close() error {
@@ -318,6 +319,10 @@ func (tun *NativeTun) MTU() (int, error) {
 	return int(ifr.MTU), nil
 }
 
+func (tun *NativeTun) BatchSize() int {
+	return 1
+}
+
 func socketCloexec(family, sotype, proto int) (fd int, err error) {
 	// See go/src/net/sys_cloexec.go for background.
 	syscall.ForkLock.RLock()
diff --git a/vendor/golang.zx2c4.com/wireguard/tun/tun_freebsd.go b/vendor/golang.zx2c4.com/wireguard/tun/tun_freebsd.go
index 42431aa3ee..7c65fd9992 100644
--- a/vendor/golang.zx2c4.com/wireguard/tun/tun_freebsd.go
+++ b/vendor/golang.zx2c4.com/wireguard/tun/tun_freebsd.go
@@ -333,45 +333,46 @@ func (tun *NativeTun) Events() <-chan Event {
 	return tun.events
 }
 
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
 	select {
 	case err := <-tun.errors:
 		return 0, err
 	default:
-		buff := buff[offset-4:]
-		n, err := tun.tunFile.Read(buff[:])
+		buf := bufs[0][offset-4:]
+		n, err := tun.tunFile.Read(buf[:])
 		if n < 4 {
 			return 0, err
 		}
-		return n - 4, err
+		sizes[0] = n - 4
+		return 1, err
 	}
 }
 
-func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
 	if offset < 4 {
 		return 0, io.ErrShortBuffer
 	}
-	buf = buf[offset-4:]
-	if len(buf) < 5 {
-		return 0, io.ErrShortBuffer
-	}
-	buf[0] = 0x00
-	buf[1] = 0x00
-	buf[2] = 0x00
-	switch buf[4] >> 4 {
-	case 4:
-		buf[3] = unix.AF_INET
-	case 6:
-		buf[3] = unix.AF_INET6
-	default:
-		return 0, unix.EAFNOSUPPORT
+	for i, buf := range bufs {
+		buf = buf[offset-4:]
+		if len(buf) < 5 {
+			return i, io.ErrShortBuffer
+		}
+		buf[0] = 0x00
+		buf[1] = 0x00
+		buf[2] = 0x00
+		switch buf[4] >> 4 {
+		case 4:
+			buf[3] = unix.AF_INET
+		case 6:
+			buf[3] = unix.AF_INET6
+		default:
+			return i, unix.EAFNOSUPPORT
+		}
+		if _, err := tun.tunFile.Write(buf); err != nil {
+			return i, err
+		}
 	}
-	return tun.tunFile.Write(buf)
-}
-
-func (tun *NativeTun) Flush() error {
-	// TODO: can flushing be implemented by buffering and using sendmmsg?
-	return nil
+	return len(bufs), nil
 }
 
 func (tun *NativeTun) Close() error {
@@ -428,3 +429,7 @@ func (tun *NativeTun) MTU() (int, error) {
 	}
 	return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
 }
+
+func (tun *NativeTun) BatchSize() int {
+	return 1
+}
diff --git a/vendor/golang.zx2c4.com/wireguard/tun/tun_linux.go b/vendor/golang.zx2c4.com/wireguard/tun/tun_linux.go
index 25dbc0749b..bd69cb552c 100644
--- a/vendor/golang.zx2c4.com/wireguard/tun/tun_linux.go
+++ b/vendor/golang.zx2c4.com/wireguard/tun/tun_linux.go
@@ -17,9 +17,8 @@ import (
 	"time"
 	"unsafe"
 
-	"golang.org/x/net/ipv6"
 	"golang.org/x/sys/unix"
-
+	"golang.zx2c4.com/wireguard/conn"
 	"golang.zx2c4.com/wireguard/rwcancel"
 )
 
@@ -33,17 +32,27 @@ type NativeTun struct {
 	index                   int32      // if index
 	errors                  chan error // async error handling
 	events                  chan Event // device related events
-	nopi                    bool       // the device was passed IFF_NO_PI
 	netlinkSock             int
 	netlinkCancel           *rwcancel.RWCancel
 	hackListenerClosed      sync.Mutex
 	statusListenersShutdown chan struct{}
+	batchSize               int
+	vnetHdr                 bool
+	udpGSO                  bool
 
 	closeOnce sync.Once
 
 	nameOnce  sync.Once // guards calling initNameCache, which sets following fields
 	nameCache string    // name of interface
 	nameErr   error
+
+	readOpMu sync.Mutex                    // readOpMu guards readBuff
+	readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
+
+	writeOpMu   sync.Mutex // writeOpMu guards toWrite, tcpGROTable
+	toWrite     []int
+	tcpGROTable *tcpGROTable
+	udpGROTable *udpGROTable
 }
 
 func (tun *NativeTun) File() *os.File {
@@ -323,57 +332,147 @@ func (tun *NativeTun) nameSlow() (string, error) {
 	return unix.ByteSliceToString(ifr[:]), nil
 }
 
-func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
-	if tun.nopi {
-		buf = buf[offset:]
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+	tun.writeOpMu.Lock()
+	defer func() {
+		tun.tcpGROTable.reset()
+		tun.udpGROTable.reset()
+		tun.writeOpMu.Unlock()
+	}()
+	var (
+		errs  error
+		total int
+	)
+	tun.toWrite = tun.toWrite[:0]
+	if tun.vnetHdr {
+		err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite)
+		if err != nil {
+			return 0, err
+		}
+		offset -= virtioNetHdrLen
 	} else {
-		// reserve space for header
-		buf = buf[offset-4:]
-
-		// add packet information header
-		buf[0] = 0x00
-		buf[1] = 0x00
-		if buf[4]>>4 == ipv6.Version {
-			buf[2] = 0x86
-			buf[3] = 0xdd
+		for i := range bufs {
+			tun.toWrite = append(tun.toWrite, i)
+		}
+	}
+	for _, bufsI := range tun.toWrite {
+		n, err := tun.tunFile.Write(bufs[bufsI][offset:])
+		if errors.Is(err, syscall.EBADFD) {
+			return total, os.ErrClosed
+		}
+		if err != nil {
+			errs = errors.Join(errs, err)
 		} else {
-			buf[2] = 0x08
-			buf[3] = 0x00
+			total += n
 		}
 	}
+	return total, errs
+}
 
-	n, err := tun.tunFile.Write(buf)
-	if errors.Is(err, syscall.EBADFD) {
-		err = os.ErrClosed
+// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
+// each buffer. It mutates sizes to reflect the size of each element of bufs,
+// and returns the number of packets read.
+func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
+	var hdr virtioNetHdr
+	err := hdr.decode(in)
+	if err != nil {
+		return 0, err
+	}
+	in = in[virtioNetHdrLen:]
+	if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
+		if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
+			// This means CHECKSUM_PARTIAL in skb context. We are responsible
+			// for computing the checksum starting at hdr.csumStart and placing
+			// at hdr.csumOffset.
+			err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
+			if err != nil {
+				return 0, err
+			}
+		}
+		if len(in) > len(bufs[0][offset:]) {
+			return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
+		}
+		n := copy(bufs[0][offset:], in)
+		sizes[0] = n
+		return 1, nil
+	}
+	if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+		return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
 	}
-	return n, err
-}
 
-func (tun *NativeTun) Flush() error {
-	// TODO: can flushing be implemented by buffering and using sendmmsg?
-	return nil
+	ipVersion := in[0] >> 4
+	switch ipVersion {
+	case 4:
+		if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+			return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+		}
+	case 6:
+		if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+			return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
		}
+	default:
+		return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
+	}
+
+	// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
+	// of the entire first packet when the kernel is handling it as part of a
+	// FORWARD path. Instead, parse the transport header length and add it onto
+	// csumStart, which is synonymous for IP header length.
+	if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+		hdr.hdrLen = hdr.csumStart + 8
+	} else {
+		if len(in) <= int(hdr.csumStart+12) {
+			return 0, errors.New("packet is too short")
+		}
+
+		tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
+		if tcpHLen < 20 || tcpHLen > 60 {
+			// A TCP header must be between 20 and 60 bytes in length.
+			return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
+		}
+		hdr.hdrLen = hdr.csumStart + tcpHLen
+	}
+
+	if len(in) < int(hdr.hdrLen) {
+		return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
+	}
+
+	if hdr.hdrLen < hdr.csumStart {
+		return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
+	}
+	cSumAt := int(hdr.csumStart + hdr.csumOffset)
+	if cSumAt+1 >= len(in) {
+		return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
+	}
+
+	return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6)
 }
 
-func (tun *NativeTun) Read(buf []byte, offset int) (n int, err error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+	tun.readOpMu.Lock()
+	defer tun.readOpMu.Unlock()
 	select {
-	case err = <-tun.errors:
+	case err := <-tun.errors:
+		return 0, err
 	default:
-		if tun.nopi {
-			n, err = tun.tunFile.Read(buf[offset:])
+		readInto := bufs[0][offset:]
+		if tun.vnetHdr {
+			readInto = tun.readBuff[:]
+		}
+		n, err := tun.tunFile.Read(readInto)
+		if errors.Is(err, syscall.EBADFD) {
+			err = os.ErrClosed
+		}
+		if err != nil {
+			return 0, err
+		}
+		if tun.vnetHdr {
+			return handleVirtioRead(readInto[:n], bufs, sizes, offset)
 		} else {
-			buff := buf[offset-4:]
-			n, err = tun.tunFile.Read(buff[:])
-			if errors.Is(err, syscall.EBADFD) {
-				err = os.ErrClosed
-			}
-			if n < 4 {
-				n = 0
-			} else {
-				n -= 4
-			}
+			sizes[0] = n
+			return 1, nil
 		}
 	}
-	return
 }
 
 func (tun *NativeTun) Events() <-chan Event {
@@ -399,6 +498,56 @@ func (tun *NativeTun) Close() error {
 	return err2
 }
 
+func (tun *NativeTun) BatchSize() int {
+	return tun.batchSize
+}
+
+const (
+	// TODO: support TSO with ECN bits
+	tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
+	tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
+)
+
+func (tun *NativeTun) initFromFlags(name string) error {
+	sc, err := tun.tunFile.SyscallConn()
+	if err != nil {
+		return err
+	}
+	if e := sc.Control(func(fd uintptr) {
+		var (
+			ifr *unix.Ifreq
+		)
+		ifr, err = unix.NewIfreq(name)
+		if err != nil {
+			return
+		}
+		err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
+		if err != nil {
+			return
+		}
+		got := ifr.Uint16()
+		if got&unix.IFF_VNET_HDR != 0 {
+			// tunTCPOffloads were added in Linux v2.6. We require their support
+			// if IFF_VNET_HDR is set.
+			err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads)
+			if err != nil {
+				return
+			}
+			tun.vnetHdr = true
+			tun.batchSize = conn.IdealBatchSize
+			// tunUDPOffloads were added in Linux v6.2. We do not return an
+			// error if they are unsupported at runtime.
+			tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil
+		} else {
+			tun.batchSize = 1
+		}
+	}); e != nil {
+		return e
+	}
+	return err
+}
+
+// CreateTUN creates a Device with the provided name and MTU.
 func CreateTUN(name string, mtu int) (Device, error) {
 	nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
 	if err != nil {
@@ -408,25 +557,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
 		return nil, err
 	}
 
-	var ifr [ifReqSize]byte
-	var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
-	nameBytes := []byte(name)
-	if len(nameBytes) >= unix.IFNAMSIZ {
-		unix.Close(nfd)
-		return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG)
+	ifr, err := unix.NewIfreq(name)
+	if err != nil {
+		return nil, err
 	}
-	copy(ifr[:], nameBytes)
-	*(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
-
-	_, _, errno := unix.Syscall(
-		unix.SYS_IOCTL,
-		uintptr(nfd),
-		uintptr(unix.TUNSETIFF),
-		uintptr(unsafe.Pointer(&ifr[0])),
-	)
-	if errno != 0 {
-		unix.Close(nfd)
-		return nil, errno
+	// IFF_VNET_HDR enables the "tun status hack" via routineHackListener()
+	// where a null write will return EINVAL indicating the TUN is up.
+	ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
+	err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr)
+	if err != nil {
+		return nil, err
 	}
 
 	err = unix.SetNonblock(nfd, true)
@@ -441,13 +581,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
 	return CreateTUNFromFile(fd, mtu)
 }
 
+// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
 func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
 	tun := &NativeTun{
 		tunFile:                 file,
 		events:                  make(chan Event, 5),
 		errors:                  make(chan error, 5),
 		statusListenersShutdown: make(chan struct{}),
-		nopi:                    false,
+		tcpGROTable:             newTCPGROTable(),
+		udpGROTable:             newUDPGROTable(),
+		toWrite:                 make([]int, 0, conn.IdealBatchSize),
 	}
 
 	name, err := tun.Name()
@@ -455,8 +598,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
 		return nil, err
 	}
 
-	// start event listener
+	err = tun.initFromFlags(name)
+	if err != nil {
+		return nil, err
+	}
 
+	// start event listener
 	tun.index, err = getIFIndex(name)
 	if err != nil {
 		return nil, err
@@ -485,6 +632,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
 	return tun, nil
 }
 
+// CreateUnmonitoredTUNFromFD creates a Device from the provided file
+// descriptor.
 func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
 	err := unix.SetNonblock(fd, true)
 	if err != nil {
@@ -492,14 +641,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
 	}
 	file := os.NewFile(uintptr(fd), "/dev/tun")
 	tun := &NativeTun{
-		tunFile: file,
-		events:  make(chan Event, 5),
-		errors:  make(chan error, 5),
-		nopi:    true,
+		tunFile:     file,
+		events:      make(chan Event, 5),
+		errors:      make(chan error, 5),
+		tcpGROTable: newTCPGROTable(),
+		udpGROTable: newUDPGROTable(),
+		toWrite:     make([]int, 0, conn.IdealBatchSize),
 	}
 	name, err := tun.Name()
 	if err != nil {
 		return nil, "", err
 	}
-	return tun, name, nil
+	err = tun.initFromFlags(name)
+	if err != nil {
+		return nil, "", err
+	}
+	return tun, name, err
 }
diff --git a/vendor/golang.zx2c4.com/wireguard/tun/tun_openbsd.go b/vendor/golang.zx2c4.com/wireguard/tun/tun_openbsd.go
index e7fd79c5b0..ae571b90c3 100644
--- a/vendor/golang.zx2c4.com/wireguard/tun/tun_openbsd.go
+++ b/vendor/golang.zx2c4.com/wireguard/tun/tun_openbsd.go
@@ -8,13 +8,13 @@ package tun
 import (
 	"errors"
 	"fmt"
+	"io"
 	"net"
 	"os"
 	"sync"
 	"syscall"
 	"unsafe"
 
-	"golang.org/x/net/ipv6"
 	"golang.org/x/sys/unix"
 )
 
@@ -204,45 +204,43 @@ func (tun *NativeTun) Events() <-chan Event {
 	return tun.events
 }
 
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
 	select {
 	case err := <-tun.errors:
 		return 0, err
 	default:
-		buff := buff[offset-4:]
-		n, err := tun.tunFile.Read(buff[:])
+		buf := bufs[0][offset-4:]
+		n, err := tun.tunFile.Read(buf[:])
 		if n < 4 {
 			return 0, err
 		}
-		return n - 4, err
+		sizes[0] = n - 4
+		return 1, err
 	}
 }
 
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
-	// reserve space for header
-
-	buff = buff[offset-4:]
-
-	// add packet information header
-
-	buff[0] = 0x00
-	buff[1] = 0x00
-	buff[2] = 0x00
-
-	if buff[4]>>4 == ipv6.Version {
-		buff[3] = unix.AF_INET6
-	} else {
-		buff[3] = unix.AF_INET
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+	if offset < 4 {
+		return 0, io.ErrShortBuffer
 	}
-
-	// write
-
-	return tun.tunFile.Write(buff)
-}
-
-func (tun *NativeTun) Flush() error {
-	// TODO: can flushing be implemented by buffering and using sendmmsg?
-	return nil
+	for i, buf := range bufs {
+		buf = buf[offset-4:]
+		buf[0] = 0x00
+		buf[1] = 0x00
+		buf[2] = 0x00
+		switch buf[4] >> 4 {
+		case 4:
+			buf[3] = unix.AF_INET
+		case 6:
+			buf[3] = unix.AF_INET6
+		default:
+			return i, unix.EAFNOSUPPORT
+		}
+		if _, err := tun.tunFile.Write(buf); err != nil {
+			return i, err
+		}
+	}
+	return len(bufs), nil
 }
 
 func (tun *NativeTun) Close() error {
@@ -329,3 +327,7 @@ func (tun *NativeTun) MTU() (int, error) {
 
 	return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
 }
+
+func (tun *NativeTun) BatchSize() int {
+	return 1
+}
diff --git a/vendor/golang.zx2c4.com/wireguard/tun/tun_windows.go b/vendor/golang.zx2c4.com/wireguard/tun/tun_windows.go
index d5abb14898..2af8e3e922 100644
--- a/vendor/golang.zx2c4.com/wireguard/tun/tun_windows.go
+++ b/vendor/golang.zx2c4.com/wireguard/tun/tun_windows.go
@@ -15,7 +15,6 @@ import (
 	_ "unsafe"
 
 	"golang.org/x/sys/windows"
-
 	"golang.zx2c4.com/wintun"
 )
 
@@ -44,6 +43,7 @@ type NativeTun struct {
 	closeOnce sync.Once
 	close     atomic.Bool
 	forcedMTU int
+	outSizes  []int
 }
 
 var (
@@ -127,6 +127,9 @@ func (tun *NativeTun) MTU() (int, error) {
 
 // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
 func (tun *NativeTun) ForceMTU(mtu int) {
+	if tun.close.Load() {
+		return
+	}
 	update := tun.forcedMTU != mtu
 	tun.forcedMTU = mtu
 	if update {
@@ -134,9 +137,14 @@ func (tun *NativeTun) ForceMTU(mtu int) {
 	}
 }
 
+func (tun *NativeTun) BatchSize() int {
+	// TODO: implement batching with wintun
+	return 1
+}
+
 // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
 
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
 	tun.running.Add(1)
 	defer tun.running.Done()
 retry:
@@ -152,11 +160,11 @@ retry:
 		packet, err := tun.session.ReceivePacket()
 		switch err {
 		case nil:
-			packetSize := len(packet)
-			copy(buff[offset:], packet)
+			n := copy(bufs[0][offset:], packet)
+			sizes[0] = n
 			tun.session.ReleaseReceivePacket(packet)
-			tun.rate.update(uint64(packetSize))
-			return packetSize, nil
+			tun.rate.update(uint64(n))
+			return 1, nil
 		case windows.ERROR_NO_MORE_ITEMS:
 			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
 				windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
@@ -173,33 +181,33 @@ retry:
 	}
 }
 
-func (tun *NativeTun) Flush() error {
-	return nil
-}
-
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
 	tun.running.Add(1)
 	defer tun.running.Done()
 	if tun.close.Load() {
 		return 0, os.ErrClosed
 	}
 
-	packetSize := len(buff) - offset
-	tun.rate.update(uint64(packetSize))
+	for i, buf := range bufs {
+		packetSize := len(buf) - offset
+		tun.rate.update(uint64(packetSize))
 
-	packet, err := tun.session.AllocateSendPacket(packetSize)
-	if err == nil {
-		copy(packet, buff[offset:])
-		tun.session.SendPacket(packet)
-		return packetSize, nil
-	}
-	switch err {
-	case windows.ERROR_HANDLE_EOF:
-		return 0, os.ErrClosed
-	case windows.ERROR_BUFFER_OVERFLOW:
-		return 0, nil // Dropping when ring is full.
+		packet, err := tun.session.AllocateSendPacket(packetSize)
+		switch err {
+		case nil:
+			// TODO: Explore options to eliminate this copy.
+			copy(packet, buf[offset:])
+			tun.session.SendPacket(packet)
+			continue
+		case windows.ERROR_HANDLE_EOF:
+			return i, os.ErrClosed
+		case windows.ERROR_BUFFER_OVERFLOW:
+			continue // Dropping when ring is full.
+		default:
+			return i, fmt.Errorf("Write failed: %w", err)
+		}
 	}
-	return 0, fmt.Errorf("Write failed: %w", err)
+	return len(bufs), nil
 }
 
 // LUID returns Windows interface instance ID.
diff --git a/vendor/modules.txt b/vendor/modules.txt
index 32be3171b2..2c3ff26507 100644
--- a/vendor/modules.txt
+++ b/vendor/modules.txt
@@ -922,8 +922,10 @@ golang.org/x/tools/internal/versions
 # golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
 ## explicit; go 1.17
 golang.zx2c4.com/wintun
-# golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 => golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675
-## explicit; go 1.19
+# golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
+## explicit; go 1.20
+golang.zx2c4.com/wireguard/conn
+golang.zx2c4.com/wireguard/conn/winrio
 golang.zx2c4.com/wireguard/rwcancel
 golang.zx2c4.com/wireguard/tun
 # google.golang.org/protobuf v1.35.1
@@ -975,5 +977,3 @@ mvdan.cc/sh/v3/fileutil
 mvdan.cc/sh/v3/pattern
 mvdan.cc/sh/v3/shell
 mvdan.cc/sh/v3/syntax
-# github.com/xxxserxxx/gotop/v4 => github.com/ersonp/gotop/v4 v4.2.1
-# golang.zx2c4.com/wireguard => golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675
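
Note for reviewers: the vendored bump above changes the tun.Device contract
from single-packet Read(buff, offset)/Write(buff, offset) plus Flush() to
vectored Read(bufs, sizes, offset)/Write(bufs, offset) with a new BatchSize()
method, so any caller in this repository must be adapted. The sketch below is
illustrative only: the interface name "tun0", the 1420 MTU, and the 16-byte
headroom are assumptions, not values taken from this codebase.

package main

import (
	"log"

	"golang.zx2c4.com/wireguard/tun"
)

// Assumed headroom: must be >= 4 for the BSD packet-information header and
// >= 10 for the Linux virtio-net header written by the vectored Write path.
const offset = 16

func main() {
	dev, err := tun.CreateTUN("tun0", 1420) // placeholder name and MTU
	if err != nil {
		log.Fatal(err)
	}
	defer dev.Close()

	// BatchSize reports how many packets one Read/Write call may carry:
	// 1 on the BSDs and Windows, conn.IdealBatchSize on Linux when
	// IFF_VNET_HDR offloads were negotiated by initFromFlags.
	batch := dev.BatchSize()
	bufs := make([][]byte, batch)
	sizes := make([]int, batch)
	for i := range bufs {
		bufs[i] = make([]byte, offset+65535) // headroom plus a max-size packet
	}

	for {
		// The return value is now a packet count, not a byte count;
		// per-packet byte counts come back through the sizes slice.
		n, err := dev.Read(bufs, sizes, offset)
		if err != nil {
			log.Fatal(err)
		}
		out := make([][]byte, 0, n)
		for i := 0; i < n; i++ {
			pkt := bufs[i][offset : offset+sizes[i]]
			log.Printf("packet: %d bytes, IP version %d", len(pkt), pkt[0]>>4)
			// Write expects each element sliced to the end of its packet.
			out = append(out, bufs[i][:offset+sizes[i]])
		}
		// Echoing reads back out is meaningless as traffic; it is here only
		// to demonstrate the vectored Write signature.
		if _, err := dev.Write(out, offset); err != nil {
			log.Println("write:", err)
		}
	}
}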
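
The Linux read path above hinges on the 10-byte virtio_net_hdr that the
kernel prepends to every packet once IFF_VNET_HDR is negotiated;
handleVirtioRead decodes it before splitting GSO super-packets. Below is a
minimal stand-alone decoder sketch, assuming the kernel's little-endian
struct virtio_net_hdr layout; the vendored equivalent (virtioNetHdr and its
decode method) lives in offload_linux.go, and the type and function names
here are hypothetical.

package main

import (
	"encoding/binary"
	"errors"
	"fmt"
)

// vnetHdr mirrors the kernel's struct virtio_net_hdr. All multi-byte
// fields are little-endian.
type vnetHdr struct {
	flags      uint8  // e.g. VIRTIO_NET_HDR_F_NEEDS_CSUM
	gsoType    uint8  // VIRTIO_NET_HDR_GSO_{NONE,TCPV4,TCPV6,UDP_L4}
	hdrLen     uint16 // header bytes up to and including the transport header
	gsoSize    uint16 // payload bytes per segment when splitting
	csumStart  uint16 // where checksumming begins (the IP header length)
	csumOffset uint16 // where to store the checksum, relative to csumStart
}

const vnetHdrLen = 10

func decodeVnetHdr(b []byte) (vnetHdr, error) {
	if len(b) < vnetHdrLen {
		return vnetHdr{}, errors.New("buffer shorter than virtio_net_hdr")
	}
	return vnetHdr{
		flags:      b[0],
		gsoType:    b[1],
		hdrLen:     binary.LittleEndian.Uint16(b[2:]),
		gsoSize:    binary.LittleEndian.Uint16(b[4:]),
		csumStart:  binary.LittleEndian.Uint16(b[6:]),
		csumOffset: binary.LittleEndian.Uint16(b[8:]),
	}, nil
}

func main() {
	// Example bytes: NEEDS_CSUM set, GSO_NONE, csumStart 20 (a 20-byte IPv4
	// header), csumOffset 16 (the TCP checksum field within the TCP header).
	raw := []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x10, 0x00}
	hdr, err := decodeVnetHdr(raw)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", hdr)
}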