Skip to content

Commit ac77bdc

Browse files
committed
Add lazy conn support for gVisor
1 parent 3185844 commit ac77bdc

File tree

5 files changed

+248
-160
lines changed

5 files changed

+248
-160
lines changed

stack_gvisor.go

Lines changed: 12 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -76,43 +76,17 @@ func (t *GVisor) Start() error {
7676
return err
7777
}
7878
tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
79-
var wq waiter.Queue
80-
handshakeCtx, cancel := context.WithCancel(context.Background())
81-
go func() {
82-
select {
83-
case <-t.ctx.Done():
84-
wq.Notify(wq.Events())
85-
case <-handshakeCtx.Done():
86-
}
87-
}()
88-
endpoint, err := r.CreateEndpoint(&wq)
89-
cancel()
90-
if err != nil {
91-
r.Complete(true)
92-
return
93-
}
94-
r.Complete(false)
95-
endpoint.SocketOptions().SetKeepAlive(true)
96-
keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
97-
endpoint.SetSockOpt(&keepAliveIdle)
98-
keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
99-
endpoint.SetSockOpt(&keepAliveInterval)
100-
tcpConn := gonet.NewTCPConn(&wq, endpoint)
101-
lAddr := tcpConn.RemoteAddr()
102-
rAddr := tcpConn.LocalAddr()
103-
if lAddr == nil || rAddr == nil {
104-
tcpConn.Close()
105-
return
79+
var metadata M.Metadata
80+
metadata.Source = M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
81+
metadata.Destination = M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
82+
conn := &gLazyConn{
83+
parentCtx: t.ctx,
84+
stack: t.stack,
85+
request: r,
86+
localAddr: metadata.Source.TCPAddr(),
87+
remoteAddr: metadata.Destination.TCPAddr(),
10688
}
107-
go func() {
108-
var metadata M.Metadata
109-
metadata.Source = M.SocksaddrFromNet(lAddr)
110-
metadata.Destination = M.SocksaddrFromNet(rAddr)
111-
hErr := t.handler.NewConnection(t.ctx, &gTCPConn{tcpConn}, metadata)
112-
if hErr != nil {
113-
endpoint.Abort()
114-
}
115-
}()
89+
_ = t.handler.NewConnection(t.ctx, conn, metadata)
11690
})
11791
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
11892
if !t.endpointIndependentNat {
@@ -129,12 +103,11 @@ func (t *GVisor) Start() error {
129103
endpoint.Abort()
130104
return
131105
}
132-
gConn := &gUDPConn{UDPConn: udpConn}
133106
go func() {
134107
var metadata M.Metadata
135108
metadata.Source = M.SocksaddrFromNet(lAddr)
136109
metadata.Destination = M.SocksaddrFromNet(rAddr)
137-
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(gConn, metadata.Destination), time.Duration(t.udpTimeout)*time.Second)
110+
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, metadata.Destination), time.Duration(t.udpTimeout)*time.Second)
138111
hErr := t.handler.NewPacketConnection(ctx, conn, metadata)
139112
if hErr != nil {
140113
endpoint.Abort()
@@ -191,7 +164,7 @@ func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
191164
})
192165
tErr := ipStack.CreateNIC(defaultNIC, ep)
193166
if tErr != nil {
194-
return nil, E.New("create nic: ", wrapStackError(tErr))
167+
return nil, E.New("create nic: ", gonet.TranslateNetstackError(tErr))
195168
}
196169
ipStack.SetRouteTable([]tcpip.Route{
197170
{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},

stack_gvisor_err.go

Lines changed: 0 additions & 50 deletions
This file was deleted.

stack_gvisor_lazy.go

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
//go:build with_gvisor
2+
3+
package tun
4+
5+
import (
6+
"context"
7+
"errors"
8+
"net"
9+
"os"
10+
"sync"
11+
"syscall"
12+
"time"
13+
14+
"github.com/sagernet/gvisor/pkg/tcpip"
15+
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
16+
"github.com/sagernet/gvisor/pkg/tcpip/header"
17+
"github.com/sagernet/gvisor/pkg/tcpip/stack"
18+
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
19+
"github.com/sagernet/gvisor/pkg/waiter"
20+
)
21+
22+
type gLazyConn struct {
23+
tcpConn *gonet.TCPConn
24+
parentCtx context.Context
25+
stack *stack.Stack
26+
request *tcp.ForwarderRequest
27+
localAddr net.Addr
28+
remoteAddr net.Addr
29+
handshakeAccess sync.Mutex
30+
handshakeDone bool
31+
handshakeErr error
32+
}
33+
34+
func (c *gLazyConn) HandshakeContext(ctx context.Context) error {
35+
if c.handshakeDone {
36+
return nil
37+
}
38+
c.handshakeAccess.Lock()
39+
defer c.handshakeAccess.Unlock()
40+
if c.handshakeDone {
41+
return nil
42+
}
43+
defer func() {
44+
c.handshakeDone = true
45+
}()
46+
var (
47+
wq waiter.Queue
48+
endpoint tcpip.Endpoint
49+
)
50+
handshakeCtx, cancel := context.WithCancel(ctx)
51+
go func() {
52+
select {
53+
case <-c.parentCtx.Done():
54+
wq.Notify(wq.Events())
55+
case <-handshakeCtx.Done():
56+
}
57+
}()
58+
endpoint, err := c.request.CreateEndpoint(&wq)
59+
cancel()
60+
if err != nil {
61+
gErr := gonet.TranslateNetstackError(err)
62+
c.handshakeErr = gErr
63+
c.request.Complete(true)
64+
return gErr
65+
}
66+
c.request.Complete(false)
67+
endpoint.SocketOptions().SetKeepAlive(true)
68+
keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
69+
endpoint.SetSockOpt(&keepAliveIdle)
70+
keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
71+
endpoint.SetSockOpt(&keepAliveInterval)
72+
tcpConn := gonet.NewTCPConn(&wq, endpoint)
73+
c.tcpConn = tcpConn
74+
return nil
75+
}
76+
77+
func (c *gLazyConn) HandshakeFailure(err error) error {
78+
if c.handshakeDone {
79+
return nil
80+
}
81+
c.handshakeAccess.Lock()
82+
defer c.handshakeAccess.Unlock()
83+
if c.handshakeDone {
84+
return nil
85+
}
86+
wErr := gWriteUnreachable(c.stack, c.request.Packet(), err)
87+
c.request.Complete(wErr == os.ErrInvalid)
88+
c.handshakeDone = true
89+
c.handshakeErr = err
90+
return nil
91+
}
92+
93+
func (c *gLazyConn) HandshakeSuccess() error {
94+
return c.HandshakeContext(context.Background())
95+
}
96+
97+
func (c *gLazyConn) Read(b []byte) (n int, err error) {
98+
if !c.handshakeDone {
99+
err = c.HandshakeContext(context.Background())
100+
if err != nil {
101+
return
102+
}
103+
} else if c.handshakeErr != nil {
104+
return 0, c.handshakeErr
105+
}
106+
return c.tcpConn.Read(b)
107+
}
108+
109+
func (c *gLazyConn) Write(b []byte) (n int, err error) {
110+
if !c.handshakeDone {
111+
err = c.HandshakeContext(context.Background())
112+
if err != nil {
113+
return
114+
}
115+
} else if c.handshakeErr != nil {
116+
return 0, c.handshakeErr
117+
}
118+
return c.tcpConn.Write(b)
119+
}
120+
121+
func (c *gLazyConn) LocalAddr() net.Addr {
122+
return c.localAddr
123+
}
124+
125+
func (c *gLazyConn) RemoteAddr() net.Addr {
126+
return c.remoteAddr
127+
}
128+
129+
func (c *gLazyConn) SetDeadline(t time.Time) error {
130+
if !c.handshakeDone {
131+
err := c.HandshakeContext(context.Background())
132+
if err != nil {
133+
return err
134+
}
135+
} else if c.handshakeErr != nil {
136+
return c.handshakeErr
137+
}
138+
return c.tcpConn.SetDeadline(t)
139+
}
140+
141+
func (c *gLazyConn) SetReadDeadline(t time.Time) error {
142+
if !c.handshakeDone {
143+
err := c.HandshakeContext(context.Background())
144+
if err != nil {
145+
return err
146+
}
147+
} else if c.handshakeErr != nil {
148+
return c.handshakeErr
149+
}
150+
return c.tcpConn.SetReadDeadline(t)
151+
}
152+
153+
func (c *gLazyConn) SetWriteDeadline(t time.Time) error {
154+
if !c.handshakeDone {
155+
err := c.HandshakeContext(context.Background())
156+
if err != nil {
157+
return err
158+
}
159+
} else if c.handshakeErr != nil {
160+
return c.handshakeErr
161+
}
162+
return c.tcpConn.SetWriteDeadline(t)
163+
}
164+
165+
func (c *gLazyConn) Close() error {
166+
c.handshakeAccess.Lock()
167+
defer c.handshakeAccess.Unlock()
168+
if !c.handshakeDone {
169+
c.request.Complete(true)
170+
c.handshakeErr = net.ErrClosed
171+
return nil
172+
} else if c.handshakeErr != nil {
173+
return nil
174+
}
175+
return c.tcpConn.Close()
176+
}
177+
178+
func (c *gLazyConn) CloseRead() error {
179+
c.handshakeAccess.Lock()
180+
defer c.handshakeAccess.Unlock()
181+
if !c.handshakeDone {
182+
c.request.Complete(true)
183+
c.handshakeErr = net.ErrClosed
184+
return nil
185+
} else if c.handshakeErr != nil {
186+
return nil
187+
}
188+
return c.tcpConn.CloseRead()
189+
}
190+
191+
func (c *gLazyConn) CloseWrite() error {
192+
c.handshakeAccess.Lock()
193+
defer c.handshakeAccess.Unlock()
194+
if !c.handshakeDone {
195+
c.request.Complete(true)
196+
c.handshakeErr = net.ErrClosed
197+
return nil
198+
} else if c.handshakeErr != nil {
199+
return nil
200+
}
201+
return c.tcpConn.CloseRead()
202+
}
203+
204+
func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) error {
205+
if errors.Is(err, syscall.ENETUNREACH) {
206+
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
207+
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable)
208+
} else {
209+
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
210+
}
211+
} else if errors.Is(err, syscall.EHOSTUNREACH) {
212+
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
213+
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostProhibited)
214+
} else {
215+
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
216+
}
217+
} else if errors.Is(err, syscall.ECONNREFUSED) {
218+
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
219+
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable)
220+
} else {
221+
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable)
222+
}
223+
}
224+
return os.ErrInvalid
225+
}
226+
227+
func gWriteUnreachable4(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv4WithICMPType) error {
228+
return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true))
229+
}
230+
231+
func gWriteUnreachable6(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv6WithICMPType) error {
232+
return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true))
233+
}

0 commit comments

Comments
 (0)