forked from Control-D-Inc/ctrld
-
Notifications
You must be signed in to change notification settings - Fork 0
/
config_quic.go
160 lines (148 loc) · 4.07 KB
/
config_quic.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
package ctrld
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"runtime"
"sync"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
)
// setupDOH3Transport (re)creates the HTTP/3 round trippers of the upstream
// according to its configured IP stack:
//
//   - both / unset: one round tripper over all bootstrap IPs
//   - v4 only:      one round tripper over the IPv4 bootstrap IPs
//   - v6 only:      one round tripper over the IPv6 bootstrap IPs
//   - split:        dedicated IPv4 and IPv6 round trippers (the IPv6 one falls
//     back to the IPv4 one when hasIPv6 reports false), plus a default round
//     tripper over all bootstrap IPs
//
// Any other IP stack value leaves the upstream untouched.
func (uc *UpstreamConfig) setupDOH3Transport() {
	var defaultIPs []string

	switch uc.IPStack {
	case IpStackSplit:
		uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
		if !hasIPv6() {
			// No usable IPv6: share the IPv4 transport.
			uc.http3RoundTripper6 = uc.http3RoundTripper4
		} else {
			uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
		}
		defaultIPs = uc.bootstrapIPs
	case IpStackV6:
		defaultIPs = uc.bootstrapIPs6
	case IpStackV4:
		defaultIPs = uc.bootstrapIPs4
	case IpStackBoth, "":
		defaultIPs = uc.bootstrapIPs
	default:
		return
	}

	uc.http3RoundTripper = uc.newDOH3Transport(defaultIPs)
}
// newDOH3Transport builds an HTTP/3 round tripper whose QUIC dialer connects
// to known IPs instead of resolving the upstream host name.
//
// When uc.BootstrapIP is set, every dial targets that single IP. Otherwise the
// dialer races connections to all of addrs (each joined with the port of the
// requested address) and keeps the first one that succeeds; with an empty
// addrs the dial fails.
//
// A finalizer is registered on the returned round tripper that closes its idle
// connections once it is garbage collected.
func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
	rt := &http3.RoundTripper{}
	rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
	rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
		// Only the port of addr is used; the host part is replaced below. The
		// split error is deliberately ignored: port is then empty and the
		// problem surfaces when the address is resolved or dialed.
		_, port, _ := net.SplitHostPort(addr)

		// if we have a bootstrap ip set, use it to avoid DNS lookup
		if uc.BootstrapIP != "" {
			addr = net.JoinHostPort(uc.BootstrapIP, port)
			ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", addr)
			// Resolve before opening the socket so that a bad address cannot
			// leave an open socket behind.
			remoteAddr, err := net.ResolveUDPAddr("udp", addr)
			if err != nil {
				return nil, err
			}
			udpConn, err := net.ListenUDP("udp", nil)
			if err != nil {
				return nil, err
			}
			conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
			if err != nil {
				// The socket was opened here, so release it when the dial fails.
				udpConn.Close()
				return nil, err
			}
			return conn, nil
		}

		dialAddrs := make([]string, len(addrs))
		for i := range addrs {
			dialAddrs[i] = net.JoinHostPort(addrs[i], port)
		}
		pd := &quicParallelDialer{}
		conn, err := pd.Dial(ctx, dialAddrs, tlsCfg, cfg)
		if err != nil {
			return nil, err
		}
		ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
		return conn, nil
	}
	runtime.SetFinalizer(rt, func(rt *http3.RoundTripper) {
		rt.CloseIdleConnections()
	})
	return rt
}
// doh3Transport returns the HTTP/3 round tripper to use for a query of the
// given DNS record type.
//
// The transports are set up on first use, and set up again whenever the
// rebootstrap flag is found set (the flag is cleared in the same atomic step).
// For the split IP stack, A queries use the IPv4 round tripper and every other
// type uses the IPv6 one; all other IP stacks share the default round tripper.
func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper {
	uc.transportOnce.Do(func() {
		uc.SetupTransport()
	})
	if needSetup := uc.rebootstrap.CompareAndSwap(true, false); needSetup {
		uc.SetupTransport()
	}

	if uc.IPStack != IpStackSplit {
		return uc.http3RoundTripper
	}
	if dnsType == dns.TypeA {
		return uc.http3RoundTripper4
	}
	return uc.http3RoundTripper6
}
// The QUIC parallel dialer lives in this file because:
//
//   - a QUIC dialer differs from net.Dialer
//   - it keeps the QUIC-free version of the code simpler

// parallelDialerResult carries the outcome of a single dial attempt made by
// quicParallelDialer: either an established connection or the error that
// prevented it.
type parallelDialerResult struct {
// conn is the QUIC connection produced by the attempt.
conn quic.EarlyConnection
// err is the failure of the attempt; nil means conn is usable.
err error
}
// quicParallelDialer dials a list of addresses concurrently over QUIC. It has
// no fields, so the zero value is ready to use.
type quicParallelDialer struct{}
// Dial performs parallel dialing to the given address list.
func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
if len(addrs) == 0 {
return nil, errors.New("empty addresses")
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
done := make(chan struct{})
defer close(done)
ch := make(chan *parallelDialerResult, len(addrs))
var wg sync.WaitGroup
wg.Add(len(addrs))
go func() {
wg.Wait()
close(ch)
}()
for _, addr := range addrs {
go func(addr string) {
defer wg.Done()
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
ch <- ¶llelDialerResult{conn: nil, err: err}
return
}
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
ch <- ¶llelDialerResult{conn: nil, err: err}
return
}
conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
select {
case ch <- ¶llelDialerResult{conn: conn, err: err}:
case <-done:
if conn != nil {
conn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "")
}
if udpConn != nil {
udpConn.Close()
}
}
}(addr)
}
errs := make([]error, 0, len(addrs))
for res := range ch {
if res.err == nil {
cancel()
return res.conn, res.err
}
errs = append(errs, res.err)
}
return nil, errors.Join(errs...)
}