@@ -5,14 +5,16 @@ import (
55 "net"
66 "net/netip"
77 "os"
8+ "sync"
89 "time"
910
1011 "github.com/sagernet/sing-tun/internal/gtcpip/checksum"
1112 "github.com/sagernet/sing-tun/internal/gtcpip/header"
12- "github.com/sagernet/sing/common/atomic "
13+ "github.com/sagernet/sing/common"
1314 "github.com/sagernet/sing/common/buf"
1415 "github.com/sagernet/sing/common/control"
1516 M "github.com/sagernet/sing/common/metadata"
17+ "github.com/sagernet/sing/common/pipe"
1618)
1719
1820type UnprivilegedConn struct {
@@ -21,7 +23,9 @@ type UnprivilegedConn struct {
2123 controlFunc control.Func
2224 destination netip.Addr
2325 receiveChan chan * unprivilegedResponse
24- readDeadline atomic.TypedValue [time.Time ]
26+ readDeadline pipe.Deadline
27+ natMap map [uint16 ]net.Conn
28+ natMapMutex sync.Mutex
2529}
2630
2731type unprivilegedResponse struct {
@@ -38,11 +42,13 @@ func newUnprivilegedConn(ctx context.Context, controlFunc control.Func, destinat
3842 conn .Close ()
3943 ctx , cancel := context .WithCancel (ctx )
4044 return & UnprivilegedConn {
41- ctx : ctx ,
42- cancel : cancel ,
43- controlFunc : controlFunc ,
44- destination : destination ,
45- receiveChan : make (chan * unprivilegedResponse ),
45+ ctx : ctx ,
46+ cancel : cancel ,
47+ controlFunc : controlFunc ,
48+ destination : destination ,
49+ receiveChan : make (chan * unprivilegedResponse ),
50+ readDeadline : pipe .MakeDeadline (),
51+ natMap : make (map [uint16 ]net.Conn ),
4652 }, nil
4753}
4854
@@ -55,6 +61,8 @@ func (c *UnprivilegedConn) Read(b []byte) (n int, err error) {
5561 return
5662 case <- c .ctx .Done ():
5763 return 0 , os .ErrClosed
64+ case <- c .readDeadline .Wait ():
65+ return 0 , os .ErrDeadlineExceeded
5866 }
5967}
6068
@@ -69,14 +77,12 @@ func (c *UnprivilegedConn) ReadMsg(b []byte, oob []byte) (n, oobn int, addr neti
6977 return
7078 case <- c .ctx .Done ():
7179 return 0 , 0 , netip.Addr {}, os .ErrClosed
80+ case <- c .readDeadline .Wait ():
81+ return 0 , 0 , netip.Addr {}, os .ErrDeadlineExceeded
7282 }
7383}
7484
7585func (c * UnprivilegedConn ) Write (b []byte ) (n int , err error ) {
76- conn , err := connect (false , c .controlFunc , c .destination )
77- if err != nil {
78- return
79- }
8086 var identifier uint16
8187 if ! c .destination .Is6 () {
8288 icmpHdr := header .ICMPv4 (b )
@@ -85,62 +91,84 @@ func (c *UnprivilegedConn) Write(b []byte) (n int, err error) {
8591 icmpHdr := header .ICMPv6 (b )
8692 identifier = icmpHdr .Ident ()
8793 }
88- if readDeadline := c .readDeadline .Load (); ! readDeadline .IsZero () {
89- conn .SetReadDeadline (readDeadline )
94+
95+ c .natMapMutex .Lock ()
96+ if err = c .ctx .Err (); err != nil {
97+ return 0 , err
98+ }
99+ conn , ok := c .natMap [identifier ]
100+ if ! ok {
101+ conn , err = connect (false , c .controlFunc , c .destination )
102+ if err != nil {
103+ c .natMapMutex .Unlock ()
104+ return 0 , err
105+ }
106+ go c .fetchResponse (conn .(* net.UDPConn ), identifier )
90107 }
108+ c .natMapMutex .Unlock ()
109+
91110 n , err = conn .Write (b )
92111 if err != nil {
93- conn .Close ( )
112+ c . removeConn ( conn .( * net. UDPConn ), identifier )
94113 return
95114 }
96- go c .fetchResponse (conn , identifier )
97115 return
98116}
99117
100- func (c * UnprivilegedConn ) fetchResponse (conn net.Conn , identifier uint16 ) {
101- done := make (chan struct {})
102- defer close (done )
103- go func () {
118+ func (c * UnprivilegedConn ) fetchResponse (conn * net.UDPConn , identifier uint16 ) {
119+ defer c .removeConn (conn , identifier )
120+ for {
121+ buffer := buf .NewPacket ()
122+ cmsgBuffer := buf .NewSize (1024 )
123+ n , oobN , _ , addr , err := conn .ReadMsgUDPAddrPort (buffer .FreeBytes (), cmsgBuffer .FreeBytes ())
124+ if err != nil {
125+ buffer .Release ()
126+ cmsgBuffer .Release ()
127+ return
128+ }
129+ buffer .Truncate (n )
130+ cmsgBuffer .Truncate (oobN )
131+ if ! c .destination .Is6 () {
132+ icmpHdr := header .ICMPv4 (buffer .Bytes ())
133+ icmpHdr .SetIdent (identifier )
134+ icmpHdr .SetChecksum (0 )
135+ icmpHdr .SetChecksum (header .ICMPv4Checksum (icmpHdr [:header .ICMPv4MinimumSize ], checksum .Checksum (icmpHdr .Payload (), 0 )))
136+ } else {
137+ icmpHdr := header .ICMPv6 (buffer .Bytes ())
138+ icmpHdr .SetIdent (identifier )
139+ // offload checksum here since we don't have source address here
140+ }
104141 select {
142+ case c .receiveChan <- & unprivilegedResponse {
143+ Buffer : buffer ,
144+ Cmsg : cmsgBuffer ,
145+ Addr : addr .Addr (),
146+ }:
105147 case <- c .ctx .Done ():
106- case <- done :
148+ buffer .Release ()
149+ cmsgBuffer .Release ()
150+ return
107151 }
108- conn .Close ()
109- }()
110- buffer := buf .NewPacket ()
111- cmsgBuffer := buf .NewSize (1024 )
112- n , oobN , _ , addr , err := conn .(* net.UDPConn ).ReadMsgUDPAddrPort (buffer .FreeBytes (), cmsgBuffer .FreeBytes ())
113- if err != nil {
114- buffer .Release ()
115- cmsgBuffer .Release ()
116- return
117152 }
118- buffer .Truncate (n )
119- cmsgBuffer .Truncate (oobN )
120- if ! c .destination .Is6 () {
121- icmpHdr := header .ICMPv4 (buffer .Bytes ())
122- icmpHdr .SetIdent (identifier )
123- icmpHdr .SetChecksum (0 )
124- icmpHdr .SetChecksum (header .ICMPv4Checksum (icmpHdr [:header .ICMPv4MinimumSize ], checksum .Checksum (icmpHdr .Payload (), 0 )))
125- } else {
126- icmpHdr := header .ICMPv6 (buffer .Bytes ())
127- icmpHdr .SetIdent (identifier )
128- // offload checksum here since we don't have source address here
129- }
130- select {
131- case c .receiveChan <- & unprivilegedResponse {
132- Buffer : buffer ,
133- Cmsg : cmsgBuffer ,
134- Addr : addr .Addr (),
135- }:
136- case <- c .ctx .Done ():
137- buffer .Release ()
138- cmsgBuffer .Release ()
153+ }
154+
155+ func (c * UnprivilegedConn ) removeConn (conn * net.UDPConn , identifier uint16 ) {
156+ c .natMapMutex .Lock ()
157+ _ = conn .Close ()
158+ if c .natMap [identifier ] == conn {
159+ delete (c .natMap , identifier )
139160 }
161+ c .natMapMutex .Unlock ()
140162}
141163
142164func (c * UnprivilegedConn ) Close () error {
165+ c .natMapMutex .Lock ()
143166 c .cancel ()
167+ for _ , conn := range c .natMap {
168+ _ = conn .Close ()
169+ }
170+ common .ClearMap (c .natMap )
171+ c .natMapMutex .Unlock ()
144172 return nil
145173}
146174
@@ -153,14 +181,14 @@ func (c *UnprivilegedConn) RemoteAddr() net.Addr {
153181}
154182
155183func (c * UnprivilegedConn ) SetDeadline (t time.Time ) error {
156- return os . ErrInvalid
184+ return c . SetReadDeadline ( t )
157185}
158186
159187func (c * UnprivilegedConn ) SetReadDeadline (t time.Time ) error {
160- c .readDeadline .Store (t )
188+ c .readDeadline .Set (t )
161189 return nil
162190}
163191
164192func (c * UnprivilegedConn ) SetWriteDeadline (t time.Time ) error {
165- return os . ErrInvalid
193+ return nil
166194}
0 commit comments