Skip to content

Commit 325d0b6

Browse files
committed
ping: Rewrite UnprivilegedConn
1 parent 737ebf0 commit 325d0b6

File tree

1 file changed

+81
-53
lines changed

1 file changed

+81
-53
lines changed

ping/socket_linux_unprivileged.go

Lines changed: 81 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@ import (
55
"net"
66
"net/netip"
77
"os"
8+
"sync"
89
"time"
910

1011
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
1112
"github.com/sagernet/sing-tun/internal/gtcpip/header"
12-
"github.com/sagernet/sing/common/atomic"
13+
"github.com/sagernet/sing/common"
1314
"github.com/sagernet/sing/common/buf"
1415
"github.com/sagernet/sing/common/control"
1516
M "github.com/sagernet/sing/common/metadata"
17+
"github.com/sagernet/sing/common/pipe"
1618
)
1719

1820
type UnprivilegedConn struct {
@@ -21,7 +23,9 @@ type UnprivilegedConn struct {
2123
controlFunc control.Func
2224
destination netip.Addr
2325
receiveChan chan *unprivilegedResponse
24-
readDeadline atomic.TypedValue[time.Time]
26+
readDeadline pipe.Deadline
27+
natMap map[uint16]net.Conn
28+
natMapMutex sync.Mutex
2529
}
2630

2731
type unprivilegedResponse struct {
@@ -38,11 +42,13 @@ func newUnprivilegedConn(ctx context.Context, controlFunc control.Func, destinat
3842
conn.Close()
3943
ctx, cancel := context.WithCancel(ctx)
4044
return &UnprivilegedConn{
41-
ctx: ctx,
42-
cancel: cancel,
43-
controlFunc: controlFunc,
44-
destination: destination,
45-
receiveChan: make(chan *unprivilegedResponse),
45+
ctx: ctx,
46+
cancel: cancel,
47+
controlFunc: controlFunc,
48+
destination: destination,
49+
receiveChan: make(chan *unprivilegedResponse),
50+
readDeadline: pipe.MakeDeadline(),
51+
natMap: make(map[uint16]net.Conn),
4652
}, nil
4753
}
4854

@@ -55,6 +61,8 @@ func (c *UnprivilegedConn) Read(b []byte) (n int, err error) {
5561
return
5662
case <-c.ctx.Done():
5763
return 0, os.ErrClosed
64+
case <-c.readDeadline.Wait():
65+
return 0, os.ErrDeadlineExceeded
5866
}
5967
}
6068

@@ -69,14 +77,12 @@ func (c *UnprivilegedConn) ReadMsg(b []byte, oob []byte) (n, oobn int, addr neti
6977
return
7078
case <-c.ctx.Done():
7179
return 0, 0, netip.Addr{}, os.ErrClosed
80+
case <-c.readDeadline.Wait():
81+
return 0, 0, netip.Addr{}, os.ErrDeadlineExceeded
7282
}
7383
}
7484

7585
func (c *UnprivilegedConn) Write(b []byte) (n int, err error) {
76-
conn, err := connect(false, c.controlFunc, c.destination)
77-
if err != nil {
78-
return
79-
}
8086
var identifier uint16
8187
if !c.destination.Is6() {
8288
icmpHdr := header.ICMPv4(b)
@@ -85,62 +91,84 @@ func (c *UnprivilegedConn) Write(b []byte) (n int, err error) {
8591
icmpHdr := header.ICMPv6(b)
8692
identifier = icmpHdr.Ident()
8793
}
88-
if readDeadline := c.readDeadline.Load(); !readDeadline.IsZero() {
89-
conn.SetReadDeadline(readDeadline)
94+
95+
c.natMapMutex.Lock()
96+
if err = c.ctx.Err(); err != nil {
97+
return 0, err
98+
}
99+
conn, ok := c.natMap[identifier]
100+
if !ok {
101+
conn, err = connect(false, c.controlFunc, c.destination)
102+
if err != nil {
103+
c.natMapMutex.Unlock()
104+
return 0, err
105+
}
106+
go c.fetchResponse(conn.(*net.UDPConn), identifier)
90107
}
108+
c.natMapMutex.Unlock()
109+
91110
n, err = conn.Write(b)
92111
if err != nil {
93-
conn.Close()
112+
c.removeConn(conn.(*net.UDPConn), identifier)
94113
return
95114
}
96-
go c.fetchResponse(conn, identifier)
97115
return
98116
}
99117

100-
func (c *UnprivilegedConn) fetchResponse(conn net.Conn, identifier uint16) {
101-
done := make(chan struct{})
102-
defer close(done)
103-
go func() {
118+
func (c *UnprivilegedConn) fetchResponse(conn *net.UDPConn, identifier uint16) {
119+
defer c.removeConn(conn, identifier)
120+
for {
121+
buffer := buf.NewPacket()
122+
cmsgBuffer := buf.NewSize(1024)
123+
n, oobN, _, addr, err := conn.ReadMsgUDPAddrPort(buffer.FreeBytes(), cmsgBuffer.FreeBytes())
124+
if err != nil {
125+
buffer.Release()
126+
cmsgBuffer.Release()
127+
return
128+
}
129+
buffer.Truncate(n)
130+
cmsgBuffer.Truncate(oobN)
131+
if !c.destination.Is6() {
132+
icmpHdr := header.ICMPv4(buffer.Bytes())
133+
icmpHdr.SetIdent(identifier)
134+
icmpHdr.SetChecksum(0)
135+
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
136+
} else {
137+
icmpHdr := header.ICMPv6(buffer.Bytes())
138+
icmpHdr.SetIdent(identifier)
139+
// offload checksum here since we don't have source address here
140+
}
104141
select {
142+
case c.receiveChan <- &unprivilegedResponse{
143+
Buffer: buffer,
144+
Cmsg: cmsgBuffer,
145+
Addr: addr.Addr(),
146+
}:
105147
case <-c.ctx.Done():
106-
case <-done:
148+
buffer.Release()
149+
cmsgBuffer.Release()
150+
return
107151
}
108-
conn.Close()
109-
}()
110-
buffer := buf.NewPacket()
111-
cmsgBuffer := buf.NewSize(1024)
112-
n, oobN, _, addr, err := conn.(*net.UDPConn).ReadMsgUDPAddrPort(buffer.FreeBytes(), cmsgBuffer.FreeBytes())
113-
if err != nil {
114-
buffer.Release()
115-
cmsgBuffer.Release()
116-
return
117152
}
118-
buffer.Truncate(n)
119-
cmsgBuffer.Truncate(oobN)
120-
if !c.destination.Is6() {
121-
icmpHdr := header.ICMPv4(buffer.Bytes())
122-
icmpHdr.SetIdent(identifier)
123-
icmpHdr.SetChecksum(0)
124-
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
125-
} else {
126-
icmpHdr := header.ICMPv6(buffer.Bytes())
127-
icmpHdr.SetIdent(identifier)
128-
// offload checksum here since we don't have source address here
129-
}
130-
select {
131-
case c.receiveChan <- &unprivilegedResponse{
132-
Buffer: buffer,
133-
Cmsg: cmsgBuffer,
134-
Addr: addr.Addr(),
135-
}:
136-
case <-c.ctx.Done():
137-
buffer.Release()
138-
cmsgBuffer.Release()
153+
}
154+
155+
func (c *UnprivilegedConn) removeConn(conn *net.UDPConn, identifier uint16) {
156+
c.natMapMutex.Lock()
157+
_ = conn.Close()
158+
if c.natMap[identifier] == conn {
159+
delete(c.natMap, identifier)
139160
}
161+
c.natMapMutex.Unlock()
140162
}
141163

142164
func (c *UnprivilegedConn) Close() error {
165+
c.natMapMutex.Lock()
143166
c.cancel()
167+
for _, conn := range c.natMap {
168+
_ = conn.Close()
169+
}
170+
common.ClearMap(c.natMap)
171+
c.natMapMutex.Unlock()
144172
return nil
145173
}
146174

@@ -153,14 +181,14 @@ func (c *UnprivilegedConn) RemoteAddr() net.Addr {
153181
}
154182

155183
func (c *UnprivilegedConn) SetDeadline(t time.Time) error {
156-
return os.ErrInvalid
184+
return c.SetReadDeadline(t)
157185
}
158186

159187
func (c *UnprivilegedConn) SetReadDeadline(t time.Time) error {
160-
c.readDeadline.Store(t)
188+
c.readDeadline.Set(t)
161189
return nil
162190
}
163191

164192
func (c *UnprivilegedConn) SetWriteDeadline(t time.Time) error {
165-
return os.ErrInvalid
193+
return nil
166194
}

0 commit comments

Comments
 (0)