Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support sslnegotiation flag #1180

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1113,20 +1113,24 @@ func (cn *conn) ssl(o values) error {
return nil
}

w := cn.writeBuf(0)
w.int32(80877103)
if err = cn.sendStartupPacket(w); err != nil {
return err
}
// only negotiate the ssl handshake if requested (which is the default).
// sllnegotiation=direct is supported by pg17 and above.
if sslnegotiation(o) {
w := cn.writeBuf(0)
w.int32(80877103)
if err = cn.sendStartupPacket(w); err != nil {
return err
}

b := cn.scratch[:1]
_, err = io.ReadFull(cn.c, b)
if err != nil {
return err
}
b := cn.scratch[:1]
_, err = io.ReadFull(cn.c, b)
if err != nil {
return err
}

if b[0] != 'S' {
return ErrSSLNotSupported
if b[0] != 'S' {
return ErrSSLNotSupported
}
}

cn.c, err = upgrade(cn.c)
Expand Down
11 changes: 11 additions & 0 deletions ssl.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,14 @@ func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) error
_, err = certs[0].Verify(opts)
return err
}

// sslnegotiation returns true if we should negotiate SSL.
// returns false if there should be no negotiation and we should upgrade immediately.
func sslnegotiation(o values) bool {
if negotiation, ok := o["sslnegotiation"]; ok {
if negotiation == "direct" {
return false
}
}
return true
}
57 changes: 39 additions & 18 deletions ssl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,30 +308,49 @@ func TestSNISupport(t *testing.T) {
conn_param string
hostname string
expected_sni string
direct bool
}{
{
name: "SNI is set by default",
conn_param: "",
hostname: "localhost",
expected_sni: "localhost",
direct: false,
},
{
name: "SNI is passed when asked for",
conn_param: "sslsni=1",
hostname: "localhost",
expected_sni: "localhost",
direct: false,
},
{
name: "SNI is not passed when disabled",
conn_param: "sslsni=0",
hostname: "localhost",
expected_sni: "",
direct: false,
},
{
name: "SNI is not set for IPv4",
conn_param: "",
hostname: "127.0.0.1",
expected_sni: "",
direct: false,
},
{
name: "SNI is set for negotiated ssl",
conn_param: "sslnegotiation=postgres",
hostname: "localhost",
expected_sni: "localhost",
direct: false,
},
{
name: "SNI is set for direct ssl",
conn_param: "sslnegotiation=direct",
hostname: "localhost",
expected_sni: "localhost",
direct: true,
},
}
for _, tt := range tests {
Expand All @@ -346,7 +365,7 @@ func TestSNISupport(t *testing.T) {
}
serverErrChan := make(chan error, 1)
serverSNINameChan := make(chan string, 1)
go mockPostgresSSL(listener, serverErrChan, serverSNINameChan)
go mockPostgresSSL(listener, tt.direct, serverErrChan, serverSNINameChan)

defer listener.Close()
defer close(serverErrChan)
Expand Down Expand Up @@ -381,7 +400,7 @@ func TestSNISupport(t *testing.T) {
//
// Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection.
// While reading clientHello catch passed SNI data and report it to nameChan.
func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan string) {
func mockPostgresSSL(listener net.Listener, direct bool, errChan chan error, nameChan chan string) {
var sniHost string

conn, err := listener.Accept()
Expand All @@ -397,23 +416,25 @@ func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan st
return
}

// Receive StartupMessage with SSL Request
startupMessage := make([]byte, 8)
if _, err := io.ReadFull(conn, startupMessage); err != nil {
errChan <- err
return
}
// StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber
if !bytes.Equal(startupMessage, []byte{0, 0, 0, 0x8, 0x4, 0xd2, 0x16, 0x2f}) {
errChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
return
}
if !direct {
// Receive StartupMessage with SSL Request
startupMessage := make([]byte, 8)
if _, err := io.ReadFull(conn, startupMessage); err != nil {
errChan <- err
return
}
// StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber
if !bytes.Equal(startupMessage, []byte{0, 0, 0, 0x8, 0x4, 0xd2, 0x16, 0x2f}) {
errChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
return
}

// Respond with SSLOk
_, err = conn.Write([]byte("S"))
if err != nil {
errChan <- err
return
// Respond with SSLOk
_, err = conn.Write([]byte("S"))
if err != nil {
errChan <- err
return
}
}

// Set up TLS context to catch clientHello. It will always error out during handshake
Expand Down