Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix UDP listener on IPv4-only Linux #787

Merged
merged 6 commits into from
Jan 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 74 additions & 11 deletions udp/udp_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

type StdConn struct {
sysFd int
isV4 bool
l *logrus.Logger
batch int
}
Expand All @@ -45,9 +46,22 @@ const (

type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32

func maybeIPV4(ip net.IP) (net.IP, bool) {
ip4 := ip.To4()
if ip4 != nil {
return ip4, true
}
return ip, false
}

func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
ipV4, isV4 := maybeIPV4(ip)
af := unix.AF_INET6
if isV4 {
af = unix.AF_INET
}
syscall.ForkLock.RLock()
fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
fd, err := unix.Socket(af, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
if err == nil {
unix.CloseOnExec(fd)
}
Expand All @@ -58,17 +72,24 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
return nil, fmt.Errorf("unable to open socket: %s", err)
}

var lip [16]byte
copy(lip[:], ip.To16())

if multi {
if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err)
}
}

//TODO: support multiple listening IPs (for limiting ipv6)
if err = unix.Bind(fd, &unix.SockaddrInet6{Addr: lip, Port: port}); err != nil {
var sa unix.Sockaddr
if isV4 {
sa4 := &unix.SockaddrInet4{Port: port}
copy(sa4.Addr[:], ipV4)
sa = sa4
} else {
sa6 := &unix.SockaddrInet6{Port: port}
copy(sa6.Addr[:], ip.To16())
sa = sa6
}
if err = unix.Bind(fd, sa); err != nil {
return nil, fmt.Errorf("unable to bind to socket: %s", err)
}

Expand All @@ -77,7 +98,7 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
//l.Println(v, err)

return &StdConn{sysFd: fd, l: l, batch: batch}, err
return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err
}

func (u *StdConn) Rebind() error {
Expand Down Expand Up @@ -143,7 +164,11 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew

//metric.Update(int64(n))
for i := 0; i < n; i++ {
udpAddr.IP = names[i][8:24]
if u.isV4 {
udpAddr.IP = names[i][4:8]
} else {
udpAddr.IP = names[i][8:24]
}
udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
}
Expand Down Expand Up @@ -192,13 +217,18 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
}

func (u *StdConn) WriteTo(b []byte, addr *Addr) error {
if u.isV4 {
return u.writeTo4(b, addr)
}
return u.writeTo6(b, addr)
}

func (u *StdConn) writeTo6(b []byte, addr *Addr) error {
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
p := (*[2]byte)(unsafe.Pointer(&rsa.Port))
p[0] = byte(addr.Port >> 8)
p[1] = byte(addr.Port)
copy(rsa.Addr[:], addr.IP)
// Little Endian -> Network Endian
rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8)
copy(rsa.Addr[:], addr.IP.To16())

for {
_, _, err := unix.Syscall6(
Expand All @@ -221,6 +251,39 @@ func (u *StdConn) WriteTo(b []byte, addr *Addr) error {
}
}

func (u *StdConn) writeTo4(b []byte, addr *Addr) error {
addrV4, isAddrV4 := maybeIPV4(addr.IP)
if !isAddrV4 {
return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
}

var rsa unix.RawSockaddrInet4
rsa.Family = unix.AF_INET
// Little Endian -> Network Endian
rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8)
copy(rsa.Addr[:], addrV4)

for {
_, _, err := unix.Syscall6(
unix.SYS_SENDTO,
uintptr(u.sysFd),
uintptr(unsafe.Pointer(&b[0])),
uintptr(len(b)),
uintptr(0),
uintptr(unsafe.Pointer(&rsa)),
uintptr(unix.SizeofSockaddrInet4),
)

if err != 0 {
return &net.OpError{Op: "sendto", Err: err}
}

//TODO: handle incomplete writes

return nil
}
}

func (u *StdConn) ReloadConfig(c *config.C) {
b := c.GetInt("listen.read_buffer", 0)
if b > 0 {
Expand Down
Loading