Skip to content

Commit

Permalink
server: add proxy protocol support
Browse files Browse the repository at this point in the history
This commit adds an `--accept-proxy-protocol-headers` flag to
CockroachDB startup commands. When set, it allows CockroachDB to parse
proxy protocol headers and use the client IP information therein.

Resolves #130706

Release note: None
  • Loading branch information
DuskEagle committed Sep 19, 2024
1 parent d8233a9 commit eda4f3e
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 9 deletions.
5 changes: 5 additions & 0 deletions pkg/base/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,11 @@ type Config struct {
// LocalityAddresses contains private IP addresses that can only be accessed
// in the corresponding locality.
LocalityAddresses []roachpb.LocalityAddress

// AcceptProxyProtocolHeaders allows CockroachDB to parse proxy protocol
// headers, and use the client IP information contained within instead of
// using the IP information in the source IP field of the incoming packets.
AcceptProxyProtocolHeaders bool
}

// AdvertiseAddr is the type of the AdvertiseAddr field in Config.
Expand Down
15 changes: 15 additions & 0 deletions pkg/cli/cliflags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,21 @@ apply. This flag is experimental.
`,
}

AcceptProxyProtocolHeaders = FlagInfo{
Name: "accept-proxy-protocol-headers",
Description: `
Allows CockroachDB to parse proxy protocol headers. Proxy protocol is used by
some proxies to retain the original client IP information after the proxy has
rewritten the source IP address of forwarded packets.
<PRE>
</PRE>
When using this flag, ensure all traffic to CockroachDB flows through a proxy
which adds proxy protocol headers, to prevent spoofing of client IP address
information.
`,
}

LocalityAdvertiseAddr = FlagInfo{
Name: "locality-advertise-addr",
Description: `
Expand Down
3 changes: 3 additions & 0 deletions pkg/cli/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,8 @@ func init() {
cliflagcfg.VarFlag(f, addr.NewAddrSetter(&serverHTTPAddr, &serverHTTPPort), cliflags.ListenHTTPAddr)
cliflagcfg.VarFlag(f, addr.NewAddrSetter(&serverHTTPAdvertiseAddr, &serverHTTPAdvertisePort), cliflags.HTTPAdvertiseAddr)

cliflagcfg.BoolFlag(f, &serverCfg.AcceptProxyProtocolHeaders, cliflags.AcceptProxyProtocolHeaders)

// Certificates directory. Use a server-specific flag and value to ignore environment
// variables, but share the same default.
cliflagcfg.StringFlag(f, &startCtx.serverSSLCertsDir, cliflags.ServerCertsDir)
Expand All @@ -463,6 +465,7 @@ func init() {
_ = f.MarkHidden(cliflags.AdvertiseAddr.Name)
_ = f.MarkHidden(cliflags.SQLAdvertiseAddr.Name)
_ = f.MarkHidden(cliflags.HTTPAdvertiseAddr.Name)
_ = f.MarkHidden(cliflags.AcceptProxyProtocolHeaders.Name)
}

if cmd == startCmd || cmd == startSingleNodeCmd {
Expand Down
3 changes: 3 additions & 0 deletions pkg/server/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ go_library(
"@com_github_marusama_semaphore//:semaphore",
"@com_github_nightlyone_lockfile//:lockfile",
"@com_github_nytimes_gziphandler//:gziphandler",
"@com_github_pires_go_proxyproto//:go-proxyproto",
"@com_github_prometheus_common//expfmt",
"@in_gopkg_yaml_v2//:yaml_v2",
"@org_golang_google_grpc//:go_default_library",
Expand Down Expand Up @@ -436,6 +437,7 @@ go_test(
"helpers_test.go",
"index_usage_stats_test.go",
"job_profiler_test.go",
"listen_and_update_addrs_test.go",
"load_endpoint_test.go",
"main_test.go",
"migration_test.go",
Expand Down Expand Up @@ -576,6 +578,7 @@ go_test(
"@com_github_grpc_ecosystem_grpc_gateway//runtime:go_default_library",
"@com_github_jackc_pgx_v4//:pgx",
"@com_github_kr_pretty//:pretty",
"@com_github_pires_go_proxyproto//:go-proxyproto",
"@com_github_prometheus_client_model//go",
"@com_github_prometheus_common//expfmt",
"@com_github_stretchr_testify//assert",
Expand Down
21 changes: 19 additions & 2 deletions pkg/server/listen_and_update_addrs.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/netutil/addr"
"github.com/cockroachdb/cockroach/pkg/util/sysutil"
"github.com/cockroachdb/errors"
"github.com/pires/go-proxyproto"
)

// ListenError is returned from Start when we fail to start listening on either
Expand Down Expand Up @@ -61,7 +62,10 @@ type rangeListenerFactory struct {
}

func (rlf *rangeListenerFactory) ListenAndUpdateAddrs(
ctx context.Context, listenAddr, advertiseAddr *string, connName string,
ctx context.Context,
listenAddr, advertiseAddr *string,
connName string,
acceptProxyProtocolHeaders bool,
) (net.Listener, error) {
h, _, err := addr.SplitHostPort(*listenAddr, "0")
if err != nil {
Expand All @@ -80,6 +84,11 @@ func (rlf *rangeListenerFactory) ListenAndUpdateAddrs(
nextAddr := net.JoinHostPort(h, strconv.Itoa(nextPort))
ln, err = net.Listen("tcp", nextAddr)
if err == nil {
if acceptProxyProtocolHeaders {
ln = &proxyproto.Listener{
Listener: ln,
}
}
if err := UpdateAddrs(ctx, listenAddr, advertiseAddr, ln.Addr()); err != nil {
return nil, errors.Wrapf(err, "internal error: cannot parse %s listen address", connName)
}
Expand All @@ -99,7 +108,10 @@ func (rlf *rangeListenerFactory) ListenAndUpdateAddrs(
// actual interface address resolved by the OS during the Listen()
// call.
func ListenAndUpdateAddrs(
ctx context.Context, addr, advertiseAddr *string, connName string,
ctx context.Context,
addr, advertiseAddr *string,
connName string,
acceptProxyProtocolHeaders bool,
) (net.Listener, error) {
ln, err := net.Listen("tcp", *addr)
if err != nil {
Expand All @@ -108,6 +120,11 @@ func ListenAndUpdateAddrs(
Addr: *addr,
}
}
if acceptProxyProtocolHeaders {
ln = &proxyproto.Listener{
Listener: ln,
}
}
if err := UpdateAddrs(ctx, addr, advertiseAddr, ln.Addr()); err != nil {
return nil, errors.Wrapf(err, "internal error: cannot parse %s listen address", connName)
}
Expand Down
90 changes: 90 additions & 0 deletions pkg/server/listen_and_update_addrs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright 2024 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package server

import (
"context"
"net"
"sync"
"testing"

"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/pires/go-proxyproto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestListenAndUpdateAddrs(t *testing.T) {
defer leaktest.AfterTest(t)()

ctx := context.Background()

t.Run("don't accept proxy protocol headers", func(t *testing.T) {
addr := "127.0.0.1:0"
advertiseAddr := "127.0.0.1:0"
ln, err := ListenAndUpdateAddrs(ctx, &addr, &advertiseAddr, "sql", false)
require.NoError(t, err)
require.NotNil(t, ln)
_, addrPort, err := net.SplitHostPort(addr)
require.NoError(t, err)
require.NotZero(t, addrPort)
_, advertiseAddrPort, err := net.SplitHostPort(addr)
require.NoError(t, err)
require.NotZero(t, advertiseAddrPort)
require.NoError(t, ln.Close())
})

t.Run("accept proxy protocol headers", func(t *testing.T) {
addr := "127.0.0.1:0"
advertiseAddr := "127.0.0.1:0"
proxyLn, err := ListenAndUpdateAddrs(ctx, &addr, &advertiseAddr, "sql", true)
require.NoError(t, err)
require.NotNil(t, proxyLn)

proxyLn, ok := proxyLn.(*proxyproto.Listener)
require.True(t, ok)

sourceAddr := &net.TCPAddr{
IP: net.ParseIP("10.20.30.40").To4(),
Port: 4242,
}

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
conn, err := proxyLn.Accept()
require.NoError(t, err)

header := conn.(*proxyproto.Conn).ProxyHeader()
assert.NotNil(t, header)
assert.Equal(t, sourceAddr, header.SourceAddr)
}()

conn, err := net.Dial("tcp", addr)
require.NoError(t, err)
defer conn.Close()

header := &proxyproto.Header{
Version: 2,
Command: proxyproto.PROXY,
TransportProtocol: proxyproto.TCPv4,
SourceAddr: sourceAddr,
DestinationAddr: conn.RemoteAddr(),
}
_, err = header.WriteTo(conn)
require.NoError(t, err)
_, err = conn.Write([]byte("ping"))
require.NoError(t, err)

wg.Wait()
})
}
2 changes: 1 addition & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1584,7 +1584,7 @@ func (s *topLevelServer) PreStart(ctx context.Context) error {
// below when the server has initialized.
pgL, loopbackPgL, rpcLoopbackDialFn, startRPCServer, err := startListenRPCAndSQL(
ctx, workersCtx, s.cfg.BaseConfig,
s.stopper, s.grpc, ListenAndUpdateAddrs, true /* enableSQLListener */)
s.stopper, s.grpc, ListenAndUpdateAddrs, true /* enableSQLListener */, s.cfg.AcceptProxyProtocolHeaders)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/server_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ func startHTTPService(
stopper *stop.Stopper,
handler http.HandlerFunc,
) error {
httpLn, err := ListenAndUpdateAddrs(ctx, &cfg.HTTPAddr, &cfg.HTTPAdvertiseAddr, "http")
httpLn, err := ListenAndUpdateAddrs(ctx, &cfg.HTTPAddr, &cfg.HTTPAdvertiseAddr, "http", cfg.AcceptProxyProtocolHeaders)
if err != nil {
return err
}
Expand Down
12 changes: 9 additions & 3 deletions pkg/server/start_listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ import (
"github.com/cockroachdb/errors"
)

type RPCListenerFactory func(ctx context.Context, addr, advertiseAddr *string, connName string) (net.Listener, error)
type RPCListenerFactory func(
ctx context.Context,
addr, advertiseAddr *string,
connName string,
acceptProxyProtocolHeaders bool,
) (net.Listener, error)

// startListenRPCAndSQL starts the RPC and SQL listeners. It returns:
// - The listener for pgwire connections coming over the network. This will be used
Expand All @@ -43,6 +48,7 @@ func startListenRPCAndSQL(
grpc *grpcServer,
rpcListenerFactory RPCListenerFactory,
enableSQLListener bool,
acceptProxyProtocolHeaders bool,
) (
sqlListener net.Listener,
pgLoopbackListener *netutil.LoopbackListener,
Expand All @@ -61,7 +67,7 @@ func startListenRPCAndSQL(
}
if ln == nil {
var err error
ln, err = rpcListenerFactory(ctx, &cfg.Addr, &cfg.AdvertiseAddr, rpcChanName)
ln, err = rpcListenerFactory(ctx, &cfg.Addr, &cfg.AdvertiseAddr, rpcChanName, acceptProxyProtocolHeaders)
if err != nil {
return nil, nil, nil, nil, err
}
Expand All @@ -71,7 +77,7 @@ func startListenRPCAndSQL(
var pgL net.Listener
if cfg.SplitListenSQL && enableSQLListener {
if cfg.SQLAddrListener == nil {
pgL, err = ListenAndUpdateAddrs(ctx, &cfg.SQLAddr, &cfg.SQLAdvertiseAddr, "sql")
pgL, err = ListenAndUpdateAddrs(ctx, &cfg.SQLAddr, &cfg.SQLAdvertiseAddr, "sql", acceptProxyProtocolHeaders)
} else {
pgL = cfg.SQLAddrListener
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/server/tenant.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ func newTenantServer(
// (until the server is ready) won't cause client connections to be rejected.
if baseCfg.SplitListenSQL && !baseCfg.DisableSQLListener {
sqlAddrListener, err := ListenAndUpdateAddrs(
ctx, &baseCfg.SQLAddr, &baseCfg.SQLAdvertiseAddr, "sql")
ctx, &baseCfg.SQLAddr, &baseCfg.SQLAdvertiseAddr, "sql", baseCfg.AcceptProxyProtocolHeaders)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -590,7 +590,7 @@ func (s *SQLServerWrapper) PreStart(ctx context.Context) error {
lf = s.sqlServer.cfg.RPCListenerFactory
}

pgL, loopbackPgL, rpcLoopbackDialFn, startRPCServer, err := startListenRPCAndSQL(ctx, workersCtx, *s.sqlServer.cfg, s.stopper, s.grpc, lf, enableSQLListener)
pgL, loopbackPgL, rpcLoopbackDialFn, startRPCServer, err := startListenRPCAndSQL(ctx, workersCtx, *s.sqlServer.cfg, s.stopper, s.grpc, lf, enableSQLListener, s.cfg.AcceptProxyProtocolHeaders)
if err != nil {
return err
}
Expand Down

0 comments on commit eda4f3e

Please sign in to comment.