Skip to content

Commit

Permalink
Support syscallReadWaiter and syscallPacketReadWaiter on windows
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx authored Dec 21, 2023
1 parent 249dc05 commit 48acfc4
Showing 1 changed file with 102 additions and 7 deletions.
109 changes: 102 additions & 7 deletions common/bufio/copy_direct_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,40 @@ package bufio

import (
"io"
"net/netip"
"os"
"syscall"
"unsafe"

"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"

"golang.org/x/sys/windows"
)

//go:linkname modws2_32 golang.org/x/sys/windows.modws2_32
var modws2_32 *windows.LazyDLL

var procrecv = modws2_32.NewProc("recv")

//go:linkname errnoErr golang.org/x/sys/windows.errnoErr
func errnoErr(e syscall.Errno) error

func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
r0, _, e1 := syscall.SyscallN(procrecv.Addr(), uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags))
n = int32(r0)
if n == -1 {
err = errnoErr(e1)
}
return
}

var _ N.ReadWaiter = (*syscallReadWaiter)(nil)

type syscallReadWaiter struct {
Expand All @@ -38,15 +62,13 @@ func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (nee
w.readFunc = func(fd uintptr) (done bool) {
if !w.hasData {
w.hasData = true
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
// socket is readable if we return false. So the `recv` syscall will not block the system thread.
return false
}
buffer := w.options.NewBuffer()
iovecList := []windows.WSABuf{windows.WSABuf{}}
iovecList[0].Buf = &buffer.FreeBytes()[0]
iovecList[0].Len = uint32(len(buffer.FreeBytes()))
var readN uint32
var flags uint32
w.readErr = windows.WSARecv(windows.Handle(fd), &iovecList[0], uint32(len(iovecList)), &readN, &flags, nil, nil)
var readN int32
readN, w.readErr = recv(windows.Handle(fd), buffer.FreeBytes(), 0)
if readN > 0 {
buffer.Truncate(int(readN))
w.options.PostReturn(buffer)
Expand Down Expand Up @@ -85,6 +107,79 @@ func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
return
}

func createSyscallPacketReadWaiter(reader any) (N.PacketReadWaiter, bool) {
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)

type syscallPacketReadWaiter struct {
rawConn syscall.RawConn
readErr error
readFrom M.Socksaddr
readFunc func(fd uintptr) (done bool)
hasData bool
buffer *buf.Buffer
options N.ReadWaitOptions
}

func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
rawConn, err := syscallConn.SyscallConn()
if err == nil {
return &syscallPacketReadWaiter{rawConn: rawConn}, true
}
}
return nil, false
}

func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
w.options = options
w.readFunc = func(fd uintptr) (done bool) {
if !w.hasData {
w.hasData = true
// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
// socket is readable if we return false. So the `recvfrom` syscall will not block the system thread.
return false
}
buffer := w.options.NewPacketBuffer()
var readN int
var from windows.Sockaddr
readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
if readN > 0 {
buffer.Truncate(readN)
w.options.PostReturn(buffer)
w.buffer = buffer
} else {
buffer.Release()
}
if w.readErr == windows.WSAEWOULDBLOCK {
return false
}
if from != nil {
switch fromAddr := from.(type) {
case *windows.SockaddrInet4:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
case *windows.SockaddrInet6:
w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
}
}
w.hasData = false
return true
}
return false
}

func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
if w.readFunc == nil {
return nil, M.Socksaddr{}, os.ErrInvalid
}
err = w.rawConn.Read(w.readFunc)
if err != nil {
return
}
if w.readErr != nil {
err = E.Cause(w.readErr, "raw read")
return
}
buffer = w.buffer
w.buffer = nil
destination = w.readFrom
return
}

0 comments on commit 48acfc4

Please sign in to comment.