diff --git a/mixing/mixclient/client.go b/mixing/mixclient/client.go new file mode 100644 index 0000000000..099aff0cf2 --- /dev/null +++ b/mixing/mixclient/client.go @@ -0,0 +1,1566 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package mixclient + +import ( + "bytes" + "context" + cryptorand "crypto/rand" + "crypto/subtle" + "errors" + "fmt" + "hash" + "io" + "math/big" + "math/bits" + "sort" + "sync" + "time" + + "decred.org/cspp/v2/solverrpc" + "github.com/decred/dcrd/chaincfg/chainhash" + "github.com/decred/dcrd/crypto/blake256" + "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/decred/dcrd/dcrutil/v4" + "github.com/decred/dcrd/mixing" + "github.com/decred/dcrd/mixing/internal/chacha20prng" + "github.com/decred/dcrd/mixing/mixpool" + "github.com/decred/dcrd/txscript/v4" + "github.com/decred/dcrd/wire" +) + +// MinPeers is the minimum number of peers required for a mix run to proceed. +const MinPeers = 3 + +// ErrExpired indicates that a dicemix session failed to complete due to the +// submitted pair request expiring. +var ErrExpired = errors.New("mixing pair request expired") + +// Wallet signs mix transactions and listens for and broadcasts mixing +// protocol messages. +// +// While wallets are responsible for generating mixed addresses, this duty is +// performed by the generator function provided to NewCoinJoin rather than +// this interface. This allows each CoinJoin to pass in a generator closure +// for different BIP0032 accounts and branches. +type Wallet interface { + BestBlock() (uint32, chainhash.Hash) + + // Mixpool returns access to the wallet's mixing message pool. + // + // The mixpool should only be used for message access and deletion, + // but never publishing; SubmitMixMessage must be used instead for + // message publishing. + Mixpool() *mixpool.Pool + + // SubmitMixMessage submits a mixing message to the wallet's mixpool + // and broadcasts it to the network. + SubmitMixMessage(ctx context.Context, msg mixing.Message) error + + // SignInput adds a signature script to a transaction input. + SignInput(tx *wire.MsgTx, index int, prevScript []byte) error + + // PublishTransaction adds the transaction to the wallet and publishes + // it to the network. + PublishTransaction(ctx context.Context, tx *wire.MsgTx) error +} + +type deadlines struct { + epoch time.Time + recvKE time.Time + recvCT time.Time + recvSR time.Time + recvDC time.Time + recvCM time.Time +} + +const timeoutDuration = 30 * time.Second + +func (d *deadlines) start(begin time.Time) { + t := begin + add := func() time.Time { + t = t.Add(timeoutDuration) + return t + } + d.recvKE = add() + d.recvCT = add() + d.recvSR = add() + d.recvDC = add() + d.recvCM = add() +} + +func (d *deadlines) shift() { + d.recvKE = d.recvCT + d.recvCT = d.recvSR + d.recvSR = d.recvDC + d.recvDC = d.recvCM + d.recvCM = d.recvCM.Add(timeoutDuration) +} + +func (d *deadlines) restart() { + d.start(d.recvCM) +} + +// Client manages local mixing client sessions. +type Client struct { + wallet Wallet + mixpool *mixpool.Pool + + // Pending and active sessions and peers (both local and, when + // blaming, remote). + pairings map[string]*pairedSessions + height uint32 + mu sync.Mutex + + pairingWG sync.WaitGroup + + blake256Hasher hash.Hash + blake256HasherMu sync.Mutex + + epoch time.Duration + + logger Logger +} + +// NewClient creates a wallet's mixing client manager. 
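+//
+// A minimal usage sketch, assuming myWallet is some concrete Wallet
+// implementation and cj was built by NewCoinJoin (error handling elided):
+//
+//	c := NewClient(myWallet)
+//	go c.Run(ctx)
+//	err := c.Dicemix(ctx, cryptorand.Reader, cj)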
+func NewClient(w Wallet) *Client { + height, _ := w.BestBlock() + return &Client{ + wallet: w, + mixpool: w.Mixpool(), + pairings: make(map[string]*pairedSessions), + blake256Hasher: blake256.New(), + epoch: 10 * time.Minute, + height: height, + } +} + +type Logger interface { + Log(args ...interface{}) + Logf(format string, args ...interface{}) +} + +func (c *Client) SetLogger(l Logger) { + c.logger = l +} + +func (c *Client) log(args ...interface{}) { + if c.logger == nil { + return + } + + c.logger.Log(args...) +} + +func (c *Client) logf(format string, args ...interface{}) { + if c.logger == nil { + return + } + + c.logger.Logf(format, args...) +} + +func (c *Client) logerrf(format string, args ...interface{}) { + // XXX use ERR subsystem + c.logf(format, args...) +} + +func (c *Client) sessionLog(sid [32]byte) *sessionLogger { + return &sessionLogger{sid: sid, logger: c.logger} +} + +type sessionLogger struct { + sid [32]byte + logger Logger +} + +func (l *sessionLogger) logf(format string, args ...interface{}) { + l.logger.Logf("sid=%x "+format, append([]interface{}{l.sid[:]}, args...)...) +} + +func (l *sessionLogger) log(args ...interface{}) { + s := fmt.Sprintf("sid=%x", l.sid[:]) + l.logger.Log(append([]interface{}{s}, args...)...) +} + +// Run runs the client manager, blocking until after the context is +// cancelled. +func (s *Client) Run(ctx context.Context) error { + s.epochTicker(ctx) + return ctx.Err() +} + +// pairedSessions tracks the waiting and in-progress mix sessions performed by +// one or more local peers using compatible pairings. +type pairedSessions struct { + localPeers map[identity]*peer + pairing []byte + runs []sessionRun +} + +type sessionRun struct { + sid [32]byte + run uint32 + mtot uint32 + + // Peers sorted by PR hashes. Each peer's myVk is its index in this + // slice. + prs []*wire.MsgMixPairReq + peers []*peer + mcounts []uint32 + roots []*big.Int +} + +// peer represents a participating client in a peer-to-peer mixing session. +// Some fields only pertain to peers created by this wallet, while the rest +// are used during blame assignment. +type peer struct { + ctx context.Context + client *Client + rand io.Reader // non-PRNG cryptographic rand + + res chan error + + pub *secp256k1.PublicKey + priv *secp256k1.PrivateKey + id *identity // serialized pubkey + pr *wire.MsgMixPairReq + coinjoin *CoinJoin + kx *mixing.KX + + prngSeed *[32]byte + prng *chacha20prng.Reader + + rs *wire.MsgMixSecrets + srMsg []*big.Int // random numbers for the exponential slot reservation mix + dcMsg wire.MixVect // anonymized messages to publish in XOR mix + + ke *wire.MsgMixKeyExchange + ct *wire.MsgMixCiphertexts + sr *wire.MsgMixSlotReserve + dc *wire.MsgMixDCNet + cm *wire.MsgMixConfirm + + // Unmixed positions. May change over multiple sessions/runs. + myVk uint32 + myStart uint32 + + // Exponential slot reservation mix + srKP [][][]byte // shared keys for exp dc-net + srMix [][]*big.Int + srMixBytes [][][]byte + + // XOR DC-net + dcKP [][]wire.MixVect + dcNet []wire.MixVect + + // Whether next run must generate fresh KX keys, SR/DC messages + freshGen bool + + // Whether this peer represents a remote peer created from revealed secrets; + // used during blaming. 
+ remote bool +} + +func newRemotePeer(pr *wire.MsgMixPairReq) *peer { + return &peer{ + id: &pr.Identity, + pr: pr, + remote: true, + } +} + +func generateSecp256k1(rand io.Reader) (*secp256k1.PublicKey, *secp256k1.PrivateKey, error) { + if rand == nil { + rand = cryptorand.Reader + } + + privateKey, err := secp256k1.GeneratePrivateKeyFromRand(rand) + if err != nil { + return nil, nil, err + } + + publicKey := privateKey.PubKey() + + return publicKey, privateKey, nil +} + +// Dicemix performs a new mixing session for a coinjoin mix transaction. +func (c *Client) Dicemix(ctx context.Context, rand io.Reader, cj *CoinJoin) error { + pub, priv, err := generateSecp256k1(rand) + if err != nil { + return err + } + + p := &peer{ + ctx: ctx, + client: c, + res: make(chan error, 1), + pub: pub, + priv: priv, + id: (*[33]byte)(pub.SerializeCompressed()), + rand: rand, + coinjoin: cj, + freshGen: true, + } + + pr, err := wire.NewMsgMixPairReq(*p.id, cj.prExpiry, cj.mixValue, + string(mixing.ScriptClassP2PKHv0), cj.tx.Version, + cj.tx.LockTime, cj.mcount, cj.inputValue, cj.prUTXOs, + cj.change) + if err != nil { + return err + } + err = p.signAndHash(pr) + if err != nil { + return err + } + pairingID, err := pr.Pairing() + if err != nil { + return err + } + p.pr = pr + + c.logf("debug: created local peer id=%x PR=%s", p.id[:], p.pr.Hash()) + + c.mu.Lock() + pairing := c.pairings[string(pairingID)] + if pairing == nil { + pairing = c.newPairings(pairingID, nil) + c.pairings[string(pairingID)] = pairing + } + pairing.localPeers[*p.id] = p + c.mu.Unlock() + + err = p.submit(pr) + if err != nil { + c.mu.Lock() + pendingPairing := c.pairings[string(pairingID)] + if pendingPairing != nil { + delete(pendingPairing.localPeers, *p.id) + if len(pendingPairing.localPeers) == 0 { + delete(c.pairings, string(pairingID)) + } + } + c.mu.Unlock() + return err + } + + select { + case res := <-p.res: + return res + case <-ctx.Done(): + return ctx.Err() + } +} + +// ExpireMessages removes all mixpool messages and sessions that indicate an +// expiry height at or before the height parameter and removes any pending +// DiceMix sessions that did not complete. +func (c *Client) ExpireMessages(height uint32) { + c.mu.Lock() + defer c.mu.Unlock() + + // Just mark the current height for expireMessages() to use later, + // after epoch strikes. This allows us to keep using PRs during + // session forming even if they expire, which is required to continue + // mixing paired sessions that are performing reruns. + c.height = height + + // mixpool similarly expires messages in the background after current + // epoch ends. Call it here while in the current epoch, rather than + // when expireMessage() is called later immediately after the next + // epoch strikes. + c.mixpool.ExpireMessages(c.height) +} + +func (c *Client) expireMessages() { + for pairID, ps := range c.pairings { + for id, p := range ps.localPeers { + prHash := p.pr.Hash() + if !c.mixpool.HaveMessage(&prHash) { + delete(ps.localPeers, id) + // p.res is buffered. If the write is + // blocked, we have already served this peer + // or sent another error. + select { + case p.res <- ErrExpired: + default: + } + + } + } + if len(ps.localPeers) == 0 { + delete(c.pairings, pairID) + } + } +} + +// SetEpoch modifies the sessions to use a longer or shorter epoch duration. +// The default epoch is 10 minutes. 
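+//
+// All mixing peers must agree on the epoch duration: session IDs and key
+// exchange messages commit to the epoch timestamp, so a client configured
+// with a different epoch will never pair with the rest of the network.
+// Intended primarily for testing, e.g.:
+//
+//	c.SetEpoch(time.Minute)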
+func (c *Client) SetEpoch(epoch time.Duration) { + c.epoch = epoch + c.mixpool.SetEpoch(epoch) +} + +// waitForEpoch blocks until the next epoch, or errors when the context is +// cancelled early. Returns the calculated epoch for stage timeout +// calculations. +func (c *Client) waitForEpoch(ctx context.Context) (time.Time, error) { + now := time.Now().UTC() + epoch := now.Truncate(c.epoch).Add(c.epoch) + duration := epoch.Sub(now) + timer := time.NewTimer(duration) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return epoch, ctx.Err() + case <-timer.C: + return epoch, nil + } +} + +func (p *peer) signAndHash(m mixing.Message) error { + err := mixing.SignMessage(m, p.priv) + if err != nil { + return err + } + + p.client.blake256HasherMu.Lock() + m.WriteHash(p.client.blake256Hasher) + p.client.blake256HasherMu.Unlock() + + return nil +} + +func (p *peer) submit(m mixing.Message) error { + err := p.client.wallet.SubmitMixMessage(p.ctx, m) + if err != nil { + return fmt.Errorf("submit %T: %w", m, err) + } + return nil +} + +func (p *peer) signAndSubmit(m mixing.Message) error { + err := p.signAndHash(m) + if err != nil { + return err + } + return p.submit(m) +} + +func (c *Client) newPairings(pairing []byte, peers map[identity]*peer) *pairedSessions { + if peers == nil { + peers = make(map[identity]*peer) + } + ps := &pairedSessions{ + localPeers: peers, + pairing: pairing, + runs: nil, + } + return ps +} + +func (c *Client) epochTicker(ctx context.Context) { + //c.waitForEpoch(ctx) + + for { + epoch, err := c.waitForEpoch(ctx) + if err != nil { + break + } + + c.log("debug: epoch tick") + + var d deadlines + d.epoch = epoch + d.start(epoch) + + // Wait for any previous pairSession calls to timeout if they + // have not yet formed a session before the next epoch tick. + c.pairingWG.Wait() + + c.mu.Lock() + c.mixpool.RemoveConfirmedRuns() + c.expireMessages() + for _, p := range c.pairings { + prs := c.mixpool.CompatiblePRs(p.pairing) + prsMap := make(map[identity]struct{}) + for _, pr := range prs { + prsMap[pr.Identity] = struct{}{} + } + // Clone the p.localPeers map, only including PRs + // currently accepted to mixpool. Adding additional + // waiting local peers must not add more to the map in + // use by runSessions, and deleting peers in a formed + // session from the pending map must not inadvertently + // remove from runSession's ps.localPeers map. 
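+			// (Assigning a Go map copies only the map header,
+			// so ps.localPeers = p.localPeers would alias the
+			// same buckets; the fresh allocation below provides
+			// the isolation described above.)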
+			localPeers := make(map[identity]*peer)
+			for id, peer := range p.localPeers {
+				if _, ok := prsMap[id]; ok {
+					localPeers[id] = peer
+				}
+			}
+			c.logf("debug: have %d compatible/%d local PRs waiting for pairing %x",
+				len(prs), len(localPeers), p.pairing)
+			if len(prs) < MinPeers {
+				continue
+			}
+			ps := *p
+			ps.localPeers = localPeers
+			c.pairingWG.Add(1) // pairSession calls Done
+			go c.pairSession(ctx, &ps, prs, d)
+		}
+		c.mu.Unlock()
+	}
+}
+
+func sortPRsForSession(prs []*wire.MsgMixPairReq, epoch uint64) [32]byte {
+	sort.Slice(prs, func(i, j int) bool {
+		a := prs[i].Hash()
+		b := prs[j].Hash()
+		return bytes.Compare(a[:], b[:]) == -1
+	})
+
+	prHashes := make([]chainhash.Hash, len(prs))
+	for i := range prs {
+		prHashes[i] = prs[i].Hash()
+	}
+	return mixing.DeriveSessionID(prHashes, epoch)
+}
+
+func (c *Client) pairSession(ctx context.Context, ps *pairedSessions, prs []*wire.MsgMixPairReq, d deadlines) {
+	// This session pairing attempt, and calling pairSession again with
+	// fresh PRs, must end before the next call to pairSession for this
+	// pairing type.
+	nextEpoch := d.epoch.Add(c.epoch)
+	ctx, cancel := context.WithDeadline(ctx, nextEpoch)
+	defer cancel()
+
+	// Defer removal of completed mix messages. Add local peers back to
+	// the client to be paired in a later session if the mix was
+	// unsuccessful or only some peers were included.
+	var mixedPRs []*wire.MsgMixPairReq
+	defer func() {
+		unmixedPeers := ps.localPeers
+		for _, pr := range mixedPRs {
+			delete(unmixedPeers, pr.Identity)
+		}
+		if len(unmixedPeers) == 0 {
+			return
+		}
+		c.mu.Lock()
+		pendingPairing := c.pairings[string(ps.pairing)]
+		if pendingPairing == nil {
+			pendingPairing = c.newPairings(ps.pairing, unmixedPeers)
+			c.pairings[string(ps.pairing)] = pendingPairing
+		} else {
+			for id, p := range unmixedPeers {
+				prHash := p.pr.Hash()
+				if p.ctx.Err() == nil && c.mixpool.HaveMessage(&prHash) {
+					pendingPairing.localPeers[id] = p
+				}
+			}
+		}
+		c.mu.Unlock()
+	}()
+
+	unixEpoch := uint64(d.epoch.Unix())
+
+	var madePairing bool
+	defer func() {
+		if !madePairing {
+			c.pairingWG.Done()
+		}
+	}()
+
+	var sid [32]byte
+	var run uint32
+	var sesLog *sessionLogger
+	for {
+		if run == 0 {
+			sid = sortPRsForSession(prs, unixEpoch)
+			sesLog = c.sessionLog(sid)
+
+			prHashes := make([]chainhash.Hash, len(prs))
+			for i := range prs {
+				prHashes[i] = prs[i].Hash()
+			}
+
+			sesRun := sessionRun{
+				sid:     sid,
+				run:     run,
+				prs:     prs,
+				mcounts: make([]uint32, 0, len(prs)),
+			}
+			var m uint32
+			var localPeerCount int
+			for i, pr := range prs {
+				p := ps.localPeers[pr.Identity]
+				if p != nil {
+					localPeerCount++
+				} else {
+					p = newRemotePeer(pr)
+				}
+				p.myVk = uint32(i)
+				p.myStart = m
+
+				sesRun.peers = append(sesRun.peers, p)
+				sesRun.mcounts = append(sesRun.mcounts, p.pr.MessageCount)
+
+				m += p.pr.MessageCount
+			}
+			sesRun.mtot = m
+			ps.runs = append(ps.runs, sesRun)
+
+			if localPeerCount == 0 {
+				return
+			}
+
+			sesLog.logf("created session for pairid=%x from %d total (%d local) PRs %s",
+				ps.pairing, len(prHashes), localPeerCount, prHashes)
+			sesLog.logf("debug: len(ps.runs)=%d", len(ps.runs))
+		} else {
+			// Calculate new deadlines for reruns in the same
+			// session.
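+			// With the default 30 second stage timeout,
+			// restart schedules the rerun's recvKE deadline at
+			// the previous recvCM plus 30s, and each following
+			// stage 30s after the one before (see
+			// deadlines.start).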
+			d.restart()
+		}
+
+		err := c.run(ctx, ps, &d, &madePairing)
+
+		var altses *alternateSession
+		if errors.As(err, &altses) {
+			prs = altses.prs
+			if len(prs) < MinPeers {
+				sesLog.logf("Aborting session with too few remaining peers")
+				return
+			}
+
+			d.shift()
+
+			if sid != altses.sid {
+				sesLog.logf("Recreating as session %x (pairid=%x)", altses.sid, ps.pairing)
+				run = 0
+			}
+
+			continue
+		}
+
+		var blamed blamedIdentities
+		if errors.Is(err, errBlameRequired) {
+			err := c.blame(ctx, &ps.runs[len(ps.runs)-1], reported /* XXX */)
+			if !errors.As(err, &blamed) {
+				c.logerrf("aborting session for failed blame assignment: %v", err)
+				return
+			}
+		}
+		if blamed != nil || errors.As(err, &blamed) {
+			// Blamed peers were identified, either during the run
+			// in a way that all participants could have observed,
+			// or following revealing secrets and blame
+			// assignment. Begin a rerun excluding these peers.
+			rerun := excludeBlamed(&ps.runs[len(ps.runs)-1], blamed)
+			ps.runs = append(ps.runs, rerun)
+			run = rerun.run
+			continue
+		}
+
+		// Any other run error is not actionable.
+		if err != nil {
+			sesLog.logf("run error: %v", err)
+			return
+		}
+
+		mixedPRs = ps.runs[len(ps.runs)-1].prs
+		return
+	}
+}
+
+var (
+	errRunStageTimeout = errors.New("mix run stage timeout")
+	errUnblamedRerun   = errors.New("unblamed rerun")
+	errBlameRequired   = errors.New("blame required")
+)
+
+type alternateSession struct {
+	prs          []*wire.MsgMixPairReq
+	sid          [32]byte
+	unresponsive []*wire.MsgMixPairReq
+	err          error
+}
+
+func (e *alternateSession) Error() string {
+	return "recreated alternate session"
+}
+func (e *alternateSession) Unwrap() error {
+	return e.err
+}
+
+type timeoutError struct {
+	exclude []identity
+	prs     []*wire.MsgMixPairReq
+}
+
+func (e *timeoutError) Error() string                { return "timeout" }
+func (e *timeoutError) Exclude() []identity          { return e.exclude }
+func (e *timeoutError) Rerun() []*wire.MsgMixPairReq { return e.prs }
+
+func (c *Client) run(ctx context.Context, ps *pairedSessions, d *deadlines, madePairing *bool) error {
+	var blamed blamedIdentities
+
+	unixEpoch := uint64(d.epoch.Unix())
+
+	mp := c.wallet.Mixpool()
+	sesRun := &ps.runs[len(ps.runs)-1]
+	run := sesRun.run
+	prs := sesRun.prs
+
+	// Helper function to run a callback on all local peers. Closes over
+	// the sesRun variable, which will be updated to the appropriate
+	// session and run after a session is agreed upon.
+	forLocalPeers := func(f func(p *peer) error) {
+		for _, p := range sesRun.peers {
+			if p.remote {
+				continue
+			}
+			err := f(p)
+			if err != nil {
+				c.logerrf("%s", err)
+			}
+		}
+	}
+
+	// A map of identity public keys to their PR positions is used to
+	// sort all messages in the same order as the PRs are ordered.
+ identityIndices := make(map[identity]int) + for i, pr := range prs { + identityIndices[pr.Identity] = i + } + + seenPRs := make([]chainhash.Hash, len(prs)) + for i := range prs { + seenPRs[i] = prs[i].Hash() + } + + forLocalPeers(func(p *peer) error { + p.coinjoin.resetUnmixed(prs) + + if p.freshGen { + p.freshGen = false + + // Generate a new PRNG seed + p.prngSeed = new([32]byte) + _, err := io.ReadFull(p.rand, p.prngSeed[:]) + if err != nil { + return err + } + p.prng = chacha20prng.New(p.prngSeed[:], run) + + // Generate fresh keys from this run's PRNG + p.kx, err = mixing.NewKX(p.prng) + if err != nil { + return err + } + + // Generate fresh SR messages + p.srMsg = make([]*big.Int, p.pr.MessageCount) + for i := range p.srMsg { + p.srMsg[i], err = cryptorand.Int(p.rand, mixing.F) + if err != nil { + return err + } + } + + // Generate fresh DC messages + p.dcMsg, err = p.coinjoin.gen() + if err != nil { + return err + } + if len(p.dcMsg) != int(p.pr.MessageCount) { + return errors.New("Gen returned wrong message count") + } + for _, m := range p.dcMsg { + if len(m) != msize { + err := fmt.Errorf("Gen returned bad message "+ + "length [%v != %v]", len(m), msize) + return err + } + } + } else { + // Generate a new PRNG from existing seed and this run + // number. + p.prng = chacha20prng.New(p.prngSeed[:], run) + } + + // Perform key exchange + srMsgBytes := make([][]byte, len(p.srMsg)) + for i := range p.srMsg { + srMsgBytes[i] = p.srMsg[i].Bytes() + } + rs := wire.NewMsgMixSecrets(*p.id, sesRun.sid, run, + *p.prngSeed, srMsgBytes, p.dcMsg) + commitment := rs.Hash() + ecdhPub := *(*[33]byte)(p.kx.ECDHPublicKey.SerializeCompressed()) + pqPub := *p.kx.PQPublicKey + ke := wire.NewMsgMixKeyExchange(*p.id, sesRun.sid, unixEpoch, run, + ecdhPub, pqPub, commitment, seenPRs) + + p.ke = ke + p.rs = rs + return p.signAndSubmit(ke) + }) + + // Receive key exchange messages. + // + // In run 0, it is possible that the attempted session (the session ID + // used by our PR messages) does not match the same session attempted + // by all other remote peers, due to not all peers initially agreeing + // on the same set of PRs. When there is agreement, the session can + // be run like normal, and any reruns will only be performed with some + // of the original peers removed. + // + // When there is session disagreement, we attempt to find a new + // session that the majority of peers will be able to participate in. + // All KE messages that match the pairing ID are received, and each + // seen PRs slice is checked. PRs that were never followed up by a KE + // are immediately excluded. + var kes []*wire.MsgMixKeyExchange + var err error + recvKEs := func(sesRun *sessionRun) (kes []*wire.MsgMixKeyExchange, err error) { + rcv := new(mixpool.Received) + rcv.Run = sesRun.run + rcv.Sid = sesRun.sid + rcv.KEs = make([]*wire.MsgMixKeyExchange, 0, len(sesRun.prs)) + ctx, cancel := context.WithDeadlineCause(ctx, + d.recvKE, errRunStageTimeout) + defer cancel() + err = mp.Receive(ctx, len(sesRun.prs), rcv) + return rcv.KEs, err + } + + switch { + case run == 0: + // Receive KEs for the last attempted session. Local + // peers may have been modified (new keys generated, and myVk + // indexes changed) if this is a recreated session, and we + // cannot continue mix using these messages. + // + // XXX: do we need to keep peer info available for previous + // session attempts? It is possible that previous sessions + // may be operable now if all wallets have come to agree on a + // previous session we also tried to form. 
+		kes, err = recvKEs(sesRun)
+		if len(kes) == len(sesRun.prs) {
+			break
+		}
+		if err := ctx.Err(); err != nil {
+			return err
+		}
+
+		// Alternate session needs to be attempted. Do not form an
+		// alternate session if we've already entered into the next
+		// epoch without forming a complete session. The session
+		// forming will be performed by a new goroutine started by the
+		// epoch ticker, possibly with additional PRs.
+		nextEpoch := d.epoch.Add(c.epoch)
+		if time.Now().After(nextEpoch) {
+			return err
+		}
+
+		altses := c.alternateSession(ps.pairing, sesRun.prs, nil, d)
+		if errors.Is(altses.err, ErrTooFewPeers) || altses.sid == sesRun.sid {
+			altses = c.alternateSession(ps.pairing, sesRun.prs,
+				altses.unresponsive, d)
+		}
+		if altses.err != nil {
+			return err
+		}
+		return altses
+
+	default:
+		// Receive KEs only for the agreed-upon session.
+		kes, err = recvKEs(sesRun)
+		if err != nil {
+			return err
+		}
+	}
+
+	for _, ke := range kes {
+		if idx, ok := identityIndices[ke.Identity]; ok {
+			sesRun.peers[idx].ke = ke
+		}
+	}
+
+	// Remove paired local peers from waiting pairing.
+	if run == 0 {
+		c.mu.Lock()
+		if waiting := c.pairings[string(ps.pairing)]; waiting != nil {
+			for id := range ps.localPeers {
+				delete(waiting.localPeers, id)
+			}
+			if len(waiting.localPeers) == 0 {
+				delete(c.pairings, string(ps.pairing))
+			}
+		}
+		c.mu.Unlock()
+
+		if !*madePairing {
+			c.pairingWG.Done()
+			*madePairing = true
+		}
+	}
+
+	sort.Slice(kes, func(i, j int) bool {
+		a := identityIndices[kes[i].Identity]
+		b := identityIndices[kes[j].Identity]
+		return a < b
+	})
+
+	sesLog := c.sessionLog(sesRun.sid)
+
+	involvesLocalPeers := false
+	for _, ke := range kes {
+		if ps.localPeers[ke.Identity] != nil {
+			involvesLocalPeers = true
+			break
+		}
+	}
+	if !involvesLocalPeers {
+		return errors.New("excluded all local peers")
+	}
+
+	ecdhPublicKeys := make([]*secp256k1.PublicKey, 0, len(prs))
+	pqpk := make([]*mixing.PQPublicKey, 0, len(prs))
+	for _, ke := range kes {
+		ecdhPub, err := secp256k1.ParsePubKey(ke.ECDH[:])
+		if err != nil {
+			blamed = append(blamed, ke.Identity)
+			continue
+		}
+		ecdhPublicKeys = append(ecdhPublicKeys, ecdhPub)
+		pqpk = append(pqpk, &ke.PQPK)
+	}
+	if len(blamed) > 0 {
+		return blamed
+	}
+
+	forLocalPeers(func(p *peer) error {
+		// Create shared keys and ciphertexts for each peer
+		pqct, err := p.kx.Encapsulate(p.prng, pqpk, int(p.myVk))
+		if err != nil {
+			return err
+		}
+
+		// Send ciphertext messages
+		seenKEs := make([]chainhash.Hash, len(kes))
+		for i := range kes {
+			seenKEs[i] = kes[i].Hash()
+		}
+		ct := wire.NewMsgMixCiphertexts(*p.id, sesRun.sid, run, pqct, seenKEs)
+		p.ct = ct
+		return p.signAndSubmit(ct)
+	})
+
+	// Receive all ciphertext messages
+	rcv := new(mixpool.Received)
+	rcv.Sid = sesRun.sid
+	rcv.KEs = nil
+	rcv.CTs = make([]*wire.MsgMixCiphertexts, 0, len(prs))
+	rcvCtx, rcvCtxCancel := context.WithDeadlineCause(context.Background(),
+		d.recvCT, errRunStageTimeout)
+	err = mp.Receive(rcvCtx, len(prs), rcv)
+	rcvCtxCancel()
+	cts := rcv.CTs
+	for _, ct := range cts {
+		if idx, ok := identityIndices[ct.Identity]; ok {
+			sesRun.peers[idx].ct = ct
+		}
+	}
+	if len(cts) != len(prs) || errors.Is(err, errRunStageTimeout) {
+		// Blame peers
+		sesLog.logf("received %d CTs for %d peers; rerunning", len(cts), len(prs))
+		return errUnblamedRerun
+	}
+	if err != nil {
+		return err
+	}
+	sort.Slice(cts, func(i, j int) bool {
+		a := identityIndices[cts[i].Identity]
+		b := identityIndices[cts[j].Identity]
+		return a < b
+	})
+
+	blamedMap := make(map[identity]struct{})
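+	// forLocalPeers only logs errors returned by its callback, so the
+	// closure below records failures requiring further action in these
+	// captured variables instead of returning them.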
+	var blameErr error
+	forLocalPeers(func(p *peer) error {
+		revealed := &mixing.RevealedKeys{
+			ECDHPublicKeys: ecdhPublicKeys,
+			Ciphertexts:    make([]mixing.PQCiphertext, 0, len(prs)),
+			MyIndex:        p.myVk,
+		}
+		ctIds := make([]identity, 0, len(cts))
+		for _, ct := range cts {
+			if len(ct.Ciphertexts) != len(prs) {
+				// Everyone sees this, can blame now.
+				blamedMap[ct.Identity] = struct{}{}
+				return nil
+			}
+			revealed.Ciphertexts = append(revealed.Ciphertexts, ct.Ciphertexts[p.myVk])
+			ctIds = append(ctIds, ct.Identity)
+		}
+
+		// Derive shared secret keys
+		shared, err := p.kx.SharedSecrets(revealed, sesRun.sid[:], run, sesRun.mcounts)
+		if err != nil {
+			blameErr = errBlameRequired
+			return nil
+		}
+		p.srKP = shared.SRSecrets
+		p.dcKP = shared.DCSecrets
+
+		// Calculate slot reservation DC-net vectors
+		p.srMix = make([][]*big.Int, p.pr.MessageCount)
+		for i := range p.srMix {
+			pads := mixing.SRMixPads(p.srKP[i], p.myStart+uint32(i))
+			p.srMix[i] = mixing.SRMix(p.srMsg[i], pads)
+		}
+		srMixBytes := mixing.IntVectorsToBytes(p.srMix)
+
+		// Broadcast message commitment and exponential DC-mix vectors
+		// for slot reservations.
+		seenCTs := make([]chainhash.Hash, len(cts))
+		for i := range cts {
+			seenCTs[i] = cts[i].Hash()
+		}
+		sr := wire.NewMsgMixSlotReserve(*p.id, sesRun.sid, run, srMixBytes, seenCTs)
+		p.sr = sr
+		return p.signAndSubmit(sr)
+	})
+	if len(blamedMap) > 0 {
+		for id := range blamedMap {
+			blamed = append(blamed, id)
+		}
+		return blamed
+	}
+	if blameErr != nil {
+		return blameErr
+	}
+
+	// Receive all slot reservation messages
+	rcv.CTs = nil
+	rcv.SRs = make([]*wire.MsgMixSlotReserve, 0, len(prs))
+	rcvCtx, rcvCtxCancel = context.WithDeadlineCause(context.Background(),
+		d.recvSR, errRunStageTimeout)
+	err = mp.Receive(rcvCtx, len(prs), rcv)
+	rcvCtxCancel()
+	srs := rcv.SRs
+	for _, sr := range srs {
+		if idx, ok := identityIndices[sr.Identity]; ok {
+			sesRun.peers[idx].sr = sr
+		}
+	}
+	if len(srs) != len(prs) || errors.Is(err, errRunStageTimeout) {
+		// Blame peers
+		sesLog.logf("received %d SRs for %d peers; rerunning", len(srs), len(prs))
+		return errUnblamedRerun
+	}
+	if err != nil {
+		return err
+	}
+	sort.Slice(srs, func(i, j int) bool {
+		a := identityIndices[srs[i].Identity]
+		b := identityIndices[srs[j].Identity]
+		return a < b
+	})
+
+	// Recover roots
+	var roots []*big.Int
+	vs := make([][][]byte, 0, len(prs))
+	for _, sr := range srs {
+		vs = append(vs, sr.DCMix...)
+	}
+	powerSums := mixing.AddVectors(mixing.IntVectorsFromBytes(vs)...)
+	coeffs := mixing.Coefficients(powerSums)
+	roots, err = solverrpc.Roots(coeffs, mixing.F)
+	if err != nil {
+		// Blame peers
+		return errBlameRequired
+	}
+	sesRun.roots = roots
+
+	forLocalPeers(func(p *peer) error {
+		// Find reserved slots
+		slots := make([]uint32, 0, p.pr.MessageCount)
+		for _, m := range p.srMsg {
+			slot := constTimeSlotSearch(m, roots)
+			if slot == -1 {
+				// Blame peers
+				return errBlameRequired
+			}
+			slots = append(slots, uint32(slot))
+		}
+
+		// Calculate XOR DC-net vectors
+		p.dcNet = make([]wire.MixVect, p.pr.MessageCount)
+		for i, slot := range slots {
+			my := p.myStart + uint32(i)
+			pads := mixing.DCMixPads(p.dcKP[i], my)
+			p.dcNet[i] = wire.MixVect(mixing.DCMix(pads, p.dcMsg[i][:], slot))
+		}
+
+		// Broadcast XOR DC-net vectors.
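+		// Pads shared between each pair of peers cancel when all
+		// published vectors are XORed together; e.g. with three
+		// peers and peer 1's message m1 in its reserved slot:
+		//
+		//	(m1^k12^k13) ^ (k12^k23) ^ (k13^k23) = m1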
+		seenSRs := make([]chainhash.Hash, len(srs))
+		for i := range srs {
+			seenSRs[i] = srs[i].Hash()
+		}
+		dc := wire.NewMsgMixDCNet(*p.id, sesRun.sid, run, p.dcNet, seenSRs)
+		p.dc = dc
+		return p.signAndSubmit(dc)
+	})
+
+	// Receive all DC messages
+	rcv.SRs = nil
+	rcv.DCs = make([]*wire.MsgMixDCNet, 0, len(prs))
+	rcvCtx, rcvCtxCancel = context.WithDeadlineCause(context.Background(),
+		d.recvDC, errRunStageTimeout)
+	err = mp.Receive(rcvCtx, len(prs), rcv)
+	rcvCtxCancel()
+	dcs := rcv.DCs
+	for _, dc := range dcs {
+		if idx, ok := identityIndices[dc.Identity]; ok {
+			sesRun.peers[idx].dc = dc
+		}
+	}
+	if len(dcs) != len(prs) || errors.Is(err, errRunStageTimeout) {
+		// Blame peers
+		sesLog.logf("received %d DCs for %d peers; rerunning", len(dcs), len(prs))
+		return errUnblamedRerun
+	}
+	if err != nil {
+		return err
+	}
+	sort.Slice(dcs, func(i, j int) bool {
+		a := identityIndices[dcs[i].Identity]
+		b := identityIndices[dcs[j].Identity]
+		return a < b
+	})
+
+	// Solve XOR dc-net
+	dcVecs := make([]mixing.Vec, 0, sesRun.mtot)
+	for _, dc := range dcs {
+		for _, vec := range dc.DCNet {
+			dcVecs = append(dcVecs, mixing.Vec(vec))
+		}
+	}
+	mixedMsgs := mixing.XorVectors(dcVecs)
+
+	type missingMessage interface {
+		error
+		MissingMessage()
+	}
+	var errMissingMessage missingMessage
+	forLocalPeers(func(p *peer) error {
+		// Add outputs for each mixed message
+		for i := range mixedMsgs {
+			mixedMsg := mixedMsgs[i][:]
+			p.coinjoin.addMixedMessage(mixedMsg)
+		}
+		p.coinjoin.sort()
+
+		// Confirm that our messages are present, and sign each peer's
+		// provided inputs.
+		err = p.coinjoin.confirm(c.wallet)
+		if err != nil {
+			if errMissingMessage == nil {
+				errors.As(err, &errMissingMessage)
+			}
+			return nil
+		}
+
+		// Broadcast partially signed mix tx
+		seenDCs := make([]chainhash.Hash, len(dcs))
+		for i := range dcs {
+			seenDCs[i] = dcs[i].Hash()
+		}
+		cm := wire.NewMsgMixConfirm(*p.id, sesRun.sid, run,
+			p.coinjoin.Tx().Copy(), seenDCs)
+		p.cm = cm
+		return p.signAndSubmit(cm)
+	})
+	if errMissingMessage != nil {
+		sesLog.logf("missing message; blaming and rerunning")
+		return errMissingMessage
+	}
+
+	// Receive all CM messages
+	rcv.DCs = nil
+	rcv.CMs = make([]*wire.MsgMixConfirm, 0, len(prs))
+	rcvCtx, rcvCtxCancel = context.WithDeadlineCause(context.Background(),
+		d.recvCM, errRunStageTimeout)
+	err = mp.Receive(rcvCtx, len(prs), rcv)
+	rcvCtxCancel()
+	cms := rcv.CMs
+	for _, cm := range cms {
+		if idx, ok := identityIndices[cm.Identity]; ok {
+			sesRun.peers[idx].cm = cm
+		}
+	}
+	if len(cms) != len(prs) || errors.Is(err, errRunStageTimeout) {
+		// Blame peers
+		sesLog.logf("received %d CMs for %d peers; rerunning", len(cms), len(prs))
+		return errUnblamedRerun
+	}
+	if err != nil {
+		return err
+	}
+	sort.Slice(cms, func(i, j int) bool {
+		a := identityIndices[cms[i].Identity]
+		b := identityIndices[cms[j].Identity]
+		return a < b
+	})
+
+	// Merge and validate all signatures. Only a single coinjoin is
+	// needed at this point.
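+	// Every local peer assembled an identical transaction (mixed
+	// messages were added and sorted deterministically above), so the
+	// signatures carried by all confirm messages can be merged into
+	// any one local peer's coinjoin.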
+ var cj *CoinJoin + pubkeys := make(map[wire.OutPoint][]byte) + for _, p := range sesRun.peers { + if cj == nil && !p.remote { + cj = p.coinjoin + } + for _, utxo := range p.pr.UTXOs { + pubkeys[utxo.OutPoint] = utxo.PubKey + } + } + for _, cm := range cms { + err := cj.mergeSignatures(cm) + if err != nil { + blamed = append(blamed, cm.Identity) + } + } + if len(blamed) > 0 { + return blamed + } + p2pkhv0Script := []byte{ + 0: txscript.OP_DUP, + 1: txscript.OP_HASH160, + 2: txscript.OP_DATA_20, + 23: txscript.OP_EQUALVERIFY, + 24: txscript.OP_CHECKSIG, + } + sigCache, err := txscript.NewSigCache(uint(len(cj.tx.TxIn))) + if err != nil { + return errUnblamedRerun + } + blamedInputs := make(map[wire.OutPoint]struct{}) + const scriptFlags = txscript.ScriptDiscourageUpgradableNops | + txscript.ScriptVerifyCleanStack | + txscript.ScriptVerifyCheckLockTimeVerify | + txscript.ScriptVerifyCheckSequenceVerify | + txscript.ScriptVerifyTreasury + const scriptVersion = 0 + for i, in := range cj.tx.TxIn { + pk, ok := pubkeys[in.PreviousOutPoint] + if !ok { + blamedInputs[in.PreviousOutPoint] = struct{}{} + continue + } + pkh := dcrutil.Hash160(pk) + copy(p2pkhv0Script[3:23], pkh) + engine, err := txscript.NewEngine(p2pkhv0Script, cj.tx, i, + scriptFlags, scriptVersion, sigCache) + if err != nil { + c.logf("warn: blaming peer for error creating script engine: %v", err) + blamedInputs[in.PreviousOutPoint] = struct{}{} + continue + } + if engine.Execute() != nil { + blamedInputs[in.PreviousOutPoint] = struct{}{} + } + } + if len(blamedInputs) > 0 { + blameInputLoop: + for _, pr := range prs { + for _, in := range pr.UTXOs { + if _, ok := blamedInputs[in.OutPoint]; ok { + blamed = append(blamed, pr.Identity) + continue blameInputLoop + } + } + } + return blamed + } + + err = c.wallet.PublishTransaction(context.Background(), cj.tx) + if err != nil { + return err + } + + mp.RemoveSession(sesRun.sid, true) + + forLocalPeers(func(p *peer) error { + select { + case p.res <- nil: + default: + } + return nil + }) + return nil +} + +func (c *Client) alternateSession(pairing []byte, prs []*wire.MsgMixPairReq, + ignoreUnresponsive []*wire.MsgMixPairReq, d *deadlines) *alternateSession { + + unixEpoch := uint64(d.epoch.Unix()) + + kes := c.mixpool.ReceiveKEsByPairing(pairing, unixEpoch) + + // Sort KEs by identity first (just to group these together) followed + // by the total referenced PR counts in decreasing order. + // When ranging over KEs below, this will allow us to consider the + // order in which other peers created their KEs, and how they are + // forming their sessions. + sort.Slice(kes, func(i, j int) bool { + a := kes[i] + b := kes[j] + if bytes.Compare(a.Identity[:], b.Identity[:]) == -1 { + return true + } + if len(a.SeenPRs) > len(b.SeenPRs) { + return true + } + return false + }) + + ignoreUnresponsiveIdentities := make(map[identity]struct{}) + ignoreUnresponsivePRHashes := make(map[chainhash.Hash]struct{}) + for _, pr := range ignoreUnresponsive { + ignoreUnresponsiveIdentities[pr.Identity] = struct{}{} + ignoreUnresponsivePRHashes[pr.Hash()] = struct{}{} + } + + prsByHash := make(map[chainhash.Hash]*wire.MsgMixPairReq) + prHashByIdentity := make(map[identity]chainhash.Hash) + for _, pr := range prs { + prsByHash[pr.Hash()] = pr + prHashByIdentity[pr.Identity] = pr.Hash() + } + + // Only one KE per peer identity (the KE that references the most PR + // hashes) is used for determining session agreement. 
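+	// Because of the sort above, the retained KE for an identity is
+	// the one referencing the most PR hashes, i.e. the most complete
+	// view of the pairing that the peer has claimed to observe.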
+ type peerMsgs struct { + pr *wire.MsgMixPairReq + ke *wire.MsgMixKeyExchange + } + msgsByIdentityMap := make(map[identity]*peerMsgs) + for i := range kes { + ke := kes[i] + if msgsByIdentityMap[ke.Identity] != nil { + continue + } + pr := prsByHash[prHashByIdentity[ke.Identity]] + if pr == nil { + err := fmt.Errorf("missing PR %s by %x, but have their KE %s", + prHashByIdentity[ke.Identity], ke.Identity[:], ke.Hash()) + c.logger.Log(err) + continue + } + msgsByIdentityMap[ke.Identity] = &peerMsgs{ + pr: pr, + ke: ke, + } + } + + // Discover unresponsive peers if we have not already calculated them + // and have to ignore them. + var unresponsive []*wire.MsgMixPairReq + if len(ignoreUnresponsive) == 0 { + for _, pr := range prs { + id := pr.Identity + if _, ok := msgsByIdentityMap[id]; !ok { + c.logf("Identity %x (PR %s) is unresponsive", id[:], pr.Hash()) + unresponsive = append(unresponsive, pr) + } + } + } + + // A total count of all unique PRs/identities (including not just + // those we have observed PRs for, but also other unknown PR hashes + // referenced by other peers) is needed. All participating peers must + // come to agreement on this number, otherwise the exclusion rules + // based on majority may result in different peers being removed and + // agreement never being reached. Any PRs not seen by a majority of + // other peers (i.e. without counting peers seeing their own PR) are + // immediately removed from any formed session attempt. + referencedPRCounts := make(map[chainhash.Hash]int) + for _, msgs := range msgsByIdentityMap { + selfPRHash := msgs.pr.Hash() + for _, prHash := range msgs.ke.SeenPRs { + if selfPRHash == prHash { + continue + } + // Do not count ignored unresponsive peer's PRs. + if _, ok := ignoreUnresponsivePRHashes[prHash]; ok { + continue + } + referencedPRCounts[prHash]++ + } + } + prCount := len(referencedPRCounts) + neededForMajority := len(referencedPRCounts) / 2 + remainingPRs := make([]chainhash.Hash, 0, prCount) + prsSeenCount := make([]int, 0, prCount) + for prHash, count := range referencedPRCounts { + if count >= neededForMajority { + remainingPRs = append(remainingPRs, prHash) + prsSeenCount = append(prsSeenCount, count) + } + } + + // Sort PRs by increasing seen counts first, then lexicographically by + // PR hash. + sort.Slice(remainingPRs, func(i, j int) bool { + switch { + case prsSeenCount[i] < prsSeenCount[j]: + return true + case prsSeenCount[i] > prsSeenCount[j]: + return false + default: + a := remainingPRs[i][:] + b := remainingPRs[j][:] + return bytes.Compare(a, b) == -1 + } + }) + + // Incrementally remove PRs that are not seen by all peers, and + // recalculate total observance counts after excluding the peer that + // was dropped. A session is formed when all remaining PR counts + // equal one less than the total number of PRs still considered + // (due to not counting observing one's own PR). + for { + if len(prsSeenCount) < MinPeers { + return &alternateSession{ + unresponsive: unresponsive, + err: ErrTooFewPeers, + } + } + if prsSeenCount[0] == len(prsSeenCount)-1 { + prs := make([]*wire.MsgMixPairReq, len(remainingPRs)) + for i, prHash := range remainingPRs { + prs[i] = prsByHash[prHash] + if prs[i] == nil { + // Agreement should be reached, but we + // won't be able to participate in the + // formed session. 
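+					// (This occurs when a majority of
+					// peers referenced a PR hash that
+					// was never accepted to our own
+					// mixpool.)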
+ return &alternateSession{ + unresponsive: unresponsive, + err: ErrUnknownPRs, + } + } + } + newSessionID := sortPRsForSession(prs, unixEpoch) + return &alternateSession{ + prs: prs, + sid: newSessionID, + unresponsive: unresponsive, + } + } + + // removedPR := remainingPRHashes[0] + remainingPRs = remainingPRs[1:] + prsSeenCount = prsSeenCount[:0] + + prCounts := make(map[chainhash.Hash]int, len(remainingPRs)) + for _, hash := range remainingPRs { + prCounts[hash] = 0 + } + for _, hash := range remainingPRs { + pr := prsByHash[hash] + if pr == nil { + continue + } + selfPRHash := pr.Hash() + msgs := msgsByIdentityMap[pr.Identity] + if msgs == nil { + continue + } + ke := msgs.ke + for _, prHash := range ke.SeenPRs { + if prHash == selfPRHash { + continue + } + if _, ok := prCounts[prHash]; ok { + prCounts[prHash]++ + } + } + } + for _, prHash := range remainingPRs { + prsSeenCount = append(prsSeenCount, prCounts[prHash]) + } + } + +} + +func excludeBlamed(prevRun *sessionRun, blamed blamedIdentities) sessionRun { + blamedMap := make(map[identity]struct{}) + for _, id := range blamed { + blamedMap[id] = struct{}{} + } + + peers := prevRun.peers[:0] + prs := prevRun.prs[:0] + mcounts := prevRun.mcounts[:0] + var mtot uint32 + + for _, p := range prevRun.peers { + if _, ok := blamedMap[*p.id]; ok { + continue + } + + peers = append(peers, p) + prs = append(prs, p.pr) + mcounts = append(mcounts, p.pr.MessageCount) + mtot += p.pr.MessageCount + } + + return sessionRun{ + sid: prevRun.sid, + run: prevRun.run + 1, + mtot: mtot, + + prs: prs, + peers: peers, + mcounts: mcounts, + roots: nil, + } +} + +var fieldLen = uint(len(mixing.F.Bytes())) + +// constTimeSlotSearch searches for the index of secret in roots in constant time. +// Returns -1 if the secret is not found. +func constTimeSlotSearch(secret *big.Int, roots []*big.Int) int { + paddedSecret := make([]byte, fieldLen) + secretBytes := secret.Bytes() + off, _ := bits.Sub(fieldLen, uint(len(secretBytes)), 0) + copy(paddedSecret[off:], secretBytes) + + slot := -1 + buf := make([]byte, fieldLen) + for i := range roots { + rootBytes := roots[i].Bytes() + off, _ := bits.Sub(fieldLen, uint(len(rootBytes)), 0) + copy(buf[off:], rootBytes) + cmp := subtle.ConstantTimeCompare(paddedSecret, buf) + slot = subtle.ConstantTimeSelect(cmp, i, slot) + for j := range buf { + buf[j] = 0 + } + } + return slot +} diff --git a/mixing/mixpool/mixpool.go b/mixing/mixpool/mixpool.go new file mode 100644 index 0000000000..85abe003d3 --- /dev/null +++ b/mixing/mixpool/mixpool.go @@ -0,0 +1,1340 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +// Package mixpool provides an in-memory pool of mixing messages for full nodes +// that relay these messages and mixing wallets that send and receive them. +package mixpool + +import ( + "bytes" + "context" + "fmt" + "sort" + "sync" + "time" + + "github.com/decred/dcrd/chaincfg/chainhash" + "github.com/decred/dcrd/chaincfg/v3" + "github.com/decred/dcrd/dcrutil/v4" + "github.com/decred/dcrd/mixing" + "github.com/decred/dcrd/mixing/utxoproof" + "github.com/decred/dcrd/txscript/v4/stdscript" + "github.com/decred/dcrd/wire" +) + +const minconf = 2 +const feeRate = 0.0001e8 + +type idPubKey = [33]byte + +type msgtype int + +// Message type constants, for quickly checking looked up entries by message +// hash match the expected type (without performing a type assertion). +// Excludes PR. 
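+//
+// Values begin at 1 so that (msgtype - 1) can index the per-run message
+// count array; see runstate.countFor.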
+const ( + msgtypeKE msgtype = 1 + iota + msgtypeCT + msgtypeSR + msgtypeDC + msgtypeCM + msgtypeRS + + nmsgtypes = msgtypeRS +) + +func (m msgtype) String() string { + switch m { + case msgtypeKE: + return "KE" + case msgtypeCT: + return "CT" + case msgtypeSR: + return "SR" + case msgtypeDC: + return "DC" + case msgtypeCM: + return "CM" + case msgtypeRS: + return "RS" + default: + return "?" + } +} + +// entry describes non-PR messages accepted to the pool. +type entry struct { + hash chainhash.Hash + sid [32]byte + recvTime time.Time + msg mixing.Message + msgtype msgtype + run uint32 +} + +type session struct { + sid [32]byte + runs []runstate + expiry uint32 + bc broadcast +} + +type runstate struct { + run uint32 + prs []chainhash.Hash + counts [nmsgtypes]uint32 + hashes map[chainhash.Hash]struct{} +} + +func (r *runstate) countFor(t msgtype) uint32 { + return r.counts[t-1] +} + +func (r *runstate) incrementCountFor(t msgtype) { + r.counts[t-1]++ +} + +type broadcast struct { + ch chan struct{} + mu sync.Mutex +} + +// wait returns the wait channel that is closed whenever a message is received +// for a session. Waiters must acquire the pool lock before reading messages. +func (b *broadcast) wait() <-chan struct{} { + b.mu.Lock() + ch := b.ch + b.mu.Unlock() + + return ch +} + +func (b *broadcast) signal() { + b.mu.Lock() + close(b.ch) + b.ch = make(chan struct{}) + b.mu.Unlock() +} + +// Pool records in-memory mix messages that have been broadcast over the +// peer-to-peer network. +type Pool struct { + mtx sync.RWMutex + prs map[chainhash.Hash]*wire.MsgMixPairReq + outPoints map[wire.OutPoint]chainhash.Hash + pool map[chainhash.Hash]entry + messagesByIdentity map[idPubKey][]chainhash.Hash + latestKE map[idPubKey]*wire.MsgMixKeyExchange + sessions map[[32]byte]*session + sessionsByTxHash map[chainhash.Hash]*session + epoch time.Duration + expireHeight uint32 + expireSem chan struct{} + + blockchain BlockChain + utxoFetcher UtxoFetcher + feeRate int64 + params *chaincfg.Params +} + +// UtxoEntry provides details regarding unspent transaction outputs. +type UtxoEntry interface { + IsSpent() bool + PkScript() []byte + ScriptVersion() uint16 + BlockHeight() int64 + Amount() int64 +} + +// UtxoFetcher defines methods used to validate unspent transaction outputs in +// the pair request message. It is optional, but should be implemented by full +// nodes that have this capability to detect and stop relay of spam and junk +// messages. +type UtxoFetcher interface { + // FetchUtxoEntry defines the function to use to fetch unspent + // transaction output information. + FetchUtxoEntry(wire.OutPoint) (UtxoEntry, error) +} + +// BlockChain queries the current status of the blockchain. Its methods should +// be able to be implemented by both full nodes and SPV wallets. +type BlockChain interface { + // ChainParams identifies which chain parameters the mixing pool is + // associated with. + ChainParams() *chaincfg.Params + + // BestHeader returns the hash and height of the current tip block. + BestHeader() (chainhash.Hash, int64) +} + +// NewPool returns a new mixing pool that accepts and validates mixing messages +// required for distributed transaction mixing. 
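+//
+// A minimal integration sketch, assuming chain is some concrete
+// BlockChain implementation:
+//
+//	pool := mixpool.NewPool(chain)
+//	accepted, err := pool.AcceptMessage(msg)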
+func NewPool(blockchain BlockChain) *Pool { + pool := &Pool{ + prs: make(map[chainhash.Hash]*wire.MsgMixPairReq), + outPoints: make(map[wire.OutPoint]chainhash.Hash), + pool: make(map[chainhash.Hash]entry), + messagesByIdentity: make(map[idPubKey][]chainhash.Hash), + latestKE: make(map[idPubKey]*wire.MsgMixKeyExchange), + sessions: make(map[[32]byte]*session), + sessionsByTxHash: make(map[chainhash.Hash]*session), + epoch: 10 * time.Minute, // XXX: mainnet epoch: add to chainparams + expireHeight: 0, + expireSem: make(chan struct{}, 1), + blockchain: blockchain, + feeRate: feeRate, + params: blockchain.ChainParams(), + } + if u, ok := blockchain.(UtxoFetcher); ok { + pool.utxoFetcher = u + } + return pool +} + +// XXX: can remove this method after adding mixing epoch to chainparams. +func (p *Pool) SetEpoch(d time.Duration) { + p.epoch = d +} + +// MixPRHashes returns the hashes of all MixPR messages recorded by the pool. +// This data is provided to peers requesting initial state of the mixpool. +func (p *Pool) MixPRHashes() []chainhash.Hash { + p.mtx.RLock() + hashes := make([]chainhash.Hash, 0, len(p.prs)) + for hash := range p.prs { + hashes = append(hashes, hash) + } + p.mtx.RUnlock() + + return hashes +} + +// Message searches the mixing pool for a message by its hash. +func (p *Pool) Message(query *chainhash.Hash) (mixing.Message, error) { + p.mtx.RLock() + pr := p.prs[*query] + e, ok := p.pool[*query] + p.mtx.RUnlock() + if pr != nil { + return pr, nil + } + if !ok || e.msg == nil { + return nil, fmt.Errorf("message not found") + } + return e.msg, nil +} + +// HaveMessage checks whether the mixing pool contains a message by its hash. +func (p *Pool) HaveMessage(query *chainhash.Hash) bool { + p.mtx.RLock() + _, ok := p.pool[*query] + if !ok { + _, ok = p.prs[*query] + } + p.mtx.RUnlock() + return ok +} + +// MixPR searches the mixing pool for a PR message by its hash. +func (p *Pool) MixPR(query *chainhash.Hash) (*wire.MsgMixPairReq, error) { + var pr *wire.MsgMixPairReq + + p.mtx.RLock() + pr = p.prs[*query] + p.mtx.RUnlock() + + if pr == nil { + return nil, fmt.Errorf("PR message not found") + } + + return pr, nil +} + +// MixPRs returns all MixPR messages with hashes matching the query. Unknown +// messages are ignored. +// +// If query is nil, all PRs are returned. +// +// In both cases, any expired PRs that are still internally tracked by the +// mixpool for ongoing sessions are excluded from the result set. +func (p *Pool) MixPRs(query []chainhash.Hash) []*wire.MsgMixPairReq { + res := make([]*wire.MsgMixPairReq, 0, len(query)) + + p.mtx.Lock() + defer p.mtx.Unlock() + + p.removeConfirmedRuns() + + if query == nil { + res = make([]*wire.MsgMixPairReq, 0, len(p.prs)) + for _, pr := range p.prs { + // Exclude expired but not yet removed PRs. + if pr.Expiry <= p.expireHeight { + continue + } + + res = append(res, pr) + } + return res + } + + for i := range query { + pr, ok := p.prs[query[i]] + if ok { + // Exclude expired but not yet removed PRs. + if pr.Expiry <= p.expireHeight { + continue + } + + res = append(res, pr) + } + } + + return res +} + +// CompatiblePRs returns all MixPR messages with pairing descriptions matching +// the parameter. 
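+//
+// Results are returned in a deterministic order (double spends removed,
+// then sorted lexicographically by hash) so that every peer observing
+// the same PRs derives the same pairing slice and session ID.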
+func (p *Pool) CompatiblePRs(pairing []byte) []*wire.MsgMixPairReq {
+	p.mtx.RLock()
+	defer p.mtx.RUnlock()
+
+	res := make([]*wire.MsgMixPairReq, 0, len(p.prs))
+	for _, pr := range p.prs {
+		prPairing, _ := pr.Pairing()
+		if bytes.Equal(pairing, prPairing) {
+			res = append(res, pr)
+		}
+	}
+
+	// Sort by decreasing expiries and remove any PRs double spending an
+	// output with an earlier expiry.
+	sort.Slice(res, func(i, j int) bool {
+		return res[i].Expiry > res[j].Expiry
+	})
+	seen := make(map[wire.OutPoint]uint32)
+	for i, pr := range res {
+		for _, utxo := range pr.UTXOs {
+			prevExpiry, ok := seen[utxo.OutPoint]
+			if !ok {
+				seen[utxo.OutPoint] = pr.Expiry
+			} else if pr.Expiry < prevExpiry {
+				res[i] = nil
+			}
+		}
+	}
+	filtered := res[:0]
+	for i := range res {
+		if res[i] != nil {
+			filtered = append(filtered, res[i])
+		}
+	}
+
+	// Sort again lexicographically by hash.
+	sort.Slice(filtered, func(i, j int) bool {
+		a := filtered[i].Hash()
+		b := filtered[j].Hash()
+		return bytes.Compare(a[:], b[:]) == -1
+	})
+	return filtered
+}
+
+// ExpireMessages will, after the current epoch period ends, remove all pair
+// requests that indicate an expiry at or before the height parameter and
+// remove all messages that chain back to a removed pair request.
+func (p *Pool) ExpireMessages(height uint32) {
+	p.mtx.Lock()
+	defer p.mtx.Unlock()
+
+	if height > p.expireHeight {
+		p.expireHeight = height
+	}
+
+	select {
+	case p.expireSem <- struct{}{}:
+		go p.expireMessages()
+	default:
+	}
+}
+
+// waitForEpoch blocks until the next epoch occurs.
+func (p *Pool) waitForEpoch() {
+	now := time.Now().UTC()
+	epoch := now.Truncate(p.epoch).Add(p.epoch)
+	duration := epoch.Sub(now)
+	time.Sleep(duration)
+}
+
+func (p *Pool) expireMessages() {
+	p.waitForEpoch()
+
+	p.mtx.Lock()
+	defer func() {
+		<-p.expireSem
+		p.mtx.Unlock()
+	}()
+
+	height := p.expireHeight
+	p.expireHeight = 0
+
+	// Expire sessions and their messages
+	for sid, ses := range p.sessions {
+		if ses.expiry > height {
+			continue
+		}
+
+		delete(p.sessions, sid)
+		for _, r := range ses.runs {
+			for hash := range r.hashes {
+				delete(p.pool, hash)
+			}
+		}
+	}
+
+	// Expire PRs and remove identity tracking
+	for hash, pr := range p.prs {
+		if pr.Expiry > height {
+			continue
+		}
+
+		logf("Removing expired PR %s by %x",
+			hash, pr.Identity[:])
+		delete(p.prs, hash)
+		delete(p.messagesByIdentity, pr.Identity)
+		delete(p.latestKE, pr.Identity)
+	}
+}
+
+// RemoveMessage removes a message that was rejected by the network.
+func (p *Pool) RemoveMessage(msg mixing.Message) {
+	p.mtx.Lock()
+	defer p.mtx.Unlock()
+
+	msgHash := msg.Hash()
+	delete(p.pool, msgHash)
+	if pr, ok := msg.(*wire.MsgMixPairReq); ok {
+		if p.prs[msgHash] != nil {
+			logf("Removing PR %s by %x",
+				msgHash, pr.Identity[:])
+		}
+		delete(p.prs, msgHash)
+		delete(p.latestKE, pr.Identity)
+	}
+	if ke, ok := msg.(*wire.MsgMixKeyExchange); ok {
+		delete(p.latestKE, ke.Identity)
+	}
+}
+
+// RemoveIdentities removes all messages from the mixpool that were created by
+// any of the identities.
+func (p *Pool) RemoveIdentities(identities [][33]byte) {
+	p.mtx.Lock()
+	defer p.mtx.Unlock()
+
+	for i := range identities {
+		id := &identities[i]
+		for _, hash := range p.messagesByIdentity[*id] {
+			delete(p.pool, hash)
+			if p.prs[hash] != nil {
+				logf("Removing PR %s by %x",
+					hash, id[:])
+			}
+			delete(p.prs, hash)
+		}
+		delete(p.messagesByIdentity, *id)
+		delete(p.latestKE, *id)
+	}
+}
+
+func (p *Pool) removeSession(sid [32]byte, txHash *chainhash.Hash, success bool) {
+	ses := p.sessions[sid]
+	if ses == nil {
+		return
+	}
+
+	// Delete PRs used to form final run
+	var removePRs []chainhash.Hash
+	var lastRun *runstate
+	if success {
+		lastRun = &ses.runs[len(ses.runs)-1]
+		removePRs = lastRun.prs
+	}
+
+	if txHash != nil || success {
+		if txHash == nil {
+			// XXX: may be better to store this in the runstate as
+			// a CM is received.
+			for h := range lastRun.hashes {
+				if e, ok := p.pool[h]; ok && e.msgtype == msgtypeCM {
+					cm := e.msg.(*wire.MsgMixConfirm)
+					hash := cm.Mix.TxHash()
+					txHash = &hash
+					break
+				}
+			}
+		}
+		if txHash != nil {
+			delete(p.sessionsByTxHash, *txHash)
+		}
+	}
+
+	delete(p.sessions, sid)
+	for _, r := range ses.runs {
+		for hash := range r.hashes {
+			delete(p.pool, hash)
+		}
+	}
+
+	for _, prHash := range removePRs {
+		delete(p.pool, prHash)
+		if pr := p.prs[prHash]; pr != nil {
+			logf("Removing mixed PR %s by %x", prHash,
+				pr.Identity[:])
+			delete(p.prs, prHash)
+			delete(p.latestKE, pr.Identity)
+		}
+	}
+}
+
+// RemoveSession removes all non-PR messages from a completed or errored
+// session. PR messages of a successful run (or rerun) are also removed.
+func (p *Pool) RemoveSession(sid [32]byte, success bool) {
+	p.mtx.Lock()
+	defer p.mtx.Unlock()
+
+	p.removeSession(sid, nil, success)
+}
+
+// RemoveConfirmedRuns removes all messages including pair requests from
+// runs which ended in each peer sending a confirm mix message.
+func (p *Pool) RemoveConfirmedRuns() {
+	p.mtx.Lock()
+	defer p.mtx.Unlock()
+
+	p.removeConfirmedRuns()
+}
+
+func (p *Pool) removeConfirmedRuns() {
+	for sid, ses := range p.sessions {
+		lastRun := &ses.runs[len(ses.runs)-1]
+		cmCount := lastRun.countFor(msgtypeCM)
+		if len(lastRun.prs) != int(cmCount) {
+			continue
+		}
+
+		delete(p.sessions, sid)
+		for _, prevRun := range ses.runs[:len(ses.runs)-1] {
+			for hash := range prevRun.hashes {
+				delete(p.pool, hash)
+			}
+		}
+		for hash := range lastRun.hashes {
+			delete(p.pool, hash)
+		}
+
+		for _, hash := range lastRun.prs {
+			delete(p.pool, hash)
+			pr := p.prs[hash]
+			if pr != nil {
+				logf("Removing PR %s by %x",
+					hash, pr.Identity[:])
+				delete(p.prs, hash)
+				delete(p.latestKE, pr.Identity)
+			}
+		}
+	}
+}
+
+// RemoveConfirmedMixes removes sessions and messages belonging to a completed
+// session that resulted in a published or mined transaction. Transaction
+// hashes not associated with a session are ignored. PRs from the successful
+// mix run are removed from the pool.
+func (p *Pool) RemoveConfirmedMixes(txHashes []chainhash.Hash) {
+	p.mtx.Lock()
+	defer p.mtx.Unlock()
+
+	for _, hash := range txHashes {
+		ses := p.sessionsByTxHash[hash]
+		if ses == nil {
+			continue
+		}
+
+		p.removeSession(ses.sid, &hash, true)
+	}
+}
+
+// ReceiveKEsByPairing returns the most recently received run-0 KE messages by
+// a peer that reference PRs of a particular pairing and epoch.
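+//
+// This is used when forming alternate sessions: the returned KEs reveal
+// which PRs each peer claims to have observed during the epoch.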
+func (p *Pool) ReceiveKEsByPairing(pairing []byte, epoch uint64) []*wire.MsgMixKeyExchange { + p.mtx.RLock() + defer p.mtx.RUnlock() + + var kes []*wire.MsgMixKeyExchange + for id, ke := range p.latestKE { + prHash := p.messagesByIdentity[id][0] + pr := p.prs[prHash] + prPairing, err := pr.Pairing() + if err != nil { + continue + } + if bytes.Equal(pairing, prPairing) && ke.Epoch == epoch { + kes = append(kes, ke) + } + } + return kes +} + +// RemoveUnresponsiveDuringEpoch removes pair requests of unresponsive peers +// that did not provide any key exchange messages during the epoch in which a +// mix occurred. +func (p *Pool) RemoveUnresponsiveDuringEpoch(pairing []byte, epoch uint64) { + p.mtx.Lock() + defer p.mtx.Unlock() + + var unresponsive []*wire.MsgMixPairReq + +PRLoop: + for _, pr := range p.prs { + prPairing, err := pr.Pairing() + if err != nil { + log.Warnf("pr.Pairing() of accepted pair request errored: %v", + err) + continue + } + if !bytes.Equal(pairing, prPairing) { + continue + } + + for _, msgHash := range p.messagesByIdentity[pr.Identity] { + msg, ok := p.pool[msgHash].msg.(*wire.MsgMixKeyExchange) + if !ok { + continue + } + if msg.Epoch == epoch { + continue PRLoop + } + } + + unresponsive = append(unresponsive, pr) + } + + for _, pr := range unresponsive { + id := &pr.Identity + for _, hash := range p.messagesByIdentity[*id] { + delete(p.pool, hash) + if p.prs[hash] != nil { + logf("Removing PR %s by %x", hash, id[:]) + } + delete(p.prs, hash) + } + delete(p.messagesByIdentity, *id) + delete(p.latestKE, *id) + } +} + +// Received is a parameter for Pool.Receive describing the session and run to +// receive messages for, and slices for returning results. Only non-nil slices +// will be appended to. Received messages are unsorted. +type Received struct { + Sid [32]byte + Run uint32 + KEs []*wire.MsgMixKeyExchange + CTs []*wire.MsgMixCiphertexts + SRs []*wire.MsgMixSlotReserve + DCs []*wire.MsgMixDCNet + CMs []*wire.MsgMixConfirm + RSs []*wire.MsgMixSecrets +} + +// Receive returns messages matching a session, run, and message type, waiting +// until all described messages have been received, or earlier with the +// messages received so far if the context is cancelled before this point. +// +// Receive only returns results for the session ID and run increment in the r +// parameter. If no such session or run has any messages currently accepted +// in the mixpool, the method immediately errors. +// +// If any secrets messages are received for the described session and run, and +// r.RSs is nil, Receive immediately returns ErrSecretsRevealed. An +// additional call to Receive with a non-nil RSs can be used to receive all of +// the secrets after each peer publishes their own revealed secrets. +func (p *Pool) Receive(ctx context.Context, expectedMessages int, r *Received) error { + sid := r.Sid + run := r.Run + var bc *broadcast + var rs *runstate + var err error + + p.mtx.RLock() + ses, ok := p.sessions[sid] + if !ok { + p.mtx.RUnlock() + return fmt.Errorf("unknown session") + } + bc = &ses.bc + if run >= uint32(len(ses.runs)) { + p.mtx.RUnlock() + return fmt.Errorf("unknown run") + } + rs = &ses.runs[run] + +Loop: + for { + // Pool is locked for reads. Count if the total number of + // expected messages have been received. 
+// Receive returns messages matching a session, run, and message type, waiting
+// until all described messages have been received, or earlier with the
+// messages received so far if the context is cancelled before this point.
+//
+// Receive only returns results for the session ID and run increment in the r
+// parameter.  If no such session or run has any messages currently accepted
+// in the mixpool, the method immediately errors.
+//
+// If any secrets messages are received for the described session and run, and
+// r.RSs is nil, Receive immediately returns ErrSecretsRevealed.  An
+// additional call to Receive with a non-nil RSs can be used to receive all of
+// the secrets after each peer publishes their own revealed secrets.
+func (p *Pool) Receive(ctx context.Context, expectedMessages int, r *Received) error {
+	sid := r.Sid
+	run := r.Run
+	var bc *broadcast
+	var rs *runstate
+	var err error
+
+	p.mtx.RLock()
+	ses, ok := p.sessions[sid]
+	if !ok {
+		p.mtx.RUnlock()
+		return fmt.Errorf("unknown session")
+	}
+	bc = &ses.bc
+	if run >= uint32(len(ses.runs)) {
+		p.mtx.RUnlock()
+		return fmt.Errorf("unknown run")
+	}
+	rs = &ses.runs[run]
+
+Loop:
+	for {
+		// Pool is locked for reads.  Check whether the total number
+		// of expected messages has been received.
+		received := 0
+		for hash := range rs.hashes {
+			msgtype := p.pool[hash].msgtype
+			switch {
+			case msgtype == msgtypeKE && r.KEs != nil:
+				received++
+			case msgtype == msgtypeCT && r.CTs != nil:
+				received++
+			case msgtype == msgtypeSR && r.SRs != nil:
+				received++
+			case msgtype == msgtypeDC && r.DCs != nil:
+				received++
+			case msgtype == msgtypeCM && r.CMs != nil:
+				received++
+			case msgtype == msgtypeRS:
+				if r.RSs == nil {
+					// Since initial reporters of secrets
+					// need to take the blame for
+					// erroneous blame assignment if no
+					// issue was detected, we only trigger
+					// this for RS messages that do not
+					// reference any other previous RS.
+					prev := p.pool[hash].msg.(*wire.MsgMixSecrets).PrevMsgs()
+					if len(prev) == 0 {
+						p.mtx.RUnlock()
+						return ErrSecretsRevealed
+					}
+				} else {
+					received++
+				}
+			}
+		}
+		if received >= expectedMessages {
+			break
+		}
+
+		// Unlock while waiting for the broadcast channel.
+		p.mtx.RUnlock()
+
+		select {
+		case <-ctx.Done():
+			// Set error to be returned, but still lock the pool
+			// and collect received messages.
+			err = ctx.Err()
+			p.mtx.RLock()
+			break Loop
+		case <-bc.wait():
+		}
+
+		p.mtx.RLock()
+	}
+
+	// Pool is locked for reads.  Collect all of the messages.
+	for hash := range rs.hashes {
+		msg := p.pool[hash].msg
+		switch msg := msg.(type) {
+		case *wire.MsgMixKeyExchange:
+			if r.KEs != nil {
+				r.KEs = append(r.KEs, msg)
+			}
+		case *wire.MsgMixCiphertexts:
+			if r.CTs != nil {
+				r.CTs = append(r.CTs, msg)
+			}
+		case *wire.MsgMixSlotReserve:
+			if r.SRs != nil {
+				r.SRs = append(r.SRs, msg)
+			}
+		case *wire.MsgMixDCNet:
+			if r.DCs != nil {
+				r.DCs = append(r.DCs, msg)
+			}
+		case *wire.MsgMixConfirm:
+			if r.CMs != nil {
+				r.CMs = append(r.CMs, msg)
+			}
+		case *wire.MsgMixSecrets:
+			if r.RSs != nil {
+				r.RSs = append(r.RSs, msg)
+			}
+		}
+	}
+
+	p.mtx.RUnlock()
+	return err
+}
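The two-call pattern for revealed secrets described in the doc comment might be driven as in the following sketch, assuming mixpool exports ErrSecretsRevealed as referenced above; sid, run, and count are taken as known inputs:

```go
package main

import (
	"context"
	"errors"

	"github.com/decred/dcrd/mixing/mixpool"
	"github.com/decred/dcrd/wire"
)

// receiveDCsOrSecrets waits for DC-net messages; if any peer began
// revealing secrets, it switches to collecting every RS message instead
// so blame assignment can proceed.
func receiveDCsOrSecrets(ctx context.Context, pool *mixpool.Pool,
	sid [32]byte, run uint32, count int) (*mixpool.Received, error) {

	r := &mixpool.Received{Sid: sid, Run: run,
		DCs: make([]*wire.MsgMixDCNet, 0, count)}
	err := pool.Receive(ctx, count, r)
	if errors.Is(err, mixpool.ErrSecretsRevealed) {
		// A peer began blame assignment; publish our own secrets
		// (not shown) and gather everyone's RS messages.
		r = &mixpool.Received{Sid: sid, Run: run,
			RSs: make([]*wire.MsgMixSecrets, 0, count)}
		err = pool.Receive(ctx, count, r)
	}
	return r, err
}
```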
+// AcceptMessage accepts a mixing message into the pool.
+//
+// Messages must contain the mixing participant's identity and contain a valid
+// signature committing to all non-signature fields.
+//
+// PR messages will not be accepted if they reference an unknown UTXO or if not
+// enough fee is contributed.  Any other message will not be accepted if it
+// references previous messages that are not recorded by the pool.
+func (p *Pool) AcceptMessage(msg mixing.Message) (accepted mixing.Message, err error) {
+	hash := msg.Hash()
+	alreadyAccepted := func() bool {
+		_, ok := p.pool[hash]
+		if !ok {
+			_, ok = p.prs[hash]
+		}
+		return ok
+	}
+
+	// Check if already accepted.
+	p.mtx.RLock()
+	ok := alreadyAccepted()
+	p.mtx.RUnlock()
+	if ok {
+		return nil, nil
+	}
+
+	// Require message to be signed by the presented identity.
+	if !mixing.VerifyMessageSignature(msg) {
+		return nil, ruleError(ErrInvalidSignature)
+	}
+	id := (*idPubKey)(msg.Pub())
+
+	var msgtype msgtype
+	switch msg := msg.(type) {
+	case *wire.MsgMixPairReq:
+		accepted, err := p.acceptPR(msg, &hash, id)
+		if err != nil {
+			return nil, err
+		}
+		// Avoid returning a non-nil mixing.Message in the return
+		// variable together with a nil PR.
+		if accepted == nil {
+			return nil, nil
+		}
+		return accepted, nil
+
+	case *wire.MsgMixKeyExchange:
+		accepted, err := p.acceptKE(msg, &hash, id)
+		if err != nil {
+			return nil, err
+		}
+		// Avoid returning a non-nil mixing.Message in the return
+		// variable together with a nil KE.
+		if accepted == nil {
+			return nil, nil
+		}
+		return accepted, nil
+
+	case *wire.MsgMixCiphertexts:
+		msgtype = msgtypeCT
+	case *wire.MsgMixSlotReserve:
+		msgtype = msgtypeSR
+	case *wire.MsgMixDCNet:
+		msgtype = msgtypeDC
+	case *wire.MsgMixConfirm:
+		msgtype = msgtypeCM
+	case *wire.MsgMixSecrets:
+		msgtype = msgtypeRS
+	default:
+		return nil, fmt.Errorf("unknown mix message type %T", msg)
+	}
+
+	sid := *(*[32]byte)(msg.Sid())
+
+	p.mtx.Lock()
+	defer p.mtx.Unlock()
+
+	// Read lock was given up to acquire write lock.  Check if already
+	// accepted.
+	if alreadyAccepted() {
+		return nil, nil
+	}
+
+	// Check prior message existence in the pool, and only accept messages
+	// that reference other known and accepted messages of the correct type
+	// and sid.
+	//
+	// TODO: Consider returning an error containing the unknown messages, so
+	// they can be getdata'd, and if they are not received or are garbage,
+	// peers can be kicked.
+	prevMsgs := msg.PrevMsgs()
+	for i := range prevMsgs {
+		// RS messages may reference previous messages of any type;
+		// all other messages must reference messages of the
+		// immediately preceding type.
+		looktype := msgtype - 1
+		if msgtype == msgtypeRS {
+			looktype = 0
+		}
+		_, ok := p.lookupEntry(prevMsgs[i], looktype, &sid)
+		if !ok {
+			return nil, fmt.Errorf("%s %s references unknown "+
+				"previous message %s", msgtype, &hash,
+				&prevMsgs[i])
+		}
+	}
+
+	// Check that a message from this identity does not reuse a run number
+	// for the session.
+	for _, prevHash := range p.messagesByIdentity[*id] {
+		e := p.pool[prevHash]
+		run := msg.GetRun()
+		if e.msgtype == msgtype && e.msg.GetRun() == run &&
+			bytes.Equal(e.msg.Sid(), msg.Sid()) {
+			return nil, fmt.Errorf("message %v by identity %x "+
+				"reuses run number %d in session %x, "+
+				"conflicting with already accepted message %v",
+				hash, *id, run, msg.Sid(), prevHash)
+		}
+	}
+
+	ses := p.sessions[sid]
+	if ses == nil {
+		return nil, fmt.Errorf("%s %s belongs to unknown session %x",
+			msgtype, &hash, &sid)
+	}
+
+	err = p.acceptEntry(msg, msgtype, &hash, id, ses)
+	if err != nil {
+		return nil, err
+	}
+	return msg, nil
+}
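On the networking side, AcceptMessage's nil/nil duplicate convention matters for relay decisions. A rough sketch of a gossip handler, with the handler name being hypothetical:

```go
package main

import (
	"github.com/decred/dcrd/mixing"
	"github.com/decred/dcrd/mixing/mixpool"
)

// onMixMessage validates and accepts a gossiped message, reporting whether
// it should be relayed to other peers.
func onMixMessage(pool *mixpool.Pool, msg mixing.Message) (relay bool, err error) {
	accepted, err := pool.AcceptMessage(msg)
	if err != nil {
		// Signature, fee, expiry, or reference checks failed; a real
		// handler might increase the sending peer's ban score here.
		return false, err
	}
	// A nil accepted message with a nil error indicates a duplicate of
	// an already-accepted message; do not re-relay those.
	return accepted != nil, nil
}
```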
+func (p *Pool) acceptPR(pr *wire.MsgMixPairReq, hash *chainhash.Hash, id *idPubKey) (accepted *wire.MsgMixPairReq, err error) {
+	switch {
+	case len(pr.UTXOs) == 0: // Require at least one utxo.
+		return nil, ruleError(ErrMissingUTXOs)
+	case pr.MessageCount == 0: // Require at least one mixed message.
+		return nil, ruleError(ErrInvalidMessageCount)
+	case pr.InputValue < int64(pr.MessageCount)*pr.MixAmount:
+		return nil, ruleError(ErrInvalidTotalMixAmount)
+	case pr.Change != nil:
+		if isDustAmount(pr.Change.Value, p2pkhv0PkScriptSize, feeRate) {
+			return nil, ruleError(ErrChangeDust)
+		}
+		if !stdscript.IsPubKeyHashScriptV0(pr.Change.PkScript) &&
+			!stdscript.IsScriptHashScriptV0(pr.Change.PkScript) {
+			return nil, ruleError(ErrInvalidScript)
+		}
+	}
+
+	// Check that the expiry has not been reached, and that it is not too
+	// far into the future.  This limits replay attacks.
+	_, curHeight := p.blockchain.BestHeader()
+	maxExpiry := mixing.MaxExpiry(uint32(curHeight), p.params)
+	switch {
+	case uint32(curHeight) >= pr.Expiry:
+		return nil, fmt.Errorf("message has expired")
+	case pr.Expiry > maxExpiry:
+		return nil, fmt.Errorf("expiry is too far into future")
+	}
+
+	// Require known script classes.
+	switch mixing.ScriptClass(pr.ScriptClass) {
+	case mixing.ScriptClassP2PKHv0:
+	default:
+		return nil, fmt.Errorf("unsupported mixing script class")
+	}
+
+	// Require enough fee contributed from this mixing participant.
+	// Size estimation assumes mixing.ScriptClassP2PKHv0 outputs and inputs.
+	err = checkFee(pr, p.feeRate)
+	if err != nil {
+		return nil, err
+	}
+
+	// If able, sanity check UTXOs.
+	if p.utxoFetcher != nil {
+		err := p.checkUTXOs(pr)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	p.mtx.Lock()
+	defer p.mtx.Unlock()
+
+	// Check if already accepted.
+	if _, ok := p.prs[*hash]; ok {
+		return nil, nil
+	}
+
+	// Discourage identity reuse.  PRs should be the first message sent by
+	// this identity, and there should only be one PR per identity.
+	if len(p.messagesByIdentity[*id]) != 0 {
+		return nil, fmt.Errorf("identity reused for a PR message")
+	}
+
+	// Only accept PRs that double spend outpoints if they expire later
+	// than existing PRs.  Otherwise, reject this PR message.
+	for i := range pr.UTXOs {
+		otherPRHash := p.outPoints[pr.UTXOs[i].OutPoint]
+		otherPR, ok := p.prs[otherPRHash]
+		if !ok {
+			continue
+		}
+		if otherPR.Expiry >= pr.Expiry {
+			err := fmt.Errorf("PR double spends outpoints of " +
+				"already-accepted PR message without " +
+				"increasing expiry")
+			return nil, err
+		}
+	}
+
+	// Accept the PR.
+	p.prs[*hash] = pr
+	for i := range pr.UTXOs {
+		p.outPoints[pr.UTXOs[i].OutPoint] = *hash
+	}
+	p.messagesByIdentity[*id] = append(make([]chainhash.Hash, 0, 16), *hash)
+
+	return pr, nil
+}
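To make the fee requirement concrete, this is the arithmetic checkFee (defined later in this file) applies to a PR, under purely illustrative values:

```go
package main

// Illustrative only: a PR mixing four outputs of 1 DCR with change.
const (
	mixAmount    = 100_000_000 // 1 DCR in atoms
	messageCount = 4
	changeValue  = 1_000_000
	inputValue   = messageCount*mixAmount + changeValue + 6_000
)

// The implied fee contribution is what checkFee compares against the
// size-based requirement: inputValue - messageCount*mixAmount - change.
var fee int64 = inputValue - messageCount*mixAmount - changeValue // 6000 atoms
```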
+// checkUTXOs checks that UTXOs exist and are confirmed, that the sum of the
+// UTXO values matches the input value, and that the proof of ownership is
+// valid.
+func (p *Pool) checkUTXOs(pr *wire.MsgMixPairReq) error {
+	var totalValue int64
+	_, curHeight := p.blockchain.BestHeader()
+
+	for i := range pr.UTXOs {
+		utxo := &pr.UTXOs[i]
+		entry, err := p.utxoFetcher.FetchUtxoEntry(utxo.OutPoint)
+		if err != nil {
+			return err
+		}
+		if entry == nil || entry.IsSpent() {
+			return fmt.Errorf("output %v is not unspent",
+				&utxo.OutPoint)
+		}
+		height := entry.BlockHeight()
+		if !confirmed(minconf, height, curHeight) {
+			return fmt.Errorf("output %v is unconfirmed",
+				&utxo.OutPoint)
+		}
+		if entry.ScriptVersion() != 0 {
+			return fmt.Errorf("output %v does not use script version 0",
+				&utxo.OutPoint)
+		}
+
+		// Check proof of key ownership and ability to sign coinjoin
+		// inputs.
+		utxoPkScript := entry.PkScript()
+		var extractPubKeyHash160 func([]byte) []byte
+		switch {
+		case stdscript.IsPubKeyHashScriptV0(utxoPkScript):
+			extractPubKeyHash160 = stdscript.ExtractPubKeyHashV0
+		case stdscript.IsStakeGenPubKeyHashScriptV0(utxoPkScript):
+			extractPubKeyHash160 = stdscript.ExtractStakeGenPubKeyHashV0
+		case stdscript.IsStakeRevocationPubKeyHashScriptV0(utxoPkScript):
+			extractPubKeyHash160 = stdscript.ExtractStakeRevocationPubKeyHashV0
+		case stdscript.IsTreasuryGenPubKeyHashScriptV0(utxoPkScript):
+			extractPubKeyHash160 = stdscript.ExtractTreasuryGenPubKeyHashV0
+		default:
+			return fmt.Errorf("unsupported output script for UTXO %s", &utxo.OutPoint)
+		}
+		valid := validateOwnerProofP2PKHv0(extractPubKeyHash160,
+			utxoPkScript, utxo.PubKey, utxo.Signature, pr.Expires())
+		if !valid {
+			return ruleError(ErrInvalidUTXOProof)
+		}
+
+		totalValue += entry.Amount()
+	}
+
+	if totalValue != pr.InputValue {
+		return fmt.Errorf("input value does not match sum of UTXO " +
+			"values")
+	}
+
+	return nil
+}
+
+func validateOwnerProofP2PKHv0(extractFunc func([]byte) []byte, pkscript, pubkey, sig []byte, expires uint32) bool {
+	extractedHash160 := extractFunc(pkscript)
+	pubkeyHash160 := dcrutil.Hash160(pubkey)
+	if !bytes.Equal(extractedHash160, pubkeyHash160) {
+		return false
+	}
+
+	return utxoproof.ValidateSecp256k1P2PKH(pubkey, sig, expires)
+}
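A compact restatement of what the ownership proof binds for a plain version-0 P2PKH UTXO; a sketch assuming pkScript, pubKey, sig, and expires come from the PR being validated (the helper name is illustrative):

```go
package main

import (
	"bytes"

	"github.com/decred/dcrd/dcrutil/v4"
	"github.com/decred/dcrd/mixing/utxoproof"
	"github.com/decred/dcrd/txscript/v4/stdscript"
)

// ownsUTXO reports whether pubKey hashes to the script's HASH160 and the
// signature proves key ownership for the PR's expiry height.
func ownsUTXO(pkScript, pubKey, sig []byte, expires uint32) bool {
	return bytes.Equal(stdscript.ExtractPubKeyHashV0(pkScript),
		dcrutil.Hash160(pubKey)) &&
		utxoproof.ValidateSecp256k1P2PKH(pubKey, sig, expires)
}
```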
+func (p *Pool) acceptKE(ke *wire.MsgMixKeyExchange, hash *chainhash.Hash, id *idPubKey) (accepted *wire.MsgMixKeyExchange, err error) {
+	sid := ke.SessionID
+
+	// In all runs, previous PR messages in the KE must be sorted.
+	// This defines the initial unmixed peer positions.
+	sorted := sort.SliceIsSorted(ke.SeenPRs, func(i, j int) bool {
+		a := ke.SeenPRs[i][:]
+		b := ke.SeenPRs[j][:]
+		return bytes.Compare(a, b) == -1
+	})
+	if !sorted {
+		err := fmt.Errorf("KE message contains unsorted previous PR " +
+			"hashes")
+		return nil, err
+	}
+
+	// Run-0 KE messages define a session ID by hashing the sorted,
+	// previously-seen PR message hashes together with the epoch.  This
+	// must match the sid also present in the message.  Later runs after a
+	// failed run may drop peers from the SeenPRs set, but the sid remains
+	// the same.  A sid cannot be conjured out of thin air, and other
+	// messages seen from the network for an unknown session are not
+	// accepted.
+	if ke.Run == 0 {
+		derivedSid := mixing.DeriveSessionID(ke.SeenPRs, ke.Epoch)
+		if sid != derivedSid {
+			return nil, ruleError(ErrInvalidSessionID)
+		}
+	}
+
+	p.mtx.Lock()
+	defer p.mtx.Unlock()
+
+	// Check if already accepted.
+	if _, ok := p.pool[*hash]; ok {
+		return nil, nil
+	}
+
+	// While KEs are allowed to reference unknown PRs, they must at least
+	// reference the PR submitted by their own identity.  The
+	// messagesByIdentity map will contain at least one entry after a PR
+	// is received, and their PR will be the first hash in this slice.
+	// Of all PRs that are known, their pairing types must be compatible.
+	if len(p.messagesByIdentity[ke.Identity]) == 0 {
+		err := fmt.Errorf("KE identity %x has not submitted a PR", ke.Identity)
+		return nil, err
+	}
+	ownPRHash := p.messagesByIdentity[ke.Identity][0]
+	var ownPR *wire.MsgMixPairReq
+	prs := make([]*wire.MsgMixPairReq, 0, len(ke.SeenPRs))
+	var pairing []byte
+	for _, seenPR := range ke.SeenPRs {
+		pr, ok := p.prs[seenPR]
+		if !ok {
+			continue
+		}
+		if seenPR == ownPRHash {
+			ownPR = pr
+		}
+		if pairing == nil {
+			var err error
+			pairing, err = pr.Pairing()
+			if err != nil {
+				return nil, err
+			}
+		} else {
+			pairing2, err := pr.Pairing()
+			if err != nil {
+				return nil, err
+			}
+			if !bytes.Equal(pairing, pairing2) {
+				err := fmt.Errorf("referenced PRs are incompatible")
+				return nil, err
+			}
+		}
+		prs = append(prs, pr)
+	}
+	if ownPR == nil {
+		err := fmt.Errorf("KE identity's own PR unexpectedly missing from mixpool")
+		return nil, err
+	}
+
+	ses := p.sessions[sid]
+
+	// Create a session for the first run-0 KE.
+	if ses == nil {
+		if ke.Run != 0 {
+			err := fmt.Errorf("unknown session for run-%d KE",
+				ke.Run)
+			return nil, err
+		}
+
+		expiry := ^uint32(0)
+		for i := range prs {
+			prExpiry := prs[i].Expires()
+			if expiry > prExpiry {
+				expiry = prExpiry
+			}
+		}
+		ses = &session{
+			sid:    sid,
+			runs:   make([]runstate, 0, 4),
+			expiry: expiry,
+			bc:     broadcast{ch: make(chan struct{})},
+		}
+		p.sessions[sid] = ses
+	}
+
+	err = p.acceptEntry(ke, msgtypeKE, hash, id, ses)
+	if err != nil {
+		return nil, err
+	}
+	p.latestKE[*id] = ke
+	return ke, nil
+}
+
+func (p *Pool) acceptEntry(msg mixing.Message, msgtype msgtype, hash *chainhash.Hash,
+	id *[33]byte, ses *session) error {
+
+	run := msg.GetRun()
+	if run > uint32(len(ses.runs)) {
+		return fmt.Errorf("message skips runs")
+	}
+
+	var rs *runstate
+	if msgtype == msgtypeKE && run == uint32(len(ses.runs)) {
+		// Add a runstate for the next run.
+		ses.runs = append(ses.runs, runstate{
+			run:    run,
+			prs:    msg.PrevMsgs(),
+			hashes: make(map[chainhash.Hash]struct{}),
+		})
+		rs = &ses.runs[len(ses.runs)-1]
+	} else {
+		// Add to the existing runstate.
+		rs = &ses.runs[run]
+	}
+
+	rs.hashes[*hash] = struct{}{}
+	e := entry{
+		hash:     *hash,
+		sid:      ses.sid,
+		recvTime: time.Now(),
+		msg:      msg,
+		msgtype:  msgtype,
+		run:      run,
+	}
+	p.pool[*hash] = e
+	p.messagesByIdentity[*id] = append(p.messagesByIdentity[*id], *hash)
+
+	if cm, ok := msg.(*wire.MsgMixConfirm); ok {
+		p.sessionsByTxHash[cm.Mix.TxHash()] = ses
+	}
+
+	rs.incrementCountFor(msgtype)
+	ses.bc.signal()
+
+	return nil
+}
+
+// lookupEntry returns the message entry matching a message hash with msgtype
+// and session id.  If msgtype is zero, any message type can be looked up.
+func (p *Pool) lookupEntry(hash chainhash.Hash, msgtype msgtype, sid *[32]byte) (entry, bool) {
+	e, ok := p.pool[hash]
+	if !ok {
+		return entry{}, false
+	}
+	if msgtype != 0 && e.msgtype != msgtype {
+		return entry{}, false
+	}
+	if e.sid != *sid {
+		return entry{}, false
+	}
+
+	return e, true
+}
+
+func confirmed(minConf, txHeight, curHeight int64) bool {
+	return confirms(txHeight, curHeight) >= minConf
+}
+
+func confirms(txHeight, curHeight int64) int64 {
+	switch {
+	case txHeight == -1, txHeight > curHeight:
+		return 0
+	default:
+		return curHeight - txHeight + 1
+	}
+}
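The confirmation helpers count inclusively, so a transaction mined at the current tip height already has one confirmation. A hypothetical in-package test makes the boundary cases explicit:

```go
func TestConfirms(t *testing.T) {
	if got := confirms(100, 100); got != 1 { // mined at the tip
		t.Fatalf("confirms(100, 100) = %d, want 1", got)
	}
	if got := confirms(-1, 105); got != 0 { // unmined
		t.Fatalf("confirms(-1, 105) = %d, want 0", got)
	}
	if !confirmed(1, 100, 100) { // one confirmation satisfies minconf=1
		t.Fatal("confirmed(1, 100, 100) = false, want true")
	}
}
```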
+// isDustAmount determines whether a transaction output value and script length
+// would cause the output to be considered dust.  Transactions with dust
+// outputs are not standard and are rejected by mempools with default
+// policies.
+func isDustAmount(amount int64, scriptSize int, relayFeePerKb int64) bool {
+	// Calculate the total (estimated) cost to the network.  This is
+	// calculated using the serialize size of the output plus the serial
+	// size of a transaction input which redeems it.  The output is assumed
+	// to be compressed P2PKH as this is the most common script type.  Use
+	// the average size of a compressed P2PKH redeem input (165) rather than
+	// the largest possible (txsizes.RedeemP2PKHInputSize).
+	totalSize := 8 + 2 + wire.VarIntSerializeSize(uint64(scriptSize)) +
+		scriptSize + 165
+
+	// Dust is defined as an output value where the total cost to the network
+	// (output size + input size) is greater than 1/3 of the relay fee.
+	return amount*1000/(3*int64(totalSize)) < relayFeePerKb
+}
+
+func checkFee(pr *wire.MsgMixPairReq, feeRate int64) error {
+	fee := pr.InputValue - int64(pr.MessageCount)*pr.MixAmount
+	if pr.Change != nil {
+		fee -= pr.Change.Value
+	}
+
+	estimatedSize := estimateP2PKHv0SerializeSize(len(pr.UTXOs),
+		int(pr.MessageCount), pr.Change != nil)
+	requiredFee := feeForSerializeSize(feeRate, estimatedSize)
+	if fee < requiredFee {
+		return fmt.Errorf("not enough input value, or too low fee")
+	}
+
+	return nil
+}
+
+func feeForSerializeSize(relayFeePerKb int64, txSerializeSize int) int64 {
+	fee := relayFeePerKb * int64(txSerializeSize) / 1000
+
+	if fee == 0 && relayFeePerKb > 0 {
+		fee = relayFeePerKb
+	}
+
+	const maxAmount = 21e6 * 1e8
+	if fee < 0 || fee > maxAmount {
+		fee = maxAmount
+	}
+
+	return fee
+}
+
+const (
+	redeemP2PKHv0SigScriptSize = 1 + 73 + 1 + 33
+	p2pkhv0PkScriptSize        = 1 + 1 + 1 + 20 + 1 + 1
+)
+
+func estimateP2PKHv0SerializeSize(inputs, outputs int, hasChange bool) int {
+	// Sum the estimated sizes of the inputs and outputs.
+	txInsSize := inputs * estimateInputSize(redeemP2PKHv0SigScriptSize)
+	txOutsSize := outputs * estimateOutputSize(p2pkhv0PkScriptSize)
+
+	changeSize := 0
+	if hasChange {
+		changeSize = estimateOutputSize(p2pkhv0PkScriptSize)
+		outputs++
+	}
+
+	// 12 additional bytes are for version, locktime and expiry.
+	return 12 + (2 * wire.VarIntSerializeSize(uint64(inputs))) +
+		wire.VarIntSerializeSize(uint64(outputs)) +
+		txInsSize + txOutsSize + changeSize
+}
+
+// estimateInputSize returns the worst case serialize size estimate for a tx input
+func estimateInputSize(scriptSize int) int {
+	return 32 + // previous tx
+		4 + // output index
+		1 + // tree
+		8 + // amount
+		4 + // block height
+		4 + // block index
+		wire.VarIntSerializeSize(uint64(scriptSize)) + // size of script
+		scriptSize + // script itself
+		4 // sequence
+}
+
+// estimateOutputSize returns the worst case serialize size estimate for a tx output
+func estimateOutputSize(scriptSize int) int {
+	return 8 + // output value
+		2 + // version
+		wire.VarIntSerializeSize(uint64(scriptSize)) + // size of script
+		scriptSize // script itself
+}
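Plugging the constants into these estimators gives a worked example of the required PR fee; an in-package sketch with an assumed relay rate of 10000 atoms/kB and an illustrative function name:

```go
func exampleRequiredFee() int64 {
	// 2 P2PKHv0 inputs, 4 mixed outputs, plus change:
	//   per input:  32+4+1+8+4+4+1+108+4 = 166 bytes
	//   per output: 8+2+1+25             = 36 bytes
	//   overhead:   12 + 2*1 + 1         = 15 bytes
	//   total:      2*166 + 5*36 + 15    = 527 bytes
	size := estimateP2PKHv0SerializeSize(2, 4, true) // 527
	return feeForSerializeSize(10_000, size)         // 10000*527/1000 = 5270 atoms
}
```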
diff --git a/wire/msgmixsecrets.go b/wire/msgmixsecrets.go
index 53d0b06e2d..7532cc0bf5 100644
--- a/wire/msgmixsecrets.go
+++ b/wire/msgmixsecrets.go
@@ -25,6 +25,7 @@ type MsgMixSecrets struct {
 	Seed            [32]byte
 	SlotReserveMsgs [][]byte
 	DCNetMsgs       MixVect
+	SeenSecrets     []chainhash.Hash
 
 	// hash records the hash of the message.  It is a member of the
 	// message for convenience and performance, but is never automatically
@@ -43,6 +44,9 @@ func writeMixVect(op string, w io.Writer, pver uint32, vec MixVect) error {
 	if err != nil {
 		return err
 	}
+	if len(vec) == 0 {
+		return nil
+	}
 	err = WriteVarInt(w, pver, MixMsgSize)
 	if err != nil {
 		return err
@@ -138,6 +142,25 @@ func (msg *MsgMixSecrets) BtcDecode(r io.Reader, pver uint32) error {
 	}
 	msg.DCNetMsgs = dcnetMsgs
 
+	count, err := ReadVarInt(r, pver)
+	if err != nil {
+		return err
+	}
+	if count > MaxMixPeers {
+		msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]",
+			count, MaxMixPeers)
+		return messageError(op, ErrTooManyPrevMixMsgs, msg)
+	}
+
+	seen := make([]chainhash.Hash, count)
+	for i := range seen {
+		err := readElement(r, &seen[i])
+		if err != nil {
+			return err
+		}
+	}
+	msg.SeenSecrets = seen
+
 	return nil
 }
 
@@ -173,6 +196,17 @@ func (msg *MsgMixSecrets) Hash() chainhash.Hash {
 	return msg.hash
 }
 
+// Commitment returns a hash committing to the contents of the reveal secrets
+// message without committing to any previous seen messages or the message
+// signature.  This is the hash that is referenced by peers' key exchange
+// messages.
+func (msg *MsgMixSecrets) Commitment(h hash.Hash) chainhash.Hash {
+	msgCopy := *msg
+	msgCopy.SeenSecrets = nil
+	msgCopy.WriteHash(h)
+	return msgCopy.hash
+}
+
 // WriteHash serializes the message to a hasher and records the sum in the
 // message's Hash field.
 //
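Commitment exists so a run-0 KE can commit to a secrets message before its SeenSecrets (and therefore its final hash) are known. A usage sketch, assuming identity, sid, run, seed, srMsgs, and dcMsgs come from an in-progress run and the blake256 hasher used elsewhere in the mixing module:

```go
package main

import (
	"github.com/decred/dcrd/chaincfg/chainhash"
	"github.com/decred/dcrd/crypto/blake256"
	"github.com/decred/dcrd/wire"
)

func exampleCommitment(identity [33]byte, sid [32]byte, run uint32,
	seed [32]byte, srMsgs [][]byte, dcMsgs wire.MixVect) chainhash.Hash {

	h := blake256.New()
	rs := wire.NewMsgMixSecrets(identity, sid, run, seed, srMsgs, dcMsgs)
	// Peers reference this commitment in their key exchanges; rs.Hash()
	// will differ once SeenSecrets is populated during blame assignment.
	return rs.Commitment(h)
}
```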
@@ -196,6 +230,14 @@ func (msg *MsgMixSecrets) WriteHash(h hash.Hash) {
 //
 // This method never errors for invalid message construction.
 func (msg *MsgMixSecrets) writeMessageNoSignature(op string, w io.Writer, pver uint32) error {
+	// Limit to the maximum number of previous message hashes.
+	count := len(msg.SeenSecrets)
+	if count > MaxMixPeers {
+		msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]",
+			count, MaxMixPeers)
+		return messageError(op, ErrTooManyPrevMixMsgs, msg)
+	}
+
 	err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Run, &msg.Seed)
 	if err != nil {
 		return err
@@ -217,6 +259,17 @@ func (msg *MsgMixSecrets) writeMessageNoSignature(op string, w io.Writer, pver u
 		return err
 	}
 
+	err = WriteVarInt(w, pver, uint64(count))
+	if err != nil {
+		return err
+	}
+	for i := range msg.SeenSecrets {
+		err := writeElement(w, &msg.SeenSecrets[i])
+		if err != nil {
+			return err
+		}
+	}
+
 	return nil
 }
 
@@ -242,7 +295,7 @@ func (msg *MsgMixSecrets) MaxPayloadLength(pver uint32) uint32 {
 	}
 
 	// See tests for this calculation
-	return 54444
+	return 70831
 }
 
 // Pub returns the message sender's public key identity.
@@ -255,13 +308,14 @@ func (msg *MsgMixSecrets) Pub() []byte {
 func (msg *MsgMixSecrets) Sig() []byte {
 	return msg.Signature[:]
 }
 
-// PrevMsgs returns nil.  Previous messages are not needed to perform blame
-// assignment, because of the assumption that all previous messages must have
-// been received for a blame stage to be necessary.  Additionally, a
-// commitment to the secrets message is included in the key exchange, and
-// future message hashes are not available at that time.
+// PrevMsgs returns previously revealed secrets messages by other peers.  An
+// honest peer who reports an error to begin blame assignment does not
+// reference any previous secrets messages, while the secrets messages later
+// revealed by the remaining peers reference all previously seen secrets.
+// Dishonest peers who initially reveal their secrets when no blame
+// assignment is necessary are themselves removed in future runs.
 func (msg *MsgMixSecrets) PrevMsgs() []chainhash.Hash {
-	return nil
+	return msg.SeenSecrets
 }
 
 // Sid returns the session ID.
@@ -287,5 +341,6 @@ func NewMsgMixSecrets(identity [33]byte, sid [32]byte, run uint32,
 		Seed:            seed,
 		SlotReserveMsgs: slotReserveMsgs,
 		DCNetMsgs:       dcNetMsgs,
+		SeenSecrets:     make([]chainhash.Hash, 0),
 	}
 }
diff --git a/wire/msgmixsecrets_test.go b/wire/msgmixsecrets_test.go
index 2905d94eb4..e0fc7935a4 100644
--- a/wire/msgmixsecrets_test.go
+++ b/wire/msgmixsecrets_test.go
@@ -12,6 +12,7 @@ import (
 	"testing"
 
 	"github.com/davecgh/go-spew/spew"
+	"github.com/decred/dcrd/chaincfg/chainhash"
 )
 
 func newTestMixSecrets() *MsgMixSecrets {
@@ -34,7 +35,13 @@ func newTestMixSecrets() *MsgMixSecrets {
 		copy(m[b-0x89][:], repeat(b, 20))
 	}
 
+	seenRSs := make([]chainhash.Hash, 4)
+	for b := byte(0x8D); b < 0x91; b++ {
+		copy(seenRSs[b-0x8D][:], repeat(b, 32))
+	}
+
 	rs := NewMsgMixSecrets(id, sid, run, seed, sr, m)
+	rs.SeenSecrets = seenRSs
 	rs.Signature = sig
 
 	return rs
@@ -67,12 +74,18 @@ func TestMsgMixSecretsWire(t *testing.T) {
 	expected = append(expected, repeat(0x87, 32)...)
 	expected = append(expected, 0x20)
 	expected = append(expected, repeat(0x88, 32)...)
-	// Four slot reservation mixed messages (repeating 20 bytes of 0x89, 0x8a, 0x8b, 0x8c)
+	// Four xor dc-net mixed messages (repeating 20 bytes of 0x89, 0x8a, 0x8b, 0x8c)
 	expected = append(expected, 0x04, 0x14)
 	expected = append(expected, repeat(0x89, 20)...)
 	expected = append(expected, repeat(0x8a, 20)...)
 	expected = append(expected, repeat(0x8b, 20)...)
 	expected = append(expected, repeat(0x8c, 20)...)
+	// Four seen RSs (repeating 32 bytes of 0x8d, 0x8e, 0x8f, 0x90)
+	expected = append(expected, 0x04)
+	expected = append(expected, repeat(0x8d, 32)...)
+	expected = append(expected, repeat(0x8e, 32)...)
+	expected = append(expected, repeat(0x8f, 32)...)
+	expected = append(expected, repeat(0x90, 32)...)
 
 	expectedSerializationEqual(t, buf.Bytes(), expected)
 
@@ -168,7 +181,9 @@ func TestMsgMixSecretsMaxPayloadLength(t *testing.T) {
 		MaxMixMcount*varBytesLen(MaxMixFieldValLen) + // Unpadded SR values
 		uint32(VarIntSerializeSize(MaxMixMcount)) + // DC-net message count
 		uint32(VarIntSerializeSize(MixMsgSize)) + // DC-net message size
-		MaxMixMcount*MixMsgSize // DC-net messages
+		MaxMixMcount*MixMsgSize + // DC-net messages
+		uint32(VarIntSerializeSize(MaxMixPeers)) + // RS count
+		32*MaxMixPeers // RS hashes
 
 	tests := []struct {
 		name string