-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathftl_auth_proxy.go
164 lines (155 loc) · 5.54 KB
/
ftl_auth_proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
package mysqlauthproxy
import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"sync"
)
// Proxy is a MySQL protocol proxy. It accepts client connections without
// verifying the credentials they present and forwards their traffic to a
// backend server over a connection that the proxy authenticates itself.
type Proxy struct {
	// host and port form the TCP address the proxy listens on.
	host string
	port int
	// cfgFn resolves the backend connection configuration. ListenAndServe
	// calls it once and shares the result between all client connections.
	cfgFn func(context.Context) (*Config, error)
	// logger receives diagnostic messages; NewProxy substitutes
	// defaultLogger when nil is passed.
	logger Logger
	// portBinding, if non-nil, receives the port actually bound by
	// ListenAndServe (e.g. the OS-chosen port when port is 0).
	portBinding chan int
}
// NewProxy returns a Proxy that will listen on host:port and connect to the
// backend described by the configuration cfgFn returns.
//
// A nil logger is replaced by defaultLogger. If portBinding is non-nil, the
// port actually bound by ListenAndServe is sent on it once listening starts.
func NewProxy(host string, port int, cfgFn func(context.Context) (*Config, error), logger Logger, portBinding chan int) *Proxy {
	proxy := &Proxy{
		host:        host,
		port:        port,
		cfgFn:       cfgFn,
		logger:      defaultLogger,
		portBinding: portBinding,
	}
	// Only a genuinely nil interface falls back to the default logger.
	if logger != nil {
		proxy.logger = logger
	}
	return proxy
}
// ListenAndServe listens on p.host:p.port for incoming MySQL client
// connections and serves each one on its own goroutine.
//
// The backend configuration is resolved once via cfgFn and shared by every
// connection. If portBinding is non-nil, the port actually bound is sent on
// it once the listener is ready.
//
// It returns an error if the listener cannot be created or cfgFn fails.
// Otherwise it blocks until ctx is done, then closes the listener and returns
// nil. Connections that are already being proxied are not interrupted.
func (p *Proxy) ListenAndServe(ctx context.Context) error {
	socket, err := net.Listen("tcp", fmt.Sprintf("%s:%d", p.host, p.port))
	if err != nil {
		return err
	}
	// Release the listener on every return path, including cfgFn failure.
	defer socket.Close()

	cfg, err := p.cfgFn(ctx)
	if err != nil {
		return err
	}
	if p.portBinding != nil {
		// Don't block forever on an unread channel if we are being shut down.
		select {
		case p.portBinding <- socket.Addr().(*net.TCPAddr).Port:
		case <-ctx.Done():
			return nil
		}
	}

	// Accept blocks until a client connects, so cancellation has to close the
	// listener to unblock it; done stops this watcher once we return.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-ctx.Done():
		case <-done:
		}
		_ = socket.Close()
	}()

	for {
		con, err := socket.Accept()
		if err != nil {
			if ctx.Err() != nil {
				// The listener was closed because ctx is done.
				return nil
			}
			p.logger.Print(fmt.Sprintf("failed to accept connection: %v", err))
			continue
		}
		go p.handleConnection(ctx, con, cfg)
	}
}
// handleConnection proxies a single client connection to a freshly opened,
// already-authenticated backend connection.
//
// Towards the client the proxy acts as the server: it sends its own handshake,
// reads the client's handshake response without checking the credentials,
// replies with an OK packet and then copies raw bytes in both directions
// until both sides are finished. Both con and the backend connection are
// closed before it returns. Failures are logged, not returned.
func (p *Proxy) handleConnection(ctx context.Context, con net.Conn, cfg *Config) {
	defer con.Close()
	backend, err := connectToBackend(ctx, cfg)
	if err != nil {
		p.logger.Print(fmt.Sprintf("failed to connect to backend: %v", err))
		return
	}
	defer backend.Close()

	if err := writeServerHandshakePacket(con, backend); err != nil {
		p.logger.Print(fmt.Sprintf("failed to write server handshake packet: %v", err))
		return
	}

	// Read the client's handshake response, reusing mysqlConn's packet
	// reader. The server handshake went out with sequence 0, so the client
	// answers with sequence 1.
	// NOTE(review): mc.buf wraps con; any bytes it reads ahead of the
	// handshake response would be lost once we switch to raw copying below.
	// Clients wait for the OK packet before sending more, so this is
	// presumably safe — confirm.
	mc := &mysqlConn{
		maxAllowedPacket: maxPacketSize,
		maxWriteSize:     maxPacketSize - 1,
		closech:          make(chan struct{}),
		cfg:              cfg,
		buf:              newBuffer(con),
		netConn:          con,
		rawConn:          con,
		sequence:         1,
	}
	mc.parseTime = mc.cfg.ParseTime
	// The packet's content (user name, auth response) is deliberately
	// ignored: any credentials are accepted.
	if _, err := mc.readPacket(); err != nil {
		p.logger.Print(fmt.Sprintf("failed to read client handshake packet: %v", err))
		return
	}

	// Tell the client it is authenticated.
	// NOTE(review): assumes writePacket (as in go-sql-driver/mysql) overwrites
	// the first 4 bytes with the packet header, leaving the payload
	// 0x00 0x00 0x00 (OK marker, affected rows, last insert id) without
	// status flags/warnings — confirm all target clients tolerate that.
	if err := mc.writePacket([]byte{0, 0, 0, 0, 0, 0, 0}); err != nil {
		p.logger.Print(fmt.Sprintf("failed to write OK packet: %v", err))
		return
	}

	// All good, start proxying bytes. We only return after both directions
	// are done, so a half-closed connection can still drain the other way.
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		p.pipe(backend.netConn, con, "client to backend")
	}()
	go func() {
		defer wg.Done()
		p.pipe(con, backend.netConn, "backend to client")
	}()
	wg.Wait()
}

// pipe copies bytes from src to dst until src reports EOF or an error, then
// half-closes dst (or closes it outright if half-close is unsupported) so the
// peer on that side sees end-of-stream instead of waiting forever. Without
// this, one side disconnecting would leave the opposite copy blocked.
// direction is only used in log messages.
func (p *Proxy) pipe(dst, src net.Conn, direction string) {
	if _, err := io.Copy(dst, src); err != nil {
		p.logger.Print(fmt.Sprintf("failed to copy from %s: %v", direction, err))
	}
	if cw, ok := dst.(interface{ CloseWrite() error }); ok {
		_ = cw.CloseWrite()
		return
	}
	_ = dst.Close()
}
// connectToBackend opens a new connection to the backend MySQL server
// described by cfg using the package's connector, which performs the
// handshake and authentication. The returned connection is therefore ready
// for its underlying network connection to be proxied.
//
// The caller owns the returned connection and must Close it. Errors from the
// connector are returned unchanged.
func connectToBackend(ctx context.Context, cfg *Config) (*mysqlConn, error) {
	conn, err := newConnector(cfg).Connect(ctx)
	if err != nil {
		return nil, err
	}
	mc, ok := conn.(*mysqlConn)
	if !ok {
		// The proxy needs direct access to the *mysqlConn internals; don't
		// leak a connection of any other type.
		_ = conn.Close()
		return nil, fmt.Errorf("unexpected backend connection type %T", conn)
	}
	return mc, nil
}
// writeServerHandshakePacket writes the initial server greeting
// (Protocol::HandshakeV10, sequence 0) to writer. It advertises the
// "caching_sha2_password" auth plugin with a fixed scramble; the proxy
// accepts any credentials, so the scramble is never verified.
//
// The advertised capability flags are those of backendConnection, with
// secure-connection and plugin-auth forced on and TLS turned off.
// backendConnection itself is not modified.
// NOTE(review): these are the capabilities the backend advertised, not
// necessarily those negotiated for this backend session; if a client picks a
// capability the backend session did not (e.g. compression), the proxied
// byte stream may not line up — confirm.
func writeServerHandshakePacket(writer io.Writer, backendConnection *mysqlConn) error {
	const protocolVersion = 10

	flags := backendConnection.flags | clientSecureConn | clientPluginAuth
	flags &^= clientSSL

	// 20-byte scramble split 8 + 12, the second part NUL-terminated.
	scramble1 := []byte{1, 2, 3, 4, 5, 6, 7, 8}
	scramble2 := []byte{9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0}

	payload := []byte{protocolVersion}
	payload = append(payload, "8.1.0"...) // server version, hard coded for now
	payload = append(payload, 0)          // ...NUL-terminated
	payload = binary.LittleEndian.AppendUint32(payload, 1) // connection id
	payload = append(payload, scramble1...)                // auth-plugin-data-part-1
	payload = append(payload, 0)                           // filler
	payload = binary.LittleEndian.AppendUint16(payload, uint16(flags)) // capability flags (lower 2 bytes)
	payload = append(payload, 0)                                        // character set
	payload = append(payload, 0, 0)                                     // status flags
	payload = binary.LittleEndian.AppendUint16(payload, uint16(flags>>16)) // capability flags (upper 2 bytes)
	// Total scramble length including the trailing NUL (21), as real servers
	// send it; some clients use it to locate the plugin name.
	payload = append(payload, byte(len(scramble1)+len(scramble2)))
	payload = append(payload, make([]byte, 10)...) // reserved
	payload = append(payload, scramble2...)        // auth-plugin-data-part-2
	payload = append(payload, "caching_sha2_password"...)
	payload = append(payload, 0) // plugin name is NUL-terminated

	// Packet header: 3-byte little-endian payload length, then sequence id 0.
	pktLen := len(payload)
	packet := make([]byte, 0, 4+pktLen)
	packet = append(packet, byte(pktLen), byte(pktLen>>8), byte(pktLen>>16), 0)
	packet = append(packet, payload...)

	_, err := writer.Write(packet)
	return err
}