From 954f000b379a4880cdc24e6da0bbf48f012683fd Mon Sep 17 00:00:00 2001 From: Andrew Boyle Date: Thu, 9 Jan 2025 13:54:05 -0600 Subject: [PATCH] add headerbp middleware to httpbp --- httpbp/client_middlewares.go | 29 ++++ httpbp/config.go | 1 + httpbp/middlewares.go | 27 ++++ httpbp/server_test.go | 274 +++++++++++++++++++++++++++++++++++ 4 files changed, 331 insertions(+) diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index b02948e78..16185ce67 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -15,6 +15,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/reddit/baseplate.go/breakerbp" + "github.com/reddit/baseplate.go/internal/headerbp" //lint:ignore SA1019 This library is internal only, not actually deprecated "github.com/reddit/baseplate.go/internalv2compat" "github.com/reddit/baseplate.go/retrybp" @@ -88,6 +89,11 @@ func NewClient(config ClientConfig, middleware ...ClientMiddleware) (*http.Clien if config.CircuitBreaker != nil { defaults = append([]ClientMiddleware{CircuitBreaker(*config.CircuitBreaker)}, defaults...) } + + // only add the middleware to forward baseplate headers if the client is for internal calls + if config.InternalOnly { + defaults = append(defaults, ForwardBaseplateHeaders(config.Slug)) + } middleware = append(middleware, defaults...) return &http.Client{ @@ -349,3 +355,26 @@ func PrometheusClientMetrics(serverSlug string) ClientMiddleware { }) } } + +// ForwardBaseplateHeaders is a middleware that forwards baseplate headers from the context to the outgoing request. +// +// If it detects any new baseplate headers set on the request, it will reject the request and return an error. +func ForwardBaseplateHeaders(client string) ClientMiddleware { + return func(next http.RoundTripper) http.RoundTripper { + return roundTripperFunc(func(req *http.Request) (*http.Response, error) { + for k := range req.Header { + if err := headerbp.CheckClientHeader(k, + headerbp.WithHTTPClient("", client, ""), + ); err != nil { + return nil, err + } + } + headerbp.SetOutgoingHeaders( + req.Context(), + headerbp.WithHTTPClient("", client, ""), + headerbp.WithHeaderSetter(req.Header.Set), + ) + return next.RoundTrip(req) + }) + } +} diff --git a/httpbp/config.go b/httpbp/config.go index dd7fc0603..f43e16966 100644 --- a/httpbp/config.go +++ b/httpbp/config.go @@ -16,6 +16,7 @@ type ClientConfig struct { MaxConnections int `yaml:"maxConnections"` CircuitBreaker *breakerbp.Config `yaml:"circuitBreaker"` RetryOptions []retry.Option + InternalOnly bool } // Validate checks ClientConfig for any missing or erroneous values. diff --git a/httpbp/middlewares.go b/httpbp/middlewares.go index 2d4b7ea88..fe5af2baf 100644 --- a/httpbp/middlewares.go +++ b/httpbp/middlewares.go @@ -15,6 +15,7 @@ import ( "github.com/reddit/baseplate.go/ecinterface" "github.com/reddit/baseplate.go/errorsbp" + "github.com/reddit/baseplate.go/internal/headerbp" //lint:ignore SA1019 This library is internal only, not actually deprecated "github.com/reddit/baseplate.go/internalv2compat" "github.com/reddit/baseplate.go/log" @@ -517,3 +518,29 @@ func (rr *responseRecorder) WriteHeader(code int) { rr.ResponseWriter.WriteHeader(code) rr.responseCode = code } + +// ExtractBaseplateHeaders is a middleware that extracts baseplate headers from the incoming request and adds them to the context. +func ExtractBaseplateHeaders(service string) Middleware { + return func(name string, next HandlerFunc) HandlerFunc { + return func(ctx context.Context, w http.ResponseWriter, r *http.Request) (err error) { + if r.Header.Get(headerbp.IsUntrustedRequestHeaderCanonicalHTTP) != "" { + for k := range r.Header { + if headerbp.IsBaseplateHeader(k) { + r.Header.Del(k) + } + } + return next(ctx, w, r) + } + headers := headerbp.NewIncomingHeaders( + headerbp.WithHTTPService(service, name), + ) + for k, v := range r.Header { + if len(v) > 0 { + headers.RecordHeader(k, v[0]) + } + } + ctx = headers.SetOnContext(ctx) + return next(ctx, w, r.WithContext(ctx)) + } + } +} diff --git a/httpbp/server_test.go b/httpbp/server_test.go index a5a278465..0f61b5d0a 100644 --- a/httpbp/server_test.go +++ b/httpbp/server_test.go @@ -7,13 +7,19 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "reflect" "testing" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/reddit/baseplate.go" "github.com/reddit/baseplate.go/ecinterface" "github.com/reddit/baseplate.go/httpbp" + "github.com/reddit/baseplate.go/internal/headerbp" "github.com/reddit/baseplate.go/log" + "github.com/reddit/baseplate.go/secrets" ) func TestEndpoint(t *testing.T) { @@ -427,3 +433,271 @@ func TestPanicRecovery(t *testing.T) { t.Fatalf("unexpected service code") } } + +func TestBaseplateHeaderPropagation(t *testing.T) { + expectedHeaders := map[string][]string{ + "x-bp-from-edge": {"true"}, + "x-bp-test": {"foo"}, + } + store, _, err := secrets.NewTestSecrets(context.TODO(), nil) + if err != nil { + t.Fatalf("failed to create test secrets: %v", err) + } + t.Cleanup(func() { + store.Close() + }) + bp := baseplate.NewTestBaseplate(baseplate.NewTestBaseplateArgs{ + Config: baseplate.Config{ + Addr: ":8081", + }, + Store: store, + EdgeContextImpl: ecinterface.Mock(), + }) + downstreamServer, err := httpbp.NewBaseplateServer(httpbp.ServerArgs{ + Baseplate: bp, + Endpoints: map[httpbp.Pattern]httpbp.Endpoint{ + "/say-hello": { + Name: "say-hello", + Methods: []string{http.MethodGet}, + Handle: func(ctx context.Context, writer http.ResponseWriter, request *http.Request) error { + for wantKey, wantValue := range expectedHeaders { + if v := request.Header.Values(wantKey); len(v) == 0 { + t.Fatalf("missing header %q", wantKey) + } else if diff := cmp.Diff(v, wantValue, cmpopts.SortSlices(func(a, b string) bool { + return a < b + })); diff != "" { + t.Fatalf("header %q values mismatch (-want +got):\n%s", wantKey, diff) + } + } + return nil + }, + }, + }, + Middlewares: []httpbp.Middleware{ + httpbp.ExtractBaseplateHeaders("originHTTPBPV0"), + }, + }) + if err != nil { + t.Fatalf("failed to create test downstreamServer: %v", err) + } + t.Cleanup(func() { + downstreamServer.Close() + }) + go downstreamServer.Serve() + + downstreamBaseURL, err := url.Parse("http://" + downstreamServer.Baseplate().GetConfig().Addr + "/") + if err != nil { + t.Fatalf("failed to parse test originServer base URL: %v", err) + } + + downstreamClient, err := httpbp.NewClient( + httpbp.ClientConfig{ + Slug: "downstreamHTTPBPV0", + InternalOnly: true, + }, + withBaseURL(downstreamBaseURL), + ) + if err != nil { + t.Fatalf("failed to create test client: %v", err) + } + + originServer, err := httpbp.NewBaseplateServer(httpbp.ServerArgs{ + Baseplate: bp, + Endpoints: map[httpbp.Pattern]httpbp.Endpoint{ + "/say-hello": { + Name: "say-hello", + Methods: []string{http.MethodGet}, + Handle: func(ctx context.Context, writer http.ResponseWriter, request *http.Request) error { + for wantKey, wantValue := range expectedHeaders { + if v := request.Header.Values(wantKey); len(v) == 0 { + t.Fatalf("missing header %q", wantKey) + } else if diff := cmp.Diff(v, wantValue, cmpopts.SortSlices(func(a, b string) bool { + return a < b + })); diff != "" { + t.Fatalf("header %q values mismatch (-want +got):\n%s", wantKey, diff) + } + } + + req, err := http.NewRequest( + http.MethodGet, + downstreamBaseURL.JoinPath("say-hello").String(), + nil, + ) + if err != nil { + t.Fatalf("creating request: %v", err) + } + + resp, err := downstreamClient.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } + + invalidReq, err := http.NewRequest( + http.MethodGet, + downstreamBaseURL.JoinPath("say-hello").String(), + nil, + ) + if err != nil { + t.Fatalf("creating request: %v", err) + } + invalidReq.Header.Set("x-bp-test", "bar") + + if _, err := downstreamClient.Do(req); !errors.Is(err, headerbp.ErrNewInternalHeaderNotAllowed) { + t.Fatalf("error mismatch, want %v, got %v", headerbp.ErrNewInternalHeaderNotAllowed, err) + } + return nil + }, + }, + }, + Middlewares: []httpbp.Middleware{ + httpbp.ExtractBaseplateHeaders("originHTTPBPV0"), + }, + }) + if err != nil { + t.Fatalf("failed to create test originServer: %v", err) + } + t.Cleanup(func() { + originServer.Close() + }) + go originServer.Serve() + + baseURL, err := url.Parse("http://" + originServer.Baseplate().GetConfig().Addr + "/") + if err != nil { + t.Fatalf("failed to parse test originServer base URL: %v", err) + } + + client, err := httpbp.NewClient( + httpbp.ClientConfig{ + Slug: "downstreamHTTPBPV0", + }, + withBaseURL(baseURL), + ) + if err != nil { + t.Fatalf("failed to create test client: %v", err) + } + + req, err := http.NewRequest( + http.MethodGet, + baseURL.JoinPath("say-hello").String(), + nil, + ) + if err != nil { + t.Fatalf("creating request: %v", err) + } + for name, values := range expectedHeaders { + req.Header.Set(name, values[0]) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } +} + +func TestBaseplateHeaderPropagation_untrusted(t *testing.T) { + expectedHeaders := map[string][]string{ + "x-bp-from-edge": {"true"}, + "x-bp-test": {"foo"}, + } + store, _, err := secrets.NewTestSecrets(context.TODO(), nil) + if err != nil { + t.Fatalf("failed to create test secrets: %v", err) + } + t.Cleanup(func() { + store.Close() + }) + bp := baseplate.NewTestBaseplate(baseplate.NewTestBaseplateArgs{ + Config: baseplate.Config{ + Addr: ":8081", + }, + Store: store, + EdgeContextImpl: ecinterface.Mock(), + }) + + originServer, err := httpbp.NewBaseplateServer(httpbp.ServerArgs{ + Baseplate: bp, + Endpoints: map[httpbp.Pattern]httpbp.Endpoint{ + "/say-hello": { + Name: "say-hello", + Methods: []string{http.MethodGet}, + Handle: func(ctx context.Context, writer http.ResponseWriter, request *http.Request) error { + for wantKey := range expectedHeaders { + if v := request.Header.Values(wantKey); len(v) != 0 { + t.Fatalf("expected no values for header %q, got %+v", wantKey, v) + } + } + + return nil + }, + }, + }, + Middlewares: []httpbp.Middleware{ + httpbp.ExtractBaseplateHeaders("originHTTPBPV0"), + }, + }) + if err != nil { + t.Fatalf("failed to create test originServer: %v", err) + } + t.Cleanup(func() { + originServer.Close() + }) + go originServer.Serve() + + baseURL, err := url.Parse("http://" + originServer.Baseplate().GetConfig().Addr + "/") + if err != nil { + t.Fatalf("failed to parse test originServer base URL: %v", err) + } + + client, err := httpbp.NewClient( + httpbp.ClientConfig{ + Slug: "downstreamHTTPBPV0", + }, + withBaseURL(baseURL), + ) + if err != nil { + t.Fatalf("failed to create test client: %v", err) + } + + req, err := http.NewRequest( + http.MethodGet, + baseURL.JoinPath("say-hello").String(), + nil, + ) + if err != nil { + t.Fatalf("creating request: %v", err) + } + req.Header.Set("X-Rddt-Untrusted", "1") + for name, values := range expectedHeaders { + req.Header.Set(name, values[0]) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } +} + +func withBaseURL(baseURL *url.URL) httpbp.ClientMiddleware { + return func(next http.RoundTripper) http.RoundTripper { + return roundTripperFunc(func(req *http.Request) (resp *http.Response, err error) { + resolved := req.Clone(req.Context()) + resolved.URL = baseURL.ResolveReference(req.URL) + return next.RoundTrip(resolved) + }) + } +} + +type roundTripperFunc func(req *http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +}