diff --git a/webhook/webhook.go b/webhook/webhook.go new file mode 100644 index 00000000..dbf90169 --- /dev/null +++ b/webhook/webhook.go @@ -0,0 +1,312 @@ +package webhook + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "encoding/hex" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" +) + +var ( + // CurrentSignatureScheme is the current latest signature scheme. + CurrentSignatureScheme = SignatureSchemeV1 + + // AllSignatureSchemes is a list of all supported signature schemes. + AllSignatureSchemes = []SignatureScheme{ /* populated by init() */ } + + // HTTPHeaderSignature is the name of the header that contains signature packages. + HTTPHeaderSignature = "Do-Signature" + // HTTPHeaderEventName is the name of the header that contains the event name. + HTTPHeaderEventName = "Do-Event-Name" + + // DefaultTolerance is the default time tolerance for signature verification (3 minutes). + DefaultTolerance time.Duration = 3 * 60 * time.Second + + // ErrExpiredSignature indicates that the signature timestamp is outside of the allowed tolerance. + ErrExpiredSignature = fmt.Errorf("signature has expired") + // ErrNoVerifiedSignature indicates that no verified signature was found. + ErrNoVerifiedSignature = fmt.Errorf("no verified signature") + // ErrNotSigned indicates that the payload is not signed. + ErrNotSigned = fmt.Errorf("payload not signed") +) + +var ( + signatureSchemesByVersion = map[int]SignatureScheme{} +) + +func registerSignatureScheme(s SignatureScheme) { + signatureSchemesByVersion[s.Version()] = s + AllSignatureSchemes = append(AllSignatureSchemes, s) +} + +func init() { + // Schemes should be ordered by version number descending i.e. from newest to oldest. + registerSignatureScheme(SignatureSchemeV1) +} + +// SignaturePackage contains multiple signatures. +type SignaturePackage struct { + Timestamp time.Time + Signatures []Signature +} + +// NewSignaturePackage creates a signature package. +func NewSignaturePackage(t time.Time, payload []byte, secrets []string) SignaturePackage { + p := SignaturePackage{ + Timestamp: t, + } + + for _, scheme := range AllSignatureSchemes { + for _, secret := range secrets { + p.Signatures = append(p.Signatures, NewSignature(scheme, t, payload, secret)) + } + } + + return p +} + +// String returns the string representation of the signature package. +func (p *SignaturePackage) String() string { + value := make([]string, 0, len(p.Signatures)+1) + + value = append(value, fmt.Sprintf("t=%d", p.Timestamp.Unix())) + for _, s := range p.Signatures { + value = append(value, s.String()) + } + + return strings.Join(value, ",") +} + +// ParseSignaturePackage parses a signature package from its string representation. +func ParseSignaturePackage(value string) (SignaturePackage, error) { + sigPack := SignaturePackage{} + + pairs := strings.Split(value, ",") + for _, p := range pairs { + parts := strings.SplitN(p, "=", 2) + if len(parts) != 2 { + return SignaturePackage{}, fmt.Errorf("invalid signature package") + } + + k, v := parts[0], parts[1] + if k == "t" { + if !sigPack.Timestamp.IsZero() { + return SignaturePackage{}, fmt.Errorf("timestamp cannot be specified multiple times") + } + ts, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return SignaturePackage{}, fmt.Errorf("timestamp must be an integer") + } + sigPack.Timestamp = time.Unix(ts, 0).UTC() + } else { + sig, err := ParseSignature(p) + if err != nil { + return SignaturePackage{}, err + } + sigPack.Signatures = append(sigPack.Signatures, sig) + } + } + + if sigPack.Timestamp.IsZero() { + return SignaturePackage{}, fmt.Errorf("missing timestamp") + } + + return sigPack, nil +} + +// VerificationOpts sets options for verifying signature packages. +type VerificationOpts struct { + // Tolerance configures the maximum allowed signature age. Signatures older than this time window will fail verification. + // If unset, defaults to DefaultTolerance. + Tolerance time.Duration + // IgnoreTolerance skips checking if the signature package timestamp is within the allowed tolerance. + IgnoreTolerance bool + // Now is an optional override of time.Now. + Now func() time.Time + // UntrustedSchemes is a list of signature schemes that are untrusted. + UntrustedSchemes []SignatureScheme +} + +// Verify verifies the given signature package. Verification passes if at least of the signatures in the package is verified. +func (p SignaturePackage) Verify(payload []byte, secret string, opts VerificationOpts) error { + now := time.Now() + if opts.Now != nil { + now = opts.Now() + } + + if !opts.IgnoreTolerance { + tolerance := DefaultTolerance + if opts.Tolerance > 0 { + tolerance = opts.Tolerance + } + if now.Sub(p.Timestamp) > tolerance { + return ErrExpiredSignature + } + } + + if len(p.Signatures) == 0 { + return ErrNotSigned + } + + // try to find at least one verified signature +verifySignatures: + for _, s := range p.Signatures { + for _, scheme := range opts.UntrustedSchemes { + if scheme.Version() == s.Scheme.Version() { + continue verifySignatures + } + } + verified := s.Verify(payload, secret, p.Timestamp) + if verified == nil { + return nil + } + } + + return ErrNoVerifiedSignature +} + +// SignHTTPRequest signs the given HTTP request and sets the signature header. +func SignHTTPRequest(r *http.Request, t time.Time, secrets []string) error { + body, err := r.GetBody() + if err != nil { + return err + } + defer body.Close() + + payload, err := io.ReadAll(body) + if err != nil { + return err + } + + sigPack := NewSignaturePackage(t, payload, secrets) + r.Header.Set(HTTPHeaderSignature, sigPack.String()) + return nil +} + +// VerifyHTTPRequest verifies an HTTP request. +func VerifyHTTPRequest(r *http.Request, secret string, opts VerificationOpts) error { + header := r.Header.Get(HTTPHeaderSignature) + if header == "" { + return ErrNotSigned + } + + sigPack, err := ParseSignaturePackage(header) + if err != nil { + return fmt.Errorf("parsing signature header: %w", err) + } + + body, err := io.ReadAll(r.Body) + if err != nil { + return fmt.Errorf("reading request body: %w", err) + } + // Replace the body with a new reader after reading from the original + r.Body = io.NopCloser(bytes.NewBuffer(body)) + + return sigPack.Verify(body, secret, opts) +} + +// NewSignature creates a new signature. +func NewSignature(scheme SignatureScheme, t time.Time, payload []byte, secret string) Signature { + return Signature{ + Scheme: scheme, + Value: scheme.Sign(t, payload, secret), + } +} + +// Signature describes a signature. +type Signature struct { + Scheme SignatureScheme + Value string +} + +// String returns the string representation of a signature. +func (s Signature) String() string { + return fmt.Sprintf("v%d=%s", s.Scheme.Version(), s.Value) +} + +// Equal compares two signatures for equality without leaking timing information. +func (s Signature) Equal(o Signature) bool { + return subtle.ConstantTimeCompare([]byte(s.Value), []byte(o.Value)) == 1 +} + +// Verify verifies the given signature. The timestamp that was used to generate this signature must be provided. +func (s Signature) Verify(payload []byte, secret string, t time.Time) error { + if s.Scheme == nil { + return fmt.Errorf("invalid signature scheme") + } + + freshSig := NewSignature(s.Scheme, t, payload, secret) + if !s.Equal(freshSig) { + return ErrNoVerifiedSignature + } + + // the signatures are identical + return nil +} + +// ParseSignature attempts to parse a signature from its string representation. +func ParseSignature(value string) (Signature, error) { + parts := strings.SplitN(value, "=", 2) + if len(parts) != 2 { + return Signature{}, fmt.Errorf("invalid signature format") + } + + versionStr, value := parts[0], parts[1] + if !strings.HasPrefix(versionStr, "v") { + return Signature{}, fmt.Errorf("invalid signature format") + } + version, err := strconv.ParseInt(versionStr[1:], 10, 0) + if err != nil { + return Signature{}, fmt.Errorf("signature scheme version must be an integer") + } + scheme := signatureSchemesByVersion[int(version)] + if scheme == nil { + return Signature{}, fmt.Errorf("invalid signature scheme version %d", version) + } + + return Signature{ + Scheme: scheme, + Value: value, + }, nil +} + +// SignatureScheme describes a signature scheme. +type SignatureScheme interface { + Sign(t time.Time, payload []byte, secret string) string + Version() int +} + +// SignatureSchemeV1 computes an HMAC-SHA256 signature of the timestamp and payload in the following format: +// +// {unix timestamp}.{payload} +// +// The resulting signature is then hex-encoded. +var SignatureSchemeV1 SignatureScheme = &signatureSchemeV1{} + +type signatureSchemeV1 struct{} + +// Version returns the scheme version. +func (s *signatureSchemeV1) Version() int { + return 1 +} + +// Sign signs a payload. +func (s *signatureSchemeV1) Sign(t time.Time, payload []byte, secret string) string { + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write([]byte(fmt.Sprintf("%d", t.Unix()))) + mac.Write([]byte(".")) + mac.Write(payload) + return hex.EncodeToString(mac.Sum(nil)) +} + +// EventName returns a namespaced event name. +func EventName(ns string, name string) string { + return fmt.Sprintf("%s.%s", ns, name) +} diff --git a/webhook/webhook_test.go b/webhook/webhook_test.go new file mode 100644 index 00000000..0f7e86e7 --- /dev/null +++ b/webhook/webhook_test.go @@ -0,0 +1,316 @@ +package webhook + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func init() { + // add a fake signature scheme to test functionality related to multiple schemes + registerSignatureScheme(FakeSignatureScheme) +} + +var ( + // find the easter egg šŸ„š + testPayload = []byte("it is wednesday my dudes šŸ•·ļø") + testSecret = "du-TY1GUFGk" +) + +func TestSignatureSchemeV1(t *testing.T) { + ts := time.Date(2000, 1, 1, 10, 0, 0, 0, time.UTC) + sig := NewSignature(SignatureSchemeV1, ts, testPayload, testSecret) + require.Equal(t, "v1=b70100cf2943bec15996e3ae9392d0dcaf21f285fa81969108185d47b292dfa2", sig.String()) + require.Equal(t, SignatureSchemeV1.Sign(ts, testPayload, testSecret), sig.Value) + + sig = NewSignature(SignatureSchemeV1, ts, testPayload, "other-secret") + require.Equal(t, "v1=817555f45dd54e36c87ad1a349083e5d2e706cae2eae7f4077379f5444f5b985", sig.String()) + + sig = NewSignature(SignatureSchemeV1, ts.Add(time.Minute), testPayload, "other-secret") + require.Equal(t, "v1=eb5eb314fd727bcbae0713640b66540c64dc7e7cfa18f715deee29ed5db59347", sig.String()) +} + +func TestSignature(t *testing.T) { + ts := time.Date(2000, 1, 1, 10, 0, 0, 0, time.UTC) + var sig Signature + t.Run("new", func(t *testing.T) { + sig = NewSignature(SignatureSchemeV1, ts, testPayload, testSecret) + require.Equal(t, Signature{ + Scheme: SignatureSchemeV1, + Value: SignatureSchemeV1.Sign(ts, testPayload, testSecret), + }, sig) + require.Equal(t, "v1=b70100cf2943bec15996e3ae9392d0dcaf21f285fa81969108185d47b292dfa2", sig.String()) + }) + + t.Run("verify", func(t *testing.T) { + err := sig.Verify(testPayload, testSecret, ts) + require.NoError(t, err) + + err = sig.Verify(testPayload, "other-secret", ts) + require.Equal(t, ErrNoVerifiedSignature, err) + + err = sig.Verify(testPayload, testSecret, ts.Add(time.Hour)) + require.Equal(t, ErrNoVerifiedSignature, err) + + err = sig.Verify([]byte("other-payload"), testSecret, ts) + require.Equal(t, ErrNoVerifiedSignature, err) + + err = Signature{}.Verify(nil, "", ts) + require.EqualError(t, err, "invalid signature scheme") + }) + + t.Run("parse", func(t *testing.T) { + parsed, err := ParseSignature(sig.String()) + require.NoError(t, err) + require.Equal(t, Signature{ + Scheme: sig.Scheme, + Value: sig.Value, + }, parsed) + + require.True(t, sig.Equal(parsed)) + differentSig := NewSignature(sig.Scheme, ts, testPayload, "other-secret") + require.False(t, sig.Equal(differentSig)) + + for val, err := range map[string]string{ + "šŸŒ": "invalid signature format", + "šŸŒ=": "invalid signature format", + "=a": "invalid signature format", + "všŸŒ=": "signature scheme version must be an integer", + "v=a": "signature scheme version must be an integer", + "v0=": "invalid signature scheme version 0", + } { + _, got := ParseSignature(val) + require.EqualError(t, got, err, val) + } + }) +} + +func TestSignaturePackage(t *testing.T) { + ts := time.Date(2000, 1, 1, 10, 0, 0, 0, time.UTC) + secrets := []string{testSecret, "some-secret"} + var sp SignaturePackage + t.Run("new", func(t *testing.T) { + sp = NewSignaturePackage(ts, testPayload, secrets) + require.Equal( + t, + `t=946720800,`+ + `v1=b70100cf2943bec15996e3ae9392d0dcaf21f285fa81969108185d47b292dfa2,v1=b3218d58417e81cf347b439091b9ede800b2e1555f90fee81ac94f67c249da26,`+ + `v1337=946720800:du-TY1GUFGk:(32),v1337=946720800:some-secret:(32)`, + sp.String(), + ) + }) + + t.Run("verify", func(t *testing.T) { + nowWithinTolerance := func() time.Time { + return ts.Add(DefaultTolerance) + } + + // verified: happy path + err := sp.Verify(testPayload, testSecret, VerificationOpts{Now: nowWithinTolerance}) + require.NoError(t, err) + + // false: expired signature + err = sp.Verify(testPayload, testSecret, VerificationOpts{ + Now: func() time.Time { + return ts.Add(DefaultTolerance + time.Second) + }, + }) + require.Equal(t, ErrExpiredSignature, err) + + // false: expired signature w/ custom tolerance + err = sp.Verify(testPayload, testSecret, VerificationOpts{ + Tolerance: 3 * time.Second, + Now: func() time.Time { + return ts.Add(5 * time.Second) + }, + }) + require.Equal(t, ErrExpiredSignature, err) + + // verified: expired signature w/ ignore tolerance + err = sp.Verify(testPayload, testSecret, VerificationOpts{ + IgnoreTolerance: true, + Now: func() time.Time { + return ts.Add(DefaultTolerance + time.Second) + }, + }) + require.NoError(t, err) + + // false: signature signed by unknown secret + err = sp.Verify(testPayload, "other-secret", VerificationOpts{Now: nowWithinTolerance}) + require.Equal(t, ErrNoVerifiedSignature, err) + + // false: signature does not match payload + err = sp.Verify([]byte("other-payload"), testSecret, VerificationOpts{Now: nowWithinTolerance}) + require.Equal(t, ErrNoVerifiedSignature, err) + + // false: signature signed by untrusted scheme + err = sp.Verify(testPayload, testSecret, VerificationOpts{ + Now: nowWithinTolerance, + UntrustedSchemes: AllSignatureSchemes, + }) + require.Equal(t, ErrNoVerifiedSignature, err) + + // verified: only one of the schemes is untrusted + err = sp.Verify(testPayload, testSecret, VerificationOpts{ + Now: nowWithinTolerance, + UntrustedSchemes: []SignatureScheme{FakeSignatureScheme}, + }) + require.NoError(t, err) + }) + + t.Run("parse", func(t *testing.T) { + // sanity checks to double check the correctness of the hardcoded test string above + require.Equal(t, ts.Unix(), int64(946720800)) + + v1s1, err := ParseSignature("v1=b70100cf2943bec15996e3ae9392d0dcaf21f285fa81969108185d47b292dfa2") + require.NoError(t, err) + require.Equal(t, NewSignature(SignatureSchemeV1, ts, testPayload, secrets[0]), v1s1) + v1s2, err := ParseSignature("v1=b3218d58417e81cf347b439091b9ede800b2e1555f90fee81ac94f67c249da26") + require.NoError(t, err) + require.Equal(t, NewSignature(SignatureSchemeV1, ts, testPayload, secrets[1]), v1s2) + + v1337s1, err := ParseSignature("v1337=946720800:du-TY1GUFGk:(32)") + require.NoError(t, err) + require.Equal(t, NewSignature(FakeSignatureScheme, ts, testPayload, secrets[0]), v1337s1) + v1337s2, err := ParseSignature("v1337=946720800:some-secret:(32)") + require.NoError(t, err) + require.Equal(t, NewSignature(FakeSignatureScheme, ts, testPayload, secrets[1]), v1337s2) + + parsed, err := ParseSignaturePackage(sp.String()) + require.NoError(t, err) + require.Equal(t, sp, parsed) + require.Equal(t, sp.String(), parsed.String()) + + // error cases + for val, err := range map[string]string{ + "šŸŒ": "invalid signature package", + "v999=šŸŒ": "invalid signature scheme version 999", + "v1=b70100cf2943bec15996e3ae9392d0dcaf21f285fa81969108185d47b292dfa2": "missing timestamp", + "t=šŸŒ": "timestamp must be an integer", + "t=123,v1=b70100cf2943bec15996e3ae9392d0dcaf21f285fa81969108185d47b292dfa2,t=341": "timestamp cannot be specified multiple times", + } { + _, got := ParseSignaturePackage(val) + require.EqualError(t, got, err, val) + } + }) +} + +var FakeSignatureScheme = &fakeSignatureScheme{} + +type fakeSignatureScheme struct{} + +func (s *fakeSignatureScheme) Version() int { + return 1337 +} + +func (s *fakeSignatureScheme) Sign(t time.Time, payload []byte, secret string) string { + return strings.ReplaceAll(fmt.Sprintf("%d:%s:(%d)", t.Unix(), secret, len(payload)), ",", "") +} + +func TestHTTPRequests(t *testing.T) { + now := time.Date(2000, 1, 1, 10, 0, 0, 0, time.UTC) + nowFunc := func() time.Time { return now } + + tcs := []struct { + name string + handler http.HandlerFunc + request func(t *testing.T, url string) *http.Request + }{ + { + name: "happy path", + request: func(t *testing.T, url string) *http.Request { + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(testPayload)) + require.NoError(t, err) + // sign request + err = SignHTTPRequest(req, now, []string{testSecret, "other-secret"}) + require.NoError(t, err) + return req + }, + handler: func(w http.ResponseWriter, r *http.Request) { + err := VerifyHTTPRequest(r, testSecret, VerificationOpts{Now: nowFunc}) + require.NoError(t, err) + + err = VerifyHTTPRequest(r, "other-secret", VerificationOpts{Now: nowFunc}) + require.NoError(t, err) + + err = VerifyHTTPRequest(r, testSecret, VerificationOpts{ + Now: nowFunc, + UntrustedSchemes: AllSignatureSchemes, + }) + require.Equal(t, ErrNoVerifiedSignature, err) + + headerValue := r.Header.Get(HTTPHeaderSignature) + require.Equal(t, "t=946720800,v1=b70100cf2943bec15996e3ae9392d0dcaf21f285fa81969108185d47b292dfa2,v1=817555f45dd54e36c87ad1a349083e5d2e706cae2eae7f4077379f5444f5b985,v1337=946720800:du-TY1GUFGk:(32),v1337=946720800:other-secret:(32)", headerValue) + + // sanity check & ensure that r.Body is still accessible + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + sig := NewSignature(SignatureSchemeV1, time.Unix(946720800, 0), body, testSecret) + require.Equal(t, Signature{ + Scheme: SignatureSchemeV1, + Value: "b70100cf2943bec15996e3ae9392d0dcaf21f285fa81969108185d47b292dfa2", + }, sig) + }, + }, + { + name: "unsigned request", + request: func(t *testing.T, url string) *http.Request { + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(testPayload)) + require.NoError(t, err) + return req + }, + handler: func(w http.ResponseWriter, r *http.Request) { + err := VerifyHTTPRequest(r, testSecret, VerificationOpts{Now: nowFunc}) + require.Equal(t, ErrNotSigned, err) + }, + }, + { + name: "header present without signatures", + request: func(t *testing.T, url string) *http.Request { + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(testPayload)) + require.NoError(t, err) + req.Header.Set(HTTPHeaderSignature, fmt.Sprintf("t=%d", now.Unix())) + return req + }, + handler: func(w http.ResponseWriter, r *http.Request) { + err := VerifyHTTPRequest(r, testSecret, VerificationOpts{Now: nowFunc}) + require.Equal(t, ErrNotSigned, err) + }, + }, + { + name: "bad header value", + request: func(t *testing.T, url string) *http.Request { + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(testPayload)) + require.NoError(t, err) + req.Header.Set(HTTPHeaderSignature, "šŸ§˜šŸ»ā€ā™‚ļøšŸŒšŸ„–šŸš—šŸ“±šŸŽ‰āœ…") + return req + }, + handler: func(w http.ResponseWriter, r *http.Request) { + err := VerifyHTTPRequest(r, testSecret, VerificationOpts{Now: nowFunc}) + require.EqualError(t, err, "parsing signature header: invalid signature package") + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tc.handler(w, r) + w.WriteHeader(http.StatusUnavailableForLegalReasons) + })) + + req := tc.request(t, server.URL) + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusUnavailableForLegalReasons, res.StatusCode) + }) + } +}