From 586600723ba9d7426b165e2167583866dc50bad5 Mon Sep 17 00:00:00 2001 From: Josh Rickmar Date: Tue, 30 Apr 2024 18:51:48 +0000 Subject: [PATCH] wire: add previous revealed secrets hashes to RS message For proper blaming behavior, peers who publish their secrets are blamed and removed from the following run if all secrets were revealed, but no other misbehavior was detected. In order to correctly pin the blame on the peers who disrupted the mix by initially revealing their secrets messages, this blaming will only be triggered when a reveal secrets message is received that does not reference any other received secrets messages. Because the secrets message hash now commits to these previous hashes by its signature, an additional method is added to return the commitment hash (to be published and verified in key exchange messages) that does not hash the previous messages (as they do not exist at the time of creating the key exchange). While here, an issue was discovered and corrected in the serialization of the MixVect type. When the vector has zero length, deserializing would return early after reading the first 0 dimension, but serialization was writing both the count and message size dimensions. This was corrected by changing the serialization method to return early if the count is zero. --- mixing/mixclient/client.go | 1566 ++++++++++++++++++++++++++++++++++++ mixing/mixpool/mixpool.go | 1340 ++++++++++++++++++++++++++++++ wire/msgmixsecrets.go | 69 +- wire/msgmixsecrets_test.go | 19 +- 4 files changed, 2985 insertions(+), 9 deletions(-) create mode 100644 mixing/mixclient/client.go create mode 100644 mixing/mixpool/mixpool.go diff --git a/mixing/mixclient/client.go b/mixing/mixclient/client.go new file mode 100644 index 0000000000..099aff0cf2 --- /dev/null +++ b/mixing/mixclient/client.go @@ -0,0 +1,1566 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package mixclient + +import ( + "bytes" + "context" + cryptorand "crypto/rand" + "crypto/subtle" + "errors" + "fmt" + "hash" + "io" + "math/big" + "math/bits" + "sort" + "sync" + "time" + + "decred.org/cspp/v2/solverrpc" + "github.com/decred/dcrd/chaincfg/chainhash" + "github.com/decred/dcrd/crypto/blake256" + "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/decred/dcrd/dcrutil/v4" + "github.com/decred/dcrd/mixing" + "github.com/decred/dcrd/mixing/internal/chacha20prng" + "github.com/decred/dcrd/mixing/mixpool" + "github.com/decred/dcrd/txscript/v4" + "github.com/decred/dcrd/wire" +) + +// MinPeers is the minimum number of peers required for a mix run to proceed. +const MinPeers = 3 + +// ErrExpired indicates that a dicemix session failed to complete due to the +// submitted pair request expiring. +var ErrExpired = errors.New("mixing pair request expired") + +// Wallet signs mix transactions and listens for and broadcasts mixing +// protocol messages. +// +// While wallets are responsible for generating mixed addresses, this duty is +// performed by the generator function provided to NewCoinJoin rather than +// this interface. This allows each CoinJoin to pass in a generator closure +// for different BIP0032 accounts and branches. +type Wallet interface { + BestBlock() (uint32, chainhash.Hash) + + // Mixpool returns access to the wallet's mixing message pool. + // + // The mixpool should only be used for message access and deletion, + // but never publishing; SubmitMixMessage must be used instead for + // message publishing. + Mixpool() *mixpool.Pool + + // SubmitMixMessage submits a mixing message to the wallet's mixpool + // and broadcasts it to the network. + SubmitMixMessage(ctx context.Context, msg mixing.Message) error + + // SignInput adds a signature script to a transaction input. + SignInput(tx *wire.MsgTx, index int, prevScript []byte) error + + // PublishTransaction adds the transaction to the wallet and publishes + // it to the network. + PublishTransaction(ctx context.Context, tx *wire.MsgTx) error +} + +type deadlines struct { + epoch time.Time + recvKE time.Time + recvCT time.Time + recvSR time.Time + recvDC time.Time + recvCM time.Time +} + +const timeoutDuration = 30 * time.Second + +func (d *deadlines) start(begin time.Time) { + t := begin + add := func() time.Time { + t = t.Add(timeoutDuration) + return t + } + d.recvKE = add() + d.recvCT = add() + d.recvSR = add() + d.recvDC = add() + d.recvCM = add() +} + +func (d *deadlines) shift() { + d.recvKE = d.recvCT + d.recvCT = d.recvSR + d.recvSR = d.recvDC + d.recvDC = d.recvCM + d.recvCM = d.recvCM.Add(timeoutDuration) +} + +func (d *deadlines) restart() { + d.start(d.recvCM) +} + +// Client manages local mixing client sessions. +type Client struct { + wallet Wallet + mixpool *mixpool.Pool + + // Pending and active sessions and peers (both local and, when + // blaming, remote). + pairings map[string]*pairedSessions + height uint32 + mu sync.Mutex + + pairingWG sync.WaitGroup + + blake256Hasher hash.Hash + blake256HasherMu sync.Mutex + + epoch time.Duration + + logger Logger +} + +// NewClient creates a wallet's mixing client manager. +func NewClient(w Wallet) *Client { + height, _ := w.BestBlock() + return &Client{ + wallet: w, + mixpool: w.Mixpool(), + pairings: make(map[string]*pairedSessions), + blake256Hasher: blake256.New(), + epoch: 10 * time.Minute, + height: height, + } +} + +type Logger interface { + Log(args ...interface{}) + Logf(format string, args ...interface{}) +} + +func (c *Client) SetLogger(l Logger) { + c.logger = l +} + +func (c *Client) log(args ...interface{}) { + if c.logger == nil { + return + } + + c.logger.Log(args...) +} + +func (c *Client) logf(format string, args ...interface{}) { + if c.logger == nil { + return + } + + c.logger.Logf(format, args...) +} + +func (c *Client) logerrf(format string, args ...interface{}) { + // XXX use ERR subsystem + c.logf(format, args...) +} + +func (c *Client) sessionLog(sid [32]byte) *sessionLogger { + return &sessionLogger{sid: sid, logger: c.logger} +} + +type sessionLogger struct { + sid [32]byte + logger Logger +} + +func (l *sessionLogger) logf(format string, args ...interface{}) { + l.logger.Logf("sid=%x "+format, append([]interface{}{l.sid[:]}, args...)...) +} + +func (l *sessionLogger) log(args ...interface{}) { + s := fmt.Sprintf("sid=%x", l.sid[:]) + l.logger.Log(append([]interface{}{s}, args...)...) +} + +// Run runs the client manager, blocking until after the context is +// cancelled. +func (s *Client) Run(ctx context.Context) error { + s.epochTicker(ctx) + return ctx.Err() +} + +// pairedSessions tracks the waiting and in-progress mix sessions performed by +// one or more local peers using compatible pairings. +type pairedSessions struct { + localPeers map[identity]*peer + pairing []byte + runs []sessionRun +} + +type sessionRun struct { + sid [32]byte + run uint32 + mtot uint32 + + // Peers sorted by PR hashes. Each peer's myVk is its index in this + // slice. + prs []*wire.MsgMixPairReq + peers []*peer + mcounts []uint32 + roots []*big.Int +} + +// peer represents a participating client in a peer-to-peer mixing session. +// Some fields only pertain to peers created by this wallet, while the rest +// are used during blame assignment. +type peer struct { + ctx context.Context + client *Client + rand io.Reader // non-PRNG cryptographic rand + + res chan error + + pub *secp256k1.PublicKey + priv *secp256k1.PrivateKey + id *identity // serialized pubkey + pr *wire.MsgMixPairReq + coinjoin *CoinJoin + kx *mixing.KX + + prngSeed *[32]byte + prng *chacha20prng.Reader + + rs *wire.MsgMixSecrets + srMsg []*big.Int // random numbers for the exponential slot reservation mix + dcMsg wire.MixVect // anonymized messages to publish in XOR mix + + ke *wire.MsgMixKeyExchange + ct *wire.MsgMixCiphertexts + sr *wire.MsgMixSlotReserve + dc *wire.MsgMixDCNet + cm *wire.MsgMixConfirm + + // Unmixed positions. May change over multiple sessions/runs. + myVk uint32 + myStart uint32 + + // Exponential slot reservation mix + srKP [][][]byte // shared keys for exp dc-net + srMix [][]*big.Int + srMixBytes [][][]byte + + // XOR DC-net + dcKP [][]wire.MixVect + dcNet []wire.MixVect + + // Whether next run must generate fresh KX keys, SR/DC messages + freshGen bool + + // Whether this peer represents a remote peer created from revealed secrets; + // used during blaming. + remote bool +} + +func newRemotePeer(pr *wire.MsgMixPairReq) *peer { + return &peer{ + id: &pr.Identity, + pr: pr, + remote: true, + } +} + +func generateSecp256k1(rand io.Reader) (*secp256k1.PublicKey, *secp256k1.PrivateKey, error) { + if rand == nil { + rand = cryptorand.Reader + } + + privateKey, err := secp256k1.GeneratePrivateKeyFromRand(rand) + if err != nil { + return nil, nil, err + } + + publicKey := privateKey.PubKey() + + return publicKey, privateKey, nil +} + +// Dicemix performs a new mixing session for a coinjoin mix transaction. +func (c *Client) Dicemix(ctx context.Context, rand io.Reader, cj *CoinJoin) error { + pub, priv, err := generateSecp256k1(rand) + if err != nil { + return err + } + + p := &peer{ + ctx: ctx, + client: c, + res: make(chan error, 1), + pub: pub, + priv: priv, + id: (*[33]byte)(pub.SerializeCompressed()), + rand: rand, + coinjoin: cj, + freshGen: true, + } + + pr, err := wire.NewMsgMixPairReq(*p.id, cj.prExpiry, cj.mixValue, + string(mixing.ScriptClassP2PKHv0), cj.tx.Version, + cj.tx.LockTime, cj.mcount, cj.inputValue, cj.prUTXOs, + cj.change) + if err != nil { + return err + } + err = p.signAndHash(pr) + if err != nil { + return err + } + pairingID, err := pr.Pairing() + if err != nil { + return err + } + p.pr = pr + + c.logf("debug: created local peer id=%x PR=%s", p.id[:], p.pr.Hash()) + + c.mu.Lock() + pairing := c.pairings[string(pairingID)] + if pairing == nil { + pairing = c.newPairings(pairingID, nil) + c.pairings[string(pairingID)] = pairing + } + pairing.localPeers[*p.id] = p + c.mu.Unlock() + + err = p.submit(pr) + if err != nil { + c.mu.Lock() + pendingPairing := c.pairings[string(pairingID)] + if pendingPairing != nil { + delete(pendingPairing.localPeers, *p.id) + if len(pendingPairing.localPeers) == 0 { + delete(c.pairings, string(pairingID)) + } + } + c.mu.Unlock() + return err + } + + select { + case res := <-p.res: + return res + case <-ctx.Done(): + return ctx.Err() + } +} + +// ExpireMessages removes all mixpool messages and sessions that indicate an +// expiry height at or before the height parameter and removes any pending +// DiceMix sessions that did not complete. +func (c *Client) ExpireMessages(height uint32) { + c.mu.Lock() + defer c.mu.Unlock() + + // Just mark the current height for expireMessages() to use later, + // after epoch strikes. This allows us to keep using PRs during + // session forming even if they expire, which is required to continue + // mixing paired sessions that are performing reruns. + c.height = height + + // mixpool similarly expires messages in the background after current + // epoch ends. Call it here while in the current epoch, rather than + // when expireMessage() is called later immediately after the next + // epoch strikes. + c.mixpool.ExpireMessages(c.height) +} + +func (c *Client) expireMessages() { + for pairID, ps := range c.pairings { + for id, p := range ps.localPeers { + prHash := p.pr.Hash() + if !c.mixpool.HaveMessage(&prHash) { + delete(ps.localPeers, id) + // p.res is buffered. If the write is + // blocked, we have already served this peer + // or sent another error. + select { + case p.res <- ErrExpired: + default: + } + + } + } + if len(ps.localPeers) == 0 { + delete(c.pairings, pairID) + } + } +} + +// SetEpoch modifies the sessions to use a longer or shorter epoch duration. +// The default epoch is 10 minutes. +func (c *Client) SetEpoch(epoch time.Duration) { + c.epoch = epoch + c.mixpool.SetEpoch(epoch) +} + +// waitForEpoch blocks until the next epoch, or errors when the context is +// cancelled early. Returns the calculated epoch for stage timeout +// calculations. +func (c *Client) waitForEpoch(ctx context.Context) (time.Time, error) { + now := time.Now().UTC() + epoch := now.Truncate(c.epoch).Add(c.epoch) + duration := epoch.Sub(now) + timer := time.NewTimer(duration) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return epoch, ctx.Err() + case <-timer.C: + return epoch, nil + } +} + +func (p *peer) signAndHash(m mixing.Message) error { + err := mixing.SignMessage(m, p.priv) + if err != nil { + return err + } + + p.client.blake256HasherMu.Lock() + m.WriteHash(p.client.blake256Hasher) + p.client.blake256HasherMu.Unlock() + + return nil +} + +func (p *peer) submit(m mixing.Message) error { + err := p.client.wallet.SubmitMixMessage(p.ctx, m) + if err != nil { + return fmt.Errorf("submit %T: %w", m, err) + } + return nil +} + +func (p *peer) signAndSubmit(m mixing.Message) error { + err := p.signAndHash(m) + if err != nil { + return err + } + return p.submit(m) +} + +func (c *Client) newPairings(pairing []byte, peers map[identity]*peer) *pairedSessions { + if peers == nil { + peers = make(map[identity]*peer) + } + ps := &pairedSessions{ + localPeers: peers, + pairing: pairing, + runs: nil, + } + return ps +} + +func (c *Client) epochTicker(ctx context.Context) { + //c.waitForEpoch(ctx) + + for { + epoch, err := c.waitForEpoch(ctx) + if err != nil { + break + } + + c.log("debug: epoch tick") + + var d deadlines + d.epoch = epoch + d.start(epoch) + + // Wait for any previous pairSession calls to timeout if they + // have not yet formed a session before the next epoch tick. + c.pairingWG.Wait() + + c.mu.Lock() + c.mixpool.RemoveConfirmedRuns() + c.expireMessages() + for _, p := range c.pairings { + prs := c.mixpool.CompatiblePRs(p.pairing) + prsMap := make(map[identity]struct{}) + for _, pr := range prs { + prsMap[pr.Identity] = struct{}{} + } + // Clone the p.localPeers map, only including PRs + // currently accepted to mixpool. Adding additional + // waiting local peers must not add more to the map in + // use by runSessions, and deleting peers in a formed + // session from the pending map must not inadvertently + // remove from runSession's ps.localPeers map. + localPeers := make(map[identity]*peer) + for id, peer := range p.localPeers { + if _, ok := prsMap[id]; ok { + localPeers[id] = peer + } + } + c.logf("debug: have %d compatible/%d local PRs waiting for pairing %x", + len(prs), len(localPeers), p.pairing) + if len(prs) < MinPeers { + continue + } + ps := *p + ps.localPeers = localPeers + for id, peer := range p.localPeers { + ps.localPeers[id] = peer + } + c.pairingWG.Add(1) // pairSession calls Done + go c.pairSession(ctx, &ps, prs, d) + } + c.mu.Unlock() + } +} + +func sortPRsForSession(prs []*wire.MsgMixPairReq, epoch uint64) [32]byte { + sort.Slice(prs, func(i, j int) bool { + a := prs[i].Hash() + b := prs[j].Hash() + return bytes.Compare(a[:], b[:]) == -1 + }) + + prHashes := make([]chainhash.Hash, len(prs)) + for i := range prs { + prHashes[i] = prs[i].Hash() + } + return mixing.DeriveSessionID(prHashes, epoch) +} + +func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wire.MsgMixPairReq, d deadlines) { + // This session pairing attempt, and calling pairSession again with + // fresh PRs, must end before the next call to pairSession for this + // pairing type. + nextEpoch := d.epoch.Add(c.epoch) + ctx, cancel := context.WithDeadline(ctx, nextEpoch) + defer cancel() + + // Defer removal of completed mix messages. Add local peers back to + // the client to be paired in a later session if the mix was + // unsuccessful or only some peers were included. + var mixedPRs []*wire.MsgMixPairReq + defer func() { + unmixedPeers := ps.localPeers + for _, pr := range mixedPRs { + delete(unmixedPeers, pr.Identity) + } + if len(unmixedPeers) == 0 { + return + } + c.mu.Lock() + pendingPairing := c.pairings[string(ps.pairing)] + if pendingPairing == nil { + pendingPairing = c.newPairings(ps.pairing, unmixedPeers) + c.pairings[string(ps.pairing)] = pendingPairing + } else { + for id, p := range unmixedPeers { + prHash := p.pr.Hash() + if p.ctx.Err() == nil && c.mixpool.HaveMessage(&prHash) { + pendingPairing.localPeers[id] = p + } + } + } + c.mu.Unlock() + }() + + unixEpoch := uint64(d.epoch.Unix()) + + var madePairing bool + defer func() { + if !madePairing { + c.pairingWG.Done() + } + }() + + var sid [32]byte + var run uint32 + var sesLog *sessionLogger + for { + if run == 0 { + sid = sortPRsForSession(prs, unixEpoch) + sesLog = c.sessionLog(sid) + + prHashes := make([]chainhash.Hash, len(prs)) + for i := range prs { + prHashes[i] = prs[i].Hash() + } + + sesRun := sessionRun{ + sid: sid, + run: run, + prs: prs, + mcounts: make([]uint32, 0, len(prs)), + } + var m uint32 + var localPeerCount int + for i, pr := range prs { + p := ps.localPeers[pr.Identity] + if p != nil { + localPeerCount++ + } else { + p = newRemotePeer(pr) + } + p.myVk = uint32(i) + p.myStart = m + + sesRun.peers = append(sesRun.peers, p) + sesRun.mcounts = append(sesRun.mcounts, p.pr.MessageCount) + + m += p.pr.MessageCount + } + sesRun.mtot = m + ps.runs = append(ps.runs, sesRun) + + if localPeerCount == 0 { + return + } + + sesLog.logf("created session for pairid=%x from %d total %d local PRs %s", + ps.pairing, len(prHashes), localPeerCount, prHashes) + sesLog.logf("debug: len(ps.runs)=%d", len(ps.runs)) + } else { + // Calculate new deadlines for reruns in the same session. + d.restart() + } + + err := c.run(ctx, ps, &d, &madePairing) + + var altses *alternateSession + if errors.As(err, &altses) { + prs = altses.prs + if len(prs) < MinPeers { + sesLog.logf("Aborting session with too few remaining peers") + return + } + + d.shift() + + if sid != altses.sid { + sesLog.logf("Recreating as session %x (pairid=%x)", altses.sid, ps.pairing) + run = 0 + } + + continue + } + + var blamed blamedIdentities + if errors.Is(err, errBlameRequired) { + err := c.blame(ctx, &ps.runs[len(ps.runs)-1], reported /* XXX */) + if !errors.As(err, &blamed) { + c.logerrf("aborting session for failed blame assignment: %v", err) + return + } + } + if blamed != nil || errors.As(err, &blamed) { + // Blamed peers were identified, either during the run + // in a way that all participants could have observed, + // or following revealing secrets and blame + // assignment. Begin a rerun excluding these peers. + rerun := excludeBlamed(&ps.runs[len(ps.runs)-1], blamed) + ps.runs = append(ps.runs, rerun) + run = rerun.run + continue + } + + // Any other run error is not actionable. + if err != nil { + sesLog.logf("run error: %v", err) + return + } + + mixedPRs = ps.runs[len(ps.runs)-1].prs + return + } +} + +var ( + errRunStageTimeout = errors.New("mix run stage timeout") + errUnblamedRerun = errors.New("unblamed rerun") + errBlameRequired = errors.New("blame required") +) + +type alternateSession struct { + prs []*wire.MsgMixPairReq + sid [32]byte + unresponsive []*wire.MsgMixPairReq + err error +} + +func (e *alternateSession) Error() string { + return "recreated alternate session" +} +func (e *alternateSession) Unwrap() error { + return e.err +} + +type timeoutError struct { + exclude []identity + prs []*wire.MsgMixPairReq +} + +func (e *timeoutError) Error() string { return "timeout" } +func (e *timeoutError) Exclude() []identity { return e.exclude } +func (e *timeoutError) Rerun() []*wire.MsgMixPairReq { return e.prs } + +func (c *Client) run(ctx context.Context, ps *pairedSessions, d *deadlines, madePairing *bool) error { + var blamed blamedIdentities + + unixEpoch := uint64(d.epoch.Unix()) + + mp := c.wallet.Mixpool() + sesRun := &ps.runs[len(ps.runs)-1] + run := sesRun.run + prs := sesRun.prs + + // Helper function to run a callback on all local peers. Closes over + // the sesRun variable and will be updated to the appropriate session + // and run after a session is agreed upon. + forLocalPeers := func(f func(p *peer) error) { + for _, p := range sesRun.peers { + if p.remote { + continue + } + err := f(p) + if err != nil { + c.logerrf("%s", err) + } + } + } + + // A map of identity public keys to their PR position sort all + // messages in the same order as the PRs are ordered. + identityIndices := make(map[identity]int) + for i, pr := range prs { + identityIndices[pr.Identity] = i + } + + seenPRs := make([]chainhash.Hash, len(prs)) + for i := range prs { + seenPRs[i] = prs[i].Hash() + } + + forLocalPeers(func(p *peer) error { + p.coinjoin.resetUnmixed(prs) + + if p.freshGen { + p.freshGen = false + + // Generate a new PRNG seed + p.prngSeed = new([32]byte) + _, err := io.ReadFull(p.rand, p.prngSeed[:]) + if err != nil { + return err + } + p.prng = chacha20prng.New(p.prngSeed[:], run) + + // Generate fresh keys from this run's PRNG + p.kx, err = mixing.NewKX(p.prng) + if err != nil { + return err + } + + // Generate fresh SR messages + p.srMsg = make([]*big.Int, p.pr.MessageCount) + for i := range p.srMsg { + p.srMsg[i], err = cryptorand.Int(p.rand, mixing.F) + if err != nil { + return err + } + } + + // Generate fresh DC messages + p.dcMsg, err = p.coinjoin.gen() + if err != nil { + return err + } + if len(p.dcMsg) != int(p.pr.MessageCount) { + return errors.New("Gen returned wrong message count") + } + for _, m := range p.dcMsg { + if len(m) != msize { + err := fmt.Errorf("Gen returned bad message "+ + "length [%v != %v]", len(m), msize) + return err + } + } + } else { + // Generate a new PRNG from existing seed and this run + // number. + p.prng = chacha20prng.New(p.prngSeed[:], run) + } + + // Perform key exchange + srMsgBytes := make([][]byte, len(p.srMsg)) + for i := range p.srMsg { + srMsgBytes[i] = p.srMsg[i].Bytes() + } + rs := wire.NewMsgMixSecrets(*p.id, sesRun.sid, run, + *p.prngSeed, srMsgBytes, p.dcMsg) + commitment := rs.Hash() + ecdhPub := *(*[33]byte)(p.kx.ECDHPublicKey.SerializeCompressed()) + pqPub := *p.kx.PQPublicKey + ke := wire.NewMsgMixKeyExchange(*p.id, sesRun.sid, unixEpoch, run, + ecdhPub, pqPub, commitment, seenPRs) + + p.ke = ke + p.rs = rs + return p.signAndSubmit(ke) + }) + + // Receive key exchange messages. + // + // In run 0, it is possible that the attempted session (the session ID + // used by our PR messages) does not match the same session attempted + // by all other remote peers, due to not all peers initially agreeing + // on the same set of PRs. When there is agreement, the session can + // be run like normal, and any reruns will only be performed with some + // of the original peers removed. + // + // When there is session disagreement, we attempt to find a new + // session that the majority of peers will be able to participate in. + // All KE messages that match the pairing ID are received, and each + // seen PRs slice is checked. PRs that were never followed up by a KE + // are immediately excluded. + var kes []*wire.MsgMixKeyExchange + var err error + recvKEs := func(sesRun *sessionRun) (kes []*wire.MsgMixKeyExchange, err error) { + rcv := new(mixpool.Received) + rcv.Run = sesRun.run + rcv.Sid = sesRun.sid + rcv.KEs = make([]*wire.MsgMixKeyExchange, 0, len(sesRun.prs)) + ctx, cancel := context.WithDeadlineCause(ctx, + d.recvKE, errRunStageTimeout) + defer cancel() + err = mp.Receive(ctx, len(sesRun.prs), rcv) + return rcv.KEs, err + } + + switch { + case run == 0: + // Receive KEs for the last attempted session. Local + // peers may have been modified (new keys generated, and myVk + // indexes changed) if this is a recreated session, and we + // cannot continue mix using these messages. + // + // XXX: do we need to keep peer info available for previous + // session attempts? It is possible that previous sessions + // may be operable now if all wallets have come to agree on a + // previous session we also tried to form. + kes, err = recvKEs(sesRun) + if len(kes) == len(sesRun.prs) { + break + } + if err := ctx.Err(); err != nil { + return err + } + + // Alternate session needs to be attempted. Do not form an + // alternate session if we've already entered into the next + // epoch without forming a complete session. The session + // forming will be performed by a new goroutine started by the + // epoch ticker, possibly with additional PRs. + nextEpoch := d.epoch.Add(c.epoch) + if time.Now().After(nextEpoch) { + return err + } + + altses := c.alternateSession(ps.pairing, sesRun.prs, nil, d) + if errors.Is(altses.err, ErrTooFewPeers) || altses.sid == sesRun.sid { + altses = c.alternateSession(ps.pairing, sesRun.prs, + altses.unresponsive, d) + } + if altses.err != nil { + return err + } + return altses + + default: + // Receive KEs only for the agreed-upon session. + kes, err = recvKEs(sesRun) + if err != nil { + return err + } + } + + for _, ke := range kes { + if idx, ok := identityIndices[ke.Identity]; ok { + sesRun.peers[idx].ke = ke + } + } + + // Remove paired local peers from waiting pairing. + if run == 0 { + c.mu.Lock() + if waiting := c.pairings[string(ps.pairing)]; waiting != nil { + for id := range ps.localPeers { + delete(waiting.localPeers, id) + } + if len(waiting.localPeers) == 0 { + delete(c.pairings, string(ps.pairing)) + } + } + c.mu.Unlock() + + if !*madePairing { + c.pairingWG.Done() + *madePairing = true + } + } + + sort.Slice(kes, func(i, j int) bool { + a := identityIndices[kes[i].Identity] + b := identityIndices[kes[j].Identity] + return a < b + }) + + sesLog := c.sessionLog(sesRun.sid) + + involvesLocalPeers := false + for _, ke := range kes { + if ps.localPeers[ke.Identity] != nil { + involvesLocalPeers = true + break + } + } + if !involvesLocalPeers { + return errors.New("excluded all local peers") + } + + ecdhPublicKeys := make([]*secp256k1.PublicKey, 0, len(prs)) + pqpk := make([]*mixing.PQPublicKey, 0, len(prs)) + for _, ke := range kes { + ecdhPub, err := secp256k1.ParsePubKey(ke.ECDH[:]) + if err != nil { + blamed = append(blamed, ke.Identity) + continue + } + ecdhPublicKeys = append(ecdhPublicKeys, ecdhPub) + pqpk = append(pqpk, &ke.PQPK) + } + if len(blamed) > 0 { + return blamed + } + + forLocalPeers(func(p *peer) error { + // Create shared keys and ciphextexts for each peer + pqct, err := p.kx.Encapsulate(p.prng, pqpk, int(p.myVk)) + if err != nil { + return err + } + + // Send ciphertext messages + seenKEs := make([]chainhash.Hash, len(kes)) + for i := range kes { + seenKEs[i] = kes[i].Hash() + } + ct := wire.NewMsgMixCiphertexts(*p.id, sesRun.sid, run, pqct, seenKEs) + p.ct = ct + return p.signAndSubmit(ct) + }) + + // Receive all ciphertext messages + rcv := new(mixpool.Received) + rcv.Sid = sesRun.sid + rcv.KEs = nil + rcv.CTs = make([]*wire.MsgMixCiphertexts, 0, len(prs)) + rcvCtx, rcvCtxCancel := context.WithDeadlineCause(context.Background(), + d.recvCT, errRunStageTimeout) + err = mp.Receive(rcvCtx, len(prs), rcv) + rcvCtxCancel() + cts := rcv.CTs + for _, ct := range cts { + if idx, ok := identityIndices[ct.Identity]; ok { + sesRun.peers[idx].ct = ct + } + } + if len(cts) != len(prs) || errors.Is(err, errRunStageTimeout) { + // Blame peers + sesLog.logf("received %d CTs for %d peers; rerunning", len(cts), len(prs)) + return errUnblamedRerun + } + if err != nil { + return err + } + sort.Slice(cts, func(i, j int) bool { + a := identityIndices[cts[i].Identity] + b := identityIndices[cts[j].Identity] + return a < b + }) + + blamedMap := make(map[identity]struct{}) + var blameErr error + forLocalPeers(func(p *peer) error { + revealed := &mixing.RevealedKeys{ + ECDHPublicKeys: ecdhPublicKeys, + Ciphertexts: make([]mixing.PQCiphertext, 0, len(prs)), + MyIndex: p.myVk, + } + ctIds := make([]identity, 0, len(cts)) + for _, ct := range cts { + if len(ct.Ciphertexts) != len(prs) { + // Everyone sees this, can blame now. + blamedMap[ct.Identity] = struct{}{} + return nil + } + revealed.Ciphertexts = append(revealed.Ciphertexts, ct.Ciphertexts[p.myVk]) + ctIds = append(ctIds, ct.Identity) + } + + // Derive shared secret keys + shared, err := p.kx.SharedSecrets(revealed, sesRun.sid[:], run, sesRun.mcounts) + if err != nil { + blameErr = errBlameRequired + return nil + } + p.srKP = shared.SRSecrets + p.dcKP = shared.DCSecrets + + // Calculate slot reservation DC-net vectors + p.srMix = make([][]*big.Int, p.pr.MessageCount) + for i := range p.srMix { + pads := mixing.SRMixPads(p.srKP[i], p.myStart+uint32(i)) + p.srMix[i] = mixing.SRMix(p.srMsg[i], pads) + } + srMixBytes := mixing.IntVectorsToBytes(p.srMix) + + // Broadcast message commitment and exponential DC-mix vectors for slot + // reservations. + seenCTs := make([]chainhash.Hash, len(cts)) + for i := range cts { + seenCTs[i] = cts[i].Hash() + } + sr := wire.NewMsgMixSlotReserve(*p.id, sesRun.sid, run, srMixBytes, seenCTs) + p.sr = sr + return p.signAndSubmit(sr) + }) + if len(blamedMap) > 0 { + for id := range blamedMap { + blamed = append(blamed, id) + } + return blamed + } + if blameErr != nil { + return blameErr + } + + // Receive all slot reservation messages + rcv.CTs = nil + rcv.SRs = make([]*wire.MsgMixSlotReserve, 0, len(prs)) + rcvCtx, rcvCtxCancel = context.WithDeadlineCause(context.Background(), + d.recvSR, errRunStageTimeout) + err = mp.Receive(rcvCtx, len(prs), rcv) + rcvCtxCancel() + srs := rcv.SRs + for _, sr := range srs { + if idx, ok := identityIndices[sr.Identity]; ok { + sesRun.peers[idx].sr = sr + } + } + if len(srs) != len(prs) || errors.Is(err, errRunStageTimeout) { + // Blame peers + sesLog.logf("received %d SRs for %d peers; rerunning", len(srs), len(prs)) + return errUnblamedRerun + } + if err != nil { + return err + } + sort.Slice(srs, func(i, j int) bool { + a := identityIndices[srs[i].Identity] + b := identityIndices[srs[j].Identity] + return a < b + }) + + // Recover roots + var roots []*big.Int + vs := make([][][]byte, 0, len(prs)) + for _, sr := range srs { + vs = append(vs, sr.DCMix...) + } + powerSums := mixing.AddVectors(mixing.IntVectorsFromBytes(vs)...) + coeffs := mixing.Coefficients(powerSums) + roots, err = solverrpc.Roots(coeffs, mixing.F) + if err != nil { + // Blame peers + return errors.New("blame required") + } + sesRun.roots = roots + + forLocalPeers(func(p *peer) error { + // Find reserved slots + slots := make([]uint32, 0, p.pr.MessageCount) + for _, m := range p.srMsg { + slot := constTimeSlotSearch(m, roots) + if slot == -1 { + // Blame peers + return errors.New("blame required") + } + slots = append(slots, uint32(slot)) + } + + // Calculate XOR DC-net vectors + p.dcNet = make([]wire.MixVect, p.pr.MessageCount) + for i, slot := range slots { + my := p.myStart + uint32(i) + pads := mixing.DCMixPads(p.dcKP[i], my) + p.dcNet[i] = wire.MixVect(mixing.DCMix(pads, p.dcMsg[i][:], slot)) + } + + // Broadcast XOR DC-net vectors. + seenSRs := make([]chainhash.Hash, len(cts)) + for i := range srs { + seenSRs[i] = srs[i].Hash() + } + dc := wire.NewMsgMixDCNet(*p.id, sesRun.sid, run, p.dcNet, seenSRs) + p.dc = dc + return p.signAndSubmit(dc) + }) + + // Receive all DC messages + rcv.SRs = nil + rcv.DCs = make([]*wire.MsgMixDCNet, 0, len(prs)) + rcvCtx, rcvCtxCancel = context.WithDeadlineCause(context.Background(), + d.recvDC, errRunStageTimeout) + err = mp.Receive(rcvCtx, len(prs), rcv) + rcvCtxCancel() + dcs := rcv.DCs + for _, dc := range dcs { + if idx, ok := identityIndices[dc.Identity]; ok { + sesRun.peers[idx].dc = dc + } + } + if len(dcs) != len(prs) || errors.Is(err, errRunStageTimeout) { + // Blame peers + sesLog.logf("received %d DCs for %d peers; rerunning", len(dcs), len(prs)) + return errUnblamedRerun + } + if err != nil { + return err + } + sort.Slice(dcs, func(i, j int) bool { + a := identityIndices[dcs[i].Identity] + b := identityIndices[dcs[j].Identity] + return a < b + }) + + // Solve XOR dc-net + dcVecs := make([]mixing.Vec, 0, sesRun.mtot) + for _, dc := range dcs { + for _, vec := range dc.DCNet { + dcVecs = append(dcVecs, mixing.Vec(vec)) + } + } + mixedMsgs := mixing.XorVectors(dcVecs) + + type missingMessage interface { + error + MissingMessage() + } + var errMissingMessage missingMessage + forLocalPeers(func(p *peer) error { + // Add outputs for each mixed message + for i := range mixedMsgs { + mixedMsg := mixedMsgs[i][:] + p.coinjoin.addMixedMessage(mixedMsg) + } + p.coinjoin.sort() + + // Confirm that our messages are present, and sign each peer's + // provided inputs. + err = p.coinjoin.confirm(c.wallet) + if err != nil { + if errMissingMessage != nil { + errors.As(err, &errMissingMessage) + } + return nil + } + + // Broadcast partially signed mix tx + seenDCs := make([]chainhash.Hash, len(dcs)) + for i := range dcs { + seenDCs[i] = dcs[i].Hash() + } + cm := wire.NewMsgMixConfirm(*p.id, sesRun.sid, run, + p.coinjoin.Tx().Copy(), seenDCs) + p.cm = cm + return p.signAndSubmit(cm) + }) + if errMissingMessage != nil { + sesLog.logf("missing message; blaming and rerunning") + return errMissingMessage + } + + // Receive all CM messages + rcv.DCs = nil + rcv.CMs = make([]*wire.MsgMixConfirm, 0, len(prs)) + rcvCtx, rcvCtxCancel = context.WithDeadlineCause(context.Background(), + d.recvCM, errRunStageTimeout) + err = mp.Receive(rcvCtx, len(prs), rcv) + rcvCtxCancel() + cms := rcv.CMs + for _, cm := range cms { + if idx, ok := identityIndices[cm.Identity]; ok { + sesRun.peers[idx].cm = cm + } + } + if len(cms) != len(prs) || errors.Is(err, errRunStageTimeout) { + // Blame peers + sesLog.logf("received %d CMs for %d peers; rerunning", len(cms), len(prs)) + return errUnblamedRerun + } + if err != nil { + return err + } + sort.Slice(cms, func(i, j int) bool { + a := identityIndices[cms[i].Identity] + b := identityIndices[cms[j].Identity] + return a < b + }) + + // Merge and validate all signatures. Only a single coinjoin is + // needed at this point. + var cj *CoinJoin + pubkeys := make(map[wire.OutPoint][]byte) + for _, p := range sesRun.peers { + if cj == nil && !p.remote { + cj = p.coinjoin + } + for _, utxo := range p.pr.UTXOs { + pubkeys[utxo.OutPoint] = utxo.PubKey + } + } + for _, cm := range cms { + err := cj.mergeSignatures(cm) + if err != nil { + blamed = append(blamed, cm.Identity) + } + } + if len(blamed) > 0 { + return blamed + } + p2pkhv0Script := []byte{ + 0: txscript.OP_DUP, + 1: txscript.OP_HASH160, + 2: txscript.OP_DATA_20, + 23: txscript.OP_EQUALVERIFY, + 24: txscript.OP_CHECKSIG, + } + sigCache, err := txscript.NewSigCache(uint(len(cj.tx.TxIn))) + if err != nil { + return errUnblamedRerun + } + blamedInputs := make(map[wire.OutPoint]struct{}) + const scriptFlags = txscript.ScriptDiscourageUpgradableNops | + txscript.ScriptVerifyCleanStack | + txscript.ScriptVerifyCheckLockTimeVerify | + txscript.ScriptVerifyCheckSequenceVerify | + txscript.ScriptVerifyTreasury + const scriptVersion = 0 + for i, in := range cj.tx.TxIn { + pk, ok := pubkeys[in.PreviousOutPoint] + if !ok { + blamedInputs[in.PreviousOutPoint] = struct{}{} + continue + } + pkh := dcrutil.Hash160(pk) + copy(p2pkhv0Script[3:23], pkh) + engine, err := txscript.NewEngine(p2pkhv0Script, cj.tx, i, + scriptFlags, scriptVersion, sigCache) + if err != nil { + c.logf("warn: blaming peer for error creating script engine: %v", err) + blamedInputs[in.PreviousOutPoint] = struct{}{} + continue + } + if engine.Execute() != nil { + blamedInputs[in.PreviousOutPoint] = struct{}{} + } + } + if len(blamedInputs) > 0 { + blameInputLoop: + for _, pr := range prs { + for _, in := range pr.UTXOs { + if _, ok := blamedInputs[in.OutPoint]; ok { + blamed = append(blamed, pr.Identity) + continue blameInputLoop + } + } + } + return blamed + } + + err = c.wallet.PublishTransaction(context.Background(), cj.tx) + if err != nil { + return err + } + + mp.RemoveSession(sesRun.sid, true) + + forLocalPeers(func(p *peer) error { + select { + case p.res <- nil: + default: + } + return nil + }) + return nil +} + +func (c *Client) alternateSession(pairing []byte, prs []*wire.MsgMixPairReq, + ignoreUnresponsive []*wire.MsgMixPairReq, d *deadlines) *alternateSession { + + unixEpoch := uint64(d.epoch.Unix()) + + kes := c.mixpool.ReceiveKEsByPairing(pairing, unixEpoch) + + // Sort KEs by identity first (just to group these together) followed + // by the total referenced PR counts in decreasing order. + // When ranging over KEs below, this will allow us to consider the + // order in which other peers created their KEs, and how they are + // forming their sessions. + sort.Slice(kes, func(i, j int) bool { + a := kes[i] + b := kes[j] + if bytes.Compare(a.Identity[:], b.Identity[:]) == -1 { + return true + } + if len(a.SeenPRs) > len(b.SeenPRs) { + return true + } + return false + }) + + ignoreUnresponsiveIdentities := make(map[identity]struct{}) + ignoreUnresponsivePRHashes := make(map[chainhash.Hash]struct{}) + for _, pr := range ignoreUnresponsive { + ignoreUnresponsiveIdentities[pr.Identity] = struct{}{} + ignoreUnresponsivePRHashes[pr.Hash()] = struct{}{} + } + + prsByHash := make(map[chainhash.Hash]*wire.MsgMixPairReq) + prHashByIdentity := make(map[identity]chainhash.Hash) + for _, pr := range prs { + prsByHash[pr.Hash()] = pr + prHashByIdentity[pr.Identity] = pr.Hash() + } + + // Only one KE per peer identity (the KE that references the most PR + // hashes) is used for determining session agreement. + type peerMsgs struct { + pr *wire.MsgMixPairReq + ke *wire.MsgMixKeyExchange + } + msgsByIdentityMap := make(map[identity]*peerMsgs) + for i := range kes { + ke := kes[i] + if msgsByIdentityMap[ke.Identity] != nil { + continue + } + pr := prsByHash[prHashByIdentity[ke.Identity]] + if pr == nil { + err := fmt.Errorf("missing PR %s by %x, but have their KE %s", + prHashByIdentity[ke.Identity], ke.Identity[:], ke.Hash()) + c.logger.Log(err) + continue + } + msgsByIdentityMap[ke.Identity] = &peerMsgs{ + pr: pr, + ke: ke, + } + } + + // Discover unresponsive peers if we have not already calculated them + // and have to ignore them. + var unresponsive []*wire.MsgMixPairReq + if len(ignoreUnresponsive) == 0 { + for _, pr := range prs { + id := pr.Identity + if _, ok := msgsByIdentityMap[id]; !ok { + c.logf("Identity %x (PR %s) is unresponsive", id[:], pr.Hash()) + unresponsive = append(unresponsive, pr) + } + } + } + + // A total count of all unique PRs/identities (including not just + // those we have observed PRs for, but also other unknown PR hashes + // referenced by other peers) is needed. All participating peers must + // come to agreement on this number, otherwise the exclusion rules + // based on majority may result in different peers being removed and + // agreement never being reached. Any PRs not seen by a majority of + // other peers (i.e. without counting peers seeing their own PR) are + // immediately removed from any formed session attempt. + referencedPRCounts := make(map[chainhash.Hash]int) + for _, msgs := range msgsByIdentityMap { + selfPRHash := msgs.pr.Hash() + for _, prHash := range msgs.ke.SeenPRs { + if selfPRHash == prHash { + continue + } + // Do not count ignored unresponsive peer's PRs. + if _, ok := ignoreUnresponsivePRHashes[prHash]; ok { + continue + } + referencedPRCounts[prHash]++ + } + } + prCount := len(referencedPRCounts) + neededForMajority := len(referencedPRCounts) / 2 + remainingPRs := make([]chainhash.Hash, 0, prCount) + prsSeenCount := make([]int, 0, prCount) + for prHash, count := range referencedPRCounts { + if count >= neededForMajority { + remainingPRs = append(remainingPRs, prHash) + prsSeenCount = append(prsSeenCount, count) + } + } + + // Sort PRs by increasing seen counts first, then lexicographically by + // PR hash. + sort.Slice(remainingPRs, func(i, j int) bool { + switch { + case prsSeenCount[i] < prsSeenCount[j]: + return true + case prsSeenCount[i] > prsSeenCount[j]: + return false + default: + a := remainingPRs[i][:] + b := remainingPRs[j][:] + return bytes.Compare(a, b) == -1 + } + }) + + // Incrementally remove PRs that are not seen by all peers, and + // recalculate total observance counts after excluding the peer that + // was dropped. A session is formed when all remaining PR counts + // equal one less than the total number of PRs still considered + // (due to not counting observing one's own PR). + for { + if len(prsSeenCount) < MinPeers { + return &alternateSession{ + unresponsive: unresponsive, + err: ErrTooFewPeers, + } + } + if prsSeenCount[0] == len(prsSeenCount)-1 { + prs := make([]*wire.MsgMixPairReq, len(remainingPRs)) + for i, prHash := range remainingPRs { + prs[i] = prsByHash[prHash] + if prs[i] == nil { + // Agreement should be reached, but we + // won't be able to participate in the + // formed session. + return &alternateSession{ + unresponsive: unresponsive, + err: ErrUnknownPRs, + } + } + } + newSessionID := sortPRsForSession(prs, unixEpoch) + return &alternateSession{ + prs: prs, + sid: newSessionID, + unresponsive: unresponsive, + } + } + + // removedPR := remainingPRHashes[0] + remainingPRs = remainingPRs[1:] + prsSeenCount = prsSeenCount[:0] + + prCounts := make(map[chainhash.Hash]int, len(remainingPRs)) + for _, hash := range remainingPRs { + prCounts[hash] = 0 + } + for _, hash := range remainingPRs { + pr := prsByHash[hash] + if pr == nil { + continue + } + selfPRHash := pr.Hash() + msgs := msgsByIdentityMap[pr.Identity] + if msgs == nil { + continue + } + ke := msgs.ke + for _, prHash := range ke.SeenPRs { + if prHash == selfPRHash { + continue + } + if _, ok := prCounts[prHash]; ok { + prCounts[prHash]++ + } + } + } + for _, prHash := range remainingPRs { + prsSeenCount = append(prsSeenCount, prCounts[prHash]) + } + } + +} + +func excludeBlamed(prevRun *sessionRun, blamed blamedIdentities) sessionRun { + blamedMap := make(map[identity]struct{}) + for _, id := range blamed { + blamedMap[id] = struct{}{} + } + + peers := prevRun.peers[:0] + prs := prevRun.prs[:0] + mcounts := prevRun.mcounts[:0] + var mtot uint32 + + for _, p := range prevRun.peers { + if _, ok := blamedMap[*p.id]; ok { + continue + } + + peers = append(peers, p) + prs = append(prs, p.pr) + mcounts = append(mcounts, p.pr.MessageCount) + mtot += p.pr.MessageCount + } + + return sessionRun{ + sid: prevRun.sid, + run: prevRun.run + 1, + mtot: mtot, + + prs: prs, + peers: peers, + mcounts: mcounts, + roots: nil, + } +} + +var fieldLen = uint(len(mixing.F.Bytes())) + +// constTimeSlotSearch searches for the index of secret in roots in constant time. +// Returns -1 if the secret is not found. +func constTimeSlotSearch(secret *big.Int, roots []*big.Int) int { + paddedSecret := make([]byte, fieldLen) + secretBytes := secret.Bytes() + off, _ := bits.Sub(fieldLen, uint(len(secretBytes)), 0) + copy(paddedSecret[off:], secretBytes) + + slot := -1 + buf := make([]byte, fieldLen) + for i := range roots { + rootBytes := roots[i].Bytes() + off, _ := bits.Sub(fieldLen, uint(len(rootBytes)), 0) + copy(buf[off:], rootBytes) + cmp := subtle.ConstantTimeCompare(paddedSecret, buf) + slot = subtle.ConstantTimeSelect(cmp, i, slot) + for j := range buf { + buf[j] = 0 + } + } + return slot +} diff --git a/mixing/mixpool/mixpool.go b/mixing/mixpool/mixpool.go new file mode 100644 index 0000000000..85abe003d3 --- /dev/null +++ b/mixing/mixpool/mixpool.go @@ -0,0 +1,1340 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +// Package mixpool provides an in-memory pool of mixing messages for full nodes +// that relay these messages and mixing wallets that send and receive them. +package mixpool + +import ( + "bytes" + "context" + "fmt" + "sort" + "sync" + "time" + + "github.com/decred/dcrd/chaincfg/chainhash" + "github.com/decred/dcrd/chaincfg/v3" + "github.com/decred/dcrd/dcrutil/v4" + "github.com/decred/dcrd/mixing" + "github.com/decred/dcrd/mixing/utxoproof" + "github.com/decred/dcrd/txscript/v4/stdscript" + "github.com/decred/dcrd/wire" +) + +const minconf = 2 +const feeRate = 0.0001e8 + +type idPubKey = [33]byte + +type msgtype int + +// Message type constants, for quickly checking looked up entries by message +// hash match the expected type (without performing a type assertion). +// Excludes PR. +const ( + msgtypeKE msgtype = 1 + iota + msgtypeCT + msgtypeSR + msgtypeDC + msgtypeCM + msgtypeRS + + nmsgtypes = msgtypeRS +) + +func (m msgtype) String() string { + switch m { + case msgtypeKE: + return "KE" + case msgtypeCT: + return "CT" + case msgtypeSR: + return "SR" + case msgtypeDC: + return "DC" + case msgtypeCM: + return "CM" + case msgtypeRS: + return "RS" + default: + return "?" + } +} + +// entry describes non-PR messages accepted to the pool. +type entry struct { + hash chainhash.Hash + sid [32]byte + recvTime time.Time + msg mixing.Message + msgtype msgtype + run uint32 +} + +type session struct { + sid [32]byte + runs []runstate + expiry uint32 + bc broadcast +} + +type runstate struct { + run uint32 + prs []chainhash.Hash + counts [nmsgtypes]uint32 + hashes map[chainhash.Hash]struct{} +} + +func (r *runstate) countFor(t msgtype) uint32 { + return r.counts[t-1] +} + +func (r *runstate) incrementCountFor(t msgtype) { + r.counts[t-1]++ +} + +type broadcast struct { + ch chan struct{} + mu sync.Mutex +} + +// wait returns the wait channel that is closed whenever a message is received +// for a session. Waiters must acquire the pool lock before reading messages. +func (b *broadcast) wait() <-chan struct{} { + b.mu.Lock() + ch := b.ch + b.mu.Unlock() + + return ch +} + +func (b *broadcast) signal() { + b.mu.Lock() + close(b.ch) + b.ch = make(chan struct{}) + b.mu.Unlock() +} + +// Pool records in-memory mix messages that have been broadcast over the +// peer-to-peer network. +type Pool struct { + mtx sync.RWMutex + prs map[chainhash.Hash]*wire.MsgMixPairReq + outPoints map[wire.OutPoint]chainhash.Hash + pool map[chainhash.Hash]entry + messagesByIdentity map[idPubKey][]chainhash.Hash + latestKE map[idPubKey]*wire.MsgMixKeyExchange + sessions map[[32]byte]*session + sessionsByTxHash map[chainhash.Hash]*session + epoch time.Duration + expireHeight uint32 + expireSem chan struct{} + + blockchain BlockChain + utxoFetcher UtxoFetcher + feeRate int64 + params *chaincfg.Params +} + +// UtxoEntry provides details regarding unspent transaction outputs. +type UtxoEntry interface { + IsSpent() bool + PkScript() []byte + ScriptVersion() uint16 + BlockHeight() int64 + Amount() int64 +} + +// UtxoFetcher defines methods used to validate unspent transaction outputs in +// the pair request message. It is optional, but should be implemented by full +// nodes that have this capability to detect and stop relay of spam and junk +// messages. +type UtxoFetcher interface { + // FetchUtxoEntry defines the function to use to fetch unspent + // transaction output information. + FetchUtxoEntry(wire.OutPoint) (UtxoEntry, error) +} + +// BlockChain queries the current status of the blockchain. Its methods should +// be able to be implemented by both full nodes and SPV wallets. +type BlockChain interface { + // ChainParams identifies which chain parameters the mixing pool is + // associated with. + ChainParams() *chaincfg.Params + + // BestHeader returns the hash and height of the current tip block. + BestHeader() (chainhash.Hash, int64) +} + +// NewPool returns a new mixing pool that accepts and validates mixing messages +// required for distributed transaction mixing. +func NewPool(blockchain BlockChain) *Pool { + pool := &Pool{ + prs: make(map[chainhash.Hash]*wire.MsgMixPairReq), + outPoints: make(map[wire.OutPoint]chainhash.Hash), + pool: make(map[chainhash.Hash]entry), + messagesByIdentity: make(map[idPubKey][]chainhash.Hash), + latestKE: make(map[idPubKey]*wire.MsgMixKeyExchange), + sessions: make(map[[32]byte]*session), + sessionsByTxHash: make(map[chainhash.Hash]*session), + epoch: 10 * time.Minute, // XXX: mainnet epoch: add to chainparams + expireHeight: 0, + expireSem: make(chan struct{}, 1), + blockchain: blockchain, + feeRate: feeRate, + params: blockchain.ChainParams(), + } + if u, ok := blockchain.(UtxoFetcher); ok { + pool.utxoFetcher = u + } + return pool +} + +// XXX: can remove this method after adding mixing epoch to chainparams. +func (p *Pool) SetEpoch(d time.Duration) { + p.epoch = d +} + +// MixPRHashes returns the hashes of all MixPR messages recorded by the pool. +// This data is provided to peers requesting initial state of the mixpool. +func (p *Pool) MixPRHashes() []chainhash.Hash { + p.mtx.RLock() + hashes := make([]chainhash.Hash, 0, len(p.prs)) + for hash := range p.prs { + hashes = append(hashes, hash) + } + p.mtx.RUnlock() + + return hashes +} + +// Message searches the mixing pool for a message by its hash. +func (p *Pool) Message(query *chainhash.Hash) (mixing.Message, error) { + p.mtx.RLock() + pr := p.prs[*query] + e, ok := p.pool[*query] + p.mtx.RUnlock() + if pr != nil { + return pr, nil + } + if !ok || e.msg == nil { + return nil, fmt.Errorf("message not found") + } + return e.msg, nil +} + +// HaveMessage checks whether the mixing pool contains a message by its hash. +func (p *Pool) HaveMessage(query *chainhash.Hash) bool { + p.mtx.RLock() + _, ok := p.pool[*query] + if !ok { + _, ok = p.prs[*query] + } + p.mtx.RUnlock() + return ok +} + +// MixPR searches the mixing pool for a PR message by its hash. +func (p *Pool) MixPR(query *chainhash.Hash) (*wire.MsgMixPairReq, error) { + var pr *wire.MsgMixPairReq + + p.mtx.RLock() + pr = p.prs[*query] + p.mtx.RUnlock() + + if pr == nil { + return nil, fmt.Errorf("PR message not found") + } + + return pr, nil +} + +// MixPRs returns all MixPR messages with hashes matching the query. Unknown +// messages are ignored. +// +// If query is nil, all PRs are returned. +// +// In both cases, any expired PRs that are still internally tracked by the +// mixpool for ongoing sessions are excluded from the result set. +func (p *Pool) MixPRs(query []chainhash.Hash) []*wire.MsgMixPairReq { + res := make([]*wire.MsgMixPairReq, 0, len(query)) + + p.mtx.Lock() + defer p.mtx.Unlock() + + p.removeConfirmedRuns() + + if query == nil { + res = make([]*wire.MsgMixPairReq, 0, len(p.prs)) + for _, pr := range p.prs { + // Exclude expired but not yet removed PRs. + if pr.Expiry <= p.expireHeight { + continue + } + + res = append(res, pr) + } + return res + } + + for i := range query { + pr, ok := p.prs[query[i]] + if ok { + // Exclude expired but not yet removed PRs. + if pr.Expiry <= p.expireHeight { + continue + } + + res = append(res, pr) + } + } + + return res +} + +// CompatiblePRs returns all MixPR messages with pairing descriptions matching +// the parameter. +func (p *Pool) CompatiblePRs(pairing []byte) []*wire.MsgMixPairReq { + p.mtx.RLock() + defer p.mtx.RUnlock() + + res := make([]*wire.MsgMixPairReq, 0, len(p.prs)) + for _, pr := range p.prs { + prPairing, _ := pr.Pairing() + if bytes.Equal(pairing, prPairing) { + res = append(res, pr) + } + } + + // Sort by decreasing expiries and remove any PRs double spending an + // output with an earlier expiry. + sort.Slice(res, func(i, j int) bool { + return res[i].Expiry >= res[j].Expiry + }) + seen := make(map[wire.OutPoint]uint32) + for i, pr := range res { + for _, utxo := range pr.UTXOs { + prevExpiry, ok := seen[utxo.OutPoint] + if !ok { + seen[utxo.OutPoint] = pr.Expiry + } else if pr.Expiry < prevExpiry { + res[i] = nil + } + } + } + filtered := res[:0] + for i := range res { + if res[i] != nil { + filtered = append(filtered, res[i]) + } + } + + // Sort again lexicographically by hash. + sort.Slice(filtered, func(i, j int) bool { + a := filtered[i].Hash() + b := filtered[j].Hash() + return bytes.Compare(a[:], b[:]) < 1 + }) + return filtered +} + +// ExpireMessages will, after the current epoch period ends, remove all pair +// requests that indicate an expiry at or before the height parameter and +// removes all messages that chain back to a removed pair request. +func (p *Pool) ExpireMessages(height uint32) { + p.mtx.Lock() + defer p.mtx.Unlock() + + if height > p.expireHeight { + p.expireHeight = height + } + + select { + case p.expireSem <- struct{}{}: + go p.expireMessages() + default: + } +} + +// waitForEpoch blocks until the next epoch occurs. +func (p *Pool) waitForEpoch() { + now := time.Now().UTC() + epoch := now.Truncate(p.epoch).Add(p.epoch) + duration := epoch.Sub(now) + time.Sleep(duration) +} + +func (p *Pool) expireMessages() { + p.waitForEpoch() + + p.mtx.Lock() + defer func() { + <-p.expireSem + p.mtx.Unlock() + }() + + height := p.expireHeight + p.expireHeight = 0 + + // Expire sessions and their messages + for sid, ses := range p.sessions { + if ses.expiry > height { + continue + } + + delete(p.sessions, sid) + for _, r := range ses.runs { + for hash := range r.hashes { + delete(p.pool, hash) + } + } + } + + // Expire PRs and remove identity tracking + for hash, pr := range p.prs { + if pr.Expiry > height { + continue + } + + logf("Removing expired PR %s by %x", + hash, pr.Identity[:]) + delete(p.prs, hash) + delete(p.messagesByIdentity, pr.Identity) + delete(p.latestKE, pr.Identity) + } +} + +// RemoveMessage removes a message that was rejected by the network. +func (p *Pool) RemoveMessage(msg mixing.Message) { + p.mtx.Lock() + defer p.mtx.Unlock() + + msgHash := msg.Hash() + delete(p.pool, msgHash) + if pr, ok := msg.(*wire.MsgMixPairReq); ok { + if p.prs[msgHash] != nil { + logf("Removing PR %s by %x", + msgHash, pr.Identity[:]) + } + delete(p.prs, msgHash) + delete(p.latestKE, pr.Identity) + } + if ke, ok := msg.(*wire.MsgMixKeyExchange); ok { + delete(p.latestKE, ke.Identity) + } +} + +// RemoveIdentities removes all messages from the mixpool that were created by +// any of the identities. +func (p *Pool) RemoveIdentities(identities [][33]byte) { + p.mtx.Lock() + defer p.mtx.Unlock() + + for i := range identities { + id := &identities[i] + for _, hash := range p.messagesByIdentity[*id] { + delete(p.pool, hash) + if p.prs[hash] != nil { + logf("Removing PR %s by %x", + hash, id[:]) + } + delete(p.prs, hash) + } + delete(p.messagesByIdentity, *id) + delete(p.latestKE, *id) + } +} + +func (p *Pool) removeSession(sid [32]byte, txHash *chainhash.Hash, success bool) { + ses := p.sessions[sid] + if ses == nil { + return + } + + // Delete PRs used to form final run + var removePRs []chainhash.Hash + var lastRun *runstate + if success { + lastRun = &ses.runs[len(ses.runs)-1] + removePRs = lastRun.prs + } + + if txHash != nil || success { + if txHash == nil { + // XXX: may be better to store this in the runstate as + // a CM is received. + for h := range lastRun.hashes { + if e, ok := p.pool[h]; ok && e.msgtype == msgtypeCM { + cm := e.msg.(*wire.MsgMixConfirm) + hash := cm.Mix.TxHash() + txHash = &hash + break + } + } + } + if txHash != nil { + delete(p.sessionsByTxHash, *txHash) + } + } + + delete(p.sessions, sid) + for _, r := range ses.runs { + for hash := range r.hashes { + delete(p.pool, hash) + } + } + + for _, prHash := range removePRs { + delete(p.pool, prHash) + if pr := p.prs[prHash]; pr != nil { + logf("Removing mixed PR %s by %x", prHash, + pr.Identity[:]) + delete(p.prs, prHash) + delete(p.latestKE, pr.Identity) + } + } +} + +// RemoveSession removes all non-PR messages from a completed or errored +// session. PR messages of a successful run (or rerun) must also be removed. +func (p *Pool) RemoveSession(sid [32]byte, success bool) { + p.mtx.Lock() + defer p.mtx.Unlock() + + p.removeSession(sid, nil, success) +} + +// RemoveConfirmedRuns removes all messages including pair requests from +// runs which ended in each peer sending a confirm mix message. +func (p *Pool) RemoveConfirmedRuns() { + p.mtx.Lock() + defer p.mtx.Unlock() + +} + +func (p *Pool) removeConfirmedRuns() { + for sid, ses := range p.sessions { + lastRun := &ses.runs[len(ses.runs)-1] + cmCount := lastRun.countFor(msgtypeCM) + if len(lastRun.prs) != int(cmCount) { + continue + } + + delete(p.sessions, sid) + for _, prevRun := range ses.runs[:len(ses.runs)-1] { + for hash := range prevRun.hashes { + delete(p.pool, hash) + } + } + for hash := range lastRun.hashes { + delete(p.pool, hash) + } + + for _, hash := range lastRun.prs { + delete(p.pool, hash) + pr := p.prs[hash] + if pr != nil { + logf("Removing PR %s by %x", + hash, pr.Identity[:]) + delete(p.prs, hash) + delete(p.latestKE, pr.Identity) + } + } + } +} + +// RemoveConfirmedMixes removes sessions and messages belong to a completed +// session that resulted in a published or mined transactions. Transaction +// hashes not associated with a session are ignored. PRs from the successful +// mix run are removed from the pool. +func (p *Pool) RemoveConfirmedMixes(txHashes []chainhash.Hash) { + p.mtx.Lock() + defer p.mtx.Unlock() + + for _, hash := range txHashes { + ses := p.sessionsByTxHash[hash] + if ses == nil { + continue + } + + p.removeSession(ses.sid, &hash, true) + } +} + +// ReceiveKEsByPairing returns the most recently received run-0 KE messages by +// a peer that reference PRs of a particular pairing and epoch. +func (p *Pool) ReceiveKEsByPairing(pairing []byte, epoch uint64) []*wire.MsgMixKeyExchange { + p.mtx.RLock() + defer p.mtx.RUnlock() + + var kes []*wire.MsgMixKeyExchange + for id, ke := range p.latestKE { + prHash := p.messagesByIdentity[id][0] + pr := p.prs[prHash] + prPairing, err := pr.Pairing() + if err != nil { + continue + } + if bytes.Equal(pairing, prPairing) && ke.Epoch == epoch { + kes = append(kes, ke) + } + } + return kes +} + +// RemoveUnresponsiveDuringEpoch removes pair requests of unresponsive peers +// that did not provide any key exchange messages during the epoch in which a +// mix occurred. +func (p *Pool) RemoveUnresponsiveDuringEpoch(pairing []byte, epoch uint64) { + p.mtx.Lock() + defer p.mtx.Unlock() + + var unresponsive []*wire.MsgMixPairReq + +PRLoop: + for _, pr := range p.prs { + prPairing, err := pr.Pairing() + if err != nil { + log.Warnf("pr.Pairing() of accepted pair request errored: %v", + err) + continue + } + if !bytes.Equal(pairing, prPairing) { + continue + } + + for _, msgHash := range p.messagesByIdentity[pr.Identity] { + msg, ok := p.pool[msgHash].msg.(*wire.MsgMixKeyExchange) + if !ok { + continue + } + if msg.Epoch == epoch { + continue PRLoop + } + } + + unresponsive = append(unresponsive, pr) + } + + for _, pr := range unresponsive { + id := &pr.Identity + for _, hash := range p.messagesByIdentity[*id] { + delete(p.pool, hash) + if p.prs[hash] != nil { + logf("Removing PR %s by %x", hash, id[:]) + } + delete(p.prs, hash) + } + delete(p.messagesByIdentity, *id) + delete(p.latestKE, *id) + } +} + +// Received is a parameter for Pool.Receive describing the session and run to +// receive messages for, and slices for returning results. Only non-nil slices +// will be appended to. Received messages are unsorted. +type Received struct { + Sid [32]byte + Run uint32 + KEs []*wire.MsgMixKeyExchange + CTs []*wire.MsgMixCiphertexts + SRs []*wire.MsgMixSlotReserve + DCs []*wire.MsgMixDCNet + CMs []*wire.MsgMixConfirm + RSs []*wire.MsgMixSecrets +} + +// Receive returns messages matching a session, run, and message type, waiting +// until all described messages have been received, or earlier with the +// messages received so far if the context is cancelled before this point. +// +// Receive only returns results for the session ID and run increment in the r +// parameter. If no such session or run has any messages currently accepted +// in the mixpool, the method immediately errors. +// +// If any secrets messages are received for the described session and run, and +// r.RSs is nil, Receive immediately returns ErrSecretsRevealed. An +// additional call to Receive with a non-nil RSs can be used to receive all of +// the secrets after each peer publishes their own revealed secrets. +func (p *Pool) Receive(ctx context.Context, expectedMessages int, r *Received) error { + sid := r.Sid + run := r.Run + var bc *broadcast + var rs *runstate + var err error + + p.mtx.RLock() + ses, ok := p.sessions[sid] + if !ok { + p.mtx.RUnlock() + return fmt.Errorf("unknown session") + } + bc = &ses.bc + if run >= uint32(len(ses.runs)) { + p.mtx.RUnlock() + return fmt.Errorf("unknown run") + } + rs = &ses.runs[run] + +Loop: + for { + // Pool is locked for reads. Count if the total number of + // expected messages have been received. + received := 0 + for hash := range rs.hashes { + msgtype := p.pool[hash].msgtype + switch { + case msgtype == msgtypeKE && r.KEs != nil: + received++ + case msgtype == msgtypeCT && r.CTs != nil: + received++ + case msgtype == msgtypeSR && r.SRs != nil: + received++ + case msgtype == msgtypeDC && r.DCs != nil: + received++ + case msgtype == msgtypeCM && r.CMs != nil: + received++ + case msgtype == msgtypeRS: + if r.RSs == nil { + // Since initial reporters of secrets + // need to take the blame for + // erroneous blame assignment if no + // issue was detected, we only trigger + // this for RS messages that do not + // reference any other previous RS. + prev := p.pool[hash].msg.(*wire.MsgMixSecrets).PrevMsgs + if len(prev) == 0 { + p.mtx.RUnlock() + return ErrSecretsRevealed + } + } else { + received++ + } + } + } + if received >= expectedMessages { + break + } + + // Unlock while waiting for the broadcast channel. + p.mtx.RUnlock() + + select { + case <-ctx.Done(): + // Set error to be returned, but still lock the pool + // and collect received messages. + err = ctx.Err() + p.mtx.RLock() + break Loop + case <-bc.wait(): + } + + p.mtx.RLock() + } + + // Pool is locked for reads. Collect all of the messages. + for hash := range rs.hashes { + msg := p.pool[hash].msg + switch msg := msg.(type) { + case *wire.MsgMixKeyExchange: + if r.KEs != nil { + r.KEs = append(r.KEs, msg) + } + case *wire.MsgMixCiphertexts: + if r.CTs != nil { + r.CTs = append(r.CTs, msg) + } + case *wire.MsgMixSlotReserve: + if r.SRs != nil { + r.SRs = append(r.SRs, msg) + } + case *wire.MsgMixDCNet: + if r.DCs != nil { + r.DCs = append(r.DCs, msg) + } + case *wire.MsgMixConfirm: + if r.CMs != nil { + r.CMs = append(r.CMs, msg) + } + case *wire.MsgMixSecrets: + if r.RSs != nil { + r.RSs = append(r.RSs, msg) + } + } + } + + p.mtx.RUnlock() + return err +} + +// AcceptMessage accepts a mixing message to the pool. +// +// Messages must contain the mixing participant's identity and contain a valid +// signature committing to all non-signature fields. +// +// PR messages will not be accepted if they reference an unknown UTXO or if not +// enough fee is contributed. Any other message will not be accepted if it +// references previous messages that are not recorded by the pool. +func (p *Pool) AcceptMessage(msg mixing.Message) (accepted mixing.Message, err error) { + hash := msg.Hash() + alreadyAccepted := func() bool { + _, ok := p.pool[hash] + if !ok { + _, ok = p.prs[hash] + } + return ok + } + + // Check if already accepted. + p.mtx.RLock() + ok := alreadyAccepted() + p.mtx.RUnlock() + if ok { + return nil, nil + } + + // Require message to be signed by the presented identity. + if !mixing.VerifyMessageSignature(msg) { + return nil, ruleError(ErrInvalidSignature) + } + id := (*idPubKey)(msg.Pub()) + + var msgtype msgtype + switch msg := msg.(type) { + case *wire.MsgMixPairReq: + accepted, err := p.acceptPR(msg, &hash, id) + if err != nil { + return nil, err + } + // Avoid returning a non-nil mixing.Message in return + // variable with a nil PR. + if accepted == nil { + return nil, nil + } + return accepted, nil + + case *wire.MsgMixKeyExchange: + accepted, err := p.acceptKE(msg, &hash, id) + if err != nil { + return nil, err + } + // Avoid returning a non-nil mixing.Message in return + // variable with a nil KE. + if accepted == nil { + return nil, nil + } + return accepted, nil + + case *wire.MsgMixCiphertexts: + msgtype = msgtypeCT + case *wire.MsgMixSlotReserve: + msgtype = msgtypeSR + case *wire.MsgMixDCNet: + msgtype = msgtypeDC + case *wire.MsgMixConfirm: + msgtype = msgtypeCM + case *wire.MsgMixSecrets: + msgtype = msgtypeRS + default: + return nil, fmt.Errorf("unknown mix message type %T", msg) + } + + sid := *(*[32]byte)(msg.Sid()) + + p.mtx.Lock() + defer p.mtx.Unlock() + + // Read lock was given up to acquire write lock. Check if already + // accepted. + if alreadyAccepted() { + return nil, nil + } + + // Check prior message existence in the pool, and only accept messages + // that reference other known and accepted messages of the correct type + // and sid. + // + // TODO: Consider return an error containing the unknown messages, so + // they can be getdata'd, and if they are not received or are garbage, + // peers can be kicked. + prevMsgs := msg.PrevMsgs() + for i := range prevMsgs { + looktype := msgtype - 1 + if msgtype == msgtypeRS { + looktype = 0 + } + _, ok := p.lookupEntry(prevMsgs[i], looktype, &sid) + if !ok { + return nil, fmt.Errorf("%s %s references unknown "+ + "previous message %s", msgtype, &hash, + &prevMsgs[i]) + } + } + + // Check that a message from this identity does not reuse a run number + // for the session. + for _, prevHash := range p.messagesByIdentity[*id] { + e := p.pool[prevHash] + run := msg.GetRun() + if e.msgtype == msgtype && e.msg.GetRun() == run && + bytes.Equal(e.msg.Sid(), msg.Sid()) { + return nil, fmt.Errorf("message %v by identity %x "+ + "reuses run number %d in session %x, "+ + "conflicting with already accepted message %v", + hash, *id, run, msg.Sid(), prevHash) + } + } + + ses := p.sessions[sid] + if ses == nil { + return nil, fmt.Errorf("%s %s belongs to unknown session %x", + msgtype, &hash, &sid) + } + + err = p.acceptEntry(msg, msgtype, &hash, id, ses) + if err != nil { + return nil, err + } + return msg, nil +} + +func (p *Pool) acceptPR(pr *wire.MsgMixPairReq, hash *chainhash.Hash, id *idPubKey) (accepted *wire.MsgMixPairReq, err error) { + switch { + case len(pr.UTXOs) == 0: // Require at least one utxo. + return nil, ruleError(ErrMissingUTXOs) + case pr.MessageCount == 0: // Require at least one mixed message. + return nil, ruleError(ErrInvalidMessageCount) + case pr.InputValue < int64(pr.MessageCount)*pr.MixAmount: + return nil, ruleError(ErrInvalidTotalMixAmount) + case pr.Change != nil: + if isDustAmount(pr.Change.Value, p2pkhv0PkScriptSize, feeRate) { + return nil, ruleError(ErrChangeDust) + } + if !stdscript.IsPubKeyHashScriptV0(pr.Change.PkScript) && + !stdscript.IsScriptHashScriptV0(pr.Change.PkScript) { + return nil, ruleError(ErrInvalidScript) + } + } + + // Check that expiry has not been reached, nor that it is too far + // into the future. This limits replay attacks. + _, curHeight := p.blockchain.BestHeader() + maxExpiry := mixing.MaxExpiry(uint32(curHeight), p.params) + switch { + case uint32(curHeight) >= pr.Expiry: + return nil, fmt.Errorf("message has expired") + case pr.Expiry > maxExpiry: + return nil, fmt.Errorf("expiry is too far into future") + } + + // Require known script classes. + switch mixing.ScriptClass(pr.ScriptClass) { + case mixing.ScriptClassP2PKHv0: + default: + return nil, fmt.Errorf("unsupported mixing script class") + } + + // Require enough fee contributed from this mixing participant. + // Size estimation assumes mixing.ScriptClassP2PKHv0 outputs and inputs. + err = checkFee(pr, p.feeRate) + if err != nil { + return nil, err + } + + // If able, sanity check UTXOs. + if p.utxoFetcher != nil { + err := p.checkUTXOs(pr) + if err != nil { + return nil, err + } + } + + p.mtx.Lock() + defer p.mtx.Unlock() + + // Check if already accepted. + if _, ok := p.prs[*hash]; ok { + return nil, nil + } + + // Discourage identity reuse. PRs should be the first message sent by + // this identity, and there should only be one PR per identity. + if len(p.messagesByIdentity[*id]) != 0 { + return nil, fmt.Errorf("identity reused for a PR message") + } + + // Only accept PRs that double spend outpoints if they expire later + // than existing PRs. Otherwise, reject this PR message. + for i := range pr.UTXOs { + otherPRHash := p.outPoints[pr.UTXOs[i].OutPoint] + otherPR, ok := p.prs[otherPRHash] + if !ok { + continue + } + if otherPR.Expiry >= pr.Expiry { + err := fmt.Errorf("PR double spends outpoints of " + + "already-accepted PR message without " + + "increasing expiry") + return nil, err + } + } + + // Accept the PR + p.prs[*hash] = pr + for i := range pr.UTXOs { + p.outPoints[pr.UTXOs[i].OutPoint] = *hash + } + p.messagesByIdentity[*id] = append(make([]chainhash.Hash, 0, 16), *hash) + + return pr, nil +} + +// Check that UTXOs exist, have confirmations, sum of UTXO values matches the +// input value, and proof of ownership is valid. +func (p *Pool) checkUTXOs(pr *wire.MsgMixPairReq) error { + var totalValue int64 + _, curHeight := p.blockchain.BestHeader() + + for i := range pr.UTXOs { + utxo := &pr.UTXOs[i] + entry, err := p.utxoFetcher.FetchUtxoEntry(utxo.OutPoint) + if err != nil { + return err + } + if entry == nil || entry.IsSpent() { + return fmt.Errorf("output %v is not unspent", + &utxo.OutPoint) + } + height := entry.BlockHeight() + if !confirmed(minconf, height, curHeight) { + return fmt.Errorf("output %v is unconfirmed", + &utxo.OutPoint) + } + if entry.ScriptVersion() != 0 { + return fmt.Errorf("output %v does not use script version 0", + &utxo.OutPoint) + } + + // Check proof of key ownership and ability to sign coinjoin + // inputs. + utxoPkScript := entry.PkScript() + var extractPubKeyHash160 func([]byte) []byte + switch { + case stdscript.IsPubKeyHashScriptV0(utxoPkScript): + extractPubKeyHash160 = stdscript.ExtractPubKeyHashV0 + case stdscript.IsStakeGenPubKeyHashScriptV0(utxoPkScript): + extractPubKeyHash160 = stdscript.ExtractStakeGenPubKeyHashV0 + case stdscript.IsStakeRevocationPubKeyHashScriptV0(utxoPkScript): + extractPubKeyHash160 = stdscript.ExtractStakeRevocationPubKeyHashV0 + case stdscript.IsTreasuryGenPubKeyHashScriptV0(utxoPkScript): + extractPubKeyHash160 = stdscript.ExtractTreasuryGenPubKeyHashV0 + default: + return fmt.Errorf("unsupported output script for UTXO %s", &utxo.OutPoint) + } + valid := validateOwnerProofP2PKHv0(extractPubKeyHash160, + utxoPkScript, utxo.PubKey, utxo.Signature, pr.Expires()) + if !valid { + return ruleError(ErrInvalidUTXOProof) + } + + totalValue += entry.Amount() + } + + if totalValue != pr.InputValue { + return fmt.Errorf("input value does not match sum of UTXO " + + "values") + } + + return nil +} + +func validateOwnerProofP2PKHv0(extractFunc func([]byte) []byte, pkscript, pubkey, sig []byte, expires uint32) bool { + extractedHash160 := extractFunc(pkscript) + pubkeyHash160 := dcrutil.Hash160(pubkey) + if !bytes.Equal(extractedHash160, pubkeyHash160) { + return false + } + + return utxoproof.ValidateSecp256k1P2PKH(pubkey, sig, expires) +} + +func (p *Pool) acceptKE(ke *wire.MsgMixKeyExchange, hash *chainhash.Hash, id *idPubKey) (accepted *wire.MsgMixKeyExchange, err error) { + sid := ke.SessionID + + // In all runs, previous PR messages in the KE must be sorted. + // This defines the initial unmixed peer positions. + sorted := sort.SliceIsSorted(ke.SeenPRs, func(i, j int) bool { + a := ke.SeenPRs[i][:] + b := ke.SeenPRs[j][:] + return bytes.Compare(a, b) == -1 + }) + if !sorted { + err := fmt.Errorf("KE message contains unsorted previous PR " + + "hashes") + return nil, err + } + + // Run-0 KE messages define a session ID by hashing all previously-seen + // PR message hashes. This must match the sid also present in the + // message. Later runs after a failed run may drop peers from the + // SeenPRs set, but the sid remains the same. A sid cannot be conjured + // out of thin air, and other messages seen from the network for an + // unknown session are not accepted. + if ke.Run == 0 { + derivedSid := mixing.DeriveSessionID(ke.SeenPRs, ke.Epoch) + if sid != derivedSid { + return nil, ruleError(ErrInvalidSessionID) + } + } + + p.mtx.Lock() + defer p.mtx.Unlock() + + // Check if already accepted. + if _, ok := p.pool[*hash]; ok { + return nil, nil + } + + // While KEs are allowed to reference unknown PRs, they must at least + // reference the PR submitted by their own identity. The + // messagesByIdentity map will contain at least one entry after a PR + // is received, and their PR will be the first hash in this slice. + // Of all PRs that are known, their pairing types must be compatible. + if len(p.messagesByIdentity[ke.Identity]) == 0 { + err := fmt.Errorf("KE identity %x has not submitted a PR", ke.Identity) + return nil, err + } + ownPRHash := p.messagesByIdentity[ke.Identity][0] + var ownPR *wire.MsgMixPairReq + prs := make([]*wire.MsgMixPairReq, 0, len(ke.SeenPRs)) + var pairing []byte + for _, seenPR := range ke.SeenPRs { + pr, ok := p.prs[seenPR] + if !ok { + continue + } + if seenPR == ownPRHash { + ownPR = pr + } + if pairing == nil { + var err error + pairing, err = pr.Pairing() + if err != nil { + return nil, err + } + } else { + pairing2, err := pr.Pairing() + if err != nil { + return nil, err + } + if !bytes.Equal(pairing, pairing2) { + err := fmt.Errorf("referenced PRs are incompatible") + return nil, err + } + } + } + if ownPR == nil { + err := fmt.Errorf("KE identity's own PR unexpectedly missing from mixpool") + return nil, err + } + + ses := p.sessions[sid] + + // Create a session for the first run-0 KE + if ses == nil { + if ke.Run != 0 { + err := fmt.Errorf("unknown session for run-%d KE", + ke.Run) + return nil, err + } + + expiry := ^uint32(0) + for i := range prs { + prExpiry := prs[i].Expires() + if expiry > prExpiry { + expiry = prExpiry + } + } + ses = &session{ + sid: sid, + runs: make([]runstate, 0, 4), + expiry: expiry, + bc: broadcast{ch: make(chan struct{})}, + } + p.sessions[sid] = ses + } + + err = p.acceptEntry(ke, msgtypeKE, hash, id, ses) + if err != nil { + return nil, err + } + p.latestKE[*id] = ke + return ke, nil +} + +func (p *Pool) acceptEntry(msg mixing.Message, msgtype msgtype, hash *chainhash.Hash, + id *[33]byte, ses *session) error { + + run := msg.GetRun() + if msg.GetRun() > uint32(len(ses.runs)) { + return fmt.Errorf("message skips runs") + } + + var rs *runstate + if msgtype == msgtypeKE && msg.GetRun() == uint32(len(ses.runs)) { + // Add a runstate for the next run. + ses.runs = append(ses.runs, runstate{ + run: msg.GetRun(), + prs: msg.PrevMsgs(), + hashes: make(map[chainhash.Hash]struct{}), + }) + rs = &ses.runs[len(ses.runs)-1] + } else { + // Add to existing runstate + rs = &ses.runs[run] + } + + rs.hashes[*hash] = struct{}{} + e := entry{ + hash: *hash, + sid: ses.sid, + recvTime: time.Now(), + msg: msg, + msgtype: msgtype, + run: msg.GetRun(), + } + p.pool[*hash] = e + p.messagesByIdentity[*id] = append(p.messagesByIdentity[*id], *hash) + + if cm, ok := msg.(*wire.MsgMixConfirm); ok { + p.sessionsByTxHash[cm.Mix.TxHash()] = ses + } + + rs.incrementCountFor(msgtype) + ses.bc.signal() + + return nil +} + +// lookupEntry returns the message entry matching a message hash with msgtype +// and session id. If msgtype is zero, any message type can be looked up. +func (p *Pool) lookupEntry(hash chainhash.Hash, msgtype msgtype, sid *[32]byte) (entry, bool) { + e, ok := p.pool[hash] + if !ok { + return entry{}, false + } + if msgtype != 0 && e.msgtype != msgtype { + return entry{}, false + } + if e.sid != *sid { + return entry{}, false + } + + return e, true +} + +func confirmed(minConf, txHeight, curHeight int64) bool { + return confirms(txHeight, curHeight) >= minConf +} +func confirms(txHeight, curHeight int64) int64 { + switch { + case txHeight == -1, txHeight > curHeight: + return 0 + default: + return curHeight - txHeight + 1 + } +} + +// isDustAmount determines whether a transaction output value and script length would +// cause the output to be considered dust. Transactions with dust outputs are +// not standard and are rejected by mempools with default policies. +func isDustAmount(amount int64, scriptSize int, relayFeePerKb int64) bool { + // Calculate the total (estimated) cost to the network. This is + // calculated using the serialize size of the output plus the serial + // size of a transaction input which redeems it. The output is assumed + // to be compressed P2PKH as this is the most common script type. Use + // the average size of a compressed P2PKH redeem input (165) rather than + // the largest possible (txsizes.RedeemP2PKHInputSize). + totalSize := 8 + 2 + wire.VarIntSerializeSize(uint64(scriptSize)) + + scriptSize + 165 + + // Dust is defined as an output value where the total cost to the network + // (output size + input size) is greater than 1/3 of the relay fee. + return amount*1000/(3*int64(totalSize)) < relayFeePerKb +} + +func checkFee(pr *wire.MsgMixPairReq, feeRate int64) error { + fee := pr.InputValue - int64(pr.MessageCount)*pr.MixAmount + if pr.Change != nil { + fee -= pr.Change.Value + } + + estimatedSize := estimateP2PKHv0SerializeSize(len(pr.UTXOs), + int(pr.MessageCount), pr.Change != nil) + requiredFee := feeForSerializeSize(feeRate, estimatedSize) + if fee < requiredFee { + return fmt.Errorf("not enough input value, or too low fee") + } + + return nil +} + +func feeForSerializeSize(relayFeePerKb int64, txSerializeSize int) int64 { + fee := relayFeePerKb * int64(txSerializeSize) / 1000 + + if fee == 0 && relayFeePerKb > 0 { + fee = relayFeePerKb + } + + const maxAmount = 21e6 * 1e8 + if fee < 0 || fee > maxAmount { + fee = maxAmount + } + + return fee +} + +const ( + redeemP2PKHv0SigScriptSize = 1 + 73 + 1 + 33 + p2pkhv0PkScriptSize = 1 + 1 + 1 + 20 + 1 + 1 +) + +func estimateP2PKHv0SerializeSize(inputs, outputs int, hasChange bool) int { + // Sum the estimated sizes of the inputs and outputs. + txInsSize := inputs * estimateInputSize(redeemP2PKHv0SigScriptSize) + txOutsSize := outputs * estimateOutputSize(p2pkhv0PkScriptSize) + + changeSize := 0 + if hasChange { + changeSize = estimateOutputSize(p2pkhv0PkScriptSize) + outputs++ + } + + // 12 additional bytes are for version, locktime and expiry. + return 12 + (2 * wire.VarIntSerializeSize(uint64(inputs))) + + wire.VarIntSerializeSize(uint64(outputs)) + + txInsSize + txOutsSize + changeSize +} + +// estimateInputSize returns the worst case serialize size estimate for a tx input +func estimateInputSize(scriptSize int) int { + return 32 + // previous tx + 4 + // output index + 1 + // tree + 8 + // amount + 4 + // block height + 4 + // block index + wire.VarIntSerializeSize(uint64(scriptSize)) + // size of script + scriptSize + // script itself + 4 // sequence +} + +// estimateOutputSize returns the worst case serialize size estimate for a tx output +func estimateOutputSize(scriptSize int) int { + return 8 + // previous tx + 2 + // version + wire.VarIntSerializeSize(uint64(scriptSize)) + // size of script + scriptSize // script itself +} diff --git a/wire/msgmixsecrets.go b/wire/msgmixsecrets.go index 53d0b06e2d..7532cc0bf5 100644 --- a/wire/msgmixsecrets.go +++ b/wire/msgmixsecrets.go @@ -25,6 +25,7 @@ type MsgMixSecrets struct { Seed [32]byte SlotReserveMsgs [][]byte DCNetMsgs MixVect + SeenSecrets []chainhash.Hash // hash records the hash of the message. It is a member of the // message for convenience and performance, but is never automatically @@ -43,6 +44,9 @@ func writeMixVect(op string, w io.Writer, pver uint32, vec MixVect) error { if err != nil { return err } + if len(vec) == 0 { + return nil + } err = WriteVarInt(w, pver, MixMsgSize) if err != nil { return err @@ -138,6 +142,25 @@ func (msg *MsgMixSecrets) BtcDecode(r io.Reader, pver uint32) error { } msg.DCNetMsgs = dcnetMsgs + count, err := ReadVarInt(r, pver) + if err != nil { + return err + } + if count > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + count, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) + } + + seen := make([]chainhash.Hash, count) + for i := range seen { + err := readElement(r, &seen[i]) + if err != nil { + return err + } + } + msg.SeenSecrets = seen + return nil } @@ -173,6 +196,17 @@ func (msg *MsgMixSecrets) Hash() chainhash.Hash { return msg.hash } +// Commitment returns a hash committing to the contents of the reveal secrets +// message without committing to any previous seen messages or the message +// signature. This is the hash that is referenced by peers' key exchange +// messages. +func (msg *MsgMixSecrets) Commitment(h hash.Hash) chainhash.Hash { + msgCopy := *msg + msgCopy.SeenSecrets = nil + msgCopy.WriteHash(h) + return msgCopy.hash +} + // WriteHash serializes the message to a hasher and records the sum in the // message's Hash field. // @@ -196,6 +230,14 @@ func (msg *MsgMixSecrets) WriteHash(h hash.Hash) { // // This method never errors for invalid message construction. func (msg *MsgMixSecrets) writeMessageNoSignature(op string, w io.Writer, pver uint32) error { + // Limit to max previous messages hashes. + count := len(msg.SeenSecrets) + if count > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + count, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) + } + err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Run, &msg.Seed) if err != nil { return err @@ -217,6 +259,17 @@ func (msg *MsgMixSecrets) writeMessageNoSignature(op string, w io.Writer, pver u return err } + err = WriteVarInt(w, pver, uint64(count)) + if err != nil { + return err + } + for i := range msg.SeenSecrets { + err := writeElement(w, &msg.SeenSecrets[i]) + if err != nil { + return err + } + } + return nil } @@ -242,7 +295,7 @@ func (msg *MsgMixSecrets) MaxPayloadLength(pver uint32) uint32 { } // See tests for this calculation - return 54444 + return 70831 } // Pub returns the message sender's public key identity. @@ -255,13 +308,14 @@ func (msg *MsgMixSecrets) Sig() []byte { return msg.Signature[:] } -// PrevMsgs returns nil. Previous messages are not needed to perform blame -// assignment, because of the assumption that all previous messages must have -// been received for a blame stage to be necessary. Additionally, a -// commitment to the secrets message is included in the key exchange, and -// future message hashes are not available at that time. +// PrevMsgs returns previously revealed secrets messages by other peers. An +// honest peer who needs to report blame assignment does not need to reference +// any previous secrets messages, and a secrets message with other referenced +// secrets is necessary to begin blame assignment. Dishonest peers who +// initially reveal their secrets without blame assignment being necessary are +// themselves removed in future runs. func (msg *MsgMixSecrets) PrevMsgs() []chainhash.Hash { - return nil + return msg.SeenSecrets } // Sid returns the session ID. @@ -287,5 +341,6 @@ func NewMsgMixSecrets(identity [33]byte, sid [32]byte, run uint32, Seed: seed, SlotReserveMsgs: slotReserveMsgs, DCNetMsgs: dcNetMsgs, + SeenSecrets: make([]chainhash.Hash, 0), } } diff --git a/wire/msgmixsecrets_test.go b/wire/msgmixsecrets_test.go index 2905d94eb4..e0fc7935a4 100644 --- a/wire/msgmixsecrets_test.go +++ b/wire/msgmixsecrets_test.go @@ -12,6 +12,7 @@ import ( "testing" "github.com/davecgh/go-spew/spew" + "github.com/decred/dcrd/chaincfg/chainhash" ) func newTestMixSecrets() *MsgMixSecrets { @@ -34,7 +35,13 @@ func newTestMixSecrets() *MsgMixSecrets { copy(m[b-0x89][:], repeat(b, 20)) } + seenRSs := make([]chainhash.Hash, 4) + for b := byte(0x8D); b < 0x91; b++ { + copy(seenRSs[b-0x8D][:], repeat(b, 32)) + } + rs := NewMsgMixSecrets(id, sid, run, seed, sr, m) + rs.SeenSecrets = seenRSs rs.Signature = sig return rs @@ -67,12 +74,18 @@ func TestMsgMixSecretsWire(t *testing.T) { expected = append(expected, repeat(0x87, 32)...) expected = append(expected, 0x20) expected = append(expected, repeat(0x88, 32)...) - // Four slot reservation mixed messages (repeating 20 bytes of 0x89, 0x8a, 0x8b, 0x8c) + // Four xor dc-net mixed messages (repeating 20 bytes of 0x89, 0x8a, 0x8b, 0x8c) expected = append(expected, 0x04, 0x14) expected = append(expected, repeat(0x89, 20)...) expected = append(expected, repeat(0x8a, 20)...) expected = append(expected, repeat(0x8b, 20)...) expected = append(expected, repeat(0x8c, 20)...) + // Four seen RSs (repeating 32 bytes of 0x8d, 0x8e, 0x8f, 0x90) + expected = append(expected, 0x04) + expected = append(expected, repeat(0x8d, 32)...) + expected = append(expected, repeat(0x8e, 32)...) + expected = append(expected, repeat(0x8f, 32)...) + expected = append(expected, repeat(0x90, 32)...) expectedSerializationEqual(t, buf.Bytes(), expected) @@ -168,7 +181,9 @@ func TestMsgMixSecretsMaxPayloadLength(t *testing.T) { MaxMixMcount*varBytesLen(MaxMixFieldValLen) + // Unpadded SR values uint32(VarIntSerializeSize(MaxMixMcount)) + // DC-net message count uint32(VarIntSerializeSize(MixMsgSize)) + // DC-net message size - MaxMixMcount*MixMsgSize // DC-net messages + MaxMixMcount*MixMsgSize + // DC-net messages + uint32(VarIntSerializeSize(MaxMixPeers)) + // RS count + 32*MaxMixPeers // RS hashes tests := []struct { name string