From 7467ff53d55bd7130e866674c6d3490ff74c9204 Mon Sep 17 00:00:00 2001 From: Matheus Degiovani Date: Thu, 2 Nov 2023 12:37:27 -0300 Subject: [PATCH 1/4] gcs: Add func to determine max cfilter size. This will be helpful in determining the max number of cfilters to return in a future batched getcfilter message. This will also nail down the maximum size of the filter based on the constants used throughout the project, such that if any of them changes, causing the max cfilter size to change, a test will break indicating the need to review any code paths that assume a max cfilter size. --- gcs/blockcf2/maxsize_test.go | 289 +++++++++++++++++++++++++++++++++++ gcs/gcs.go | 43 ++++++ gcs/go.mod | 2 +- 3 files changed, 333 insertions(+), 1 deletion(-) create mode 100644 gcs/blockcf2/maxsize_test.go diff --git a/gcs/blockcf2/maxsize_test.go b/gcs/blockcf2/maxsize_test.go new file mode 100644 index 0000000000..e14ffd6ac9 --- /dev/null +++ b/gcs/blockcf2/maxsize_test.go @@ -0,0 +1,289 @@ +// Copyright (c) 2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package blockcf2 + +import ( + "encoding/binary" + "math/bits" + "math/rand" + "sort" + "testing" + "time" + + "github.com/dchest/siphash" + "github.com/decred/dcrd/chaincfg/v3" + "github.com/decred/dcrd/gcs/v4" + "github.com/decred/dcrd/wire" +) + +// TestMaxSize verifies the max size of blockcf2 filters for various +// parameters. +// +// This test is meant to nail down the values generated by various +// package-level constants to known good hardcoded values such that if any of +// the values change, this test breaks and any code that relies on assumptions +// about the max filter size having a particular value are reviewed. +func TestMaxSize(t *testing.T) { + // entrySizeForSize returns the size of scripts that ensures that every + // script added to a filter maximizes the passed maximum size. + // + // When adding scripts to build the filter, every input script should + // be unique in order to maximize the chances of unique values after + // the siphash and modulo reduction stages. When the total size + // available to build input scripts is less than or equal to 2^8, then + // a single byte is sufficient (because each byte will have a unique + // value), and so on when the max size is 2^16, 2^24 and 2^32. + // + // In practice, for the current network constants, the script size that + // will maximize all tests will be 3. + entrySizeForSize := func(size uint32) uint32 { + switch { + case size < 1<<8: + return 1 + case size < 1<<16: + return 2 + case size < 1<<24: + return 3 + default: + return 4 + } + } + + // mainnetMaxBlockSize is the max size of mainnet blocks according to + // the most recent consensus rules. + mainnetBlockSizes := chaincfg.MainNetParams().MaximumBlockSizes + mainnetMaxBlockSize := uint32(mainnetBlockSizes[len(mainnetBlockSizes)-1]) + + const ( + // minTxSize is the minimum transaction size with one input + // (and no outputs). + minTxSize = uint32(4 + 51 + 18) + + // txOutSize is the size of an output, plus one byte for the + // pkscript length encoded as a varint. + txOutSize = uint32(10 + 1) + + // p2shSize is the size of a P2SH script (the smallest standard + // script that may be added without limits to a transaction). + p2shSize = uint32(23) + ) + + tests := []struct { + name string + n uint32 + want uint64 + }{{ + // This test asserts the size of a filter when the source + // data for the filter fills an entire P2P message. + name: "filter from data that fills an entire P2P msg", + n: wire.MaxMessagePayload / entrySizeForSize(wire.MaxMessagePayload), + want: 22541387, + }, { + // This test asserts the size of a filter when the source data + // for the filter fills as many bytes as the maximum block + // payload size for a P2P message. + name: "filter from data that fills the max block size", + n: wire.MaxBlockPayload / entrySizeForSize(wire.MaxBlockPayload), + want: 1174034, + }, { + // This test asserts the size of a filter when the source data + // for the filter fills as many bytes as the maximum consensus + // enforced block size for mainnet. + name: "filter from data that fills the max mainnet block size", + n: mainnetMaxBlockSize / entrySizeForSize(mainnetMaxBlockSize), + want: 352214, + }, { + // This test asserts the size of a filter when the source data + // for the filter is a single tx that has as many OP_RETURN + // outputs as necessary to fill a block. + name: "filter from tx filled with OP_RETURN outputs", + n: wire.MaxBlockPayload / (txOutSize + entrySizeForSize(wire.MaxBlockPayload)), + want: 251581, + }, { + // This test asserts the size of a filter when the source data + // is a set of transactions that have 4 OP_RETURN outputs, as + // enforced by the standardness policy checks of the mempool. + name: "filter from standard OP_RETURN tx", + n: wire.MaxBlockPayload / (minTxSize + (txOutSize+entrySizeForSize(wire.MaxBlockPayload))*4), + want: 27305, + }, { + // This test asserts the size of a filter when the source data + // is a single transaction with as many P2SH outputs as + // necessary to fill the largest block of any network. + name: "filter from P2SH outputs tx", + n: wire.MaxBlockPayload / (txOutSize + p2shSize), + want: 103593, + }, { + // This test asserts the size of a filter when the source data + // is a single transaction with as many P2SH outputs as + // necessary to fill the largest block on mainnet. + name: "filter from P2SH outputs tx on mainnet", + n: mainnetMaxBlockSize / (txOutSize + p2shSize), + want: 31080, + }} + + for i := range tests { + tc := tests[i] + t.Run(tc.name, func(t *testing.T) { + got := gcs.MaxFilterV2Size(B, M, tc.n) + if tc.want != got { + t.Logf("N: %d", tc.n) + t.Fatalf("Unexpected max filter size: got %d, want %d", + got, tc.want) + } + }) + } +} + +func fastReduce(x, N uint64) uint64 { + hi, _ := bits.Mul64(x, N) + return hi +} + +// TestWorstMaxSize generates the largest possible cfilter for a given random +// key. +func TestWorstMaxSize(t *testing.T) { + if testing.Short() { + t.Skip("Skipping due to -test.short") + } + + const scriptSize = 3 + rnd := rand.New(rand.NewSource(time.Now().UnixNano())) + nbScripts := wire.MaxBlockPayload / scriptSize + + endian := binary.BigEndian + aux := make([]byte, 4) + + var key [gcs.KeySize]byte + rnd.Read(key[:]) + k0 := binary.LittleEndian.Uint64(key[0:8]) + k1 := binary.LittleEndian.Uint64(key[8:16]) + + modulusNM := uint64(nbScripts * M) + + // Generate every 3-byte script (2^24) and create a map of the reduced + // value to the script values that produce that reduced value (as a + // uint32). + maxNb := 1 << (8 * scriptSize) + seenSip := make(map[uint64]struct{}, maxNb) + seenReduced := make(map[uint64][]uint32, maxNb) + keys := make([]uint64, 0, maxNb) + for c := uint32(0); c < uint32(maxNb); c++ { + endian.PutUint32(aux, c) + v := siphash.Hash(k0, k1, aux[1:]) + if _, ok := seenSip[v]; ok { + continue + } + seenSip[v] = struct{}{} + vv := fastReduce(v, modulusNM) + seenReduced[vv] = append(seenReduced[vv], c) + keys = append(keys, vv) + } + + // Sort the reduced values in ascending order. + sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] }) + for i := 0; i < 10; i++ { + t.Logf("smallest seen %d: %d (%v)", i, keys[i], seenReduced[keys[i]]) + } + for i := len(keys) - 10; i < len(keys); i++ { + t.Logf("largest seen %d: %d (%v)", i, keys[i], seenReduced[keys[i]]) + } + + // Create the scripts. The first N-1 scripts will generate the smallest + // possible reduced values. + data := make([][]byte, nbScripts) + for i := 0; i < nbScripts-1; i++ { + values := seenReduced[keys[i]] + if len(values) > 1 { + // Report on different input entries that generate + // different siphash values that nevertheless reduce to + // the same final value. + sips := make([]uint64, len(values)) + mods := make([]uint64, len(values)) + for j := range values { + endian.PutUint32(aux, values[j]) + sips[j] = siphash.Hash(k0, k1, aux[1:]) + mods[j] = fastReduce(sips[j], modulusNM) + } + t.Logf("Collision on modulo reduction for script %d: "+ + "values %v, siphashes %v, modulos %v", i, values, + sips, mods) + } + data[i] = make([]byte, scriptSize) + endian.PutUint32(aux, values[0]) + copy(data[i], aux[1:]) + seenReduced[keys[i]] = values[1:] + } + + // The last script will generate the largest possible reduced value. The + // difference between the last and second-to-last scripts is the largest + // difference possible (when using this specific random key). + i := nbScripts - 1 + data[i] = make([]byte, scriptSize) + endian.PutUint32(aux, seenReduced[keys[len(keys)-1]][0]) + copy(data[i], aux[1:]) + + // Create filter and report results. + filter, err := gcs.NewFilterV2(B, M, key, data) + if err != nil { + t.Fatal(err) + } + sz := uint64(len(filter.Bytes())) + t.Logf("Key: %x", key) + t.Logf("Number of values: %d", len(seenReduced)) + t.Logf("Scripts: %d", len(data)) + + // The maximum possible quotient happens when the first nbScripts-1 + // scripts generate values from 0..nbScripts-2 and the last script + // generates the value modulusNM -1. + N := uint64(nbScripts) + maxQuotient := ((modulusNM - 1) - (N - 1)) >> 19 + t.Logf("Max Quotient: %d", maxQuotient) + + // The max possible size can be calculated by assuming every script + // will generate a value (i.e. no removed duplicates) and the last + // script will generate the max possible quotient. + maxPossible := gcs.MaxFilterV2Size(B, M, uint32(nbScripts)) + t.Logf("Max Possible size: %d, actual size: %d", maxPossible, sz) +} + +// TestRandomFilterSize generates filters with random data. +func TestRandomFilterSize(t *testing.T) { + // Use a unique random seed each test instance and log it if the tests + // fail. + seed := time.Now().UnixNano() + rng := rand.New(rand.NewSource(seed)) + defer func(t *testing.T, seed int64) { + if t.Failed() { + t.Logf("random seed: %d", seed) + } + }(t, seed) + scriptSize := 3 + nbScripts := wire.MaxBlockPayload / scriptSize + data := make([][]byte, nbScripts) + for i := range data { + data[i] = make([]byte, scriptSize) + } + var key [gcs.KeySize]byte + + // Generate a random filter and key. + for i := range data { + rng.Read(data[i]) + } + rng.Read(key[:]) + + // Check if it's larger than possible. + filter, err := gcs.NewFilterV2(B, M, key, data) + if err != nil { + t.Fatal(err) + } + sz := uint64(len(filter.Bytes())) + maxPossible := gcs.MaxFilterV2Size(B, M, uint32(nbScripts)) + t.Logf("Max Possible size: %d, actual size: %d", maxPossible, sz) + if sz > maxPossible { + t.Fatalf("Found a random filter with max size %d > max possible size %d", + sz, maxPossible) + } +} diff --git a/gcs/gcs.go b/gcs/gcs.go index 3735b47948..f31dc39968 100644 --- a/gcs/gcs.go +++ b/gcs/gcs.go @@ -549,3 +549,46 @@ func MakeHeaderForFilter(filter *FilterV1, prevHeader *chainhash.Hash) chainhash // The final filter hash is the blake256 of the hash computed above. return chainhash.Hash(blake256.Sum256(filterTip[:])) } + +// MaxFilterV2Size returns the maximum filter size possible for the given filter +// parameters. +func MaxFilterV2Size(B uint8, M uint64, N uint32) uint64 { + // The maximum (i.e. worst) filter size for V2 filters happens when the + // following conditions are met: + // + // - Every one of the N data entries is unique. + // - Every one of the N data entries produces a unique value after + // the siphash stage, ensuring no values are prematurely removed. + // - The quotient difference is maximized and produces the largest + // possible number of one bits in unary coding for the set of unique + // scripts. + // + // Given that the values are sorted prior being encoded with the + // Golomb/Rice coding, the largest possible quotient happens when the + // difference between two consecutive values is maximized. And that + // happens when the last value has the highest possible value and the + // second-to-last has the least possible value. + // + // The highest possible value after the modulo reduction stage is + // N*M-1. And the least possible value is zero. In other words, the + // first N-1 siphashed entries are mapped, after modulo reduction, to + // the value 0 and the last entry is mapped to the value N*M-1. + // + // Thus the largest possible difference is N*M-1, with the max + // number of bits of the quotient readily determined by shifting right + // that amount by B. + n := uint64(N) + b := uint64(B) + largestDiff := n*M - 1 + maxQuoBits := largestDiff >> b + + // Finally, the maximum size of the filter is determined by assuming + // each one of the N entries takes one bit for the 0 in the quotient + // encoding and B bits for the remainder, one entry takes an additional + // maxQuoBits for the quotient encoding, N is encoded as a varint and + // any necessary padding is added. + nSerSize := uint64(wire.VarIntSerializeSize(n)) + maxBits := n + n*b + maxQuoBits + maxBytes := (maxBits+7)/8 + nSerSize + return maxBytes +} diff --git a/gcs/go.mod b/gcs/go.mod index 95c6a5630e..f33ecad8d7 100644 --- a/gcs/go.mod +++ b/gcs/go.mod @@ -6,6 +6,7 @@ require ( github.com/dchest/siphash v1.2.3 github.com/decred/dcrd/blockchain/stake/v5 v5.0.0 github.com/decred/dcrd/chaincfg/chainhash v1.0.4 + github.com/decred/dcrd/chaincfg/v3 v3.2.0 github.com/decred/dcrd/crypto/blake256 v1.0.1 github.com/decred/dcrd/txscript/v4 v4.1.0 github.com/decred/dcrd/wire v1.6.0 @@ -14,7 +15,6 @@ require ( require ( github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412 // indirect github.com/decred/base58 v1.0.5 // indirect - github.com/decred/dcrd/chaincfg/v3 v3.2.0 // indirect github.com/decred/dcrd/crypto/ripemd160 v1.0.2 // indirect github.com/decred/dcrd/database/v3 v3.0.1 // indirect github.com/decred/dcrd/dcrec v1.0.1 // indirect From 22a618be29ebec087f4ace63eb02158297c3c1ea Mon Sep 17 00:00:00 2001 From: Matheus Degiovani Date: Thu, 2 Nov 2023 12:38:13 -0300 Subject: [PATCH 2/4] wire: Add msgs to get batch of cfilters. This adds the getcfsv2 and cfiltersv2 messages. These messages are intended to allow nodes in the P2P network to request and receive a batch of version 2 committed filters that span multiple sequential blocks in a chain. These messages are intended to be used when syncing SPV clients, which require requesting all committed filters on the current main chain and currently use a large number of getcfilter/cfilter messages. One of the critical issues in the design for these messages is the MaxCFiltersV2PerBatch constant, which establishes an upper bound on the max number of cfilters requested and replied. The current value of this constant was decided based on the max block size for all of the currently deployed networks, the max possible filter size for transactions in such a block and the max P2P message size. A runtime check is added to ensure any changes to the constants that were involved in deciding this number trigger a panic so that this bound is verified and, if needed, a new protocol version is introduced to update it. This check is meant to guard against inadvertedly changing the assumptions that went into establishing this bound without reviewing it. --- wire/error.go | 5 + wire/error_test.go | 1 + wire/message.go | 8 + wire/msgcfiltersv2.go | 166 ++++++++++++++++++ wire/msgcfiltersv2_test.go | 335 +++++++++++++++++++++++++++++++++++++ wire/msggetcfsv2.go | 73 ++++++++ wire/msggetcfsv2_test.go | 246 +++++++++++++++++++++++++++ wire/protocol.go | 6 +- 8 files changed, 839 insertions(+), 1 deletion(-) create mode 100644 wire/msgcfiltersv2.go create mode 100644 wire/msgcfiltersv2_test.go create mode 100644 wire/msggetcfsv2.go create mode 100644 wire/msggetcfsv2_test.go diff --git a/wire/error.go b/wire/error.go index 81a104f255..5554e1528c 100644 --- a/wire/error.go +++ b/wire/error.go @@ -149,6 +149,10 @@ const ( // ErrTooManyPrevMixMsgs is returned when too many previous messages of // a mix run are referenced by a message. ErrTooManyPrevMixMsgs + + // ErrTooManyCFilters is returned when the number of committed filters + // exceeds the maximum allowed in a batch. + ErrTooManyCFilters ) // Map of ErrorCode values back to their constant names for pretty printing. @@ -188,6 +192,7 @@ var errorCodeStrings = map[ErrorCode]string{ ErrMixPairReqScriptClassTooLong: "ErrMixPairReqScriptClassTooLong", ErrTooManyMixPairReqUTXOs: "ErrTooManyMixPairReqUTXOs", ErrTooManyPrevMixMsgs: "ErrTooManyPrevMixMsgs", + ErrTooManyCFilters: "ErrTooManyCFilters", } // String returns the ErrorCode as a human-readable name. diff --git a/wire/error_test.go b/wire/error_test.go index 4c162f88dd..30f98e9dee 100644 --- a/wire/error_test.go +++ b/wire/error_test.go @@ -54,6 +54,7 @@ func TestMessageErrorCodeStringer(t *testing.T) { {ErrMixPairReqScriptClassTooLong, "ErrMixPairReqScriptClassTooLong"}, {ErrTooManyMixPairReqUTXOs, "ErrTooManyMixPairReqUTXOs"}, {ErrTooManyPrevMixMsgs, "ErrTooManyPrevMixMsgs"}, + {ErrTooManyCFilters, "ErrTooManyCFilters"}, {0xffff, "Unknown ErrorCode (65535)"}, } diff --git a/wire/message.go b/wire/message.go index 82d96847bd..e5dca165dd 100644 --- a/wire/message.go +++ b/wire/message.go @@ -60,6 +60,8 @@ const ( CmdMixDCNet = "mixdcnet" CmdMixConfirm = "mixconfirm" CmdMixSecrets = "mixsecrets" + CmdGetCFiltersV2 = "getcfsv2" + CmdCFiltersV2 = "cfiltersv2" ) const ( @@ -217,6 +219,12 @@ func makeEmptyMessage(command string) (Message, error) { case CmdMixSecrets: msg = &MsgMixSecrets{} + case CmdGetCFiltersV2: + msg = &MsgGetCFsV2{} + + case CmdCFiltersV2: + msg = &MsgCFiltersV2{} + default: str := fmt.Sprintf("unhandled command [%s]", command) return nil, messageError(op, ErrUnknownCmd, str) diff --git a/wire/msgcfiltersv2.go b/wire/msgcfiltersv2.go new file mode 100644 index 0000000000..352e323754 --- /dev/null +++ b/wire/msgcfiltersv2.go @@ -0,0 +1,166 @@ +// Copyright (c) 2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "fmt" + "io" + + "github.com/decred/dcrd/chaincfg/chainhash" +) + +// MaxCFiltersV2PerBatch is the maximum number of committed filters that may +// be sent in a cfiltersv2 message. +// +// This number has been decided assuming a sequence of blocks filled with +// transactions specifically designed to maximize the filter size, such that +// a cfiltersv2 message will not be larger than the maximum allowed P2P message +// size on any of the currently deployed networks. +const MaxCFiltersV2PerBatch = 100 + +func init() { + // The following runtime assertions are included to ensure if any of + // the input constants used to determine the MaxCFiltersV2PerBatch + // number are changed from their original values, then that constant is + // reviewed to still be compatible to the new protocol or consensus + // constants. + // + // In particular, the max number of cfilters in a batched reply has + // been determined by assuming a sequence of blocks of size + // MaxBlockPayload (1.25 MiB), filled with a transaction with as many + // OP_RETURNs as needed to fill a worst-case v2 filter (251581 bytes). + // + // At 100 filters per batch, with the added overhead of the commitment + // proof, the maximum size of a cfiltersv2 message should be 25 MiB, + // which is less than the maximum P2P message size of 32 MiB. + // + // Check the MaxTestSize test from the blockcf2 package for information + // on how the maximum v2 cfilter size is determined. + // + // If any of these assertions break, such that a change to + // MaxCFiltersV2PerBatch is necessary, then a new protocol version + // should be introduced to allow a modification to + // MaxCFiltersV2PerBatch. + switch { + case MaxCFiltersV2PerBatch != 100: + panic("review MaxCFiltersV2PerBatch due to constant change") + case MaxBlockPayload != 1310720: + panic("review MaxCFiltersV2PerBatch due to MaxBlockPayload change") + case MaxCFilterDataSize != 256*1024: + panic("review MaxCFiltersV2PerBatch due to MaxCFilterDataSize change") + case MaxMessagePayload != 1024*1024*32: + panic("review MaxCFiltersV2PerBatch due to MaxMessagePayload change") + case (&MsgCFiltersV2{}).MaxPayloadLength(BatchedCFiltersV2Version) != 26321001: + panic("review MaxCFiltersV2PerBatch due to MaxPayloadLength change") + case (&MsgCFiltersV2{}).MaxPayloadLength(ProtocolVersion) != 26321001: + panic("review MaxCFiltersV2PerBatch due to MaxPayloadLength change") + } +} + +// MsgCFiltersV2 implements the Message interface and represents a cfiltersv2 +// message. It is used to deliver a batch of version 2 committed gcs filters +// for a given range of blocks, along with a proof that can be used to prove +// the filter is committed to by the block header. +// +// It is delivered in response to a getcfsv2 message (MsgGetCFiltersV2). +type MsgCFiltersV2 struct { + CFilters []MsgCFilterV2 +} + +// BtcDecode decodes r using the Decred protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgCFiltersV2) BtcDecode(r io.Reader, pver uint32) error { + const op = "MsgCFiltersV2.BtcDecode" + if pver < BatchedCFiltersV2Version { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + nbCFilters, err := ReadVarInt(r, pver) + if err != nil { + return err + } + + if nbCFilters > MaxCFiltersV2PerBatch { + msg := fmt.Sprintf("%s too many cfilters sent in batch "+ + "[count %v max %v]", msg.Command(), nbCFilters, + MaxCFiltersV2PerBatch) + return messageError(op, ErrTooManyCFilters, msg) + } + + msg.CFilters = make([]MsgCFilterV2, nbCFilters) + for i := 0; i < int(nbCFilters); i++ { + cf := &msg.CFilters[i] + err := cf.BtcDecode(r, pver) + if err != nil { + return err + } + } + + return nil +} + +// BtcEncode encodes the receiver to w using the Decred protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgCFiltersV2) BtcEncode(w io.Writer, pver uint32) error { + const op = "MsgCFiltersV2.BtcEncode" + if pver < BatchedCFiltersV2Version { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + nbCFilters := len(msg.CFilters) + if nbCFilters > MaxCFiltersV2PerBatch { + msg := fmt.Sprintf("%s too many cfilters to send in batch "+ + "[count %v max %v]", msg.Command(), nbCFilters, + MaxCFiltersV2PerBatch) + return messageError(op, ErrTooManyCFilters, msg) + } + + err := WriteVarInt(w, pver, uint64(nbCFilters)) + if err != nil { + return err + } + + for i := 0; i < nbCFilters; i++ { + err = msg.CFilters[i].BtcEncode(w, pver) + if err != nil { + return err + } + } + + return nil +} + +// Command returns the protocol command string for the message. This is part +// of the Message interface implementation. +func (msg *MsgCFiltersV2) Command() string { + return CmdCFiltersV2 +} + +// MaxPayloadLength returns the maximum length the payload can be for the +// receiver. This is part of the Message interface implementation. +func (msg *MsgCFiltersV2) MaxPayloadLength(pver uint32) uint32 { + // Varint + n * individual cfilter message: + // Block hash + max filter data (including varint) + + // proof index + max num proof hashes (including varint). + return uint32(VarIntSerializeSize(MaxCFiltersV2PerBatch)) + + (chainhash.HashSize+ + uint32(VarIntSerializeSize(MaxCFilterDataSize))+ + MaxCFilterDataSize+4+ + uint32(VarIntSerializeSize(MaxHeaderProofHashes))+ + (MaxHeaderProofHashes*chainhash.HashSize))*uint32(MaxCFiltersV2PerBatch) +} + +// NewMsgCFiltersV2 returns a new cfiltersv2 message that conforms to the +// Message interface using the passed parameters and defaults for the remaining +// fields. +func NewMsgCFiltersV2(filters []MsgCFilterV2) *MsgCFiltersV2 { + return &MsgCFiltersV2{ + CFilters: filters, + } +} diff --git a/wire/msgcfiltersv2_test.go b/wire/msgcfiltersv2_test.go new file mode 100644 index 0000000000..24c74057a0 --- /dev/null +++ b/wire/msgcfiltersv2_test.go @@ -0,0 +1,335 @@ +// Copyright (c) 2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "bytes" + "errors" + "io" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/decred/dcrd/chaincfg/chainhash" +) + +// baseMsgCFiltersV2 returns a MsgCFiltersV2 struct populated with mock values +// that are used throughout tests. Note that the tests will need to be updated +// if these values are changed since they rely on the current values. +func baseMsgCFiltersV2(t *testing.T) *MsgCFiltersV2 { + t.Helper() + + filters := []MsgCFilterV2{ + *baseMsgCFilterV2(t), + } + + return NewMsgCFiltersV2(filters) +} + +// TestCFiltersV2 tests the MsgCFiltersV2 API against the latest protocol +// version. +func TestCFiltersV2(t *testing.T) { + pver := ProtocolVersion + + // Ensure the command is expected value. + wantCmd := "cfiltersv2" + msg := baseMsgCFiltersV2(t) + if cmd := msg.Command(); cmd != wantCmd { + t.Errorf("NewMsgCFiltersV2: wrong command - got %v want %v", cmd, + wantCmd) + } + + // Ensure max payload is expected value for latest protocol version. + // varint max number of cfilters + max number of cfilters * + // (Block hash + max commitment name length (including varint) + + // proof index + max num proof hashes (including varint).) + wantPayload := uint32(26321001) + maxPayload := msg.MaxPayloadLength(pver) + if maxPayload != wantPayload { + t.Errorf("MaxPayloadLength: wrong max payload length for protocol "+ + "version %d - got %v, want %v", pver, maxPayload, wantPayload) + } + + // Ensure encoding max number of cfilters with max cfilter data and + // max proof hashes returns no error. + maxData := make([]byte, MaxCFilterDataSize) + maxProofHashes := make([]chainhash.Hash, MaxHeaderProofHashes) + msg.CFilters = make([]MsgCFilterV2, 0, MaxCFiltersV2PerBatch) + for len(msg.CFilters) < MaxCFiltersV2PerBatch { + cf := baseMsgCFilterV2(t) + cf.Data = maxData + cf.ProofHashes = maxProofHashes + msg.CFilters = append(msg.CFilters, *cf) + } + + var buf bytes.Buffer + if err := msg.BtcEncode(&buf, pver); err != nil { + t.Fatal(err) + } + + // Ensure the maximum actually encoded length is less than or equal to + // the max payload length. + if uint32(buf.Len()) > maxPayload { + t.Fatalf("Largest message encoded to a buffer larger than the "+ + " MaxPayloadLength for protocol version %d - got %v, want %v", + pver, buf.Len(), maxPayload) + } +} + +// TestCFiltersV2PreviousProtocol tests the MsgCFiltersV2 API against the protocol +// prior to version BatchedCFiltersV2Version. +func TestCFiltersV2PreviousProtocol(t *testing.T) { + // Use the protocol version just prior to BatchedCFiltersV2Version changes. + pver := BatchedCFiltersV2Version - 1 + + msg := baseMsgCFiltersV2(t) + + // Test encode with old protocol version. + var buf bytes.Buffer + err := msg.BtcEncode(&buf, pver) + if !errors.Is(err, ErrMsgInvalidForPVer) { + t.Errorf("unexpected error when encoding for protocol version %d, "+ + "prior to message introduction - got %v, want %v", pver, + err, ErrMsgInvalidForPVer) + } + + // Test decode with old protocol version. + var readmsg MsgCFiltersV2 + err = readmsg.BtcDecode(&buf, pver) + if !errors.Is(err, ErrMsgInvalidForPVer) { + t.Errorf("unexpected error when decoding for protocol version %d, "+ + "prior to message introduction - got %v, want %v", pver, + err, ErrMsgInvalidForPVer) + } +} + +// TestCFiltersV2CrossProtocol tests the MsgCFiltersV2 API when encoding with +// the latest protocol version and decoding with BatchedCFiltersV2Version. +func TestCFiltersV2CrossProtocol(t *testing.T) { + msg := baseMsgCFiltersV2(t) + + // Encode with latest protocol version. + var buf bytes.Buffer + err := msg.BtcEncode(&buf, ProtocolVersion) + if err != nil { + t.Errorf("encode of MsgCFiltersV2 failed %v err <%v>", msg, err) + } + + // Decode with old protocol version. + var readmsg MsgCFiltersV2 + err = readmsg.BtcDecode(&buf, BatchedCFiltersV2Version) + if err != nil { + t.Errorf("decode of MsgCFiltersV2 failed [%v] err <%v>", buf, err) + } +} + +// TestCFiltersV2Wire tests the MsgCFiltersV2 wire encode and decode for various +// protocol versions. +func TestCFiltersV2Wire(t *testing.T) { + msgCFiltersV2 := baseMsgCFiltersV2(t) + msgCFiltersV2Encoded := []byte{ + 0x01, // Varint for number of filters + 0xba, 0xdc, 0xb8, 0xe5, 0xc1, 0xe8, 0x95, 0xe8, + 0xe8, 0xfe, 0xf8, 0xd3, 0x42, 0x5f, 0xa0, 0xbf, + 0xe9, 0xd2, 0x8f, 0xdb, 0xf7, 0x2f, 0x87, 0x19, + 0x10, 0xc4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock block hash + 0x1d, // Varint for filter data length + 0x00, 0x00, 0x00, 0x11, 0x1c, 0xa3, 0xaa, 0xfb, + 0x02, 0x30, 0x74, 0xdc, 0x5b, 0xf2, 0x49, 0x8d, + 0xf7, 0x91, 0xb7, 0xd6, 0xe8, 0x46, 0xe9, 0xf5, + 0x01, 0x60, 0x06, 0xd6, 0x00, // Filter data + 0x00, 0x00, 0x00, 0x00, // Proof index + 0x01, // Varint for num proof hashes + 0x47, 0x63, 0x69, 0x67, 0x50, 0xe6, 0x72, 0x86, + 0x7f, 0x91, 0x00, 0x68, 0x79, 0x94, 0x18, 0xdb, + 0x8d, 0xa6, 0x07, 0xba, 0xf2, 0x28, 0x08, 0x55, + 0x22, 0x48, 0xb5, 0xd0, 0xb9, 0x5f, 0x89, 0xb4, // first proof hash + } + + tests := []struct { + in *MsgCFiltersV2 // Message to encode + out *MsgCFiltersV2 // Expected decoded message + buf []byte // Wire encoding + pver uint32 // Protocol version for wire encoding + }{{ + // Latest protocol version. + msgCFiltersV2, + msgCFiltersV2, + msgCFiltersV2Encoded, + ProtocolVersion, + }, { + // Protocol version BatchedCFiltersV2Version+1. + msgCFiltersV2, + msgCFiltersV2, + msgCFiltersV2Encoded, + BatchedCFiltersV2Version + 1, + }, { + // Protocol version BatchedCFiltersV2Version. + msgCFiltersV2, + msgCFiltersV2, + msgCFiltersV2Encoded, + BatchedCFiltersV2Version, + }} + + t.Logf("Running %d tests", len(tests)) + for i, test := range tests { + // Encode the message to wire format. + var buf bytes.Buffer + err := test.in.BtcEncode(&buf, test.pver) + if err != nil { + t.Errorf("BtcEncode #%d error %v", i, err) + continue + } + if !bytes.Equal(buf.Bytes(), test.buf) { + t.Errorf("BtcEncode #%d\n got: %s want: %s", i, + spew.Sdump(buf.Bytes()), spew.Sdump(test.buf)) + continue + } + + // Decode the message from wire format. + var msg MsgCFiltersV2 + rbuf := bytes.NewReader(test.buf) + err = msg.BtcDecode(rbuf, test.pver) + if err != nil { + t.Errorf("BtcDecode #%d error %v", i, err) + continue + } + if !reflect.DeepEqual(&msg, test.out) { + t.Errorf("BtcDecode #%d\n got: %s want: %s", i, + spew.Sdump(&msg), spew.Sdump(test.out)) + continue + } + } +} + +// TestCFiltersV2WireErrors performs negative tests against wire encode and +// decode of MsgCFiltersV2 to confirm error paths work correctly. +func TestCFiltersV2WireErrors(t *testing.T) { + pver := ProtocolVersion + + // Message with valid mock values. + baseCFiltersV2 := baseMsgCFiltersV2(t) + baseCFiltersV2Encoded := []byte{ + 0x01, // Varint for number of cfilters + 0xba, 0xdc, 0xb8, 0xe5, 0xc1, 0xe8, 0x95, 0xe8, + 0xe8, 0xfe, 0xf8, 0xd3, 0x42, 0x5f, 0xa0, 0xbf, + 0xe9, 0xd2, 0x8f, 0xdb, 0xf7, 0x2f, 0x87, 0x19, + 0x10, 0xc4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock block hash + 0x1d, // Varint for filter data length + 0x00, 0x00, 0x00, 0x11, 0x1c, 0xa3, 0xaa, 0xfb, + 0x02, 0x30, 0x74, 0xdc, 0x5b, 0xf2, 0x49, 0x8d, + 0xf7, 0x91, 0xb7, 0xd6, 0xe8, 0x46, 0xe9, 0xf5, + 0x01, 0x60, 0x06, 0xd6, 0x00, // Filter data + 0x00, 0x00, 0x00, 0x00, // Proof index + 0x01, // Varint for num proof hashes + 0x47, 0x63, 0x69, 0x67, 0x50, 0xe6, 0x72, 0x86, + 0x7f, 0x91, 0x00, 0x68, 0x79, 0x94, 0x18, 0xdb, + 0x8d, 0xa6, 0x07, 0xba, 0xf2, 0x28, 0x08, 0x55, + 0x22, 0x48, 0xb5, 0xd0, 0xb9, 0x5f, 0x89, 0xb4, // first proof hash + } + + // Message that forces an error by having a data that exceeds the max + // allowed length. + badFilterData := bytes.Repeat([]byte{0x00}, MaxCFilterDataSize+1) + maxDataCFiltersV2 := baseMsgCFiltersV2(t) + maxDataCFiltersV2.CFilters[0].Data = badFilterData + maxDataCFiltersV2Encoded := []byte{ + 0x01, // Varint for number of cfilters + 0xba, 0xdc, 0xb8, 0xe5, 0xc1, 0xe8, 0x95, 0xe8, + 0xe8, 0xfe, 0xf8, 0xd3, 0x42, 0x5f, 0xa0, 0xbf, + 0xe9, 0xd2, 0x8f, 0xdb, 0xf7, 0x2f, 0x87, 0x19, + 0x10, 0xc4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock block hash + 0xfe, 0x01, 0x00, 0x04, 0x00, // Varint for filter data length + } + + // Message that forces an error by having more than the max allowed proof + // hashes. + maxHashesCFiltersV2 := baseMsgCFiltersV2(t) + maxHashesCFiltersV2.CFilters[0].ProofHashes = make([]chainhash.Hash, MaxHeaderProofHashes+1) + maxHashesCFiltersV2Encoded := []byte{ + 0x01, // Varint for number of cfilters + 0xba, 0xdc, 0xb8, 0xe5, 0xc1, 0xe8, 0x95, 0xe8, + 0xe8, 0xfe, 0xf8, 0xd3, 0x42, 0x5f, 0xa0, 0xbf, + 0xe9, 0xd2, 0x8f, 0xdb, 0xf7, 0x2f, 0x87, 0x19, + 0x10, 0xc4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock block hash + 0x1d, // Varint for filter data length + 0x00, 0x00, 0x00, 0x11, 0x1c, 0xa3, 0xaa, 0xfb, + 0x02, 0x30, 0x74, 0xdc, 0x5b, 0xf2, 0x49, 0x8d, + 0xf7, 0x91, 0xb7, 0xd6, 0xe8, 0x46, 0xe9, 0xf5, + 0x01, 0x60, 0x06, 0xd6, 0x00, // Filter data + 0x00, 0x00, 0x00, 0x00, // Proof index + 0x21, // Varint for num proof hashes + } + + // Message that forces an error by having more than the max allowed number + // of cfilters. + maxCFiltersV2 := baseMsgCFiltersV2(t) + for len(maxCFiltersV2.CFilters) < MaxCFiltersV2PerBatch+1 { + maxCFiltersV2.CFilters = append(maxCFiltersV2.CFilters, *baseMsgCFilterV2(t)) + } + maxCFiltersV2Encoded := []byte{ + 0x65, // Varint for number of cfilters + } + + tests := []struct { + in *MsgCFiltersV2 // Value to encode + buf []byte // Wire encoding + pver uint32 // Protocol version for wire encoding + max int // Max size of fixed buffer to induce errors + writeErr error // Expected write error + readErr error // Expected read error + }{ + // Force error in cfilter number varint. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 0, io.ErrShortWrite, io.EOF}, + // Force error in start of block hash. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 1, io.ErrShortWrite, io.EOF}, + // Force error in middle of block hash. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 9, io.ErrShortWrite, io.ErrUnexpectedEOF}, + // Force error in filter data len. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 33, io.ErrShortWrite, io.EOF}, + // Force error in start of filter data. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 34, io.ErrShortWrite, io.EOF}, + // Force error in middle of filter data. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 46, io.ErrShortWrite, io.ErrUnexpectedEOF}, + // Force error in start of proof index. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 63, io.ErrShortWrite, io.EOF}, + // Force error in middle of proof index. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 65, io.ErrShortWrite, io.ErrUnexpectedEOF}, + // Force error in num proof hashes. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 67, io.ErrShortWrite, io.EOF}, + // Force error in start of first proof hash. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 68, io.ErrShortWrite, io.EOF}, + // Force error in middle of first proof hash. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 78, io.ErrShortWrite, io.ErrUnexpectedEOF}, + // Force error with greater than max filter data. + {maxDataCFiltersV2, maxDataCFiltersV2Encoded, pver, 38, ErrFilterTooLarge, ErrVarBytesTooLong}, + // Force error with greater than max proof hashes. + {maxHashesCFiltersV2, maxHashesCFiltersV2Encoded, pver, 68, ErrTooManyProofs, ErrTooManyProofs}, + // Force error with greater than max cfilters. + {maxCFiltersV2, maxCFiltersV2Encoded, pver, 1, ErrTooManyCFilters, ErrTooManyCFilters}, + } + + t.Logf("Running %d tests", len(tests)) + for i, test := range tests { + // Encode to wire format. + w := newFixedWriter(test.max) + err := test.in.BtcEncode(w, test.pver) + if !errors.Is(err, test.writeErr) { + t.Errorf("BtcEncode #%d wrong error got: %v, want: %v", i, err, + test.writeErr) + continue + } + + // Decode from wire format. + var msg MsgCFiltersV2 + r := newFixedReader(test.max, test.buf) + err = msg.BtcDecode(r, test.pver) + if !errors.Is(err, test.readErr) { + t.Errorf("BtcDecode #%d wrong error got: %v, want: %v", i, err, + test.readErr) + continue + } + } +} diff --git a/wire/msggetcfsv2.go b/wire/msggetcfsv2.go new file mode 100644 index 0000000000..53fb6c17bf --- /dev/null +++ b/wire/msggetcfsv2.go @@ -0,0 +1,73 @@ +// Copyright (c) 2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "fmt" + "io" + + "github.com/decred/dcrd/chaincfg/chainhash" +) + +// MsgGetCFsV2 implements the Message interface and represents a decred +// getcfsv2 message. It is used to request a batch of version 2 committed +// filters that span a subset of a chain, from StartHash up to (and including) +// EndHash. The response is sent in a MsgCFiltersV2 message. +// +// At most MaxCFiltersV2PerBatch may be requested by each MsgGetCFsV2 +// message, which means the number of blocks between EndHash and StartHash must +// be lesser than or equal to that constant's value. +type MsgGetCFsV2 struct { + StartHash chainhash.Hash + EndHash chainhash.Hash +} + +// BtcDecode decodes r using the Decred protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgGetCFsV2) BtcDecode(r io.Reader, pver uint32) error { + const op = "MsgGetCFsV2.BtcDecode" + if pver < BatchedCFiltersV2Version { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + return readElements(r, &msg.StartHash, &msg.EndHash) +} + +// BtcEncode encodes the receiver to w using the Decred protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgGetCFsV2) BtcEncode(w io.Writer, pver uint32) error { + const op = "MsgGetCFsV2.BtcEncode" + if pver < BatchedCFiltersV2Version { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + return writeElements(w, &msg.StartHash, msg.EndHash) +} + +// Command returns the protocol command string for the message. This is part +// of the Message interface implementation. +func (msg *MsgGetCFsV2) Command() string { + return CmdGetCFiltersV2 +} + +// MaxPayloadLength returns the maximum length the payload can be for the +// receiver. This is part of the Message interface implementation. +func (msg *MsgGetCFsV2) MaxPayloadLength(pver uint32) uint32 { + // Block hash. + return chainhash.HashSize * 2 +} + +// NewMsgGetCFsV2 returns a new Decred getcfiltersv2 message that conforms +// to the Message interface using the passed parameters. +func NewMsgGetCFsV2(startHash, endHash *chainhash.Hash) *MsgGetCFsV2 { + return &MsgGetCFsV2{ + StartHash: *startHash, + EndHash: *endHash, + } +} diff --git a/wire/msggetcfsv2_test.go b/wire/msggetcfsv2_test.go new file mode 100644 index 0000000000..875fbc88f7 --- /dev/null +++ b/wire/msggetcfsv2_test.go @@ -0,0 +1,246 @@ +// Copyright (c) 2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "bytes" + "errors" + "io" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/decred/dcrd/chaincfg/chainhash" +) + +// baseMsgGetCFsV2 returns a MsgGetCFsV2 struct populated with mock +// values that are used throughout tests. Note that the tests will need to be +// updated if these values are changed since they rely on the current values. +func baseMsgGetCFsV2(t *testing.T) *MsgGetCFsV2 { + t.Helper() + + // Mock block hash. + startHashStr := "000000000000c41019872ff7db8fd2e9bfa05f42d3f8fee8e895e8c1e5b8dcba" + startHash, err := chainhash.NewHashFromStr(startHashStr) + if err != nil { + t.Fatalf("NewHashFromStr: %v", err) + } + + endHashStr := "00000000000108ac3e3f51a0f4424dd757a3b0485da0ec96592f637f27bd1cf5" + endHash, err := chainhash.NewHashFromStr(endHashStr) + if err != nil { + t.Fatalf("NewHashFromStr: %v", err) + } + + return NewMsgGetCFsV2(startHash, endHash) +} + +// TestGetCFiltersV2 tests the MsgGetCFsV2 API against the latest protocol +// version. +func TestGetCFiltersV2(t *testing.T) { + pver := ProtocolVersion + + // Ensure the command is expected value. + wantCmd := "getcfsv2" + msg := baseMsgGetCFsV2(t) + if cmd := msg.Command(); cmd != wantCmd { + t.Errorf("NewMsgGetCFsV2: wrong command - got %v want %v", cmd, + wantCmd) + } + + // Ensure max payload is expected value for latest protocol version. + // Start hash + end hash. + wantPayload := uint32(64) + maxPayload := msg.MaxPayloadLength(pver) + if maxPayload != wantPayload { + t.Errorf("MaxPayloadLength: wrong max payload length for protocol "+ + "version %d - got %v, want %v", pver, maxPayload, wantPayload) + } + + // Ensure max payload length is not more than MaxMessagePayload. + if maxPayload > MaxMessagePayload { + t.Fatalf("MaxPayloadLength: payload length (%v) for protocol version "+ + "%d exceeds MaxMessagePayload (%v).", maxPayload, pver, + MaxMessagePayload) + } +} + +// TestGetCFiltersV2PreviousProtocol tests the MsgGetCFsV2 API against the +// protocol prior to version BatchedCFiltersV2Version. +func TestGetCFiltersV2PreviousProtocol(t *testing.T) { + // Use the protocol version just prior to CFilterV2Version changes. + pver := BatchedCFiltersV2Version - 1 + + msg := baseMsgGetCFsV2(t) + + // Test encode with old protocol version. + var buf bytes.Buffer + err := msg.BtcEncode(&buf, pver) + if !errors.Is(err, ErrMsgInvalidForPVer) { + t.Errorf("unexpected error when encoding for protocol version %d, "+ + "prior to message introduction - got %v, want %v", pver, + err, ErrMsgInvalidForPVer) + } + + // Test decode with old protocol version. + var readmsg MsgGetCFsV2 + err = readmsg.BtcDecode(&buf, pver) + if !errors.Is(err, ErrMsgInvalidForPVer) { + t.Errorf("unexpected error when decoding for protocol version %d, "+ + "prior to message introduction - got %v, want %v", pver, + err, ErrMsgInvalidForPVer) + } +} + +// TestGetCFiltersV2CrossProtocol tests the MsgGetCFsV2 API when encoding +// with the latest protocol version and decoding with BatchedCFiltersV2Version. +func TestGetCFiltersV2CrossProtocol(t *testing.T) { + msg := baseMsgGetCFsV2(t) + + // Encode with latest protocol version. + var buf bytes.Buffer + err := msg.BtcEncode(&buf, ProtocolVersion) + if err != nil { + t.Errorf("encode of MsgGetCFsV2 failed %v err <%v>", msg, err) + } + + // Decode with old protocol version. + var readmsg MsgGetCFilterV2 + err = readmsg.BtcDecode(&buf, BatchedCFiltersV2Version) + if err != nil { + t.Errorf("decode of MsgGetCFsV2 failed [%v] err <%v>", buf, err) + } +} + +// TestGetCFiltersV2Wire tests the MsgGetCFsV2 wire encode and decode for +// various commitment names and protocol versions. +func TestGetCFiltersV2Wire(t *testing.T) { + // MsgGetCFsV2 message with mock block hashes. + msgGetCFsV2 := baseMsgGetCFsV2(t) + msgGetCFsV2Encoded := []byte{ + 0xba, 0xdc, 0xb8, 0xe5, 0xc1, 0xe8, 0x95, 0xe8, + 0xe8, 0xfe, 0xf8, 0xd3, 0x42, 0x5f, 0xa0, 0xbf, + 0xe9, 0xd2, 0x8f, 0xdb, 0xf7, 0x2f, 0x87, 0x19, + 0x10, 0xc4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock start hash + 0xf5, 0x1c, 0xbd, 0x27, 0x7f, 0x63, 0x2f, 0x59, + 0x96, 0xec, 0xa0, 0x5d, 0x48, 0xb0, 0xa3, 0x57, + 0xd7, 0x4d, 0x42, 0xf4, 0xa0, 0x51, 0x3f, 0x3e, + 0xac, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock end hash + } + + tests := []struct { + in *MsgGetCFsV2 // Message to encode + out *MsgGetCFsV2 // Expected decoded message + buf []byte // Wire encoding + pver uint32 // Protocol version for wire encoding + }{{ + // Latest protocol version. + msgGetCFsV2, + msgGetCFsV2, + msgGetCFsV2Encoded, + ProtocolVersion, + }, { + // Protocol version CFilterV2Version+1. + msgGetCFsV2, + msgGetCFsV2, + msgGetCFsV2Encoded, + BatchedCFiltersV2Version + 1, + }, { + // Protocol version CFilterV2Version. + msgGetCFsV2, + msgGetCFsV2, + msgGetCFsV2Encoded, + BatchedCFiltersV2Version, + }} + + t.Logf("Running %d tests", len(tests)) + for i, test := range tests { + // Encode the message to wire format. + var buf bytes.Buffer + err := test.in.BtcEncode(&buf, test.pver) + if err != nil { + t.Errorf("BtcEncode #%d error %v", i, err) + continue + } + if !bytes.Equal(buf.Bytes(), test.buf) { + t.Errorf("BtcEncode #%d\n got: %s want: %s", i, + spew.Sdump(buf.Bytes()), spew.Sdump(test.buf)) + continue + } + + // Decode the message from wire format. + var msg MsgGetCFsV2 + rbuf := bytes.NewReader(test.buf) + err = msg.BtcDecode(rbuf, test.pver) + if err != nil { + t.Errorf("BtcDecode #%d error %v", i, err) + continue + } + if !reflect.DeepEqual(&msg, test.out) { + t.Errorf("BtcDecode #%d\n got: %s want: %s", i, spew.Sdump(&msg), + spew.Sdump(test.out)) + continue + } + } +} + +// TestGetCFiltersV2WireErrors performs negative tests against wire encode and +// decode of MsgGetCFsV2 to confirm error paths work correctly. +func TestGetCFiltersV2WireErrors(t *testing.T) { + pver := ProtocolVersion + + // MsgGetCFilterV2 message with mock block hash. + baseGetCFiltersV2 := baseMsgGetCFsV2(t) + baseGetCFiltersV2Encoded := []byte{ + 0xba, 0xdc, 0xb8, 0xe5, 0xc1, 0xe8, 0x95, 0xe8, + 0xe8, 0xfe, 0xf8, 0xd3, 0x42, 0x5f, 0xa0, 0xbf, + 0xe9, 0xd2, 0x8f, 0xdb, 0xf7, 0x2f, 0x87, 0x19, + 0x10, 0xc4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock block hash + 0xf5, 0x1c, 0xbd, 0x27, 0x7f, 0x63, 0x2f, 0x59, + 0x96, 0xec, 0xa0, 0x5d, 0x48, 0xb0, 0xa3, 0x57, + 0xd7, 0x4d, 0x42, 0xf4, 0xa0, 0x51, 0x3f, 0x3e, + 0xac, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock end hash + } + + tests := []struct { + in *MsgGetCFsV2 // Value to encode + buf []byte // Wire encoding + pver uint32 // Protocol version for wire encoding + max int // Max size of fixed buffer to induce errors + writeErr error // Expected write error + readErr error // Expected read error + }{ + // Force error in start of start hash. + {baseGetCFiltersV2, baseGetCFiltersV2Encoded, pver, 0, io.ErrShortWrite, io.EOF}, + // Force error in middle of start hash. + {baseGetCFiltersV2, baseGetCFiltersV2Encoded, pver, 8, io.ErrShortWrite, io.ErrUnexpectedEOF}, + // Force error in start of end hash. + {baseGetCFiltersV2, baseGetCFiltersV2Encoded, pver, 32, io.ErrShortWrite, io.EOF}, + // Force error in middle of end hash. + {baseGetCFiltersV2, baseGetCFiltersV2Encoded, pver, 40, io.ErrShortWrite, io.ErrUnexpectedEOF}, + } + + t.Logf("Running %d tests", len(tests)) + for i, test := range tests { + // Encode to wire format. + w := newFixedWriter(test.max) + err := test.in.BtcEncode(w, test.pver) + if !errors.Is(err, test.writeErr) { + t.Errorf("BtcEncode #%d wrong error got: %v, want: %v", i, err, + test.writeErr) + continue + } + + // Decode from wire format. + var msg MsgGetCFsV2 + r := newFixedReader(test.max, test.buf) + err = msg.BtcDecode(r, test.pver) + if !errors.Is(err, test.readErr) { + t.Errorf("BtcDecode #%d wrong error got: %v, want: %v", i, err, + test.readErr) + continue + } + } +} diff --git a/wire/protocol.go b/wire/protocol.go index c2585d4380..1e0c0e497a 100644 --- a/wire/protocol.go +++ b/wire/protocol.go @@ -17,7 +17,7 @@ const ( InitialProcotolVersion uint32 = 1 // ProtocolVersion is the latest protocol version this package supports. - ProtocolVersion uint32 = 10 + ProtocolVersion uint32 = 11 // NodeBloomVersion is the protocol version which added the SFNodeBloom // service flag (unused). @@ -54,6 +54,10 @@ const ( // MixVersion is the protocol version which adds peer-to-peer mixing. MixVersion uint32 = 10 + + // BatchedCFiltersV2Version is the protocol version which adds support + // for the batched getcfsv2 and cfiltersv2 messages. + BatchedCFiltersV2Version uint32 = 11 ) // ServiceFlag identifies services supported by a Decred peer. From 5d4cfa2c25073812f5ced4f94334529fb577758b Mon Sep 17 00:00:00 2001 From: Matheus Degiovani Date: Thu, 2 Nov 2023 12:38:48 -0300 Subject: [PATCH 3/4] blockchain: Add function to locate multiple cfilters. This function will be used to reply to the recently introduced getcfsv2 P2P message. The LocateCFiltersV2 function fetches a batch of cfilters from the database and encodes it in a wire.CFiltersV2 message, ready to be sent to remote peers. --- internal/blockchain/chainio.go | 15 +++++ internal/blockchain/error.go | 8 +++ internal/blockchain/error_test.go | 2 + internal/blockchain/headercmt.go | 102 ++++++++++++++++++++++++++++++ 4 files changed, 127 insertions(+) diff --git a/internal/blockchain/chainio.go b/internal/blockchain/chainio.go index ebf1eea9eb..9c8c2ef05b 100644 --- a/internal/blockchain/chainio.go +++ b/internal/blockchain/chainio.go @@ -851,6 +851,21 @@ func dbFetchGCSFilter(dbTx database.Tx, blockHash *chainhash.Hash) (*gcs.FilterV return filter, nil } +// dbFetchRawGCSFilter fetches the raw version 2 GCS filter for the passed +// block, without decoding it from the db. +// +// WARNING: the returned slice is only valid for the duration of the database +// transaction and MUST be copied to a new buffer if it is needed after the db +// transaction ends. +// +// This function is meant for cases where the raw filter bytes will be used +// without decoding into a gcs.FilterV2 value. For a safer alternative, use +// dbFetchGCSFilter. +func dbFetchRawGCSFilter(dbTx database.Tx, blockHash *chainhash.Hash) []byte { + filterBucket := dbTx.Metadata().Bucket(gcsFilterBucketName) + return filterBucket.Get(blockHash[:]) +} + // dbPutGCSFilter uses an existing database transaction to update the version 2 // GCS filter for the given block hash using the provided filter. func dbPutGCSFilter(dbTx database.Tx, blockHash *chainhash.Hash, filter *gcs.FilterV2) error { diff --git a/internal/blockchain/error.go b/internal/blockchain/error.go index 52141c12da..f7f4e24009 100644 --- a/internal/blockchain/error.go +++ b/internal/blockchain/error.go @@ -589,6 +589,14 @@ const ( // ErrSerializeHeader indicates an attempt to serialize a block header failed. ErrSerializeHeader = ErrorKind("ErrSerializeHeader") + // ErrNotAnAncestor indicates an attempt to fetch a chain of filters where + // the start hash is not an ancestor of the end hash. + ErrNotAnAncestor = ErrorKind("ErrNotAnAncestor") + + // ErrRequestTooLarge indicates an attempt to request too much data + // in a batched request. + ErrRequestTooLarge = ErrorKind("ErrRequestTooLarge") + // ------------------------------------------ // Errors related to the UTXO backend. // ------------------------------------------ diff --git a/internal/blockchain/error_test.go b/internal/blockchain/error_test.go index 73714ac898..2262e06edb 100644 --- a/internal/blockchain/error_test.go +++ b/internal/blockchain/error_test.go @@ -150,6 +150,8 @@ func TestErrorKindStringer(t *testing.T) { {ErrNoTreasuryBalance, "ErrNoTreasuryBalance"}, {ErrInvalidateGenesisBlock, "ErrInvalidateGenesisBlock"}, {ErrSerializeHeader, "ErrSerializeHeader"}, + {ErrNotAnAncestor, "ErrNotAnAncestor"}, + {ErrRequestTooLarge, "ErrRequestTooLarge"}, {ErrUtxoBackend, "ErrUtxoBackend"}, {ErrUtxoBackendCorruption, "ErrUtxoBackendCorruption"}, {ErrUtxoBackendNotOpen, "ErrUtxoBackendNotOpen"}, diff --git a/internal/blockchain/headercmt.go b/internal/blockchain/headercmt.go index 62fd1987d3..d47c5eefaf 100644 --- a/internal/blockchain/headercmt.go +++ b/internal/blockchain/headercmt.go @@ -191,3 +191,105 @@ func (b *BlockChain) FilterByBlockHash(hash *chainhash.Hash) (*gcs.FilterV2, *He } return filter, headerProof, nil } + +// LocateCFiltersV2 fetches all committed filters between startHash and endHash +// (inclusive) and prepares a MsgCFiltersV2 response to return this batch +// of CFilters to a remote peer. +// +// The start and end blocks must both exist and the start block must be an +// ancestor to the end block. +// +// This function is safe for concurrent access. +func (b *BlockChain) LocateCFiltersV2(startHash, endHash *chainhash.Hash) (*wire.MsgCFiltersV2, error) { + // Sanity check. + b.chainLock.RLock() + startNode := b.index.LookupNode(startHash) + if startNode == nil { + b.chainLock.RUnlock() + return nil, unknownBlockError(startHash) + } + endNode := b.index.LookupNode(endHash) + if endNode == nil { + b.chainLock.RUnlock() + return nil, unknownBlockError(endHash) + } + if !startNode.IsAncestorOf(endNode) { + b.chainLock.RUnlock() + str := fmt.Sprintf("start block %s is not an ancestor of end block %s", + startHash, endHash) + return nil, contextError(ErrNotAnAncestor, str) + } + + // Figure out size of the response. + nb := int(endNode.height - startNode.height + 1) + if nb > wire.MaxCFiltersV2PerBatch { + b.chainLock.RUnlock() + str := fmt.Sprintf("number of requested cfilters %d greater than max allowed %d", + nb, wire.MaxCFiltersV2PerBatch) + return nil, contextError(ErrRequestTooLarge, str) + } + + // Allocate all internal messages in a single buffer to reduce memory + // allocation counts. + filters := make([]wire.MsgCFilterV2, nb) + proofLeaves := make([][]chainhash.Hash, nb) + + // Fetch all relevant block hashes. + node := endNode + for i := nb - 1; i >= 0; i-- { + filters[i].BlockHash = node.hash + node = node.parent + } + + // At this point all index operations have completed, so release the + // RLock. + b.chainLock.RUnlock() + + // Prepare the response from DB. + err := b.db.View(func(dbTx database.Tx) error { + totalLen := 0 + data := make([][]byte, nb) + for i := 0; i < nb; i++ { + hash := &filters[i].BlockHash + cfData := dbFetchRawGCSFilter(dbTx, hash) + if cfData == nil { + str := fmt.Sprintf("no filter available for block %s", hash) + return contextError(ErrNoFilter, str) + } + + data[i] = cfData + totalLen += len(cfData) + + var err error + proofLeaves[i], err = dbFetchHeaderCommitments(dbTx, hash) + if err != nil { + return err + } + } + + // Allocate a single backing buffer for all cfilter data and + // copy each individual db buffer into it. This reduces + // memory allocation counts and fragmentation. + buffer := make([]byte, totalLen) + var j int + for i := 0; i < nb; i++ { + sz := len(data[i]) + copy(buffer[j:], data[i]) + filters[i].Data = buffer[j : j+sz : j+sz] + j += sz + } + return nil + }) + if err != nil { + return nil, err + } + + // Prepare the response. + const proofIndex = HeaderCmtFilterIndex + for i := 0; i < nb; i++ { + proofHashes := standalone.GenerateInclusionProof(proofLeaves[i], proofIndex) + filters[i].ProofHashes = proofHashes + filters[i].ProofIndex = proofIndex + } + return wire.NewMsgCFiltersV2(filters), nil +} From 5a28304e53f8618e6479b3f337d2e5e44116acff Mon Sep 17 00:00:00 2001 From: Matheus Degiovani Date: Thu, 2 Nov 2023 12:39:06 -0300 Subject: [PATCH 4/4] multi: Respond to getcfsv2 message. This adds the appropriate processing to the peer and server structs to respond to the recently introduced getcfsv2 message. It also bumps the peer and server max supported protocol versions to version 10 (BatchedCFiltersV2Version). This message queries the chain for a batch of committed filters spanning a set of sequential blocks and will be used by SPV clients to fetch committed filters during their initial sync process. --- peer/go.mod | 2 ++ peer/peer.go | 19 ++++++++++++++++++- peer/peer_test.go | 14 ++++++++++++++ server.go | 13 ++++++++++++- 4 files changed, 46 insertions(+), 2 deletions(-) diff --git a/peer/go.mod b/peer/go.mod index cb4b8ff8c8..614778a12c 100644 --- a/peer/go.mod +++ b/peer/go.mod @@ -12,6 +12,8 @@ require ( github.com/decred/slog v1.2.0 ) +replace github.com/decred/dcrd/wire => ../wire + require ( github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412 // indirect github.com/dchest/siphash v1.2.3 // indirect diff --git a/peer/peer.go b/peer/peer.go index 438ac9de7b..8070f442d8 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -28,7 +28,7 @@ import ( const ( // MaxProtocolVersion is the max protocol version the peer supports. - MaxProtocolVersion = wire.RemoveRejectVersion + MaxProtocolVersion = wire.BatchedCFiltersV2Version // outputBufferSize is the number of elements the output channels use. outputBufferSize = 5000 @@ -129,6 +129,9 @@ type MessageListeners struct { // OnCFilterV2 is invoked when a peer receives a cfilterv2 wire message. OnCFilterV2 func(p *Peer, msg *wire.MsgCFilterV2) + // OnCFiltersV2 is invoked when a peer receives a cfiltersv2 wire message. + OnCFiltersV2 func(p *Peer, msg *wire.MsgCFiltersV2) + // OnCFHeaders is invoked when a peer receives a cfheaders wire // message. OnCFHeaders func(p *Peer, msg *wire.MsgCFHeaders) @@ -163,6 +166,10 @@ type MessageListeners struct { // message. OnGetCFilterV2 func(p *Peer, msg *wire.MsgGetCFilterV2) + // OnGetCFiltersV2 is invoked when a peer receives a getcfsv2 wire + // message. + OnGetCFiltersV2 func(p *Peer, msg *wire.MsgGetCFsV2) + // OnGetCFHeaders is invoked when a peer receives a getcfheaders // wire message. OnGetCFHeaders func(p *Peer, msg *wire.MsgGetCFHeaders) @@ -1423,6 +1430,16 @@ out: p.cfg.Listeners.OnCFilterV2(p, msg) } + case *wire.MsgGetCFsV2: + if p.cfg.Listeners.OnGetCFiltersV2 != nil { + p.cfg.Listeners.OnGetCFiltersV2(p, msg) + } + + case *wire.MsgCFiltersV2: + if p.cfg.Listeners.OnCFiltersV2 != nil { + p.cfg.Listeners.OnCFiltersV2(p, msg) + } + case *wire.MsgGetInitState: if p.cfg.Listeners.OnGetInitState != nil { p.cfg.Listeners.OnGetInitState(p, msg) diff --git a/peer/peer_test.go b/peer/peer_test.go index 1a2befb310..aa74fe6b1c 100644 --- a/peer/peer_test.go +++ b/peer/peer_test.go @@ -411,6 +411,12 @@ func TestPeerListeners(t *testing.T) { OnInitState: func(p *Peer, msg *wire.MsgInitState) { ok <- msg }, + OnGetCFiltersV2: func(p *Peer, msg *wire.MsgGetCFsV2) { + ok <- msg + }, + OnCFiltersV2: func(p *Peer, msg *wire.MsgCFiltersV2) { + ok <- msg + }, }, UserAgentName: "peer", UserAgentVersion: "1.0", @@ -565,6 +571,14 @@ func TestPeerListeners(t *testing.T) { "OnInitState", wire.NewMsgInitState(), }, + { + "OnGetCFiltersV2", + wire.NewMsgGetCFsV2(&chainhash.Hash{}, &chainhash.Hash{}), + }, + { + "OnCFiltersV2", + wire.NewMsgCFiltersV2([]wire.MsgCFilterV2{}), + }, } t.Logf("Running %d tests", len(tests)) for _, test := range tests { diff --git a/server.go b/server.go index ecfbd918d3..3a536419c9 100644 --- a/server.go +++ b/server.go @@ -74,7 +74,7 @@ const ( connectionRetryInterval = time.Second * 5 // maxProtocolVersion is the max protocol version the server supports. - maxProtocolVersion = wire.RemoveRejectVersion + maxProtocolVersion = wire.BatchedCFiltersV2Version // These fields are used to track known addresses on a per-peer basis. // @@ -1522,6 +1522,16 @@ func (sp *serverPeer) OnGetCFilterV2(_ *peer.Peer, msg *wire.MsgGetCFilterV2) { sp.QueueMessage(filterMsg, nil) } +// OnGetCFiltersV2 is invoked when a peer receives a getcfsv2 wire message. +func (sp *serverPeer) OnGetCFiltersV2(_ *peer.Peer, msg *wire.MsgGetCFsV2) { + filtersMsg, err := sp.server.chain.LocateCFiltersV2(&msg.StartHash, &msg.EndHash) + if err != nil { + return + } + + sp.QueueMessage(filtersMsg, nil) +} + // OnGetCFHeaders is invoked when a peer receives a getcfheader wire message. func (sp *serverPeer) OnGetCFHeaders(_ *peer.Peer, msg *wire.MsgGetCFHeaders) { // Disconnect and/or ban depending on the node cf services flag and @@ -2308,6 +2318,7 @@ func newPeerConfig(sp *serverPeer) *peer.Config { OnGetHeaders: sp.OnGetHeaders, OnGetCFilter: sp.OnGetCFilter, OnGetCFilterV2: sp.OnGetCFilterV2, + OnGetCFiltersV2: sp.OnGetCFiltersV2, OnGetCFHeaders: sp.OnGetCFHeaders, OnGetCFTypes: sp.OnGetCFTypes, OnGetAddr: sp.OnGetAddr,