diff --git a/pkg/base/config.go b/pkg/base/config.go index 0f45c0f36b8e..b6a3d1b4679c 100644 --- a/pkg/base/config.go +++ b/pkg/base/config.go @@ -438,6 +438,11 @@ type Config struct { // LocalityAddresses contains private IP addresses that can only be accessed // in the corresponding locality. LocalityAddresses []roachpb.LocalityAddress + + // AcceptProxyProtocolHeaders allows CockroachDB to parse proxy protocol + // headers, and use the client IP information contained within instead of + // using the IP information in the source IP field of the incoming packets. + AcceptProxyProtocolHeaders bool } // AdvertiseAddr is the type of the AdvertiseAddr field in Config. diff --git a/pkg/ccl/backupccl/backup_job.go b/pkg/ccl/backupccl/backup_job.go index f6ba5d1aa8ed..09fc9eb577a6 100644 --- a/pkg/ccl/backupccl/backup_job.go +++ b/pkg/ccl/backupccl/backup_job.go @@ -617,6 +617,21 @@ func (b *backupResumer) Resume(ctx context.Context, execCtx interface{}) error { defaultURI := details.URI var backupDest backupdest.ResolvedDestination if details.URI == "" { + // Choose which scheduled backup pts we will update at the the end of the + // backup _before_ we resolve the destination of the backup. This avoids a + // race with inc backups where backup destination resolution leads this backup + // to extend a chain that is about to be superseded by a new full backup + // chain, which could cause this inc to accidentally push the pts for the + // _new_ chain instead of the old chain it is apart of. By choosing the pts to + // move before we resolve the destination, we guarantee that we push the old + // chain. + insqlDB := p.ExecCfg().InternalDB + if err := insqlDB.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { + return planSchedulePTSChaining(ctx, p.ExecCfg().JobsKnobs(), txn, &details, b.job.CreatedBy()) + }); err != nil { + return err + } + var err error backupDest, err = backupdest.ResolveDest(ctx, p.User(), details.Destination, details.EndTime, details.IncrementalFrom, p.ExecCfg()) @@ -727,12 +742,6 @@ func (b *backupResumer) Resume(ctx context.Context, execCtx interface{}) error { return err } - if err := insqlDB.Txn(ctx, func(ctx context.Context, txn isql.Txn) error { - return planSchedulePTSChaining(ctx, p.ExecCfg().JobsKnobs(), txn, &details, b.job.CreatedBy()) - }); err != nil { - return err - } - // The description picked during original planning might still say "LATEST", // if resolving that to the actual directory only just happened above here. // Ideally we'd re-render the description now that we know the subdir, but diff --git a/pkg/cli/cliflags/flags.go b/pkg/cli/cliflags/flags.go index 79699fe98032..a254db0eac25 100644 --- a/pkg/cli/cliflags/flags.go +++ b/pkg/cli/cliflags/flags.go @@ -694,6 +694,21 @@ apply. This flag is experimental. `, } + AcceptProxyProtocolHeaders = FlagInfo{ + Name: "accept-proxy-protocol-headers", + Description: ` +Allows CockroachDB to parse proxy protocol headers. Proxy protocol is used by +some proxies to retain the original client IP information after the proxy has +rewritten the source IP address of forwarded packets. +
+
+
+When using this flag, ensure all traffic to CockroachDB flows through a proxy +which adds proxy protocol headers, to prevent spoofing of client IP address +information. +`, + } + LocalityAdvertiseAddr = FlagInfo{ Name: "locality-advertise-addr", Description: ` diff --git a/pkg/cli/flags.go b/pkg/cli/flags.go index 092161be69cb..9d202cf57382 100644 --- a/pkg/cli/flags.go +++ b/pkg/cli/flags.go @@ -450,6 +450,8 @@ func init() { cliflagcfg.VarFlag(f, addr.NewAddrSetter(&serverHTTPAddr, &serverHTTPPort), cliflags.ListenHTTPAddr) cliflagcfg.VarFlag(f, addr.NewAddrSetter(&serverHTTPAdvertiseAddr, &serverHTTPAdvertisePort), cliflags.HTTPAdvertiseAddr) + cliflagcfg.BoolFlag(f, &serverCfg.AcceptProxyProtocolHeaders, cliflags.AcceptProxyProtocolHeaders) + // Certificates directory. Use a server-specific flag and value to ignore environment // variables, but share the same default. cliflagcfg.StringFlag(f, &startCtx.serverSSLCertsDir, cliflags.ServerCertsDir) @@ -463,6 +465,7 @@ func init() { _ = f.MarkHidden(cliflags.AdvertiseAddr.Name) _ = f.MarkHidden(cliflags.SQLAdvertiseAddr.Name) _ = f.MarkHidden(cliflags.HTTPAdvertiseAddr.Name) + _ = f.MarkHidden(cliflags.AcceptProxyProtocolHeaders.Name) } if cmd == startCmd || cmd == startSingleNodeCmd { diff --git a/pkg/server/BUILD.bazel b/pkg/server/BUILD.bazel index a19f70bcdd33..1d37830494ec 100644 --- a/pkg/server/BUILD.bazel +++ b/pkg/server/BUILD.bazel @@ -368,6 +368,7 @@ go_library( "@com_github_marusama_semaphore//:semaphore", "@com_github_nightlyone_lockfile//:lockfile", "@com_github_nytimes_gziphandler//:gziphandler", + "@com_github_pires_go_proxyproto//:go-proxyproto", "@com_github_prometheus_common//expfmt", "@in_gopkg_yaml_v2//:yaml_v2", "@org_golang_google_grpc//:go_default_library", @@ -440,6 +441,7 @@ go_test( "helpers_test.go", "index_usage_stats_test.go", "job_profiler_test.go", + "listen_and_update_addrs_test.go", "load_endpoint_test.go", "main_test.go", "migration_test.go", @@ -583,6 +585,7 @@ go_test( "@com_github_grpc_ecosystem_grpc_gateway//runtime:go_default_library", "@com_github_jackc_pgx_v4//:pgx", "@com_github_kr_pretty//:pretty", + "@com_github_pires_go_proxyproto//:go-proxyproto", "@com_github_prometheus_client_model//go", "@com_github_prometheus_common//expfmt", "@com_github_stretchr_testify//assert", diff --git a/pkg/server/listen_and_update_addrs.go b/pkg/server/listen_and_update_addrs.go index d19de34349df..c0e298c7ce6e 100644 --- a/pkg/server/listen_and_update_addrs.go +++ b/pkg/server/listen_and_update_addrs.go @@ -19,6 +19,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/netutil/addr" "github.com/cockroachdb/cockroach/pkg/util/sysutil" "github.com/cockroachdb/errors" + "github.com/pires/go-proxyproto" ) // ListenError is returned from Start when we fail to start listening on either @@ -61,7 +62,10 @@ type rangeListenerFactory struct { } func (rlf *rangeListenerFactory) ListenAndUpdateAddrs( - ctx context.Context, listenAddr, advertiseAddr *string, connName string, + ctx context.Context, + listenAddr, advertiseAddr *string, + connName string, + acceptProxyProtocolHeaders bool, ) (net.Listener, error) { h, _, err := addr.SplitHostPort(*listenAddr, "0") if err != nil { @@ -80,6 +84,11 @@ func (rlf *rangeListenerFactory) ListenAndUpdateAddrs( nextAddr := net.JoinHostPort(h, strconv.Itoa(nextPort)) ln, err = net.Listen("tcp", nextAddr) if err == nil { + if acceptProxyProtocolHeaders { + ln = &proxyproto.Listener{ + Listener: ln, + } + } if err := UpdateAddrs(ctx, listenAddr, advertiseAddr, ln.Addr()); err != nil { return nil, errors.Wrapf(err, "internal error: cannot parse %s listen address", connName) } @@ -99,7 +108,10 @@ func (rlf *rangeListenerFactory) ListenAndUpdateAddrs( // actual interface address resolved by the OS during the Listen() // call. func ListenAndUpdateAddrs( - ctx context.Context, addr, advertiseAddr *string, connName string, + ctx context.Context, + addr, advertiseAddr *string, + connName string, + acceptProxyProtocolHeaders bool, ) (net.Listener, error) { ln, err := net.Listen("tcp", *addr) if err != nil { @@ -108,6 +120,11 @@ func ListenAndUpdateAddrs( Addr: *addr, } } + if acceptProxyProtocolHeaders { + ln = &proxyproto.Listener{ + Listener: ln, + } + } if err := UpdateAddrs(ctx, addr, advertiseAddr, ln.Addr()); err != nil { return nil, errors.Wrapf(err, "internal error: cannot parse %s listen address", connName) } diff --git a/pkg/server/listen_and_update_addrs_test.go b/pkg/server/listen_and_update_addrs_test.go new file mode 100644 index 000000000000..eb76941d024e --- /dev/null +++ b/pkg/server/listen_and_update_addrs_test.go @@ -0,0 +1,90 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package server + +import ( + "context" + "net" + "sync" + "testing" + + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/pires/go-proxyproto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestListenAndUpdateAddrs(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + + t.Run("don't accept proxy protocol headers", func(t *testing.T) { + addr := "127.0.0.1:0" + advertiseAddr := "127.0.0.1:0" + ln, err := ListenAndUpdateAddrs(ctx, &addr, &advertiseAddr, "sql", false) + require.NoError(t, err) + require.NotNil(t, ln) + _, addrPort, err := net.SplitHostPort(addr) + require.NoError(t, err) + require.NotZero(t, addrPort) + _, advertiseAddrPort, err := net.SplitHostPort(addr) + require.NoError(t, err) + require.NotZero(t, advertiseAddrPort) + require.NoError(t, ln.Close()) + }) + + t.Run("accept proxy protocol headers", func(t *testing.T) { + addr := "127.0.0.1:0" + advertiseAddr := "127.0.0.1:0" + proxyLn, err := ListenAndUpdateAddrs(ctx, &addr, &advertiseAddr, "sql", true) + require.NoError(t, err) + require.NotNil(t, proxyLn) + + proxyLn, ok := proxyLn.(*proxyproto.Listener) + require.True(t, ok) + + sourceAddr := &net.TCPAddr{ + IP: net.ParseIP("10.20.30.40").To4(), + Port: 4242, + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + conn, err := proxyLn.Accept() + require.NoError(t, err) + + header := conn.(*proxyproto.Conn).ProxyHeader() + assert.NotNil(t, header) + assert.Equal(t, sourceAddr, header.SourceAddr) + }() + + conn, err := net.Dial("tcp", addr) + require.NoError(t, err) + defer conn.Close() + + header := &proxyproto.Header{ + Version: 2, + Command: proxyproto.PROXY, + TransportProtocol: proxyproto.TCPv4, + SourceAddr: sourceAddr, + DestinationAddr: conn.RemoteAddr(), + } + _, err = header.WriteTo(conn) + require.NoError(t, err) + _, err = conn.Write([]byte("ping")) + require.NoError(t, err) + + wg.Wait() + }) +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 14a6dd8ef334..fe40a665a0dc 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1589,7 +1589,7 @@ func (s *topLevelServer) PreStart(ctx context.Context) error { // below when the server has initialized. pgL, loopbackPgL, rpcLoopbackDialFn, startRPCServer, err := startListenRPCAndSQL( ctx, workersCtx, s.cfg.BaseConfig, - s.stopper, s.grpc, ListenAndUpdateAddrs, true /* enableSQLListener */) + s.stopper, s.grpc, ListenAndUpdateAddrs, true /* enableSQLListener */, s.cfg.AcceptProxyProtocolHeaders) if err != nil { return err } diff --git a/pkg/server/server_http.go b/pkg/server/server_http.go index 280c12d2b0c8..f3c0f70662a8 100644 --- a/pkg/server/server_http.go +++ b/pkg/server/server_http.go @@ -279,7 +279,7 @@ func startHTTPService( stopper *stop.Stopper, handler http.HandlerFunc, ) error { - httpLn, err := ListenAndUpdateAddrs(ctx, &cfg.HTTPAddr, &cfg.HTTPAdvertiseAddr, "http") + httpLn, err := ListenAndUpdateAddrs(ctx, &cfg.HTTPAddr, &cfg.HTTPAdvertiseAddr, "http", cfg.AcceptProxyProtocolHeaders) if err != nil { return err } diff --git a/pkg/server/start_listen.go b/pkg/server/start_listen.go index bb7dbd8da4f1..eb5e252612b7 100644 --- a/pkg/server/start_listen.go +++ b/pkg/server/start_listen.go @@ -25,7 +25,12 @@ import ( "github.com/cockroachdb/errors" ) -type RPCListenerFactory func(ctx context.Context, addr, advertiseAddr *string, connName string) (net.Listener, error) +type RPCListenerFactory func( + ctx context.Context, + addr, advertiseAddr *string, + connName string, + acceptProxyProtocolHeaders bool, +) (net.Listener, error) // startListenRPCAndSQL starts the RPC and SQL listeners. It returns: // - The listener for pgwire connections coming over the network. This will be used @@ -43,6 +48,7 @@ func startListenRPCAndSQL( grpc *grpcServer, rpcListenerFactory RPCListenerFactory, enableSQLListener bool, + acceptProxyProtocolHeaders bool, ) ( sqlListener net.Listener, pgLoopbackListener *netutil.LoopbackListener, @@ -61,7 +67,7 @@ func startListenRPCAndSQL( } if ln == nil { var err error - ln, err = rpcListenerFactory(ctx, &cfg.Addr, &cfg.AdvertiseAddr, rpcChanName) + ln, err = rpcListenerFactory(ctx, &cfg.Addr, &cfg.AdvertiseAddr, rpcChanName, acceptProxyProtocolHeaders) if err != nil { return nil, nil, nil, nil, err } @@ -71,7 +77,7 @@ func startListenRPCAndSQL( var pgL net.Listener if cfg.SplitListenSQL && enableSQLListener { if cfg.SQLAddrListener == nil { - pgL, err = ListenAndUpdateAddrs(ctx, &cfg.SQLAddr, &cfg.SQLAdvertiseAddr, "sql") + pgL, err = ListenAndUpdateAddrs(ctx, &cfg.SQLAddr, &cfg.SQLAdvertiseAddr, "sql", acceptProxyProtocolHeaders) } else { pgL = cfg.SQLAddrListener } diff --git a/pkg/server/tenant.go b/pkg/server/tenant.go index 402ed52934b5..0f516405b606 100644 --- a/pkg/server/tenant.go +++ b/pkg/server/tenant.go @@ -286,7 +286,7 @@ func newTenantServer( // (until the server is ready) won't cause client connections to be rejected. if baseCfg.SplitListenSQL && !baseCfg.DisableSQLListener { sqlAddrListener, err := ListenAndUpdateAddrs( - ctx, &baseCfg.SQLAddr, &baseCfg.SQLAdvertiseAddr, "sql") + ctx, &baseCfg.SQLAddr, &baseCfg.SQLAdvertiseAddr, "sql", baseCfg.AcceptProxyProtocolHeaders) if err != nil { return nil, err } @@ -599,7 +599,7 @@ func (s *SQLServerWrapper) PreStart(ctx context.Context) error { lf = s.sqlServer.cfg.RPCListenerFactory } - pgL, loopbackPgL, rpcLoopbackDialFn, startRPCServer, err := startListenRPCAndSQL(ctx, workersCtx, *s.sqlServer.cfg, s.stopper, s.grpc, lf, enableSQLListener) + pgL, loopbackPgL, rpcLoopbackDialFn, startRPCServer, err := startListenRPCAndSQL(ctx, workersCtx, *s.sqlServer.cfg, s.stopper, s.grpc, lf, enableSQLListener, s.cfg.AcceptProxyProtocolHeaders) if err != nil { return err }