diff --git a/gcs/blockcf2/maxsize_test.go b/gcs/blockcf2/maxsize_test.go new file mode 100644 index 0000000000..e14ffd6ac9 --- /dev/null +++ b/gcs/blockcf2/maxsize_test.go @@ -0,0 +1,289 @@ +// Copyright (c) 2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package blockcf2 + +import ( + "encoding/binary" + "math/bits" + "math/rand" + "sort" + "testing" + "time" + + "github.com/dchest/siphash" + "github.com/decred/dcrd/chaincfg/v3" + "github.com/decred/dcrd/gcs/v4" + "github.com/decred/dcrd/wire" +) + +// TestMaxSize verifies the max size of blockcf2 filters for various +// parameters. +// +// This test is meant to nail down the values generated by various +// package-level constants to known good hardcoded values such that if any of +// the values change, this test breaks and any code that relies on assumptions +// about the max filter size having a particular value are reviewed. +func TestMaxSize(t *testing.T) { + // entrySizeForSize returns the size of scripts that ensures that every + // script added to a filter maximizes the passed maximum size. + // + // When adding scripts to build the filter, every input script should + // be unique in order to maximize the chances of unique values after + // the siphash and modulo reduction stages. When the total size + // available to build input scripts is less than or equal to 2^8, then + // a single byte is sufficient (because each byte will have a unique + // value), and so on when the max size is 2^16, 2^24 and 2^32. + // + // In practice, for the current network constants, the script size that + // will maximize all tests will be 3. + entrySizeForSize := func(size uint32) uint32 { + switch { + case size < 1<<8: + return 1 + case size < 1<<16: + return 2 + case size < 1<<24: + return 3 + default: + return 4 + } + } + + // mainnetMaxBlockSize is the max size of mainnet blocks according to + // the most recent consensus rules. + mainnetBlockSizes := chaincfg.MainNetParams().MaximumBlockSizes + mainnetMaxBlockSize := uint32(mainnetBlockSizes[len(mainnetBlockSizes)-1]) + + const ( + // minTxSize is the minimum transaction size with one input + // (and no outputs). + minTxSize = uint32(4 + 51 + 18) + + // txOutSize is the size of an output, plus one byte for the + // pkscript length encoded as a varint. + txOutSize = uint32(10 + 1) + + // p2shSize is the size of a P2SH script (the smallest standard + // script that may be added without limits to a transaction). + p2shSize = uint32(23) + ) + + tests := []struct { + name string + n uint32 + want uint64 + }{{ + // This test asserts the size of a filter when the source + // data for the filter fills an entire P2P message. + name: "filter from data that fills an entire P2P msg", + n: wire.MaxMessagePayload / entrySizeForSize(wire.MaxMessagePayload), + want: 22541387, + }, { + // This test asserts the size of a filter when the source data + // for the filter fills as many bytes as the maximum block + // payload size for a P2P message. + name: "filter from data that fills the max block size", + n: wire.MaxBlockPayload / entrySizeForSize(wire.MaxBlockPayload), + want: 1174034, + }, { + // This test asserts the size of a filter when the source data + // for the filter fills as many bytes as the maximum consensus + // enforced block size for mainnet. 
+ name: "filter from data that fills the max mainnet block size", + n: mainnetMaxBlockSize / entrySizeForSize(mainnetMaxBlockSize), + want: 352214, + }, { + // This test asserts the size of a filter when the source data + // for the filter is a single tx that has as many OP_RETURN + // outputs as necessary to fill a block. + name: "filter from tx filled with OP_RETURN outputs", + n: wire.MaxBlockPayload / (txOutSize + entrySizeForSize(wire.MaxBlockPayload)), + want: 251581, + }, { + // This test asserts the size of a filter when the source data + // is a set of transactions that have 4 OP_RETURN outputs, as + // enforced by the standardness policy checks of the mempool. + name: "filter from standard OP_RETURN tx", + n: wire.MaxBlockPayload / (minTxSize + (txOutSize+entrySizeForSize(wire.MaxBlockPayload))*4), + want: 27305, + }, { + // This test asserts the size of a filter when the source data + // is a single transaction with as many P2SH outputs as + // necessary to fill the largest block of any network. + name: "filter from P2SH outputs tx", + n: wire.MaxBlockPayload / (txOutSize + p2shSize), + want: 103593, + }, { + // This test asserts the size of a filter when the source data + // is a single transaction with as many P2SH outputs as + // necessary to fill the largest block on mainnet. + name: "filter from P2SH outputs tx on mainnet", + n: mainnetMaxBlockSize / (txOutSize + p2shSize), + want: 31080, + }} + + for i := range tests { + tc := tests[i] + t.Run(tc.name, func(t *testing.T) { + got := gcs.MaxFilterV2Size(B, M, tc.n) + if tc.want != got { + t.Logf("N: %d", tc.n) + t.Fatalf("Unexpected max filter size: got %d, want %d", + got, tc.want) + } + }) + } +} + +func fastReduce(x, N uint64) uint64 { + hi, _ := bits.Mul64(x, N) + return hi +} + +// TestWorstMaxSize generates the largest possible cfilter for a given random +// key. +func TestWorstMaxSize(t *testing.T) { + if testing.Short() { + t.Skip("Skipping due to -test.short") + } + + const scriptSize = 3 + rnd := rand.New(rand.NewSource(time.Now().UnixNano())) + nbScripts := wire.MaxBlockPayload / scriptSize + + endian := binary.BigEndian + aux := make([]byte, 4) + + var key [gcs.KeySize]byte + rnd.Read(key[:]) + k0 := binary.LittleEndian.Uint64(key[0:8]) + k1 := binary.LittleEndian.Uint64(key[8:16]) + + modulusNM := uint64(nbScripts * M) + + // Generate every 3-byte script (2^24) and create a map of the reduced + // value to the script values that produce that reduced value (as a + // uint32). + maxNb := 1 << (8 * scriptSize) + seenSip := make(map[uint64]struct{}, maxNb) + seenReduced := make(map[uint64][]uint32, maxNb) + keys := make([]uint64, 0, maxNb) + for c := uint32(0); c < uint32(maxNb); c++ { + endian.PutUint32(aux, c) + v := siphash.Hash(k0, k1, aux[1:]) + if _, ok := seenSip[v]; ok { + continue + } + seenSip[v] = struct{}{} + vv := fastReduce(v, modulusNM) + seenReduced[vv] = append(seenReduced[vv], c) + keys = append(keys, vv) + } + + // Sort the reduced values in ascending order. + sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] }) + for i := 0; i < 10; i++ { + t.Logf("smallest seen %d: %d (%v)", i, keys[i], seenReduced[keys[i]]) + } + for i := len(keys) - 10; i < len(keys); i++ { + t.Logf("largest seen %d: %d (%v)", i, keys[i], seenReduced[keys[i]]) + } + + // Create the scripts. The first N-1 scripts will generate the smallest + // possible reduced values. 
+ data := make([][]byte, nbScripts) + for i := 0; i < nbScripts-1; i++ { + values := seenReduced[keys[i]] + if len(values) > 1 { + // Report on different input entries that generate + // different siphash values that nevertheless reduce to + // the same final value. + sips := make([]uint64, len(values)) + mods := make([]uint64, len(values)) + for j := range values { + endian.PutUint32(aux, values[j]) + sips[j] = siphash.Hash(k0, k1, aux[1:]) + mods[j] = fastReduce(sips[j], modulusNM) + } + t.Logf("Collision on modulo reduction for script %d: "+ + "values %v, siphashes %v, modulos %v", i, values, + sips, mods) + } + data[i] = make([]byte, scriptSize) + endian.PutUint32(aux, values[0]) + copy(data[i], aux[1:]) + seenReduced[keys[i]] = values[1:] + } + + // The last script will generate the largest possible reduced value. The + // difference between the last and second-to-last scripts is the largest + // difference possible (when using this specific random key). + i := nbScripts - 1 + data[i] = make([]byte, scriptSize) + endian.PutUint32(aux, seenReduced[keys[len(keys)-1]][0]) + copy(data[i], aux[1:]) + + // Create filter and report results. + filter, err := gcs.NewFilterV2(B, M, key, data) + if err != nil { + t.Fatal(err) + } + sz := uint64(len(filter.Bytes())) + t.Logf("Key: %x", key) + t.Logf("Number of values: %d", len(seenReduced)) + t.Logf("Scripts: %d", len(data)) + + // The maximum possible quotient happens when the first nbScripts-1 + // scripts generate values from 0..nbScripts-2 and the last script + // generates the value modulusNM -1. + N := uint64(nbScripts) + maxQuotient := ((modulusNM - 1) - (N - 1)) >> 19 + t.Logf("Max Quotient: %d", maxQuotient) + + // The max possible size can be calculated by assuming every script + // will generate a value (i.e. no removed duplicates) and the last + // script will generate the max possible quotient. + maxPossible := gcs.MaxFilterV2Size(B, M, uint32(nbScripts)) + t.Logf("Max Possible size: %d, actual size: %d", maxPossible, sz) +} + +// TestRandomFilterSize generates filters with random data. +func TestRandomFilterSize(t *testing.T) { + // Use a unique random seed each test instance and log it if the tests + // fail. + seed := time.Now().UnixNano() + rng := rand.New(rand.NewSource(seed)) + defer func(t *testing.T, seed int64) { + if t.Failed() { + t.Logf("random seed: %d", seed) + } + }(t, seed) + scriptSize := 3 + nbScripts := wire.MaxBlockPayload / scriptSize + data := make([][]byte, nbScripts) + for i := range data { + data[i] = make([]byte, scriptSize) + } + var key [gcs.KeySize]byte + + // Generate a random filter and key. + for i := range data { + rng.Read(data[i]) + } + rng.Read(key[:]) + + // Check if it's larger than possible. + filter, err := gcs.NewFilterV2(B, M, key, data) + if err != nil { + t.Fatal(err) + } + sz := uint64(len(filter.Bytes())) + maxPossible := gcs.MaxFilterV2Size(B, M, uint32(nbScripts)) + t.Logf("Max Possible size: %d, actual size: %d", maxPossible, sz) + if sz > maxPossible { + t.Fatalf("Found a random filter with max size %d > max possible size %d", + sz, maxPossible) + } +} diff --git a/gcs/gcs.go b/gcs/gcs.go index 3735b47948..f31dc39968 100644 --- a/gcs/gcs.go +++ b/gcs/gcs.go @@ -549,3 +549,46 @@ func MakeHeaderForFilter(filter *FilterV1, prevHeader *chainhash.Hash) chainhash // The final filter hash is the blake256 of the hash computed above. return chainhash.Hash(blake256.Sum256(filterTip[:])) } + +// MaxFilterV2Size returns the maximum filter size possible for the given filter +// parameters. 
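+// +// B is the Golomb/Rice coding parameter (the number of remainder bits +// encoded per entry), M is the multiplier used to derive the modulo reduction +// modulus N*M, and N is the number of entries committed to the filter. The +// result is an upper bound that assumes every entry remains unique through +// the siphash and modulo reduction stages.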
+func MaxFilterV2Size(B uint8, M uint64, N uint32) uint64 { + // The maximum (i.e. worst) filter size for V2 filters happens when the + // following conditions are met: + // + // - Every one of the N data entries is unique. + // - Every one of the N data entries produces a unique value after + // the siphash stage, ensuring no values are prematurely removed. + // - The quotient difference is maximized and produces the largest + // possible number of one bits in unary coding for the set of unique + // scripts. + // + // Given that the values are sorted prior being encoded with the + // Golomb/Rice coding, the largest possible quotient happens when the + // difference between two consecutive values is maximized. And that + // happens when the last value has the highest possible value and the + // second-to-last has the least possible value. + // + // The highest possible value after the modulo reduction stage is + // N*M-1. And the least possible value is zero. In other words, the + // first N-1 siphashed entries are mapped, after modulo reduction, to + // the value 0 and the last entry is mapped to the value N*M-1. + // + // Thus the largest possible difference is N*M-1, with the max + // number of bits of the quotient readily determined by shifting right + // that amount by B. + n := uint64(N) + b := uint64(B) + largestDiff := n*M - 1 + maxQuoBits := largestDiff >> b + + // Finally, the maximum size of the filter is determined by assuming + // each one of the N entries takes one bit for the 0 in the quotient + // encoding and B bits for the remainder, one entry takes an additional + // maxQuoBits for the quotient encoding, N is encoded as a varint and + // any necessary padding is added. + nSerSize := uint64(wire.VarIntSerializeSize(n)) + maxBits := n + n*b + maxQuoBits + maxBytes := (maxBits+7)/8 + nSerSize + return maxBytes +} diff --git a/gcs/go.mod b/gcs/go.mod index 95c6a5630e..f33ecad8d7 100644 --- a/gcs/go.mod +++ b/gcs/go.mod @@ -6,6 +6,7 @@ require ( github.com/dchest/siphash v1.2.3 github.com/decred/dcrd/blockchain/stake/v5 v5.0.0 github.com/decred/dcrd/chaincfg/chainhash v1.0.4 + github.com/decred/dcrd/chaincfg/v3 v3.2.0 github.com/decred/dcrd/crypto/blake256 v1.0.1 github.com/decred/dcrd/txscript/v4 v4.1.0 github.com/decred/dcrd/wire v1.6.0 @@ -14,7 +15,6 @@ require ( require ( github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412 // indirect github.com/decred/base58 v1.0.5 // indirect - github.com/decred/dcrd/chaincfg/v3 v3.2.0 // indirect github.com/decred/dcrd/crypto/ripemd160 v1.0.2 // indirect github.com/decred/dcrd/database/v3 v3.0.1 // indirect github.com/decred/dcrd/dcrec v1.0.1 // indirect diff --git a/internal/blockchain/chainio.go b/internal/blockchain/chainio.go index ebf1eea9eb..9c8c2ef05b 100644 --- a/internal/blockchain/chainio.go +++ b/internal/blockchain/chainio.go @@ -851,6 +851,21 @@ func dbFetchGCSFilter(dbTx database.Tx, blockHash *chainhash.Hash) (*gcs.FilterV return filter, nil } +// dbFetchRawGCSFilter fetches the raw version 2 GCS filter for the passed +// block, without decoding it from the db. +// +// WARNING: the returned slice is only valid for the duration of the database +// transaction and MUST be copied to a new buffer if it is needed after the db +// transaction ends. +// +// This function is meant for cases where the raw filter bytes will be used +// without decoding into a gcs.FilterV2 value. For a safer alternative, use +// dbFetchGCSFilter. 
+func dbFetchRawGCSFilter(dbTx database.Tx, blockHash *chainhash.Hash) []byte { + filterBucket := dbTx.Metadata().Bucket(gcsFilterBucketName) + return filterBucket.Get(blockHash[:]) +} + // dbPutGCSFilter uses an existing database transaction to update the version 2 // GCS filter for the given block hash using the provided filter. func dbPutGCSFilter(dbTx database.Tx, blockHash *chainhash.Hash, filter *gcs.FilterV2) error { diff --git a/internal/blockchain/error.go b/internal/blockchain/error.go index 52141c12da..f7f4e24009 100644 --- a/internal/blockchain/error.go +++ b/internal/blockchain/error.go @@ -589,6 +589,14 @@ const ( // ErrSerializeHeader indicates an attempt to serialize a block header failed. ErrSerializeHeader = ErrorKind("ErrSerializeHeader") + // ErrNotAnAncestor indicates an attempt to fetch a chain of filters where + // the start hash is not an ancestor of the end hash. + ErrNotAnAncestor = ErrorKind("ErrNotAnAncestor") + + // ErrRequestTooLarge indicates an attempt to request too much data + // in a batched request. + ErrRequestTooLarge = ErrorKind("ErrRequestTooLarge") + // ------------------------------------------ // Errors related to the UTXO backend. // ------------------------------------------ diff --git a/internal/blockchain/error_test.go b/internal/blockchain/error_test.go index 73714ac898..2262e06edb 100644 --- a/internal/blockchain/error_test.go +++ b/internal/blockchain/error_test.go @@ -150,6 +150,8 @@ func TestErrorKindStringer(t *testing.T) { {ErrNoTreasuryBalance, "ErrNoTreasuryBalance"}, {ErrInvalidateGenesisBlock, "ErrInvalidateGenesisBlock"}, {ErrSerializeHeader, "ErrSerializeHeader"}, + {ErrNotAnAncestor, "ErrNotAnAncestor"}, + {ErrRequestTooLarge, "ErrRequestTooLarge"}, {ErrUtxoBackend, "ErrUtxoBackend"}, {ErrUtxoBackendCorruption, "ErrUtxoBackendCorruption"}, {ErrUtxoBackendNotOpen, "ErrUtxoBackendNotOpen"}, diff --git a/internal/blockchain/headercmt.go b/internal/blockchain/headercmt.go index 62fd1987d3..d47c5eefaf 100644 --- a/internal/blockchain/headercmt.go +++ b/internal/blockchain/headercmt.go @@ -191,3 +191,105 @@ func (b *BlockChain) FilterByBlockHash(hash *chainhash.Hash) (*gcs.FilterV2, *He } return filter, headerProof, nil } + +// LocateCFiltersV2 fetches all committed filters between startHash and endHash +// (inclusive) and prepares a MsgCFiltersV2 response to return this batch +// of CFilters to a remote peer. +// +// The start and end blocks must both exist and the start block must be an +// ancestor to the end block. +// +// This function is safe for concurrent access. +func (b *BlockChain) LocateCFiltersV2(startHash, endHash *chainhash.Hash) (*wire.MsgCFiltersV2, error) { + // Sanity check. + b.chainLock.RLock() + startNode := b.index.LookupNode(startHash) + if startNode == nil { + b.chainLock.RUnlock() + return nil, unknownBlockError(startHash) + } + endNode := b.index.LookupNode(endHash) + if endNode == nil { + b.chainLock.RUnlock() + return nil, unknownBlockError(endHash) + } + if !startNode.IsAncestorOf(endNode) { + b.chainLock.RUnlock() + str := fmt.Sprintf("start block %s is not an ancestor of end block %s", + startHash, endHash) + return nil, contextError(ErrNotAnAncestor, str) + } + + // Figure out size of the response. 
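+ // The requested range is only served when it spans at most + // wire.MaxCFiltersV2PerBatch blocks, which guarantees the resulting + // cfiltersv2 message fits in a single P2P message even when every + // filter has the maximum possible size.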
+ nb := int(endNode.height - startNode.height + 1) + if nb > wire.MaxCFiltersV2PerBatch { + b.chainLock.RUnlock() + str := fmt.Sprintf("number of requested cfilters %d greater than max allowed %d", + nb, wire.MaxCFiltersV2PerBatch) + return nil, contextError(ErrRequestTooLarge, str) + } + + // Allocate all internal messages in a single buffer to reduce memory + // allocation counts. + filters := make([]wire.MsgCFilterV2, nb) + proofLeaves := make([][]chainhash.Hash, nb) + + // Fetch all relevant block hashes. + node := endNode + for i := nb - 1; i >= 0; i-- { + filters[i].BlockHash = node.hash + node = node.parent + } + + // At this point all index operations have completed, so release the + // RLock. + b.chainLock.RUnlock() + + // Prepare the response from DB. + err := b.db.View(func(dbTx database.Tx) error { + totalLen := 0 + data := make([][]byte, nb) + for i := 0; i < nb; i++ { + hash := &filters[i].BlockHash + cfData := dbFetchRawGCSFilter(dbTx, hash) + if cfData == nil { + str := fmt.Sprintf("no filter available for block %s", hash) + return contextError(ErrNoFilter, str) + } + + data[i] = cfData + totalLen += len(cfData) + + var err error + proofLeaves[i], err = dbFetchHeaderCommitments(dbTx, hash) + if err != nil { + return err + } + } + + // Allocate a single backing buffer for all cfilter data and + // copy each individual db buffer into it. This reduces + // memory allocation counts and fragmentation. + buffer := make([]byte, totalLen) + var j int + for i := 0; i < nb; i++ { + sz := len(data[i]) + copy(buffer[j:], data[i]) + filters[i].Data = buffer[j : j+sz : j+sz] + j += sz + } + return nil + }) + if err != nil { + return nil, err + } + + // Prepare the response. + const proofIndex = HeaderCmtFilterIndex + for i := 0; i < nb; i++ { + proofHashes := standalone.GenerateInclusionProof(proofLeaves[i], proofIndex) + filters[i].ProofHashes = proofHashes + filters[i].ProofIndex = proofIndex + } + return wire.NewMsgCFiltersV2(filters), nil +} diff --git a/peer/go.mod b/peer/go.mod index cb4b8ff8c8..614778a12c 100644 --- a/peer/go.mod +++ b/peer/go.mod @@ -12,6 +12,8 @@ require ( github.com/decred/slog v1.2.0 ) +replace github.com/decred/dcrd/wire => ../wire + require ( github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412 // indirect github.com/dchest/siphash v1.2.3 // indirect diff --git a/peer/peer.go b/peer/peer.go index 438ac9de7b..8070f442d8 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -28,7 +28,7 @@ import ( const ( // MaxProtocolVersion is the max protocol version the peer supports. - MaxProtocolVersion = wire.RemoveRejectVersion + MaxProtocolVersion = wire.BatchedCFiltersV2Version // outputBufferSize is the number of elements the output channels use. outputBufferSize = 5000 @@ -129,6 +129,9 @@ type MessageListeners struct { // OnCFilterV2 is invoked when a peer receives a cfilterv2 wire message. OnCFilterV2 func(p *Peer, msg *wire.MsgCFilterV2) + // OnCFiltersV2 is invoked when a peer receives a cfiltersv2 wire message. + OnCFiltersV2 func(p *Peer, msg *wire.MsgCFiltersV2) + // OnCFHeaders is invoked when a peer receives a cfheaders wire // message. OnCFHeaders func(p *Peer, msg *wire.MsgCFHeaders) @@ -163,6 +166,10 @@ type MessageListeners struct { // message. OnGetCFilterV2 func(p *Peer, msg *wire.MsgGetCFilterV2) + // OnGetCFiltersV2 is invoked when a peer receives a getcfsv2 wire + // message. + OnGetCFiltersV2 func(p *Peer, msg *wire.MsgGetCFsV2) + // OnGetCFHeaders is invoked when a peer receives a getcfheaders // wire message. 
OnGetCFHeaders func(p *Peer, msg *wire.MsgGetCFHeaders) @@ -1423,6 +1430,16 @@ out: p.cfg.Listeners.OnCFilterV2(p, msg) } + case *wire.MsgGetCFsV2: + if p.cfg.Listeners.OnGetCFiltersV2 != nil { + p.cfg.Listeners.OnGetCFiltersV2(p, msg) + } + + case *wire.MsgCFiltersV2: + if p.cfg.Listeners.OnCFiltersV2 != nil { + p.cfg.Listeners.OnCFiltersV2(p, msg) + } + case *wire.MsgGetInitState: if p.cfg.Listeners.OnGetInitState != nil { p.cfg.Listeners.OnGetInitState(p, msg) diff --git a/peer/peer_test.go b/peer/peer_test.go index 1a2befb310..aa74fe6b1c 100644 --- a/peer/peer_test.go +++ b/peer/peer_test.go @@ -411,6 +411,12 @@ func TestPeerListeners(t *testing.T) { OnInitState: func(p *Peer, msg *wire.MsgInitState) { ok <- msg }, + OnGetCFiltersV2: func(p *Peer, msg *wire.MsgGetCFsV2) { + ok <- msg + }, + OnCFiltersV2: func(p *Peer, msg *wire.MsgCFiltersV2) { + ok <- msg + }, }, UserAgentName: "peer", UserAgentVersion: "1.0", @@ -565,6 +571,14 @@ func TestPeerListeners(t *testing.T) { "OnInitState", wire.NewMsgInitState(), }, + { + "OnGetCFiltersV2", + wire.NewMsgGetCFsV2(&chainhash.Hash{}, &chainhash.Hash{}), + }, + { + "OnCFiltersV2", + wire.NewMsgCFiltersV2([]wire.MsgCFilterV2{}), + }, } t.Logf("Running %d tests", len(tests)) for _, test := range tests { diff --git a/server.go b/server.go index ecfbd918d3..3a536419c9 100644 --- a/server.go +++ b/server.go @@ -74,7 +74,7 @@ const ( connectionRetryInterval = time.Second * 5 // maxProtocolVersion is the max protocol version the server supports. - maxProtocolVersion = wire.RemoveRejectVersion + maxProtocolVersion = wire.BatchedCFiltersV2Version // These fields are used to track known addresses on a per-peer basis. // @@ -1522,6 +1522,16 @@ func (sp *serverPeer) OnGetCFilterV2(_ *peer.Peer, msg *wire.MsgGetCFilterV2) { sp.QueueMessage(filterMsg, nil) } +// OnGetCFiltersV2 is invoked when a peer receives a getcfsv2 wire message. +func (sp *serverPeer) OnGetCFiltersV2(_ *peer.Peer, msg *wire.MsgGetCFsV2) { + filtersMsg, err := sp.server.chain.LocateCFiltersV2(&msg.StartHash, &msg.EndHash) + if err != nil { + return + } + + sp.QueueMessage(filtersMsg, nil) +} + // OnGetCFHeaders is invoked when a peer receives a getcfheader wire message. func (sp *serverPeer) OnGetCFHeaders(_ *peer.Peer, msg *wire.MsgGetCFHeaders) { // Disconnect and/or ban depending on the node cf services flag and @@ -2308,6 +2318,7 @@ func newPeerConfig(sp *serverPeer) *peer.Config { OnGetHeaders: sp.OnGetHeaders, OnGetCFilter: sp.OnGetCFilter, OnGetCFilterV2: sp.OnGetCFilterV2, + OnGetCFiltersV2: sp.OnGetCFiltersV2, OnGetCFHeaders: sp.OnGetCFHeaders, OnGetCFTypes: sp.OnGetCFTypes, OnGetAddr: sp.OnGetAddr, diff --git a/wire/error.go b/wire/error.go index 81a104f255..5554e1528c 100644 --- a/wire/error.go +++ b/wire/error.go @@ -149,6 +149,10 @@ const ( // ErrTooManyPrevMixMsgs is returned when too many previous messages of // a mix run are referenced by a message. ErrTooManyPrevMixMsgs + + // ErrTooManyCFilters is returned when the number of committed filters + // exceeds the maximum allowed in a batch. + ErrTooManyCFilters ) // Map of ErrorCode values back to their constant names for pretty printing. @@ -188,6 +192,7 @@ var errorCodeStrings = map[ErrorCode]string{ ErrMixPairReqScriptClassTooLong: "ErrMixPairReqScriptClassTooLong", ErrTooManyMixPairReqUTXOs: "ErrTooManyMixPairReqUTXOs", ErrTooManyPrevMixMsgs: "ErrTooManyPrevMixMsgs", + ErrTooManyCFilters: "ErrTooManyCFilters", } // String returns the ErrorCode as a human-readable name. 
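For context, the following is a minimal, hypothetical sketch of how a consumer of the peer package might wire up the new batched messages. The listener field, message constructor, and message fields come from this change; the import paths (module major versions) and the surrounding peer setup are assumptions:

```go
package cfbatch

import (
	"fmt"

	"github.com/decred/dcrd/chaincfg/chainhash"
	"github.com/decred/dcrd/peer/v3"
	"github.com/decred/dcrd/wire"
)

// listeners returns message listeners that log every filter in a batched
// cfiltersv2 reply.  Only the new callback is shown; a real config would set
// the other handlers as well.
func listeners() peer.MessageListeners {
	return peer.MessageListeners{
		OnCFiltersV2: func(p *peer.Peer, msg *wire.MsgCFiltersV2) {
			for i := range msg.CFilters {
				cf := &msg.CFilters[i]
				fmt.Printf("cfilter for %s: %d bytes, %d proof hashes\n",
					cf.BlockHash, len(cf.Data), len(cf.ProofHashes))
			}
		},
	}
}

// requestBatch queues a getcfsv2 request for the blocks from startHash
// through endHash (inclusive).  The range must not exceed
// wire.MaxCFiltersV2PerBatch blocks or the serving node will reject it.
func requestBatch(p *peer.Peer, startHash, endHash *chainhash.Hash) {
	p.QueueMessage(wire.NewMsgGetCFsV2(startHash, endHash), nil)
}
```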
diff --git a/wire/error_test.go b/wire/error_test.go index 4c162f88dd..30f98e9dee 100644 --- a/wire/error_test.go +++ b/wire/error_test.go @@ -54,6 +54,7 @@ func TestMessageErrorCodeStringer(t *testing.T) { {ErrMixPairReqScriptClassTooLong, "ErrMixPairReqScriptClassTooLong"}, {ErrTooManyMixPairReqUTXOs, "ErrTooManyMixPairReqUTXOs"}, {ErrTooManyPrevMixMsgs, "ErrTooManyPrevMixMsgs"}, + {ErrTooManyCFilters, "ErrTooManyCFilters"}, {0xffff, "Unknown ErrorCode (65535)"}, } diff --git a/wire/message.go b/wire/message.go index 82d96847bd..e5dca165dd 100644 --- a/wire/message.go +++ b/wire/message.go @@ -60,6 +60,8 @@ const ( CmdMixDCNet = "mixdcnet" CmdMixConfirm = "mixconfirm" CmdMixSecrets = "mixsecrets" + CmdGetCFiltersV2 = "getcfsv2" + CmdCFiltersV2 = "cfiltersv2" ) const ( @@ -217,6 +219,12 @@ func makeEmptyMessage(command string) (Message, error) { case CmdMixSecrets: msg = &MsgMixSecrets{} + case CmdGetCFiltersV2: + msg = &MsgGetCFsV2{} + + case CmdCFiltersV2: + msg = &MsgCFiltersV2{} + default: str := fmt.Sprintf("unhandled command [%s]", command) return nil, messageError(op, ErrUnknownCmd, str) diff --git a/wire/msgcfiltersv2.go b/wire/msgcfiltersv2.go new file mode 100644 index 0000000000..352e323754 --- /dev/null +++ b/wire/msgcfiltersv2.go @@ -0,0 +1,166 @@ +// Copyright (c) 2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "fmt" + "io" + + "github.com/decred/dcrd/chaincfg/chainhash" +) + +// MaxCFiltersV2PerBatch is the maximum number of committed filters that may +// be sent in a cfiltersv2 message. + +// +// This number was determined by assuming a sequence of blocks filled with +// transactions specifically designed to maximize the filter size, such that +// a cfiltersv2 message will not be larger than the maximum allowed P2P message +// size on any of the currently deployed networks. +const MaxCFiltersV2PerBatch = 100 + +func init() { + // The following runtime assertions are included to ensure that if any + // of the input constants used to determine the MaxCFiltersV2PerBatch + // number are changed from their original values, then that constant is + // reviewed to confirm it is still compatible with the new protocol or + // consensus constants. + // + // In particular, the max number of cfilters in a batched reply has + // been determined by assuming a sequence of blocks of size + // MaxBlockPayload (1.25 MiB), each filled with a transaction with as many + // OP_RETURNs as needed to produce a worst-case v2 filter (251581 bytes). + // + // At 100 filters per batch, with the added overhead of the commitment + // proof, the maximum size of a cfiltersv2 message should be about 25 MiB, + // which is less than the maximum P2P message size of 32 MiB. + // + // See the TestMaxSize test from the blockcf2 package for information + // on how the maximum v2 cfilter size is determined. + // + // If any of these assertions break, such that a change to + // MaxCFiltersV2PerBatch is necessary, then a new protocol version + // should be introduced to allow a modification to + // MaxCFiltersV2PerBatch.
+ switch { + case MaxCFiltersV2PerBatch != 100: + panic("review MaxCFiltersV2PerBatch due to constant change") + case MaxBlockPayload != 1310720: + panic("review MaxCFiltersV2PerBatch due to MaxBlockPayload change") + case MaxCFilterDataSize != 256*1024: + panic("review MaxCFiltersV2PerBatch due to MaxCFilterDataSize change") + case MaxMessagePayload != 1024*1024*32: + panic("review MaxCFiltersV2PerBatch due to MaxMessagePayload change") + case (&MsgCFiltersV2{}).MaxPayloadLength(BatchedCFiltersV2Version) != 26321001: + panic("review MaxCFiltersV2PerBatch due to MaxPayloadLength change") + case (&MsgCFiltersV2{}).MaxPayloadLength(ProtocolVersion) != 26321001: + panic("review MaxCFiltersV2PerBatch due to MaxPayloadLength change") + } +} + +// MsgCFiltersV2 implements the Message interface and represents a cfiltersv2 +// message. It is used to deliver a batch of version 2 committed gcs filters +// for a given range of blocks, along with a proof that can be used to prove +// the filter is committed to by the block header. +// +// It is delivered in response to a getcfsv2 message (MsgGetCFsV2). +type MsgCFiltersV2 struct { + CFilters []MsgCFilterV2 +} + +// BtcDecode decodes r using the Decred protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgCFiltersV2) BtcDecode(r io.Reader, pver uint32) error { + const op = "MsgCFiltersV2.BtcDecode" + if pver < BatchedCFiltersV2Version { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + nbCFilters, err := ReadVarInt(r, pver) + if err != nil { + return err + } + + if nbCFilters > MaxCFiltersV2PerBatch { + msg := fmt.Sprintf("%s too many cfilters sent in batch "+ + "[count %v max %v]", msg.Command(), nbCFilters, + MaxCFiltersV2PerBatch) + return messageError(op, ErrTooManyCFilters, msg) + } + + msg.CFilters = make([]MsgCFilterV2, nbCFilters) + for i := 0; i < int(nbCFilters); i++ { + cf := &msg.CFilters[i] + err := cf.BtcDecode(r, pver) + if err != nil { + return err + } + } + + return nil +} + +// BtcEncode encodes the receiver to w using the Decred protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgCFiltersV2) BtcEncode(w io.Writer, pver uint32) error { + const op = "MsgCFiltersV2.BtcEncode" + if pver < BatchedCFiltersV2Version { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + nbCFilters := len(msg.CFilters) + if nbCFilters > MaxCFiltersV2PerBatch { + msg := fmt.Sprintf("%s too many cfilters to send in batch "+ + "[count %v max %v]", msg.Command(), nbCFilters, + MaxCFiltersV2PerBatch) + return messageError(op, ErrTooManyCFilters, msg) + } + + err := WriteVarInt(w, pver, uint64(nbCFilters)) + if err != nil { + return err + } + + for i := 0; i < nbCFilters; i++ { + err = msg.CFilters[i].BtcEncode(w, pver) + if err != nil { + return err + } + } + + return nil +} + +// Command returns the protocol command string for the message. This is part +// of the Message interface implementation. +func (msg *MsgCFiltersV2) Command() string { + return CmdCFiltersV2 +} + +// MaxPayloadLength returns the maximum length the payload can be for the +// receiver. This is part of the Message interface implementation.
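+// +// With the current constants this works out to +// 1 + 100*(32 + 5 + 262144 + 4 + 1 + 32*32) = 26321001 bytes (roughly 25 MiB), +// which is below the 32 MiB MaxMessagePayload, assuming the existing +// MaxHeaderProofHashes bound of 32.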
+func (msg *MsgCFiltersV2) MaxPayloadLength(pver uint32) uint32 { + // Varint + n * individual cfilter message: + // Block hash + max filter data (including varint) + + // proof index + max num proof hashes (including varint). + return uint32(VarIntSerializeSize(MaxCFiltersV2PerBatch)) + + (chainhash.HashSize+ + uint32(VarIntSerializeSize(MaxCFilterDataSize))+ + MaxCFilterDataSize+4+ + uint32(VarIntSerializeSize(MaxHeaderProofHashes))+ + (MaxHeaderProofHashes*chainhash.HashSize))*uint32(MaxCFiltersV2PerBatch) +} + +// NewMsgCFiltersV2 returns a new cfiltersv2 message that conforms to the +// Message interface using the passed parameters and defaults for the remaining +// fields. +func NewMsgCFiltersV2(filters []MsgCFilterV2) *MsgCFiltersV2 { + return &MsgCFiltersV2{ + CFilters: filters, + } +} diff --git a/wire/msgcfiltersv2_test.go b/wire/msgcfiltersv2_test.go new file mode 100644 index 0000000000..24c74057a0 --- /dev/null +++ b/wire/msgcfiltersv2_test.go @@ -0,0 +1,335 @@ +// Copyright (c) 2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "bytes" + "errors" + "io" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/decred/dcrd/chaincfg/chainhash" +) + +// baseMsgCFiltersV2 returns a MsgCFiltersV2 struct populated with mock values +// that are used throughout tests. Note that the tests will need to be updated +// if these values are changed since they rely on the current values. +func baseMsgCFiltersV2(t *testing.T) *MsgCFiltersV2 { + t.Helper() + + filters := []MsgCFilterV2{ + *baseMsgCFilterV2(t), + } + + return NewMsgCFiltersV2(filters) +} + +// TestCFiltersV2 tests the MsgCFiltersV2 API against the latest protocol +// version. +func TestCFiltersV2(t *testing.T) { + pver := ProtocolVersion + + // Ensure the command is expected value. + wantCmd := "cfiltersv2" + msg := baseMsgCFiltersV2(t) + if cmd := msg.Command(); cmd != wantCmd { + t.Errorf("NewMsgCFiltersV2: wrong command - got %v want %v", cmd, + wantCmd) + } + + // Ensure max payload is expected value for latest protocol version. + // varint max number of cfilters + max number of cfilters * + // (Block hash + max filter data (including varint) + + // proof index + max num proof hashes (including varint)). + wantPayload := uint32(26321001) + maxPayload := msg.MaxPayloadLength(pver) + if maxPayload != wantPayload { + t.Errorf("MaxPayloadLength: wrong max payload length for protocol "+ + "version %d - got %v, want %v", pver, maxPayload, wantPayload) + } + + // Ensure encoding max number of cfilters with max cfilter data and + // max proof hashes returns no error. + maxData := make([]byte, MaxCFilterDataSize) + maxProofHashes := make([]chainhash.Hash, MaxHeaderProofHashes) + msg.CFilters = make([]MsgCFilterV2, 0, MaxCFiltersV2PerBatch) + for len(msg.CFilters) < MaxCFiltersV2PerBatch { + cf := baseMsgCFilterV2(t) + cf.Data = maxData + cf.ProofHashes = maxProofHashes + msg.CFilters = append(msg.CFilters, *cf) + } + + var buf bytes.Buffer + if err := msg.BtcEncode(&buf, pver); err != nil { + t.Fatal(err) + } + + // Ensure the maximum actually encoded length is less than or equal to + // the max payload length.
+ if uint32(buf.Len()) > maxPayload { + t.Fatalf("Largest message encoded to a buffer larger than the "+ + " MaxPayloadLength for protocol version %d - got %v, want %v", + pver, buf.Len(), maxPayload) + } +} + +// TestCFiltersV2PreviousProtocol tests the MsgCFiltersV2 API against the protocol +// prior to version BatchedCFiltersV2Version. +func TestCFiltersV2PreviousProtocol(t *testing.T) { + // Use the protocol version just prior to BatchedCFiltersV2Version changes. + pver := BatchedCFiltersV2Version - 1 + + msg := baseMsgCFiltersV2(t) + + // Test encode with old protocol version. + var buf bytes.Buffer + err := msg.BtcEncode(&buf, pver) + if !errors.Is(err, ErrMsgInvalidForPVer) { + t.Errorf("unexpected error when encoding for protocol version %d, "+ + "prior to message introduction - got %v, want %v", pver, + err, ErrMsgInvalidForPVer) + } + + // Test decode with old protocol version. + var readmsg MsgCFiltersV2 + err = readmsg.BtcDecode(&buf, pver) + if !errors.Is(err, ErrMsgInvalidForPVer) { + t.Errorf("unexpected error when decoding for protocol version %d, "+ + "prior to message introduction - got %v, want %v", pver, + err, ErrMsgInvalidForPVer) + } +} + +// TestCFiltersV2CrossProtocol tests the MsgCFiltersV2 API when encoding with +// the latest protocol version and decoding with BatchedCFiltersV2Version. +func TestCFiltersV2CrossProtocol(t *testing.T) { + msg := baseMsgCFiltersV2(t) + + // Encode with latest protocol version. + var buf bytes.Buffer + err := msg.BtcEncode(&buf, ProtocolVersion) + if err != nil { + t.Errorf("encode of MsgCFiltersV2 failed %v err <%v>", msg, err) + } + + // Decode with old protocol version. + var readmsg MsgCFiltersV2 + err = readmsg.BtcDecode(&buf, BatchedCFiltersV2Version) + if err != nil { + t.Errorf("decode of MsgCFiltersV2 failed [%v] err <%v>", buf, err) + } +} + +// TestCFiltersV2Wire tests the MsgCFiltersV2 wire encode and decode for various +// protocol versions. +func TestCFiltersV2Wire(t *testing.T) { + msgCFiltersV2 := baseMsgCFiltersV2(t) + msgCFiltersV2Encoded := []byte{ + 0x01, // Varint for number of filters + 0xba, 0xdc, 0xb8, 0xe5, 0xc1, 0xe8, 0x95, 0xe8, + 0xe8, 0xfe, 0xf8, 0xd3, 0x42, 0x5f, 0xa0, 0xbf, + 0xe9, 0xd2, 0x8f, 0xdb, 0xf7, 0x2f, 0x87, 0x19, + 0x10, 0xc4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock block hash + 0x1d, // Varint for filter data length + 0x00, 0x00, 0x00, 0x11, 0x1c, 0xa3, 0xaa, 0xfb, + 0x02, 0x30, 0x74, 0xdc, 0x5b, 0xf2, 0x49, 0x8d, + 0xf7, 0x91, 0xb7, 0xd6, 0xe8, 0x46, 0xe9, 0xf5, + 0x01, 0x60, 0x06, 0xd6, 0x00, // Filter data + 0x00, 0x00, 0x00, 0x00, // Proof index + 0x01, // Varint for num proof hashes + 0x47, 0x63, 0x69, 0x67, 0x50, 0xe6, 0x72, 0x86, + 0x7f, 0x91, 0x00, 0x68, 0x79, 0x94, 0x18, 0xdb, + 0x8d, 0xa6, 0x07, 0xba, 0xf2, 0x28, 0x08, 0x55, + 0x22, 0x48, 0xb5, 0xd0, 0xb9, 0x5f, 0x89, 0xb4, // first proof hash + } + + tests := []struct { + in *MsgCFiltersV2 // Message to encode + out *MsgCFiltersV2 // Expected decoded message + buf []byte // Wire encoding + pver uint32 // Protocol version for wire encoding + }{{ + // Latest protocol version. + msgCFiltersV2, + msgCFiltersV2, + msgCFiltersV2Encoded, + ProtocolVersion, + }, { + // Protocol version BatchedCFiltersV2Version+1. + msgCFiltersV2, + msgCFiltersV2, + msgCFiltersV2Encoded, + BatchedCFiltersV2Version + 1, + }, { + // Protocol version BatchedCFiltersV2Version. 
+ msgCFiltersV2, + msgCFiltersV2, + msgCFiltersV2Encoded, + BatchedCFiltersV2Version, + }} + + t.Logf("Running %d tests", len(tests)) + for i, test := range tests { + // Encode the message to wire format. + var buf bytes.Buffer + err := test.in.BtcEncode(&buf, test.pver) + if err != nil { + t.Errorf("BtcEncode #%d error %v", i, err) + continue + } + if !bytes.Equal(buf.Bytes(), test.buf) { + t.Errorf("BtcEncode #%d\n got: %s want: %s", i, + spew.Sdump(buf.Bytes()), spew.Sdump(test.buf)) + continue + } + + // Decode the message from wire format. + var msg MsgCFiltersV2 + rbuf := bytes.NewReader(test.buf) + err = msg.BtcDecode(rbuf, test.pver) + if err != nil { + t.Errorf("BtcDecode #%d error %v", i, err) + continue + } + if !reflect.DeepEqual(&msg, test.out) { + t.Errorf("BtcDecode #%d\n got: %s want: %s", i, + spew.Sdump(&msg), spew.Sdump(test.out)) + continue + } + } +} + +// TestCFiltersV2WireErrors performs negative tests against wire encode and +// decode of MsgCFiltersV2 to confirm error paths work correctly. +func TestCFiltersV2WireErrors(t *testing.T) { + pver := ProtocolVersion + + // Message with valid mock values. + baseCFiltersV2 := baseMsgCFiltersV2(t) + baseCFiltersV2Encoded := []byte{ + 0x01, // Varint for number of cfilters + 0xba, 0xdc, 0xb8, 0xe5, 0xc1, 0xe8, 0x95, 0xe8, + 0xe8, 0xfe, 0xf8, 0xd3, 0x42, 0x5f, 0xa0, 0xbf, + 0xe9, 0xd2, 0x8f, 0xdb, 0xf7, 0x2f, 0x87, 0x19, + 0x10, 0xc4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock block hash + 0x1d, // Varint for filter data length + 0x00, 0x00, 0x00, 0x11, 0x1c, 0xa3, 0xaa, 0xfb, + 0x02, 0x30, 0x74, 0xdc, 0x5b, 0xf2, 0x49, 0x8d, + 0xf7, 0x91, 0xb7, 0xd6, 0xe8, 0x46, 0xe9, 0xf5, + 0x01, 0x60, 0x06, 0xd6, 0x00, // Filter data + 0x00, 0x00, 0x00, 0x00, // Proof index + 0x01, // Varint for num proof hashes + 0x47, 0x63, 0x69, 0x67, 0x50, 0xe6, 0x72, 0x86, + 0x7f, 0x91, 0x00, 0x68, 0x79, 0x94, 0x18, 0xdb, + 0x8d, 0xa6, 0x07, 0xba, 0xf2, 0x28, 0x08, 0x55, + 0x22, 0x48, 0xb5, 0xd0, 0xb9, 0x5f, 0x89, 0xb4, // first proof hash + } + + // Message that forces an error by having a data that exceeds the max + // allowed length. + badFilterData := bytes.Repeat([]byte{0x00}, MaxCFilterDataSize+1) + maxDataCFiltersV2 := baseMsgCFiltersV2(t) + maxDataCFiltersV2.CFilters[0].Data = badFilterData + maxDataCFiltersV2Encoded := []byte{ + 0x01, // Varint for number of cfilters + 0xba, 0xdc, 0xb8, 0xe5, 0xc1, 0xe8, 0x95, 0xe8, + 0xe8, 0xfe, 0xf8, 0xd3, 0x42, 0x5f, 0xa0, 0xbf, + 0xe9, 0xd2, 0x8f, 0xdb, 0xf7, 0x2f, 0x87, 0x19, + 0x10, 0xc4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock block hash + 0xfe, 0x01, 0x00, 0x04, 0x00, // Varint for filter data length + } + + // Message that forces an error by having more than the max allowed proof + // hashes. 
+ maxHashesCFiltersV2 := baseMsgCFiltersV2(t) + maxHashesCFiltersV2.CFilters[0].ProofHashes = make([]chainhash.Hash, MaxHeaderProofHashes+1) + maxHashesCFiltersV2Encoded := []byte{ + 0x01, // Varint for number of cfilters + 0xba, 0xdc, 0xb8, 0xe5, 0xc1, 0xe8, 0x95, 0xe8, + 0xe8, 0xfe, 0xf8, 0xd3, 0x42, 0x5f, 0xa0, 0xbf, + 0xe9, 0xd2, 0x8f, 0xdb, 0xf7, 0x2f, 0x87, 0x19, + 0x10, 0xc4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock block hash + 0x1d, // Varint for filter data length + 0x00, 0x00, 0x00, 0x11, 0x1c, 0xa3, 0xaa, 0xfb, + 0x02, 0x30, 0x74, 0xdc, 0x5b, 0xf2, 0x49, 0x8d, + 0xf7, 0x91, 0xb7, 0xd6, 0xe8, 0x46, 0xe9, 0xf5, + 0x01, 0x60, 0x06, 0xd6, 0x00, // Filter data + 0x00, 0x00, 0x00, 0x00, // Proof index + 0x21, // Varint for num proof hashes + } + + // Message that forces an error by having more than the max allowed number + // of cfilters. + maxCFiltersV2 := baseMsgCFiltersV2(t) + for len(maxCFiltersV2.CFilters) < MaxCFiltersV2PerBatch+1 { + maxCFiltersV2.CFilters = append(maxCFiltersV2.CFilters, *baseMsgCFilterV2(t)) + } + maxCFiltersV2Encoded := []byte{ + 0x65, // Varint for number of cfilters + } + + tests := []struct { + in *MsgCFiltersV2 // Value to encode + buf []byte // Wire encoding + pver uint32 // Protocol version for wire encoding + max int // Max size of fixed buffer to induce errors + writeErr error // Expected write error + readErr error // Expected read error + }{ + // Force error in cfilter number varint. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 0, io.ErrShortWrite, io.EOF}, + // Force error in start of block hash. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 1, io.ErrShortWrite, io.EOF}, + // Force error in middle of block hash. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 9, io.ErrShortWrite, io.ErrUnexpectedEOF}, + // Force error in filter data len. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 33, io.ErrShortWrite, io.EOF}, + // Force error in start of filter data. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 34, io.ErrShortWrite, io.EOF}, + // Force error in middle of filter data. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 46, io.ErrShortWrite, io.ErrUnexpectedEOF}, + // Force error in start of proof index. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 63, io.ErrShortWrite, io.EOF}, + // Force error in middle of proof index. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 65, io.ErrShortWrite, io.ErrUnexpectedEOF}, + // Force error in num proof hashes. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 67, io.ErrShortWrite, io.EOF}, + // Force error in start of first proof hash. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 68, io.ErrShortWrite, io.EOF}, + // Force error in middle of first proof hash. + {baseCFiltersV2, baseCFiltersV2Encoded, pver, 78, io.ErrShortWrite, io.ErrUnexpectedEOF}, + // Force error with greater than max filter data. + {maxDataCFiltersV2, maxDataCFiltersV2Encoded, pver, 38, ErrFilterTooLarge, ErrVarBytesTooLong}, + // Force error with greater than max proof hashes. + {maxHashesCFiltersV2, maxHashesCFiltersV2Encoded, pver, 68, ErrTooManyProofs, ErrTooManyProofs}, + // Force error with greater than max cfilters. + {maxCFiltersV2, maxCFiltersV2Encoded, pver, 1, ErrTooManyCFilters, ErrTooManyCFilters}, + } + + t.Logf("Running %d tests", len(tests)) + for i, test := range tests { + // Encode to wire format. 
+ w := newFixedWriter(test.max) + err := test.in.BtcEncode(w, test.pver) + if !errors.Is(err, test.writeErr) { + t.Errorf("BtcEncode #%d wrong error got: %v, want: %v", i, err, + test.writeErr) + continue + } + + // Decode from wire format. + var msg MsgCFiltersV2 + r := newFixedReader(test.max, test.buf) + err = msg.BtcDecode(r, test.pver) + if !errors.Is(err, test.readErr) { + t.Errorf("BtcDecode #%d wrong error got: %v, want: %v", i, err, + test.readErr) + continue + } + } +} diff --git a/wire/msggetcfsv2.go b/wire/msggetcfsv2.go new file mode 100644 index 0000000000..53fb6c17bf --- /dev/null +++ b/wire/msggetcfsv2.go @@ -0,0 +1,73 @@ +// Copyright (c) 2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "fmt" + "io" + + "github.com/decred/dcrd/chaincfg/chainhash" +) + +// MsgGetCFsV2 implements the Message interface and represents a Decred +// getcfsv2 message. It is used to request a batch of version 2 committed +// filters that span a subset of a chain, from StartHash up to (and including) +// EndHash. The response is sent in a MsgCFiltersV2 message. +// +// At most MaxCFiltersV2PerBatch may be requested by each MsgGetCFsV2 +// message, which means the number of blocks between EndHash and StartHash must +// be less than or equal to that constant's value. +type MsgGetCFsV2 struct { + StartHash chainhash.Hash + EndHash chainhash.Hash +} + +// BtcDecode decodes r using the Decred protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgGetCFsV2) BtcDecode(r io.Reader, pver uint32) error { + const op = "MsgGetCFsV2.BtcDecode" + if pver < BatchedCFiltersV2Version { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + return readElements(r, &msg.StartHash, &msg.EndHash) +} + +// BtcEncode encodes the receiver to w using the Decred protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgGetCFsV2) BtcEncode(w io.Writer, pver uint32) error { + const op = "MsgGetCFsV2.BtcEncode" + if pver < BatchedCFiltersV2Version { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + return writeElements(w, &msg.StartHash, &msg.EndHash) +} + +// Command returns the protocol command string for the message. This is part +// of the Message interface implementation. +func (msg *MsgGetCFsV2) Command() string { + return CmdGetCFiltersV2 +} + +// MaxPayloadLength returns the maximum length the payload can be for the +// receiver. This is part of the Message interface implementation. +func (msg *MsgGetCFsV2) MaxPayloadLength(pver uint32) uint32 { + // Start hash + end hash. + return chainhash.HashSize * 2 +} + +// NewMsgGetCFsV2 returns a new Decred getcfsv2 message that conforms +// to the Message interface using the passed parameters. +func NewMsgGetCFsV2(startHash, endHash *chainhash.Hash) *MsgGetCFsV2 { + return &MsgGetCFsV2{ + StartHash: *startHash, + EndHash: *endHash, + } +} diff --git a/wire/msggetcfsv2_test.go b/wire/msggetcfsv2_test.go new file mode 100644 index 0000000000..875fbc88f7 --- /dev/null +++ b/wire/msggetcfsv2_test.go @@ -0,0 +1,246 @@ +// Copyright (c) 2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file.
+ +package wire + +import ( + "bytes" + "errors" + "io" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/decred/dcrd/chaincfg/chainhash" +) + +// baseMsgGetCFsV2 returns a MsgGetCFsV2 struct populated with mock +// values that are used throughout tests. Note that the tests will need to be +// updated if these values are changed since they rely on the current values. +func baseMsgGetCFsV2(t *testing.T) *MsgGetCFsV2 { + t.Helper() + + // Mock block hash. + startHashStr := "000000000000c41019872ff7db8fd2e9bfa05f42d3f8fee8e895e8c1e5b8dcba" + startHash, err := chainhash.NewHashFromStr(startHashStr) + if err != nil { + t.Fatalf("NewHashFromStr: %v", err) + } + + endHashStr := "00000000000108ac3e3f51a0f4424dd757a3b0485da0ec96592f637f27bd1cf5" + endHash, err := chainhash.NewHashFromStr(endHashStr) + if err != nil { + t.Fatalf("NewHashFromStr: %v", err) + } + + return NewMsgGetCFsV2(startHash, endHash) +} + +// TestGetCFiltersV2 tests the MsgGetCFsV2 API against the latest protocol +// version. +func TestGetCFiltersV2(t *testing.T) { + pver := ProtocolVersion + + // Ensure the command is expected value. + wantCmd := "getcfsv2" + msg := baseMsgGetCFsV2(t) + if cmd := msg.Command(); cmd != wantCmd { + t.Errorf("NewMsgGetCFsV2: wrong command - got %v want %v", cmd, + wantCmd) + } + + // Ensure max payload is expected value for latest protocol version. + // Start hash + end hash. + wantPayload := uint32(64) + maxPayload := msg.MaxPayloadLength(pver) + if maxPayload != wantPayload { + t.Errorf("MaxPayloadLength: wrong max payload length for protocol "+ + "version %d - got %v, want %v", pver, maxPayload, wantPayload) + } + + // Ensure max payload length is not more than MaxMessagePayload. + if maxPayload > MaxMessagePayload { + t.Fatalf("MaxPayloadLength: payload length (%v) for protocol version "+ + "%d exceeds MaxMessagePayload (%v).", maxPayload, pver, + MaxMessagePayload) + } +} + +// TestGetCFiltersV2PreviousProtocol tests the MsgGetCFsV2 API against the +// protocol prior to version BatchedCFiltersV2Version. +func TestGetCFiltersV2PreviousProtocol(t *testing.T) { + // Use the protocol version just prior to CFilterV2Version changes. + pver := BatchedCFiltersV2Version - 1 + + msg := baseMsgGetCFsV2(t) + + // Test encode with old protocol version. + var buf bytes.Buffer + err := msg.BtcEncode(&buf, pver) + if !errors.Is(err, ErrMsgInvalidForPVer) { + t.Errorf("unexpected error when encoding for protocol version %d, "+ + "prior to message introduction - got %v, want %v", pver, + err, ErrMsgInvalidForPVer) + } + + // Test decode with old protocol version. + var readmsg MsgGetCFsV2 + err = readmsg.BtcDecode(&buf, pver) + if !errors.Is(err, ErrMsgInvalidForPVer) { + t.Errorf("unexpected error when decoding for protocol version %d, "+ + "prior to message introduction - got %v, want %v", pver, + err, ErrMsgInvalidForPVer) + } +} + +// TestGetCFiltersV2CrossProtocol tests the MsgGetCFsV2 API when encoding +// with the latest protocol version and decoding with BatchedCFiltersV2Version. +func TestGetCFiltersV2CrossProtocol(t *testing.T) { + msg := baseMsgGetCFsV2(t) + + // Encode with latest protocol version. + var buf bytes.Buffer + err := msg.BtcEncode(&buf, ProtocolVersion) + if err != nil { + t.Errorf("encode of MsgGetCFsV2 failed %v err <%v>", msg, err) + } + + // Decode with old protocol version. 
+ var readmsg MsgGetCFsV2 + err = readmsg.BtcDecode(&buf, BatchedCFiltersV2Version) + if err != nil { + t.Errorf("decode of MsgGetCFsV2 failed [%v] err <%v>", buf, err) + } +} + +// TestGetCFiltersV2Wire tests the MsgGetCFsV2 wire encode and decode for +// various protocol versions. +func TestGetCFiltersV2Wire(t *testing.T) { + // MsgGetCFsV2 message with mock block hashes. + msgGetCFsV2 := baseMsgGetCFsV2(t) + msgGetCFsV2Encoded := []byte{ + 0xba, 0xdc, 0xb8, 0xe5, 0xc1, 0xe8, 0x95, 0xe8, + 0xe8, 0xfe, 0xf8, 0xd3, 0x42, 0x5f, 0xa0, 0xbf, + 0xe9, 0xd2, 0x8f, 0xdb, 0xf7, 0x2f, 0x87, 0x19, + 0x10, 0xc4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock start hash + 0xf5, 0x1c, 0xbd, 0x27, 0x7f, 0x63, 0x2f, 0x59, + 0x96, 0xec, 0xa0, 0x5d, 0x48, 0xb0, 0xa3, 0x57, + 0xd7, 0x4d, 0x42, 0xf4, 0xa0, 0x51, 0x3f, 0x3e, + 0xac, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock end hash + } + + tests := []struct { + in *MsgGetCFsV2 // Message to encode + out *MsgGetCFsV2 // Expected decoded message + buf []byte // Wire encoding + pver uint32 // Protocol version for wire encoding + }{{ + // Latest protocol version. + msgGetCFsV2, + msgGetCFsV2, + msgGetCFsV2Encoded, + ProtocolVersion, + }, { + // Protocol version BatchedCFiltersV2Version+1. + msgGetCFsV2, + msgGetCFsV2, + msgGetCFsV2Encoded, + BatchedCFiltersV2Version + 1, + }, { + // Protocol version BatchedCFiltersV2Version. + msgGetCFsV2, + msgGetCFsV2, + msgGetCFsV2Encoded, + BatchedCFiltersV2Version, + }} + + t.Logf("Running %d tests", len(tests)) + for i, test := range tests { + // Encode the message to wire format. + var buf bytes.Buffer + err := test.in.BtcEncode(&buf, test.pver) + if err != nil { + t.Errorf("BtcEncode #%d error %v", i, err) + continue + } + if !bytes.Equal(buf.Bytes(), test.buf) { + t.Errorf("BtcEncode #%d\n got: %s want: %s", i, + spew.Sdump(buf.Bytes()), spew.Sdump(test.buf)) + continue + } + + // Decode the message from wire format. + var msg MsgGetCFsV2 + rbuf := bytes.NewReader(test.buf) + err = msg.BtcDecode(rbuf, test.pver) + if err != nil { + t.Errorf("BtcDecode #%d error %v", i, err) + continue + } + if !reflect.DeepEqual(&msg, test.out) { + t.Errorf("BtcDecode #%d\n got: %s want: %s", i, spew.Sdump(&msg), + spew.Sdump(test.out)) + continue + } + } +} + +// TestGetCFiltersV2WireErrors performs negative tests against wire encode and +// decode of MsgGetCFsV2 to confirm error paths work correctly. +func TestGetCFiltersV2WireErrors(t *testing.T) { + pver := ProtocolVersion + + // MsgGetCFsV2 message with mock block hashes. + baseGetCFiltersV2 := baseMsgGetCFsV2(t) + baseGetCFiltersV2Encoded := []byte{ + 0xba, 0xdc, 0xb8, 0xe5, 0xc1, 0xe8, 0x95, 0xe8, + 0xe8, 0xfe, 0xf8, 0xd3, 0x42, 0x5f, 0xa0, 0xbf, + 0xe9, 0xd2, 0x8f, 0xdb, 0xf7, 0x2f, 0x87, 0x19, + 0x10, 0xc4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock start hash + 0xf5, 0x1c, 0xbd, 0x27, 0x7f, 0x63, 0x2f, 0x59, + 0x96, 0xec, 0xa0, 0x5d, 0x48, 0xb0, 0xa3, 0x57, + 0xd7, 0x4d, 0x42, 0xf4, 0xa0, 0x51, 0x3f, 0x3e, + 0xac, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, // Mock end hash + } + + tests := []struct { + in *MsgGetCFsV2 // Value to encode + buf []byte // Wire encoding + pver uint32 // Protocol version for wire encoding + max int // Max size of fixed buffer to induce errors + writeErr error // Expected write error + readErr error // Expected read error + }{ + // Force error in start of start hash. + {baseGetCFiltersV2, baseGetCFiltersV2Encoded, pver, 0, io.ErrShortWrite, io.EOF}, + // Force error in middle of start hash.
+ {baseGetCFiltersV2, baseGetCFiltersV2Encoded, pver, 8, io.ErrShortWrite, io.ErrUnexpectedEOF}, + // Force error in start of end hash. + {baseGetCFiltersV2, baseGetCFiltersV2Encoded, pver, 32, io.ErrShortWrite, io.EOF}, + // Force error in middle of end hash. + {baseGetCFiltersV2, baseGetCFiltersV2Encoded, pver, 40, io.ErrShortWrite, io.ErrUnexpectedEOF}, + } + + t.Logf("Running %d tests", len(tests)) + for i, test := range tests { + // Encode to wire format. + w := newFixedWriter(test.max) + err := test.in.BtcEncode(w, test.pver) + if !errors.Is(err, test.writeErr) { + t.Errorf("BtcEncode #%d wrong error got: %v, want: %v", i, err, + test.writeErr) + continue + } + + // Decode from wire format. + var msg MsgGetCFsV2 + r := newFixedReader(test.max, test.buf) + err = msg.BtcDecode(r, test.pver) + if !errors.Is(err, test.readErr) { + t.Errorf("BtcDecode #%d wrong error got: %v, want: %v", i, err, + test.readErr) + continue + } + } +} diff --git a/wire/protocol.go b/wire/protocol.go index c2585d4380..1e0c0e497a 100644 --- a/wire/protocol.go +++ b/wire/protocol.go @@ -17,7 +17,7 @@ const ( InitialProcotolVersion uint32 = 1 // ProtocolVersion is the latest protocol version this package supports. - ProtocolVersion uint32 = 10 + ProtocolVersion uint32 = 11 // NodeBloomVersion is the protocol version which added the SFNodeBloom // service flag (unused). @@ -54,6 +54,10 @@ const ( // MixVersion is the protocol version which adds peer-to-peer mixing. MixVersion uint32 = 10 + + // BatchedCFiltersV2Version is the protocol version which adds support + // for the batched getcfsv2 and cfiltersv2 messages. + BatchedCFiltersV2Version uint32 = 11 ) // ServiceFlag identifies services supported by a Decred peer.
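As a closing cross-check of the sizing argument behind MaxCFiltersV2PerBatch, here is a rough standalone sketch (not part of the change) that combines the new gcs.MaxFilterV2Size helper with the per-filter overhead of a cfiltersv2 message to confirm a worst-case batch stays under the P2P message limit. The import paths and the use of wire.MaxHeaderProofHashes and chainhash.HashSize for the proof overhead are assumptions based on the current tree:

```go
package main

import (
	"fmt"

	"github.com/decred/dcrd/chaincfg/chainhash"
	"github.com/decred/dcrd/gcs/v4"
	"github.com/decred/dcrd/gcs/v4/blockcf2"
	"github.com/decred/dcrd/wire"
)

func main() {
	// Worst-case number of filter entries for a block filled with OP_RETURN
	// outputs: 11 bytes of output overhead plus a 3-byte script per entry,
	// mirroring the "filter from tx filled with OP_RETURN outputs" case in
	// TestMaxSize.
	n := uint32(wire.MaxBlockPayload) / (11 + 3)

	// Upper bound for a single filter plus its per-filter overhead inside a
	// cfiltersv2 message: block hash, data length varint, proof index, and a
	// full header commitment inclusion proof.
	maxFilter := gcs.MaxFilterV2Size(blockcf2.B, blockcf2.M, n)
	perFilter := maxFilter +
		chainhash.HashSize +
		uint64(wire.VarIntSerializeSize(maxFilter)) +
		4 +
		uint64(wire.VarIntSerializeSize(wire.MaxHeaderProofHashes)) +
		wire.MaxHeaderProofHashes*chainhash.HashSize

	batch := wire.MaxCFiltersV2PerBatch * perFilter
	fmt.Printf("worst-case batch: %d bytes (P2P limit: %d)\n", batch,
		wire.MaxMessagePayload)
}
```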