diff --git a/node/neighbor.go b/node/neighbor.go index 3438e68ae..26ab822cd 100644 --- a/node/neighbor.go +++ b/node/neighbor.go @@ -468,9 +468,9 @@ func (localNode *LocalNode) VerifySigChain(sc *pb.SigChain, height uint32) error return nil } -func (localNode *LocalNode) VerifySigChainObjection(sc *pb.SigChain, reporterID []byte, height uint32) error { +func (localNode *LocalNode) VerifySigChainObjection(sc *pb.SigChain, reporterID []byte, height uint32) (int, error) { if !config.SigChainObjection.GetValueAtHeight(height) { - return fmt.Errorf("sigchain objection is not enabled") + return 0, fmt.Errorf("sigchain objection is not enabled") } if config.SigChainVerifySkipNode.GetValueAtHeight(height) { @@ -480,13 +480,14 @@ func (localNode *LocalNode) VerifySigChainObjection(sc *pb.SigChain, reporterID dist := chord.Distance(sc.Elems[i].Id, sc.DestId, config.NodeIDBytes*8) fingerIdx := dist.BitLen() - 1 fingerStartID := chord.PowerOffset(sc.Elems[i].Id, uint32(fingerIdx), config.NodeIDBytes*8) - if !chord.BetweenLeftIncl(fingerStartID, sc.Elems[i+1].Id, reporterID) { - return fmt.Errorf("reporter is not skipped") + if chord.BetweenLeftIncl(fingerStartID, sc.Elems[i+1].Id, reporterID) { + return i, nil } } + return 0, fmt.Errorf("reporter is not skipped") } - return nil + return 0, nil } // ShouldRejectAddr returns if remoteAddr should be rejected by localAddr diff --git a/por/porserver.go b/por/porserver.go index 2f9b4d74d..7addd71ee 100644 --- a/por/porserver.go +++ b/por/porserver.go @@ -69,7 +69,7 @@ var store Store // LocalNode interface is used to avoid cyclic dependency type LocalNode interface { - VerifySigChainObjection(sc *pb.SigChain, reporterID []byte, height uint32) error + VerifySigChainObjection(sc *pb.SigChain, reporterID []byte, height uint32) (int, error) } var localNode LocalNode @@ -113,7 +113,7 @@ type BacktrackSigChainInfo struct { type sigChainObjection struct { reporterPubkey []byte reporterID []byte - isVerified bool + skippedHop int } type sigChainObjections []*sigChainObjection @@ -268,7 +268,7 @@ func (ps *PorServer) GetMiningSigChainTxnHash(voteForHeight uint32) (common.Uint if v, ok := ps.sigChainObjectionCache.Get(porPkg.SigHash); ok { if scos, ok := v.(sigChainObjections); ok { if len(scos) >= MaxNextHopChoice { - verifiedCount := 0 + verifiedCount := make(map[int]int) for _, sco := range scos { if len(sco.reporterID) == 0 { id, err := store.GetID(sco.reporterPubkey, height) @@ -276,20 +276,27 @@ func (ps *PorServer) GetMiningSigChainTxnHash(voteForHeight uint32) (common.Uint continue } sco.reporterID = id - err = localNode.VerifySigChainObjection(porPkg.SigChain, id, height) + i, err := localNode.VerifySigChainObjection(porPkg.SigChain, id, height) if err != nil { continue } - sco.isVerified = true + sco.skippedHop = i } - if sco.isVerified { - verifiedCount++ + if sco.skippedHop > 0 { + verifiedCount[sco.skippedHop]++ } - if verifiedCount >= MaxNextHopChoice { + if verifiedCount[sco.skippedHop] >= MaxNextHopChoice { break } } - if verifiedCount >= MaxNextHopChoice { + isSigChainInvalid := false + for _, count := range verifiedCount { + if count >= MaxNextHopChoice { + isSigChainInvalid = true + break + } + } + if isSigChainInvalid { continue } } @@ -703,50 +710,61 @@ func (ps *PorServer) AddSigChainObjection(currentHeight, voteForHeight uint32, s return false } sco.reporterID = reporterID - err = localNode.VerifySigChainObjection(porPkg.SigChain, reporterID, porPkg.Height) + i, err := localNode.VerifySigChainObjection(porPkg.SigChain, reporterID, porPkg.Height) if err != nil { return false } - sco.isVerified = true + sco.skippedHop = i } var scos sigChainObjections if v, ok := ps.sigChainObjectionCache.Get(sigHash); ok { var ok bool if scos, ok = v.(sigChainObjections); ok { - verifiedCount := 0 + verifiedCount := make(map[int]int) + needVerify := false for _, s := range scos { - if bytes.Compare(s.reporterPubkey, reporterPubkey) == 0 { + if bytes.Equal(s.reporterPubkey, reporterPubkey) { return false } - if s.isVerified { - verifiedCount++ + if s.skippedHop > 0 { + verifiedCount[s.skippedHop]++ + } + if len(s.reporterID) == 0 { + needVerify = true } } - if verifiedCount >= MaxNextHopChoice { - return false + for _, count := range verifiedCount { + if count >= MaxNextHopChoice { + return false + } } - if porPkg != nil && len(scos) >= verifiedCount { + if porPkg != nil && needVerify { verifiedScos := sigChainObjections{} + verifiedCount := make(map[int]int) for _, s := range scos { reporterID, err := store.GetID(s.reporterPubkey, porPkg.Height) if err != nil { continue } s.reporterID = reporterID - err = localNode.VerifySigChainObjection(porPkg.SigChain, reporterID, porPkg.Height) + i, err := localNode.VerifySigChainObjection(porPkg.SigChain, reporterID, porPkg.Height) if err != nil { continue } - s.isVerified = true + s.skippedHop = i verifiedScos = append(verifiedScos, s) + verifiedCount[s.skippedHop]++ } scos = verifiedScos - if len(scos) >= MaxNextHopChoice { - ps.sigChainObjectionCache.Set(sigHash, scos) - return false + + for _, count := range verifiedCount { + if count >= MaxNextHopChoice { + ps.sigChainObjectionCache.Set(sigHash, scos) + return false + } } }