diff --git a/sync2/dbset/dbset.go b/sync2/dbset/dbset.go new file mode 100644 index 0000000000..daeccd7bae --- /dev/null +++ b/sync2/dbset/dbset.go @@ -0,0 +1,275 @@ +package dbset + +import ( + "fmt" + "maps" + "sync" + "time" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/fptree" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/sqlstore" +) + +// DBSet is an implementation of rangesync.OrderedSet that uses an SQL database +// as its backing store. It uses an FPTree to perform efficient range queries. +type DBSet struct { + loadMtx sync.Mutex + db sql.Executor + ft *fptree.FPTree + st *sqlstore.SyncedTable + snapshot *sqlstore.SyncedTableSnapshot + dbStore *fptree.DBBackedStore + keyLen int + maxDepth int + received map[string]struct{} +} + +var _ rangesync.OrderedSet = &DBSet{} + +// NewDBSet creates a new DBSet. +func NewDBSet( + db sql.Executor, + st *sqlstore.SyncedTable, + keyLen, maxDepth int, +) *DBSet { + return &DBSet{ + db: db, + st: st, + keyLen: keyLen, + maxDepth: maxDepth, + } +} + +func (d *DBSet) handleIDfromDB(stmt *sql.Statement) bool { + id := make(rangesync.KeyBytes, d.keyLen) + stmt.ColumnBytes(0, id[:]) + d.ft.AddStoredKey(id) + return true +} + +// EnsureLoaded ensures that the DBSet is loaded and ready to be used. +func (d *DBSet) EnsureLoaded() error { + d.loadMtx.Lock() + defer d.loadMtx.Unlock() + if d.ft != nil { + return nil + } + var err error + d.snapshot, err = d.st.Snapshot(d.db) + if err != nil { + return fmt.Errorf("error taking snapshot: %w", err) + } + count, err := d.snapshot.LoadCount(d.db) + if err != nil { + return fmt.Errorf("error loading count: %w", err) + } + d.dbStore = fptree.NewDBBackedStore(d.db, d.snapshot, count, d.keyLen) + d.ft = fptree.NewFPTree(count, d.dbStore, d.keyLen, d.maxDepth) + return d.snapshot.Load(d.db, d.handleIDfromDB) +} + +// Received returns a sequence of all items that have been received. +// Implements rangesync.OrderedSet. +func (d *DBSet) Received() rangesync.SeqResult { + return rangesync.SeqResult{ + Seq: func(yield func(k rangesync.KeyBytes) bool) { + for k := range d.received { + if !yield(rangesync.KeyBytes(k)) { + return + } + } + }, + Error: rangesync.NoSeqError, + } +} + +// Add adds an item to the DBSet. +// Implements rangesync.OrderedSet. +func (d *DBSet) Add(k rangesync.KeyBytes) error { + if has, err := d.Has(k); err != nil { + return fmt.Errorf("checking if item exists: %w", err) + } else if has { + return nil + } + d.ft.RegisterKey(k) + return nil +} + +// Receive handles a newly received item, arranging for it to be returned as part of the +// sequence returned by Received. +// Implements rangesync.OrderedSet. +func (d *DBSet) Receive(k rangesync.KeyBytes) error { + if d.received == nil { + d.received = make(map[string]struct{}) + } + d.received[string(k)] = struct{}{} + return nil +} + +func (d *DBSet) firstItem() (rangesync.KeyBytes, error) { + if err := d.EnsureLoaded(); err != nil { + return nil, err + } + return d.ft.All().First() +} + +// GetRangeInfo returns information about the range of items in the DBSet. +// Implements rangesync.OrderedSet. 
+func (d *DBSet) GetRangeInfo(x, y rangesync.KeyBytes) (rangesync.RangeInfo, error) {
+	if err := d.EnsureLoaded(); err != nil {
+		return rangesync.RangeInfo{}, err
+	}
+	if d.ft.Count() == 0 {
+		return rangesync.RangeInfo{
+			Items: rangesync.EmptySeqResult(),
+		}, nil
+	}
+	if x == nil || y == nil {
+		if x != nil || y != nil {
+			panic("BUG: GetRangeInfo called with one of x/y nil but not both")
+		}
+		var err error
+		x, err = d.firstItem()
+		if err != nil {
+			return rangesync.RangeInfo{}, fmt.Errorf("getting first item: %w", err)
+		}
+		y = x
+	}
+	fpr, err := d.ft.FingerprintInterval(x, y, -1)
+	if err != nil {
+		return rangesync.RangeInfo{}, err
+	}
+	return rangesync.RangeInfo{
+		Fingerprint: fpr.FP,
+		Count:       int(fpr.Count),
+		Items:       fpr.Items,
+	}, nil
+}
+
+// SplitRange splits the range of items in the DBSet into two parts,
+// returning information about each part and the middle item.
+// Implements rangesync.OrderedSet.
+func (d *DBSet) SplitRange(x, y rangesync.KeyBytes, count int) (rangesync.SplitInfo, error) {
+	if count <= 0 {
+		panic("BUG: bad split count")
+	}
+
+	if err := d.EnsureLoaded(); err != nil {
+		return rangesync.SplitInfo{}, err
+	}
+
+	sr, err := d.ft.Split(x, y, count)
+	if err != nil {
+		return rangesync.SplitInfo{}, err
+	}
+
+	return rangesync.SplitInfo{
+		Parts: [2]rangesync.RangeInfo{
+			{
+				Fingerprint: sr.Part0.FP,
+				Count:       int(sr.Part0.Count),
+				Items:       sr.Part0.Items,
+			},
+			{
+				Fingerprint: sr.Part1.FP,
+				Count:       int(sr.Part1.Count),
+				Items:       sr.Part1.Items,
+			},
+		},
+		Middle: sr.Middle,
+	}, nil
+}
+
+// Items returns a sequence of all items in the DBSet.
+// Implements rangesync.OrderedSet.
+func (d *DBSet) Items() rangesync.SeqResult {
+	if err := d.EnsureLoaded(); err != nil {
+		return rangesync.ErrorSeqResult(err)
+	}
+	return d.ft.All()
+}
+
+// Empty returns true if the DBSet is empty.
+// Implements rangesync.OrderedSet.
+func (d *DBSet) Empty() (bool, error) {
+	if err := d.EnsureLoaded(); err != nil {
+		return false, err
+	}
+	return d.ft.Count() == 0, nil
+}
+
+// Advance advances the DBSet to the latest state of the underlying database table.
+func (d *DBSet) Advance() error {
+	d.loadMtx.Lock()
+	defer d.loadMtx.Unlock()
+	if d.ft == nil {
+		// FIXME
+		panic("BUG: can't advance the DBSet before it's loaded")
+	}
+	oldSnapshot := d.snapshot
+	var err error
+	d.snapshot, err = d.st.Snapshot(d.db)
+	if err != nil {
+		return fmt.Errorf("error taking snapshot: %w", err)
+	}
+	d.dbStore.SetSnapshot(d.snapshot)
+	return d.snapshot.LoadSinceSnapshot(d.db, oldSnapshot, d.handleIDfromDB)
+}
+
+// Copy creates a copy of the DBSet.
+// Implements rangesync.OrderedSet.
+func (d *DBSet) Copy(syncScope bool) rangesync.OrderedSet {
+	d.loadMtx.Lock()
+	defer d.loadMtx.Unlock()
+	if d.ft == nil {
+		// FIXME
+		panic("BUG: can't copy the DBSet before it's loaded")
+	}
+	ft := d.ft.Clone().(*fptree.FPTree)
+	return &DBSet{
+		db:       d.db,
+		ft:       ft,
+		st:       d.st,
+		keyLen:   d.keyLen,
+		maxDepth: d.maxDepth,
+		dbStore:  d.dbStore,
+		received: maps.Clone(d.received),
+	}
+}
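A minimal usage sketch tying EnsureLoaded, GetRangeInfo and Release together. The table name "foo", the 32-byte key length and depth 24 mirror the tests below; the helper itself is illustrative and not part of this change:

```go
package example

import (
	"fmt"

	"github.com/spacemeshos/go-spacemesh/sql"
	"github.com/spacemeshos/go-spacemesh/sync2/dbset"
	"github.com/spacemeshos/go-spacemesh/sync2/sqlstore"
)

// fullRange prints the fingerprint and item count of an entire synced table;
// db is assumed to hold a table "foo" with a 32-byte "id" primary key column.
func fullRange(db sql.Executor) error {
	st := &sqlstore.SyncedTable{TableName: "foo", IDColumn: "id"}
	s := dbset.NewDBSet(db, st, 32, 24)
	defer s.Release()
	// nil bounds mean "start at the first item and wrap all the way around",
	// i.e. the whole set; EnsureLoaded is invoked implicitly.
	info, err := s.GetRangeInfo(nil, nil)
	if err != nil {
		return err
	}
	fmt.Printf("count=%d fp=%s\n", info.Count, info.Fingerprint)
	return nil
}
```

+
+// Has returns true if the DBSet contains the given item.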
+func (d *DBSet) Has(k rangesync.KeyBytes) (bool, error) { + if err := d.EnsureLoaded(); err != nil { + return false, err + } + + // checkKey may have false positives, but not false negatives, and it's much + // faster than querying the database + if !d.ft.CheckKey(k) { + return false, nil + } + + first, err := d.dbStore.From(k, 1).First() + if err != nil { + return false, err + } + return first.Compare(k) == 0, nil +} + +// Recent returns a sequence of items that have been added to the DBSet since the given time. +func (d *DBSet) Recent(since time.Time) (rangesync.SeqResult, int) { + return d.dbStore.Since(make(rangesync.KeyBytes, d.keyLen), since.UnixNano()) +} + +// Release releases resources associated with the DBSet. +func (d *DBSet) Release() error { + d.loadMtx.Lock() + defer d.loadMtx.Unlock() + if d.ft != nil { + d.ft.Release() + d.ft = nil + } + return nil +} diff --git a/sync2/dbset/dbset_test.go b/sync2/dbset/dbset_test.go new file mode 100644 index 0000000000..3d5654d9f0 --- /dev/null +++ b/sync2/dbset/dbset_test.go @@ -0,0 +1,353 @@ +package dbset_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sync2/dbset" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/sqlstore" +) + +const ( + testKeyLen = 32 + testDepth = 24 +) + +func requireEmpty(t *testing.T, sr rangesync.SeqResult) { + for range sr.Seq { + require.Fail(t, "expected an empty sequence") + } + require.NoError(t, sr.Error()) +} + +func firstKey(t *testing.T, sr rangesync.SeqResult) rangesync.KeyBytes { + k, err := sr.First() + require.NoError(t, err) + return k +} + +func TestDBSet_Empty(t *testing.T) { + db := sqlstore.PopulateDB(t, testKeyLen, nil) + st := &sqlstore.SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := dbset.NewDBSet(db, st, testKeyLen, testDepth) + empty, err := s.Empty() + require.NoError(t, err) + require.True(t, empty) + requireEmpty(t, s.Items()) + requireEmpty(t, s.Received()) + + info, err := s.GetRangeInfo(nil, nil) + require.NoError(t, err) + require.Equal(t, 0, info.Count) + require.Equal(t, "000000000000000000000000", info.Fingerprint.String()) + requireEmpty(t, info.Items) + + info, err = s.GetRangeInfo( + rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000")) + require.NoError(t, err) + require.Equal(t, 0, info.Count) + require.Equal(t, "000000000000000000000000", info.Fingerprint.String()) + requireEmpty(t, info.Items) + + info, err = s.GetRangeInfo( + rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("9999000000000000000000000000000000000000000000000000000000000000")) + require.NoError(t, err) + require.Equal(t, 0, info.Count) + require.Equal(t, "000000000000000000000000", info.Fingerprint.String()) + requireEmpty(t, info.Items) +} + +func TestDBSet(t *testing.T) { + ids := []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("123456789abcdef0000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("5555555555555555555555555555555555555555555555555555555555555555"), + rangesync.MustParseHexKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"), + 
rangesync.MustParseHexKeyBytes("abcdef1234567890000000000000000000000000000000000000000000000000"), + } + db := sqlstore.PopulateDB(t, testKeyLen, ids) + st := &sqlstore.SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := dbset.NewDBSet(db, st, testKeyLen, testDepth) + require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", + firstKey(t, s.Items()).String()) + has, err := s.Has( + rangesync.MustParseHexKeyBytes("9876000000000000000000000000000000000000000000000000000000000000")) + require.NoError(t, err) + require.False(t, has) + + for _, tc := range []struct { + xIdx, yIdx int + limit int + fp string + count int + startIdx, endIdx int + }{ + { + xIdx: 1, + yIdx: 1, + limit: -1, + fp: "642464b773377bbddddddddd", + count: 5, + startIdx: 1, + endIdx: 1, + }, + { + xIdx: -1, + yIdx: -1, + limit: -1, + fp: "642464b773377bbddddddddd", + count: 5, + startIdx: 0, + endIdx: 0, + }, + { + xIdx: 0, + yIdx: 3, + limit: -1, + fp: "4761032dcfe98ba555555555", + count: 3, + startIdx: 0, + endIdx: 3, + }, + { + xIdx: 2, + yIdx: 0, + limit: -1, + fp: "761032cfe98ba54ddddddddd", + count: 3, + startIdx: 2, + endIdx: 0, + }, + { + xIdx: 3, + yIdx: 2, + limit: 3, + fp: "2345679abcdef01888888888", + count: 3, + startIdx: 3, + endIdx: 1, + }, + } { + name := fmt.Sprintf("%d-%d_%d", tc.xIdx, tc.yIdx, tc.limit) + t.Run(name, func(t *testing.T) { + var x, y rangesync.KeyBytes + if tc.xIdx >= 0 { + x = ids[tc.xIdx] + y = ids[tc.yIdx] + } + t.Logf("x %v y %v limit %d", x, y, tc.limit) + var info rangesync.RangeInfo + if tc.limit < 0 { + info, err = s.GetRangeInfo(x, y) + require.NoError(t, err) + } else { + sr, err := s.SplitRange(x, y, tc.limit) + require.NoError(t, err) + info = sr.Parts[0] + } + require.Equal(t, tc.count, info.Count) + require.Equal(t, tc.fp, info.Fingerprint.String()) + require.Equal(t, ids[tc.startIdx], firstKey(t, info.Items)) + has, err := s.Has(ids[tc.startIdx]) + require.NoError(t, err) + require.True(t, has) + has, err = s.Has(ids[tc.endIdx]) + require.NoError(t, err) + require.True(t, has) + }) + } +} + +func TestDBItemStore_Receive(t *testing.T) { + ids := []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("123456789abcdef0000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("5555555555555555555555555555555555555555555555555555555555555555"), + rangesync.MustParseHexKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"), + } + db := sqlstore.PopulateDB(t, testKeyLen, ids) + st := &sqlstore.SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := dbset.NewDBSet(db, st, testKeyLen, testDepth) + require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", + firstKey(t, s.Items()).String()) + + newID := rangesync.MustParseHexKeyBytes("abcdef1234567890000000000000000000000000000000000000000000000000") + require.NoError(t, s.Receive(newID)) + + recvd := s.Received() + items, err := recvd.FirstN(1) + require.NoError(t, err) + require.NoError(t, err) + require.Equal(t, []rangesync.KeyBytes{newID}, items) + + info, err := s.GetRangeInfo(ids[2], ids[0]) + require.NoError(t, err) + require.Equal(t, 2, info.Count) + require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.String()) +} + +func TestDBItemStore_Copy(t *testing.T) { + ids := []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + 
rangesync.MustParseHexKeyBytes("123456789abcdef0000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("5555555555555555555555555555555555555555555555555555555555555555"), + rangesync.MustParseHexKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"), + } + db := sqlstore.PopulateDB(t, testKeyLen, ids) + st := &sqlstore.SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := dbset.NewDBSet(db, st, testKeyLen, testDepth) + require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", + firstKey(t, s.Items()).String()) + + copy := s.Copy(false) + + info, err := copy.GetRangeInfo(ids[2], ids[0]) + require.NoError(t, err) + require.Equal(t, 2, info.Count) + require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[2], firstKey(t, info.Items)) + + newID := rangesync.MustParseHexKeyBytes("abcdef1234567890000000000000000000000000000000000000000000000000") + require.NoError(t, copy.Receive(newID)) + + info, err = s.GetRangeInfo(ids[2], ids[0]) + require.NoError(t, err) + require.Equal(t, 2, info.Count) + require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[2], firstKey(t, info.Items)) + + items, err := s.Received().FirstN(100) + require.NoError(t, err) + require.Empty(t, items) + + info, err = s.GetRangeInfo(ids[2], ids[0]) + require.NoError(t, err) + require.Equal(t, 2, info.Count) + require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[2], firstKey(t, info.Items)) + + items, err = copy.(*dbset.DBSet).Received().FirstN(100) + require.NoError(t, err) + require.Equal(t, []rangesync.KeyBytes{newID}, items) +} + +func TestDBItemStore_Advance(t *testing.T) { + ids := []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("123456789abcdef0000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("5555555555555555555555555555555555555555555555555555555555555555"), + rangesync.MustParseHexKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"), + } + db := sqlstore.PopulateDB(t, testKeyLen, ids) + st := &sqlstore.SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := dbset.NewDBSet(db, st, testKeyLen, testDepth) + require.NoError(t, s.EnsureLoaded()) + + copy := s.Copy(false) + + info, err := s.GetRangeInfo(ids[0], ids[0]) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + info, err = copy.GetRangeInfo(ids[0], ids[0]) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + sqlstore.InsertDBItems(t, db, []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("abcdef1234567890000000000000000000000000000000000000000000000000"), + }) + + info, err = s.GetRangeInfo(ids[0], ids[0]) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + info, err = copy.GetRangeInfo(ids[0], ids[0]) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, 
info.Items)) + + require.NoError(t, s.Advance()) + + info, err = s.GetRangeInfo(ids[0], ids[0]) + require.NoError(t, err) + require.Equal(t, 5, info.Count) + require.Equal(t, "642464b773377bbddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + info, err = copy.GetRangeInfo(ids[0], ids[0]) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + info, err = s.Copy(false).GetRangeInfo(ids[0], ids[0]) + require.NoError(t, err) + require.Equal(t, 5, info.Count) + require.Equal(t, "642464b773377bbddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) +} + +func TestDBSet_Added(t *testing.T) { + ids := []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("123456789abcdef0000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("5555555555555555555555555555555555555555555555555555555555555555"), + rangesync.MustParseHexKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"), + rangesync.MustParseHexKeyBytes("abcdef1234567890000000000000000000000000000000000000000000000000"), + } + db := sqlstore.PopulateDB(t, testKeyLen, ids) + st := &sqlstore.SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := dbset.NewDBSet(db, st, testKeyLen, testDepth) + requireEmpty(t, s.Received()) + + add := []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("3333333333333333333333333333333333333333333333333333333333333333"), + rangesync.MustParseHexKeyBytes("4444444444444444444444444444444444444444444444444444444444444444"), + } + for _, item := range add { + require.NoError(t, s.Receive(item)) + } + + require.NoError(t, s.EnsureLoaded()) + + added, err := s.Received().FirstN(3) + require.NoError(t, err) + require.ElementsMatch(t, []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("3333333333333333333333333333333333333333333333333333333333333333"), + rangesync.MustParseHexKeyBytes("4444444444444444444444444444444444444444444444444444444444444444"), + }, added) + + added1, err := s.Copy(false).(*dbset.DBSet).Received().FirstN(3) + require.NoError(t, err) + require.ElementsMatch(t, added, added1) +} diff --git a/sync2/dbset/p2p_test.go b/sync2/dbset/p2p_test.go new file mode 100644 index 0000000000..598c7415ae --- /dev/null +++ b/sync2/dbset/p2p_test.go @@ -0,0 +1,448 @@ +package dbset_test + +import ( + "context" + "errors" + "io" + "slices" + "testing" + "time" + + "github.com/jonboulle/clockwork" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + "golang.org/x/sync/errgroup" + + "github.com/spacemeshos/go-spacemesh/p2p/server" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/dbset" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/sqlstore" +) + +var startDate = time.Date(2024, 8, 29, 18, 0, 0, 0, time.UTC) + +type fooRow struct { + id rangesync.KeyBytes + ts int64 +} + +func insertRow(t *testing.T, db sql.Executor, row fooRow) { + _, err := db.Exec( + "insert into foo (id, received) values (?, ?)", + func(stmt *sql.Statement) { + stmt.BindBytes(1, row.id) + stmt.BindInt64(2, row.ts) + }, nil) + require.NoError(t, err) +} + +func populateFoo(t 
*testing.T, rows []fooRow) sql.Database {
+	db := sql.InMemoryTest(t)
+	require.NoError(t, db.WithTx(context.Background(), func(tx sql.Transaction) error {
+		_, err := tx.Exec(
+			"create table foo(id char(32) not null primary key, received int)",
+			nil, nil)
+		require.NoError(t, err)
+		for _, row := range rows {
+			insertRow(t, tx, row)
+		}
+		return nil
+	}))
+	return db
+}
+
+type syncTracer struct {
+	receivedItems int
+	sentItems     int
+}
+
+var _ rangesync.Tracer = &syncTracer{}
+
+func (tr *syncTracer) OnDumbSync() {}
+
+func (tr *syncTracer) OnRecent(receivedItems, sentItems int) {
+	tr.receivedItems += receivedItems
+	tr.sentItems += sentItems
+}
+
+func addReceived(t *testing.T, db sql.Executor, to, from *dbset.DBSet) {
+	sr := from.Received()
+	for k := range sr.Seq {
+		has, err := to.Has(k)
+		require.NoError(t, err)
+		if !has {
+			insertRow(t, db, fooRow{id: k, ts: time.Now().UnixNano()})
+		}
+	}
+	require.NoError(t, sr.Error())
+	require.NoError(t, to.Advance())
+}
+
+func verifyP2P(
+	t *testing.T,
+	rowsA, rowsB []fooRow,
+	combinedItems []rangesync.KeyBytes,
+	clockAt time.Time,
+	receivedRecent, sentRecent bool,
+	opts ...rangesync.RangeSetReconcilerOption,
+) {
+	const maxDepth = 24
+	log := zaptest.NewLogger(t)
+	dbAx := populateFoo(t, rowsA)
+	dbA, err := dbAx.Tx(context.Background())
+	require.NoError(t, err)
+	defer dbA.Release()
+	dbBx := populateFoo(t, rowsB)
+	dbB, err := dbBx.Tx(context.Background())
+	require.NoError(t, err)
+	defer dbB.Release()
+	mesh, err := mocknet.FullMeshConnected(2)
+	require.NoError(t, err)
+	proto := "itest"
+	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
+	st := &sqlstore.SyncedTable{
+		TableName:       "foo",
+		IDColumn:        "id",
+		TimestampColumn: "received",
+	}
+
+	setA := dbset.NewDBSet(dbA, st, testKeyLen, maxDepth)
+	loadStart := time.Now()
+	require.NoError(t, setA.EnsureLoaded())
+	t.Logf("loaded setA in %v", time.Since(loadStart))
+
+	setB := dbset.NewDBSet(dbB, st, testKeyLen, maxDepth)
+	loadStart = time.Now()
+	require.NoError(t, setB.EnsureLoaded())
+	t.Logf("loaded setB in %v", time.Since(loadStart))
+
+	empty, err := setB.Empty()
+	require.NoError(t, err)
+	var x rangesync.KeyBytes
+	if !empty {
+		k, err := setB.Items().First()
+		require.NoError(t, err)
+		// assign to the outer x instead of shadowing it with :=, so that the
+		// non-nil bound actually reaches the Sync call below
+		x = k.Clone()
+		x.Trim(maxDepth)
+	}
+
+	var tr syncTracer
+	opts = append(opts,
+		rangesync.WithClock(clockwork.NewFakeClockAt(clockAt)),
+		rangesync.WithTracer(&tr),
+	)
+	opts = opts[:len(opts):len(opts)]
+
+	srvPeerID := mesh.Hosts()[0].ID()
+	pssA := rangesync.NewPairwiseSetSyncer(nil, "test", append(
+		opts,
+		rangesync.WithMaxSendRange(1),
+		// uncomment to enable verbose logging which may slow down tests
+		// rangesync.WithLogger(log.Named("sideA")),
+	), nil)
+	d := rangesync.NewDispatcher(log)
+	syncSetA := setA.Copy(false).(*dbset.DBSet)
+	pssA.Register(d, syncSetA)
+	srv := server.New(mesh.Hosts()[0], proto,
+		func(ctx context.Context, req []byte, stream io.ReadWriter) error {
+			return d.Dispatch(ctx, req, stream)
+		},
+		server.WithTimeout(10*time.Second),
+		server.WithLog(log))
+
+	var eg errgroup.Group
+
+	client := server.New(mesh.Hosts()[1], proto,
+		func(ctx context.Context, req []byte, stream io.ReadWriter) error {
+			return errors.New("client should not receive requests")
+		},
+		server.WithTimeout(10*time.Second),
+		server.WithLog(log))
+
+	defer func() {
+		cancel()
+		eg.Wait()
+	}()
+	eg.Go(func() error {
+		return srv.Run(ctx)
+	})
+
+	require.Eventually(t, func() bool {
+		for _, h := range mesh.Hosts() {
+			if len(h.Mux().Protocols()) ==
0 { + return false + } + } + return true + }, time.Second, 10*time.Millisecond) + + pssB := rangesync.NewPairwiseSetSyncer(client, "test", append( + opts, + rangesync.WithMaxSendRange(1), + // uncomment to enable verbose logging which may slow down tests + // rangesync.WithLogger(log.Named("sideB")), + ), nil) + + tStart := time.Now() + syncSetB := setB.Copy(false).(*dbset.DBSet) + require.NoError(t, pssB.Sync(ctx, srvPeerID, syncSetB, x, x)) + t.Logf("synced in %v, sent %d, recv %d", time.Since(tStart), pssB.Sent(), pssB.Received()) + addReceived(t, dbA, setA, syncSetA) + addReceived(t, dbB, setB, syncSetB) + + require.Equal(t, receivedRecent, tr.receivedItems > 0) + require.Equal(t, sentRecent, tr.sentItems > 0) + + if len(combinedItems) == 0 { + return + } + + actItemsA, err := setA.Items().Collect() + require.NoError(t, err) + + actItemsB, err := setB.Items().Collect() + require.NoError(t, err) + + assert.Equal(t, combinedItems, actItemsA) + assert.Equal(t, actItemsA, actItemsB) +} + +func fooR(id string, seconds int) fooRow { + return fooRow{ + rangesync.MustParseHexKeyBytes(id), + startDate.Add(time.Duration(seconds) * time.Second).UnixNano(), + } +} + +func TestP2P(t *testing.T) { + hexID := rangesync.MustParseHexKeyBytes + t.Run("predefined items", func(t *testing.T) { + verifyP2P( + t, []fooRow{ + fooR("1111111111111111111111111111111111111111111111111111111111111111", 10), + fooR("123456789abcdef0000000000000000000000000000000000000000000000000", 20), + fooR("8888888888888888888888888888888888888888888888888888888888888888", 30), + fooR("abcdef1234567890000000000000000000000000000000000000000000000000", 40), + }, + []fooRow{ + fooR("1111111111111111111111111111111111111111111111111111111111111111", 11), + fooR("123456789abcdef0000000000000000000000000000000000000000000000000", 12), + fooR("5555555555555555555555555555555555555555555555555555555555555555", 13), + fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), + }, + []rangesync.KeyBytes{ + hexID("1111111111111111111111111111111111111111111111111111111111111111"), + hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), + hexID("5555555555555555555555555555555555555555555555555555555555555555"), + hexID("8888888888888888888888888888888888888888888888888888888888888888"), + hexID("abcdef1234567890000000000000000000000000000000000000000000000000"), + }, + startDate, + false, + false, + ) + }) + t.Run("predefined items 2", func(t *testing.T) { + verifyP2P( + t, []fooRow{ + fooR("0e69888877324da35693decc7ded1b2bac16d394ced869af494568d66473a6f0", 10), + fooR("3a78db9e386493402561d9c6f69a6b434a62388f61d06d960598ebf29a3a2187", 20), + fooR("66c9aa8f3be7da713db66e56cc165a46764f88d3113244dd5964bb0a10ccacc3", 30), + fooR("72e1adaaf140d809a5da325a197341a453b00807ef8d8995fd3c8079b917c9d7", 40), + fooR("782c24553b0a8cf1d95f632054b7215be192facfb177cfd1312901dd4c9e0bfd", 50), + fooR("9e11fdb099f1118144738f9b68ca601e74b97280fd7bbc97cfc377f432e9b7b5", 60), + }, + []fooRow{ + fooR("0e69888877324da35693decc7ded1b2bac16d394ced869af494568d66473a6f0", 11), + fooR("3a78db9e386493402561d9c6f69a6b434a62388f61d06d960598ebf29a3a2187", 12), + fooR("66c9aa8f3be7da713db66e56cc165a46764f88d3113244dd5964bb0a10ccacc3", 13), + fooR("90b25f2d1ee9c9e2d20df5f2226d14ee4223ea27ba565a49aa66a9c44a51c241", 14), + fooR("9e11fdb099f1118144738f9b68ca601e74b97280fd7bbc97cfc377f432e9b7b5", 15), + fooR("c1690e47798295cca02392cbfc0a86cb5204878c04a29b3ae7701b6b51681128", 16), + }, + []rangesync.KeyBytes{ + 
hexID("0e69888877324da35693decc7ded1b2bac16d394ced869af494568d66473a6f0"), + hexID("3a78db9e386493402561d9c6f69a6b434a62388f61d06d960598ebf29a3a2187"), + hexID("66c9aa8f3be7da713db66e56cc165a46764f88d3113244dd5964bb0a10ccacc3"), + hexID("72e1adaaf140d809a5da325a197341a453b00807ef8d8995fd3c8079b917c9d7"), + hexID("782c24553b0a8cf1d95f632054b7215be192facfb177cfd1312901dd4c9e0bfd"), + hexID("90b25f2d1ee9c9e2d20df5f2226d14ee4223ea27ba565a49aa66a9c44a51c241"), + hexID("9e11fdb099f1118144738f9b68ca601e74b97280fd7bbc97cfc377f432e9b7b5"), + hexID("c1690e47798295cca02392cbfc0a86cb5204878c04a29b3ae7701b6b51681128"), + }, + startDate, + false, + false, + ) + }) + t.Run("predefined items 3", func(t *testing.T) { + verifyP2P( + t, []fooRow{ + fooR("08addda193ce5c8dfa56d58efaaaa51ccb534738027c4c73631f76811702e54f", 5), + fooR("112d34ac1724faa17502e9f1654808daa43d8e99c384c42faeccc6c713993079", 3), + fooR("8599b0264623ede5d198fd2caa537720e011ce17bd9f34c140de269f472a1126", 4), + fooR("9e8dc977998b3cbc30071202cb8ebb0c8bfa2c400fd28067f6d43c7e92acd077", 2), + fooR("a67249d334bd0c68e92b4c6d8716cdc218130c0e765838e52890133d07d35d48", 0), + fooR("e7f3c0ecf1410711cf16d8188dc0075f10d17c95208bdbf7a5c910a0ecb68085", 1), + }, + []fooRow{ + fooR("112d34ac1724faa17502e9f1654808daa43d8e99c384c42faeccc6c713993079", 3), + fooR("8599b0264623ede5d198fd2caa537720e011ce17bd9f34c140de269f472a1126", 4), + fooR("9e8dc977998b3cbc30071202cb8ebb0c8bfa2c400fd28067f6d43c7e92acd077", 2), + fooR("a67249d334bd0c68e92b4c6d8716cdc218130c0e765838e52890133d07d35d48", 0), + fooR("dc5938b62a49a31e947d48d85cf358a77dbbed0f3ad5d06e2df63da3cbe7c80a", 5), + fooR("e7f3c0ecf1410711cf16d8188dc0075f10d17c95208bdbf7a5c910a0ecb68085", 1), + }, + []rangesync.KeyBytes{ + hexID("08addda193ce5c8dfa56d58efaaaa51ccb534738027c4c73631f76811702e54f"), + hexID("112d34ac1724faa17502e9f1654808daa43d8e99c384c42faeccc6c713993079"), + hexID("8599b0264623ede5d198fd2caa537720e011ce17bd9f34c140de269f472a1126"), + hexID("9e8dc977998b3cbc30071202cb8ebb0c8bfa2c400fd28067f6d43c7e92acd077"), + hexID("a67249d334bd0c68e92b4c6d8716cdc218130c0e765838e52890133d07d35d48"), + hexID("dc5938b62a49a31e947d48d85cf358a77dbbed0f3ad5d06e2df63da3cbe7c80a"), + hexID("e7f3c0ecf1410711cf16d8188dc0075f10d17c95208bdbf7a5c910a0ecb68085"), + }, + startDate, + false, + false, + ) + }) + t.Run("predefined items with recent", func(t *testing.T) { + verifyP2P( + t, []fooRow{ + fooR("80e95b39faa731eb50eae7585a8b1cae98f503481f950fdb690e60ff86c21236", 10), + fooR("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7", 20), + fooR("d862b2413af5c252028e8f9871be8297e807661d64decd8249ac2682db168b90", 30), + fooR("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567", 40), + }, + []fooRow{ + fooR("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7", 11), + fooR("bc6218a88d1648b8145fbf93ae74af8975f193af88788e7add3608e0bc50f701", 12), + fooR("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567", 13), + fooR("fbf03324234f79a3fe0587cf5505d7e4c826cb2be38d72eafa60296ed77b3f8f", 14), + }, + []rangesync.KeyBytes{ + hexID("80e95b39faa731eb50eae7585a8b1cae98f503481f950fdb690e60ff86c21236"), + hexID("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7"), + hexID("bc6218a88d1648b8145fbf93ae74af8975f193af88788e7add3608e0bc50f701"), + hexID("d862b2413af5c252028e8f9871be8297e807661d64decd8249ac2682db168b90"), + hexID("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567"), + hexID("fbf03324234f79a3fe0587cf5505d7e4c826cb2be38d72eafa60296ed77b3f8f"), + 
}, + startDate.Add(time.Minute), + true, + true, + rangesync.WithRecentTimeSpan(48*time.Second), + ) + }) + t.Run("empty to non-empty", func(t *testing.T) { + verifyP2P( + t, nil, + []fooRow{ + fooR("1111111111111111111111111111111111111111111111111111111111111111", 11), + fooR("123456789abcdef0000000000000000000000000000000000000000000000000", 12), + fooR("5555555555555555555555555555555555555555555555555555555555555555", 13), + fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), + }, + []rangesync.KeyBytes{ + hexID("1111111111111111111111111111111111111111111111111111111111111111"), + hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), + hexID("5555555555555555555555555555555555555555555555555555555555555555"), + hexID("8888888888888888888888888888888888888888888888888888888888888888"), + }, + startDate, + false, + false, + ) + }) + t.Run("empty to non-empty with recent", func(t *testing.T) { + verifyP2P( + t, nil, + []fooRow{ + fooR("1111111111111111111111111111111111111111111111111111111111111111", 11), + fooR("123456789abcdef0000000000000000000000000000000000000000000000000", 12), + fooR("5555555555555555555555555555555555555555555555555555555555555555", 13), + fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), + }, + []rangesync.KeyBytes{ + hexID("1111111111111111111111111111111111111111111111111111111111111111"), + hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), + hexID("5555555555555555555555555555555555555555555555555555555555555555"), + hexID("8888888888888888888888888888888888888888888888888888888888888888"), + }, + startDate.Add(time.Minute), + true, + true, + rangesync.WithRecentTimeSpan(48*time.Second), + ) + }) + t.Run("non-empty to empty with recent", func(t *testing.T) { + verifyP2P( + t, + []fooRow{ + fooR("1111111111111111111111111111111111111111111111111111111111111111", 11), + fooR("123456789abcdef0000000000000000000000000000000000000000000000000", 12), + fooR("5555555555555555555555555555555555555555555555555555555555555555", 13), + fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), + }, + nil, + []rangesync.KeyBytes{ + hexID("1111111111111111111111111111111111111111111111111111111111111111"), + hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), + hexID("5555555555555555555555555555555555555555555555555555555555555555"), + hexID("8888888888888888888888888888888888888888888888888888888888888888"), + }, + startDate.Add(time.Minute), + // no actual recent exchange happens due to the initial EmptySet message + false, + false, + rangesync.WithRecentTimeSpan(48*time.Second), + ) + }) + t.Run("empty to empty", func(t *testing.T) { + verifyP2P(t, nil, nil, nil, startDate, false, false) + }) + t.Run("random test", func(t *testing.T) { + // higher values for "stress testing": + // const nShared = 8000000 + // const nUniqueA = 100 + // const nUniqueB = 80000 + const nShared = 80000 + const nUniqueA = 400 + const nUniqueB = 800 + + combined := make([]rangesync.KeyBytes, 0, nShared+nUniqueA+nUniqueB) + rowsA := make([]fooRow, nShared+nUniqueA) + for i := range rowsA { + k := rangesync.RandomKeyBytes(testKeyLen) + rowsA[i] = fooRow{ + id: k, + ts: startDate.Add(time.Duration(i) * time.Second).UnixNano(), + } + combined = append(combined, k) + } + rowsB := make([]fooRow, nShared+nUniqueB) + for i := range rowsB { + if i < nShared { + rowsB[i] = fooRow{ + id: slices.Clone(rowsA[i].id), + ts: rowsA[i].ts, + } + } else { + k := 
rangesync.RandomKeyBytes(testKeyLen)
+				rowsB[i] = fooRow{
+					id: k,
+					ts: startDate.Add(time.Duration(i) * time.Second).UnixNano(),
+				}
+				combined = append(combined, k)
+			}
+		}
+		slices.SortFunc(combined, func(a, b rangesync.KeyBytes) int {
+			return a.Compare(b)
+		})
+		verifyP2P(t, rowsA, rowsB, combined, startDate, false, false)
+	})
+}
diff --git a/sync2/fptree/dbbackedstore.go b/sync2/fptree/dbbackedstore.go
new file mode 100644
index 0000000000..d0a3e6e72f
--- /dev/null
+++ b/sync2/fptree/dbbackedstore.go
@@ -0,0 +1,75 @@
+package fptree
+
+import (
+	"github.com/spacemeshos/go-spacemesh/sql"
+	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
+	"github.com/spacemeshos/go-spacemesh/sync2/sqlstore"
+)
+
+// DBBackedStore is an implementation of IDStore that keeps track of the rows in a
+// database table, using an FPTree to store items that have arrived from a sync peer.
+type DBBackedStore struct {
+	*sqlstore.SQLIDStore
+	*FPTree
+}
+
+var _ sqlstore.IDStore = &DBBackedStore{}
+
+// NewDBBackedStore creates a new DB-backed store.
+// sizeHint is the expected number of items added to the store via RegisterKey _after_
+// the store is created.
+func NewDBBackedStore(
+	db sql.Executor,
+	sts *sqlstore.SyncedTableSnapshot,
+	sizeHint int,
+	keyLen int,
+) *DBBackedStore {
+	return &DBBackedStore{
+		SQLIDStore: sqlstore.NewSQLIDStore(db, sts, keyLen),
+		FPTree:     NewFPTreeWithValues(sizeHint, keyLen),
+	}
+}
+
+// Clone creates a copy of the store.
+// Implements IDStore.Clone.
+func (s *DBBackedStore) Clone() sqlstore.IDStore {
+	return &DBBackedStore{
+		SQLIDStore: s.SQLIDStore.Clone().(*sqlstore.SQLIDStore),
+		FPTree:     s.FPTree.Clone().(*FPTree),
+	}
+}
+
+// RegisterKey adds a hash to the store, using the FPTree so that the underlying database
+// table is unchanged.
+// Implements IDStore.
+func (s *DBBackedStore) RegisterKey(k rangesync.KeyBytes) error {
+	return s.FPTree.RegisterKey(k)
+}
+
+// All returns all the items currently in the store.
+// Implements IDStore.
+func (s *DBBackedStore) All() rangesync.SeqResult {
+	return rangesync.CombineSeqs(nil, s.SQLIDStore.All(), s.FPTree.All())
+}
+
+// From returns all the items in the store that are greater than or equal to the given key.
+// Implements IDStore.
+func (s *DBBackedStore) From(from rangesync.KeyBytes, sizeHint int) rangesync.SeqResult {
+	return rangesync.CombineSeqs(
+		from,
+		// Fewer than sizeHint items may be loaded from the database, as some
+		// may be in the FPTree, but for most cases that will do.
+		s.SQLIDStore.From(from, sizeHint),
+		s.FPTree.From(from, sizeHint))
+}
+
+// SetSnapshot sets the table snapshot to be used by the store.
+func (s *DBBackedStore) SetSnapshot(sts *sqlstore.SyncedTableSnapshot) {
+	s.SQLIDStore.SetSnapshot(sts)
+	s.FPTree.Clear()
+}
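The effect of layering an FPTree over an SQLIDStore is that newly registered keys become visible in ordered queries without touching the table. A condensed sketch along the lines of the test in dbbackedstore_test.go below (the "foo" table schema and 12-byte key length are illustrative):

```go
package example

import (
	"github.com/spacemeshos/go-spacemesh/sql"
	"github.com/spacemeshos/go-spacemesh/sync2/fptree"
	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
	"github.com/spacemeshos/go-spacemesh/sync2/sqlstore"
)

// mergedView shows a DBBackedStore merging table rows with keys received in
// memory; db must have a table "foo(id char(12) primary key, received int)".
func mergedView(db sql.Executor) ([]rangesync.KeyBytes, error) {
	st := sqlstore.SyncedTable{TableName: "foo", IDColumn: "id", TimestampColumn: "received"}
	sts, err := st.Snapshot(db) // pin the current table contents
	if err != nil {
		return nil, err
	}
	store := fptree.NewDBBackedStore(db, sts, 1, 12)
	// The new key lands in the in-memory FPTree; the table stays unchanged.
	if err := store.RegisterKey(rangesync.KeyBytes{0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0}); err != nil {
		return nil, err
	}
	// From merges snapshot rows and in-memory keys into one ordered sequence.
	return store.From(make(rangesync.KeyBytes, 12), 1).FirstN(3)
}
```

+
+// Release releases resources used by the store.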
+func (s *DBBackedStore) Release() { + s.FPTree.Release() +} diff --git a/sync2/fptree/dbbackedstore_test.go b/sync2/fptree/dbbackedstore_test.go new file mode 100644 index 0000000000..523e066141 --- /dev/null +++ b/sync2/fptree/dbbackedstore_test.go @@ -0,0 +1,96 @@ +package fptree + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/sqlstore" +) + +func TestDBBackedStore(t *testing.T) { + const keyLen = 12 + db := sql.InMemoryTest(t) + _, err := db.Exec( + fmt.Sprintf("create table foo(id char(%d) not null primary key, received int)", keyLen), + nil, nil) + require.NoError(t, err) + for _, row := range []struct { + id rangesync.KeyBytes + ts int64 + }{ + { + id: rangesync.KeyBytes{0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, + ts: 100, + }, + { + id: rangesync.KeyBytes{0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0}, + ts: 200, + }, + { + id: rangesync.KeyBytes{0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}, + ts: 300, + }, + { + id: rangesync.KeyBytes{0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0}, + ts: 400, + }, + } { + _, err := db.Exec( + "insert into foo (id, received) values (?, ?)", + func(stmt *sql.Statement) { + stmt.BindBytes(1, row.id) + stmt.BindInt64(2, row.ts) + }, nil) + require.NoError(t, err) + } + st := sqlstore.SyncedTable{ + TableName: "foo", + IDColumn: "id", + TimestampColumn: "received", + } + sts, err := st.Snapshot(db) + require.NoError(t, err) + + store := NewDBBackedStore(db, sts, 0, keyLen) + actualIDs, err := store.From(rangesync.KeyBytes{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 5).FirstN(5) + require.NoError(t, err) + require.Equal(t, []rangesync.KeyBytes{ + {0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, // wrapped around + }, actualIDs) + + actualIDs1, err := store.All().FirstN(5) + require.NoError(t, err) + require.Equal(t, actualIDs, actualIDs1) + + sr, count := store.Since(rangesync.KeyBytes{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 300) + require.Equal(t, 2, count) + actualIDs, err = sr.FirstN(3) + require.NoError(t, err) + require.Equal(t, []rangesync.KeyBytes{ + {0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}, // wrapped around + }, actualIDs) + + require.NoError(t, store.RegisterKey(rangesync.KeyBytes{0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0})) + require.NoError(t, store.RegisterKey(rangesync.KeyBytes{0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0})) + sr = store.From(rangesync.KeyBytes{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1) + actualIDs, err = sr.FirstN(6) + require.NoError(t, err) + require.Equal(t, []rangesync.KeyBytes{ + {0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0}, + }, actualIDs) +} diff --git a/sync2/fptree/export_test.go b/sync2/fptree/export_test.go new file mode 100644 index 0000000000..cd43cf160c --- /dev/null +++ b/sync2/fptree/export_test.go @@ -0,0 +1,20 @@ +package fptree + +import ( + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/sqlstore" +) + +var ErrEasySplitFailed = errEasySplitFailed + +func (ft *FPTree) EasySplit(x, y rangesync.KeyBytes, limit int) (sr SplitResult, err 
error) {
+	return ft.easySplit(x, y, limit)
+}
+
+func (ft *FPTree) PoolNodeCount() int {
+	return ft.np.nodeCount()
+}
+
+func (ft *FPTree) IDStore() sqlstore.IDStore {
+	return ft.idStore
+}
diff --git a/sync2/fptree/fptree.go b/sync2/fptree/fptree.go
new file mode 100644
index 0000000000..7dafefe63c
--- /dev/null
+++ b/sync2/fptree/fptree.go
@@ -0,0 +1,1295 @@
+package fptree
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"runtime"
+	"strconv"
+	"strings"
+
+	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
+	"github.com/spacemeshos/go-spacemesh/sync2/sqlstore"
+)
+
+var errEasySplitFailed = errors.New("easy split failed")
+
+const (
+	FingerprintSize = rangesync.FingerprintSize
+	// sizeHintCoef is used to calculate the number of pool entries to preallocate for
+	// an FPTree based on the expected number of items which this tree may contain.
+	sizeHintCoef = 2.1
+)
+
+// FPResult represents the result of a range fingerprint query against an FPTree, as
+// returned by FingerprintInterval.
+type FPResult struct {
+	// Range fingerprint
+	FP rangesync.Fingerprint
+	// Number of items in the range
+	Count uint32
+	// Interval type: -1 for normal, 0 for the whole set, 1 for wrapped around ("inverse")
+	IType int
+	// Items in the range
+	Items rangesync.SeqResult
+	// The item following the range
+	Next rangesync.KeyBytes
+}
+
+// SplitResult represents the result of a split operation.
+type SplitResult struct {
+	// The two parts of the interval
+	Part0, Part1 FPResult
+	// Middle point value
+	Middle rangesync.KeyBytes
+}
+
+// aggContext is the context used for aggregation operations.
+type aggContext struct {
+	// nodePool used by the tree
+	np *nodePool
+	// Bounds of the interval being aggregated
+	x, y rangesync.KeyBytes
+	// The current fingerprint of the items aggregated so far, since the beginning or
+	// after the split ("easy split")
+	fp rangesync.Fingerprint
+	// The fingerprint of the items aggregated in the first part of the split
+	fp0 rangesync.Fingerprint
+	// Number of items aggregated so far, since the beginning or after the split
+	// ("easy split")
+	count uint32
+	// Number of items aggregated in the first part of the split
+	count0 uint32
+	// Interval type: -1 for normal, 0 for the whole set, 1 for wrapped around ("inverse")
+	itype int
+	// Maximum remaining number of items to aggregate.
+	limit int
+	// The number of items aggregated so far.
+	total uint32
+	// The resulting item sequence.
+	items rangesync.SeqResult
+	// The item immediately following the aggregated items.
+	next rangesync.KeyBytes
+	// The prefix corresponding to the last aggregated node.
+	lastPrefix *prefix
+	// The prefix corresponding to the last aggregated node in the first part of the split.
+	lastPrefix0 *prefix
+	// Whether the aggregation is being done for an "easy split" (split operation
+	// without querying the underlying IDStore).
+	easySplit bool
+}
+
+// prefixAtOrAfterX verifies that any key with the prefix p is at or after x.
+// It can be used for the whole interval in case of a normal interval.
+// With inverse intervals, it should only be used when processing the [x, max) part of the
+// interval.
+func (ac *aggContext) prefixAtOrAfterX(p prefix) bool {
+	b := make(rangesync.KeyBytes, len(ac.x))
+	p.minID(b)
+	return b.Compare(ac.x) >= 0
+}
+
+// prefixBelowY verifies that any key with the prefix p is below y.
+// It can be used for the whole interval in case of a normal interval.
+// With inverse intervals, it should only be used when processing the [0, y) part of the
+// interval.
+func (ac *aggContext) prefixBelowY(p prefix) bool {
+	b := make(rangesync.KeyBytes, len(ac.y))
+	// If p.idAfter(b) is true, this means there's wraparound and
+	// b is zero whereas all the possible keys beginning with prefix p
+	// are non-zero. In this case, there can be no key y such that
+	// all the keys beginning with prefix p are below y.
+	return !p.idAfter(b) && b.Compare(ac.y) <= 0
+}
+
+// fingerprintAtOrAfterX verifies that the specified fingerprint, which should be derived
+// from a single key, is at or after the x bound of the interval.
+func (ac *aggContext) fingerprintAtOrAfterX(fp rangesync.Fingerprint) bool {
+	k := make(rangesync.KeyBytes, len(ac.x))
+	copy(k, fp[:])
+	return k.Compare(ac.x) >= 0
+}
+
+// fingerprintBelowY verifies that the specified fingerprint, which should be derived from
+// a single key, is below the y bound of the interval.
+func (ac *aggContext) fingerprintBelowY(fp rangesync.Fingerprint) bool {
+	k := make(rangesync.KeyBytes, len(ac.x))
+	copy(k, fp[:])
+	k[:FingerprintSize].Inc() // 1 after max key derived from the fingerprint
+	return k.Compare(ac.y) <= 0
+}
+
+// nodeAtOrAfterX verifies that the node with the given index is at or after the x bound
+// of the interval.
+func (ac *aggContext) nodeAtOrAfterX(idx nodeIndex, p prefix) bool {
+	count, fp, _ := ac.np.info(idx)
+	if count == 1 {
+		v := ac.np.value(idx)
+		if v != nil {
+			return v.Compare(ac.x) >= 0
+		}
+		return ac.fingerprintAtOrAfterX(fp)
+	}
+	return ac.prefixAtOrAfterX(p)
+}
+
+// nodeBelowY verifies that the node with the given index is below the y bound of the
+// interval.
+func (ac *aggContext) nodeBelowY(idx nodeIndex, p prefix) bool {
+	count, fp, _ := ac.np.info(idx)
+	if count == 1 {
+		v := ac.np.value(idx)
+		if v != nil {
+			return v.Compare(ac.y) < 0
+		}
+		return ac.fingerprintBelowY(fp)
+	}
+	return ac.prefixBelowY(p)
+}
+
+// pruneX returns true if the specified node can be pruned during left-aggregation because
+// all of its keys are below the x bound of the interval.
+func (ac *aggContext) pruneX(idx nodeIndex, p prefix) bool {
+	b := make(rangesync.KeyBytes, len(ac.x))
+	if !p.idAfter(b) && b.Compare(ac.x) <= 0 {
+		// idAfter derived from the prefix is at or below x => prune
+		return true
+	}
+	count, fp, _ := ac.np.info(idx)
+	if count > 1 {
+		// node has count > 1, so we can't use its fingerprint or value to
+		// determine if it's at or after X
+		return false
+	}
+	k := ac.np.value(idx)
+	if k != nil {
+		return k.Compare(ac.x) < 0
+	}
+
+	k = make(rangesync.KeyBytes, len(ac.x))
+	copy(k, fp[:])
+	k[:FingerprintSize].Inc() // 1 after max key derived from the fingerprint
+	return k.Compare(ac.x) <= 0
+}
+
+// pruneY returns true if the specified node can be pruned during right-aggregation
+// because all of its keys are at or after the y bound of the interval.
+func (ac *aggContext) pruneY(idx nodeIndex, p prefix) bool {
+	b := make(rangesync.KeyBytes, len(ac.y))
+	p.minID(b)
+	if b.Compare(ac.y) >= 0 {
+		// min ID derived from the prefix is at or after y => prune
+		return true
+	}
+
+	count, fp, _ := ac.np.info(idx)
+	if count > 1 {
+		// node has count > 1, so we can't use its fingerprint or value to
+		// determine if it's below y
+		return false
+	}
+	k := ac.np.value(idx)
+	if k == nil {
+		k = make(rangesync.KeyBytes, len(ac.y))
+		copy(k, fp[:])
+	}
+	return k.Compare(ac.y) >= 0
+}
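The prune and bound checks above all reduce to interval arithmetic on the key space: a prefix p covers exactly the keys in [minID(p), idAfter(p)). The prefix type and its minID/idAfter helpers are defined elsewhere in this package and are not part of this hunk, so here is a simplified standalone rendition of the prefixBelowY check, working at whole-byte granularity for brevity:

```go
package main

import (
	"bytes"
	"fmt"
)

// belowY reports whether every key starting with the given whole-byte prefix
// is strictly below y (keys are fixed-length, compared big-endian). This
// mirrors prefixBelowY above: increment overflow corresponds to the
// wraparound case where idAfter would be zero.
func belowY(prefix, y []byte) bool {
	idAfter := make([]byte, len(y))
	copy(idAfter, prefix) // the rest stays zero
	for i := len(prefix) - 1; ; i-- {
		if i < 0 {
			return false // wrapped around: no y can bound this prefix
		}
		idAfter[i]++
		if idAfter[i] != 0 {
			break
		}
	}
	return bytes.Compare(idAfter, y) <= 0
}

func main() {
	y := []byte{0x90, 0x00, 0x00, 0x00}
	fmt.Println(belowY([]byte{0x88}, y)) // true: keys in [0x88…, 0x89…) are below y
	fmt.Println(belowY([]byte{0x90}, y)) // false: 0x90… reaches past y
}
```

The overflow branch corresponds to p.idAfter reporting wraparound: a prefix of all ones covers the top of the key space, so no y can bound it from above.

+
+// switchToSecondPart switches aggregation to the second part of the "easy split".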
+func (ac *aggContext) switchToSecondPart() {
+	ac.limit = -1
+	ac.fp0 = ac.fp
+	ac.count0 = ac.count
+	ac.lastPrefix0 = ac.lastPrefix
+	clear(ac.fp[:])
+	ac.count = 0
+	ac.lastPrefix = nil
+}
+
+// maybeIncludeNode tries to include the full contents of the specified node in the
+// aggregation and returns whether it succeeded, based on the remaining limit and the
+// number of items in the node.
+// It also handles an "easy split" happening at the node.
+func (ac *aggContext) maybeIncludeNode(idx nodeIndex, p prefix) bool {
+	count, fp, leaf := ac.np.info(idx)
+	switch {
+	case ac.limit < 0:
+	case uint32(ac.limit) >= count:
+		ac.limit -= int(count)
+	case !ac.easySplit || !leaf:
+		return false
+	case ac.count == 0:
+		// We're doing a split and this node is over the limit, but the first part
+		// is still empty so we include this node in the first part and
+		// then switch to the second part
+		ac.limit = 0
+	default:
+		// We're doing a split and this node is over the limit, so store count and
+		// fingerprint for the first part and include the current node in the
+		// second part
+		ac.limit = -1
+		ac.fp0 = ac.fp
+		ac.count0 = ac.count
+		ac.lastPrefix0 = ac.lastPrefix
+		copy(ac.fp[:], fp[:])
+		ac.count = count
+		ac.lastPrefix = &p
+		return true
+	}
+	ac.fp.Update(fp[:])
+	ac.count += count
+	ac.lastPrefix = &p
+	if ac.easySplit && ac.limit == 0 {
+		// We're doing a split and this node is exactly at the limit, or it was
+		// above the limit but the first part was still empty, so store the count
+		// and fingerprint for the first part, which includes the current node,
+		// and zero out the count and fingerprint for the second part
+		ac.switchToSecondPart()
+	}
+	return true
+}
+
+// FPTree is a binary tree data structure designed to perform range fingerprint queries
+// efficiently.
+// FPTree can work on its own, with fingerprint query complexity being O(log n).
+// It can also be backed by an IDStore, with a depth limit on the binary tree, in which
+// case query efficiency degrades as the number of items grows.
+// O(log n) query efficiency can still be retained in this case for queries whose bounds
+// have no non-zero bits beyond the first maxDepth bits.
+// FPTree does not do any special balancing and relies on the IDs added to it being
+// uniformly distributed, which is the case for IDs based on cryptographic hashes.
+type FPTree struct {
+	trace
+	idStore  sqlstore.IDStore
+	np       *nodePool
+	root     nodeIndex
+	keyLen   int
+	maxDepth int
+}
+
+var _ sqlstore.IDStore = &FPTree{}
+
+// NewFPTreeWithValues creates an FPTree which also stores the items themselves and does
+// not make use of a backing IDStore.
+// sizeHint specifies the approximate expected number of items.
+// keyLen specifies the number of bytes in the keys used.
+func NewFPTreeWithValues(sizeHint, keyLen int) *FPTree {
+	return NewFPTree(sizeHint, nil, keyLen, 0)
+}
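Before moving on to the depth-limited constructor, a minimal sketch of the value-storing variant in action (the 12-byte key length matches the one used in dbbackedstore_test.go; the snippet itself is not part of this change):

```go
package main

import (
	"fmt"

	"github.com/spacemeshos/go-spacemesh/sync2/fptree"
	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
)

func main() {
	// A value-storing tree: no backing IDStore, maxDepth 0.
	ft := fptree.NewFPTreeWithValues(2, 12)
	for _, k := range []rangesync.KeyBytes{
		{0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0},
		{0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0},
	} {
		if err := ft.RegisterKey(k); err != nil {
			panic(err)
		}
	}
	// With values stored, CheckKey gives exact membership answers.
	fmt.Println(ft.CheckKey(rangesync.KeyBytes{0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0})) // true
	// All wraps around the key space endlessly; bound it explicitly.
	keys, err := ft.All().FirstN(2)
	fmt.Println(keys, err)
}
```

+
+// NewFPTree creates an FPTree of limited depth backed by an IDStore.
+// sizeHint specifies the approximate expected number of items.
+// keyLen specifies the number of bytes in the keys used.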
+func NewFPTree(sizeHint int, idStore sqlstore.IDStore, keyLen, maxDepth int) *FPTree {
+	var np nodePool
+	if sizeHint > 0 {
+		size := int(float64(sizeHint) * sizeHintCoef)
+		if maxDepth > 0 {
+			size = min(size, 1<<(maxDepth+1))
+		}
+		np.init(size)
+	}
+	if idStore == nil && maxDepth != 0 {
+		panic("BUG: NewFPTree: no idStore, but maxDepth specified")
+	}
+	ft := &FPTree{
+		np:       &np,
+		idStore:  idStore,
+		root:     noIndex,
+		keyLen:   keyLen,
+		maxDepth: maxDepth,
+	}
+	runtime.SetFinalizer(ft, (*FPTree).Release)
+	return ft
+}
+
+// traverse traverses the subtree rooted in idx in order and calls the given function for
+// each item.
+func (ft *FPTree) traverse(idx nodeIndex, yield func(rangesync.KeyBytes) bool) (res bool) {
+	ft.enter("traverse: idx %d", idx)
+	defer func() {
+		ft.leave(res)
+	}()
+	if idx == noIndex {
+		ft.log("no index")
+		return true
+	}
+	l := ft.np.left(idx)
+	r := ft.np.right(idx)
+	if l == noIndex && r == noIndex {
+		v := ft.np.value(idx)
+		if v != nil {
+			ft.log("yield value %s", v.ShortString())
+		}
+		if v != nil && !yield(v) {
+			return false
+		}
+		return true
+	}
+	return ft.traverse(l, yield) && ft.traverse(r, yield)
+}
+
+// traverseFrom traverses the subtree rooted in idx in order and calls the given function
+// for each item, starting from the given key.
+func (ft *FPTree) traverseFrom(
+	idx nodeIndex,
+	p prefix,
+	from rangesync.KeyBytes,
+	yield func(rangesync.KeyBytes) bool,
+) (res bool) {
+	ft.enter("traverseFrom: idx %d p %s from %s", idx, p, from)
+	defer func() {
+		ft.leave(res)
+	}()
+	if idx == noIndex {
+		return true
+	}
+	if p == emptyPrefix || ft.np.leaf(idx) {
+		v := ft.np.value(idx)
+		if v != nil && v.Compare(from) >= 0 {
+			ft.log("yield value %s", v.ShortString())
+			if !yield(v) {
+				return false
+			}
+		}
+		return true
+	}
+	if !p.highBit() {
+		return ft.traverseFrom(ft.np.left(idx), p.shift(), from, yield) &&
+			ft.traverse(ft.np.right(idx), yield)
+	} else {
+		return ft.traverseFrom(ft.np.right(idx), p.shift(), from, yield)
+	}
+}
+
+// All returns all the items currently in the tree (including those in the IDStore).
+// Implements sqlstore.IDStore.
+func (ft *FPTree) All() rangesync.SeqResult {
+	ft.np.lockRead()
+	defer ft.np.unlockRead()
+	switch {
+	case ft.root == noIndex:
+		return rangesync.EmptySeqResult()
+	case ft.storeValues():
+		return rangesync.SeqResult{
+			Seq: func(yield func(rangesync.KeyBytes) bool) {
+				for {
+					if !ft.traverse(ft.root, yield) {
+						break
+					}
+				}
+			},
+			Error: rangesync.NoSeqError,
+		}
+	}
+	return ft.idStore.All()
+}
+
+// From returns all the items in the tree that are greater than or equal to the given key.
+// Implements sqlstore.IDStore.
+func (ft *FPTree) From(from rangesync.KeyBytes, sizeHint int) rangesync.SeqResult {
+	ft.np.lockRead()
+	defer ft.np.unlockRead()
+	switch {
+	case ft.root == noIndex:
+		return rangesync.EmptySeqResult()
+	case ft.storeValues():
+		return rangesync.SeqResult{
+			Seq: func(yield func(rangesync.KeyBytes) bool) {
+				p := prefixFromKeyBytes(from)
+				if !ft.traverseFrom(ft.root, p, from, yield) {
+					return
+				}
+				for {
+					if !ft.traverse(ft.root, yield) {
+						break
+					}
+				}
+			},
+			Error: rangesync.NoSeqError,
+		}
+	}
+	return ft.idStore.From(from, sizeHint)
+}
+
+// Release releases resources used by the tree.
+// Implements sqlstore.IDStore.
+func (ft *FPTree) Release() {
+	ft.np.lockWrite()
+	defer ft.np.unlockWrite()
+	ft.np.release(ft.root)
+	ft.root = noIndex
+	if ft.idStore != nil {
+		ft.idStore.Release()
+	}
+}
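Note that in the value-storing case both All and From restart the in-order traversal from the root once a pass completes, so the resulting sequences wrap around the key space indefinitely and never terminate on their own; callers bound them with FirstN or by breaking out of the loop. A small consumption sketch (onePass is a hypothetical helper, not part of this change):

```go
package example

import "github.com/spacemeshos/go-spacemesh/sync2/rangesync"

// onePass collects keys from a wrapping sequence until it sees its first key
// again (or collects max keys), since these sequences never end on their own.
func onePass(sr rangesync.SeqResult, max int) ([]rangesync.KeyBytes, error) {
	var out []rangesync.KeyBytes
	for k := range sr.Seq {
		if len(out) > 0 && k.Compare(out[0]) == 0 {
			break // wrapped around to the start
		}
		out = append(out, k.Clone())
		if len(out) == max {
			break
		}
	}
	return out, sr.Error()
}
```

+
+// Clear removes all items from the tree.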
+// It should only be used with trees that were created using NewFPTreeWithValues.
+func (ft *FPTree) Clear() {
+	if !ft.storeValues() {
+		// if we have an idStore, it can't be cleared and thus the tree can't be
+		// cleared either
+		panic("BUG: can only clear an FPTree with values")
+	}
+	ft.Release()
+}
+
+// Clone makes a copy of the tree.
+// The copy operation is thread-safe and has complexity of O(1).
+func (ft *FPTree) Clone() sqlstore.IDStore {
+	ft.np.lockWrite()
+	defer ft.np.unlockWrite()
+	if ft.root != noIndex {
+		ft.np.ref(ft.root)
+	}
+	var idStore sqlstore.IDStore
+	if !ft.storeValues() {
+		idStore = ft.idStore.Clone()
+	}
+	return &FPTree{
+		np:       ft.np,
+		idStore:  idStore,
+		root:     ft.root,
+		keyLen:   ft.keyLen,
+		maxDepth: ft.maxDepth,
+	}
+}
+
+// pushLeafDown pushes a leaf node down the tree when the node's path matches that of the
+// new key being added, splitting it if necessary.
+func (ft *FPTree) pushLeafDown(
+	idx nodeIndex,
+	replace bool,
+	singleFP, prevFP rangesync.Fingerprint,
+	depth int,
+	curCount uint32,
+	value, prevValue rangesync.KeyBytes,
+) (newIdx nodeIndex) {
+	if idx == noIndex {
+		panic("BUG: pushLeafDown on a nonexistent node")
+	}
+	// Once we stumble upon a node with refCount > 1, we can no longer replace nodes
+	// as they're also referenced by another tree.
+	if replace && ft.np.refCount(idx) > 1 {
+		ft.np.releaseOne(idx)
+		replace = false
+	}
+	replace = replace && ft.np.refCount(idx) == 1
+	replaceIdx := noIndex
+	if replace {
+		replaceIdx = idx
+	}
+	fpCombined := rangesync.CombineFingerprints(singleFP, prevFP)
+	if ft.maxDepth != 0 && depth == ft.maxDepth {
+		newIdx = ft.np.add(fpCombined, curCount+1, noIndex, noIndex, nil, replaceIdx)
+		return newIdx
+	}
+	if curCount != 1 {
+		panic("BUG: pushDown of non-1-leaf below maxDepth")
+	}
+	dirA := singleFP.BitFromLeft(depth)
+	dirB := prevFP.BitFromLeft(depth)
+	if dirA == dirB {
+		// TODO: in the proper radix tree, these 1-child nodes should never be
+		// created, accumulating the prefix instead
+		childIdx := ft.pushLeafDown(idx, replace, singleFP, prevFP, depth+1, 1, value, prevValue)
+		if dirA {
+			newIdx = ft.np.add(fpCombined, 2, noIndex, childIdx, nil, noIndex)
+		} else {
+			newIdx = ft.np.add(fpCombined, 2, childIdx, noIndex, nil, noIndex)
+		}
+	} else {
+		idxA := ft.np.add(singleFP, 1, noIndex, noIndex, value, noIndex)
+		idxB := ft.np.add(prevFP, curCount, noIndex, noIndex, prevValue, replaceIdx)
+		if dirA {
+			newIdx = ft.np.add(fpCombined, 2, idxB, idxA, nil, noIndex)
+		} else {
+			newIdx = ft.np.add(fpCombined, 2, idxA, idxB, nil, noIndex)
+		}
+	}
+	return newIdx
+}
+
+// addValue adds a value to the subtree rooted in idx.
+func (ft *FPTree) addValue(
+	idx nodeIndex,
+	replace bool,
+	fp rangesync.Fingerprint,
+	depth int,
+	value rangesync.KeyBytes,
+) (newIdx nodeIndex) {
+	if idx == noIndex {
+		newIdx = ft.np.add(fp, 1, noIndex, noIndex, value, noIndex)
+		return newIdx
+	}
+	// Once we stumble upon a node with refCount > 1, we can no longer replace nodes
+	// as they're also referenced by another tree.
+	if replace && ft.np.refCount(idx) > 1 {
+		ft.np.releaseOne(idx)
+		replace = false
+	}
+	count, nodeFP, leaf := ft.np.info(idx)
+	left := ft.np.left(idx)
+	right := ft.np.right(idx)
+	nodeValue := ft.np.value(idx)
+	if leaf {
+		if count != 1 && (ft.maxDepth == 0 || depth != ft.maxDepth) {
+			panic("BUG: unexpected leaf node")
+		}
+		// we're at a leaf node, need to push down the old fingerprint, or,
+		// if we've reached the max depth, just update the current node
+		return ft.pushLeafDown(idx, replace, fp, nodeFP, depth, count, value, nodeValue)
+	}
+	replaceIdx := noIndex
+	if replace {
+		replaceIdx = idx
+	}
+	fpCombined := rangesync.CombineFingerprints(fp, nodeFP)
+	if fp.BitFromLeft(depth) {
+		newRight := ft.addValue(right, replace, fp, depth+1, value)
+		newIdx := ft.np.add(fpCombined, count+1, left, newRight, nil, replaceIdx)
+		if !replace && left != noIndex {
+			// the original node is not being replaced, so the reused left
+			// node has acquired another reference
+			ft.np.ref(left)
+		}
+		return newIdx
+	} else {
+		newLeft := ft.addValue(left, replace, fp, depth+1, value)
+		newIdx := ft.np.add(fpCombined, count+1, newLeft, right, nil, replaceIdx)
+		if !replace && right != noIndex {
+			// the original node is not being replaced, so the reused right
+			// node has acquired another reference
+			ft.np.ref(right)
+		}
+		return newIdx
+	}
+}
+
+// AddStoredKey adds a key to the tree, assuming that either the tree doesn't have an
+// IDStore or the IDStore already contains the key.
+func (ft *FPTree) AddStoredKey(k rangesync.KeyBytes) {
+	var fp rangesync.Fingerprint
+	fp.Update(k)
+	ft.log("addStoredHash: h %s fp %s", k, fp)
+	var v rangesync.KeyBytes
+	if ft.storeValues() {
+		v = k
+	}
+	ft.np.lockWrite()
+	defer ft.np.unlockWrite()
+	ft.root = ft.addValue(ft.root, true, fp, 0, v)
+}
+
+// RegisterKey registers a key in the tree.
+// If the tree has an IDStore, the key is also registered with the IDStore.
+func (ft *FPTree) RegisterKey(k rangesync.KeyBytes) error {
+	ft.log("addHash: k %s", k)
+	if !ft.storeValues() {
+		if err := ft.idStore.RegisterKey(k); err != nil {
+			return err
+		}
+	}
+	ft.AddStoredKey(k)
+	return nil
+}
+
+// storeValues returns true if the tree stores the values (has no IDStore).
+func (ft *FPTree) storeValues() bool {
+	return ft.idStore == nil
+}
+
+// CheckKey returns true if the tree contains or may contain the given key.
+// If this function returns false, the tree definitely doesn't contain the key.
+// If this function returns true and the tree stores the values, the key is definitely
+// contained in the tree.
+// If this function returns true and the tree doesn't store the values, the key may be
+// contained in the tree.
+func (ft *FPTree) CheckKey(k rangesync.KeyBytes) bool {
+	// We're unlikely to be able to find a node with the full prefix, but if we can
+	// find a leaf node with a matching partial prefix, that's good enough, except
+	// that we also need to check the node's fingerprint.
+	idx, _, _ := ft.followPrefix(ft.root, prefixFromKeyBytes(k), emptyPrefix)
+	if idx == noIndex {
+		return false
+	}
+	count, fp, _ := ft.np.info(idx)
+	if count != 1 {
+		return true
+	}
+	var kFP rangesync.Fingerprint
+	kFP.Update(k)
+	return fp == kFP
+}
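For an IDStore-backed tree, CheckKey thus acts as a Bloom-filter-style pre-check: a false answer is authoritative, while a true answer may be a false positive that needs confirmation against the store. A hedged sketch of one way a caller could combine the two answers; the mightContain helper and the confirmation via From are our own illustration, not part of the package:

```go
package fptree_test

import (
	"github.com/spacemeshos/go-spacemesh/sync2/fptree"
	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
)

// mightContain interprets CheckKey's tri-state answer. For a tree with
// values, a true result is already definitive; for an IDStore-backed tree we
// confirm by asking for the first item at or after k and comparing it.
func mightContain(ft *fptree.FPTree, k rangesync.KeyBytes) (bool, error) {
	if !ft.CheckKey(k) {
		return false, nil // definitely absent
	}
	first, err := ft.From(k, 1).First()
	if err != nil {
		return false, err
	}
	// From wraps around, so a non-matching first item means k is absent.
	return first != nil && first.Compare(k) == 0, nil
}
```

+
+// followPrefix follows the bit prefix p from the node idx.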
+func (ft *FPTree) followPrefix(from nodeIndex, p, followed prefix) (idx nodeIndex, rp prefix, found bool) {
+	ft.enter("followPrefix: from %d p %s highBit %v", from, p, p.highBit())
+	defer func() { ft.leave(idx, rp, found) }()
+
+	for from != noIndex {
+		switch {
+		case p.len() == 0:
+			return from, followed, true
+		case ft.np.leaf(from):
+			return from, followed, false
+		case p.highBit():
+			from = ft.np.right(from)
+			p = p.shift()
+			followed = followed.right()
+		default:
+			from = ft.np.left(from)
+			p = p.shift()
+			followed = followed.left()
+		}
+	}
+
+	return noIndex, followed, false
+}
+
+// aggregateEdge aggregates an edge of the interval. The edge may be bounded by x, by y,
+// by both x and y, or by neither; its items share a common prefix, and the aggregation
+// may additionally be bounded by a limit on the number of aggregated items.
+// It returns a boolean indicating whether the aggregation should continue (false means
+// the limit or the right edge (y) was reached) and an error, if any.
+func (ft *FPTree) aggregateEdge(
+	x, y rangesync.KeyBytes,
+	idx nodeIndex,
+	p prefix,
+	ac *aggContext,
+) (cont bool, err error) {
+	ft.enter("aggregateEdge: x %s y %s p %s limit %d count %d", x, y, p, ac.limit, ac.count)
+	defer func() {
+		ft.leave(ac.limit, ac.count, cont, err)
+	}()
+	if ft.storeValues() {
+		panic("BUG: aggregateEdge should not be used for tree with values")
+	}
+	if ac.easySplit {
+		// easySplit means we should not be querying the database,
+		// so we'll have to retry using a slower strategy
+		return false, errEasySplitFailed
+	}
+	if ac.limit == 0 && ac.next != nil {
+		ft.log("aggregateEdge: limit is 0 and end already set")
+		return false, nil
+	}
+	var startFrom rangesync.KeyBytes
+	if x == nil {
+		startFrom = make(rangesync.KeyBytes, ft.keyLen)
+		p.minID(startFrom)
+	} else {
+		startFrom = x
+	}
+	ft.log("aggregateEdge: startFrom %s", startFrom)
+	sizeHint := int(ft.np.count(idx))
+	switch {
+	case ac.limit == 0:
+		sizeHint = 1
+	case ac.limit > 0:
+		sizeHint = min(ac.limit, sizeHint)
+	}
+	sr := ft.From(startFrom, sizeHint)
+	if ac.limit == 0 {
+		next, err := sr.First()
+		if err != nil {
+			return false, err
+		}
+		ac.next = next.Clone()
+		if x != nil {
+			ft.log("aggregateEdge: limit 0: x is not nil, setting start to %s", ac.next.String())
+			ac.items = sr
+		}
+		ft.log("aggregateEdge: limit is 0 at %s", ac.next.String())
+		return false, nil
+	}
+	if x != nil {
+		ac.items = sr
+		ft.log("aggregateEdge: x is not nil, setting start to %v", sr)
+	}
+
+	n := ft.np.count(ft.root)
+	for id := range sr.Seq {
+		if ac.limit == 0 && !ac.easySplit {
+			ac.next = id.Clone()
+			ft.log("aggregateEdge: limit exhausted")
+			return false, nil
+		}
+		if n == 0 {
+			break
+		}
+		ft.log("aggregateEdge: ID %s", id)
+		if y != nil && id.Compare(y) >= 0 {
+			ac.next = id.Clone()
+			ft.log("aggregateEdge: ID is over Y: %s", id)
+			return false, nil
+		}
+		if !p.match(id) {
+			ft.log("aggregateEdge: ID doesn't match the prefix: %s", id)
+			ac.lastPrefix = &p
+			return true, nil
+		}
+		if ac.limit == 0 {
+			ft.log("aggregateEdge: switching to second part of easySplit")
+			ac.switchToSecondPart()
+		}
+		ac.fp.Update(id)
+		ac.count++
+		if ac.limit > 0 {
+			ac.limit--
+		}
+		n--
+	}
+	if err := sr.Error(); err != nil {
+		return false, err
+	}
+
+	return true, nil
+}
+
+// aggregateUpToLimit aggregates the subtree rooted in idx, up to the limit on the number
+// of aggregated items recorded in the aggContext.
+func (ft *FPTree) aggregateUpToLimit(idx nodeIndex, p prefix, ac *aggContext) (cont bool, err error) { + ft.enter("aggregateUpToLimit: idx %d p %s limit %d cur_fp %s cur_count0 %d cur_count %d", idx, p, ac.limit, + ac.fp, ac.count0, ac.count) + defer func() { + ft.leave(ac.fp, ac.count0, ac.count, err) + }() + switch { + case idx == noIndex: + ft.log("stop: no node") + return true, nil + case ac.limit == 0: + return false, nil + case ac.maybeIncludeNode(idx, p): + // node is fully included + ft.log("included fully, lastPrefix = %s", ac.lastPrefix) + return true, nil + case ft.np.leaf(idx): + // reached the limit on this node, do not need to continue after + // done with it + cont, err := ft.aggregateEdge(nil, nil, idx, p, ac) + if err != nil { + return false, err + } + if cont { + panic("BUG: expected limit not reached") + } + return false, nil + default: + pLeft := p.left() + left := ft.np.left(idx) + if left != noIndex { + if ac.maybeIncludeNode(left, pLeft) { + // left node is fully included, after which + // we need to stop somewhere in the right subtree + ft.log("include left in full") + } else { + // we must stop somewhere in the left subtree, + // and the right subtree is irrelevant unless + // easySplit is being done and we must restart + // after the limit is exhausted + ft.log("descend to the left") + if cont, err := ft.aggregateUpToLimit(left, pLeft, ac); !cont || err != nil { + return cont, err + } + if !ac.easySplit { + return false, nil + } + } + } + ft.log("descend to the right") + return ft.aggregateUpToLimit(ft.np.right(idx), p.right(), ac) + } +} + +// aggregateLeft aggregates the subtree that covers the left subtree of the LCA in case of +// normal intervals, and the subtree that covers [x, MAX] part for the inverse (wrapped +// around) intervals. +func (ft *FPTree) aggregateLeft( + idx nodeIndex, + k rangesync.KeyBytes, + p prefix, + ac *aggContext, +) (cont bool, err error) { + ft.enter("aggregateLeft: idx %d k %s p %s limit %d", idx, k.ShortString(), p, ac.limit) + defer func() { + ft.leave(ac.fp, ac.count0, ac.count, err) + }() + switch { + case idx == noIndex: + // for ac.limit == 0, it's important that we still visit the node + // so that we can get the item immediately following the included items + ft.log("stop: no node") + return true, nil + case ac.limit == 0: + return false, nil + case ac.nodeAtOrAfterX(idx, p) && ac.maybeIncludeNode(idx, p): + ft.log("including node in full: %s limit %d", p, ac.limit) + return true, nil + case (ft.maxDepth != 0 && p.len() == ft.maxDepth) || ft.np.leaf(idx): + if ac.pruneX(idx, p) { + ft.log("node %d p %s pruned", idx, p) + // we've not reached X yet so we should not stop, thus true + return true, nil + } + return ft.aggregateEdge(ac.x, nil, idx, p, ac) + case !k.BitFromLeft(p.len()): + left := ft.np.left(idx) + right := ft.np.right(idx) + ft.log("incl right node %d + go left to node %d", right, left) + cont, err := ft.aggregateLeft(left, k, p.left(), ac) + if !cont || err != nil { + return false, err + } + if right != noIndex { + return ft.aggregateUpToLimit(right, p.right(), ac) + } + return true, nil + default: + right := ft.np.right(idx) + ft.log("go right to node %d", right) + return ft.aggregateLeft(right, k, p.right(), ac) + } +} + +// aggregateRight aggregates the subtree that covers the right subtree of the LCA in case +// of normal intervals, and the subtree that covers [0, y) part for the inverse (wrapped +// around) intervals. 
+func (ft *FPTree) aggregateRight(
+	idx nodeIndex,
+	k rangesync.KeyBytes,
+	p prefix,
+	ac *aggContext,
+) (cont bool, err error) {
+	ft.enter("aggregateRight: idx %d k %s p %s limit %d", idx, k.ShortString(), p, ac.limit)
+	defer func() {
+		ft.leave(ac.fp, ac.count0, ac.count, err)
+	}()
+	switch {
+	case idx == noIndex:
+		ft.log("stop: no node")
+		return true, nil
+	case ac.limit == 0:
+		return false, nil
+	case ac.nodeBelowY(idx, p) && ac.maybeIncludeNode(idx, p):
+		ft.log("including node in full: %s limit %d", p, ac.limit)
+		return ac.limit != 0, nil
+	case (ft.maxDepth != 0 && p.len() == ft.maxDepth) || ft.np.leaf(idx):
+		if ac.pruneY(idx, p) {
+			ft.log("node %d p %s pruned", idx, p)
+			return false, nil
+		}
+		return ft.aggregateEdge(nil, ac.y, idx, p, ac)
+	case !k.BitFromLeft(p.len()):
+		left := ft.np.left(idx)
+		ft.log("go left to node %d", left)
+		return ft.aggregateRight(left, k, p.left(), ac)
+	default:
+		left := ft.np.left(idx)
+		right := ft.np.right(idx)
+		ft.log("incl left node %d + go right to node %d", left, right)
+		if left != noIndex {
+			cont, err := ft.aggregateUpToLimit(left, p.left(), ac)
+			if !cont || err != nil {
+				return false, err
+			}
+		}
+		return ft.aggregateRight(right, k, p.right(), ac)
+	}
+}
+
+// aggregateXX aggregates intervals of the form [x, x), which denote the whole set.
+func (ft *FPTree) aggregateXX(ac *aggContext) (err error) {
+	// [x, x) interval which denotes the whole set unless
+	// the limit is specified, in which case we need to start aggregating
+	// with x and wrap around if necessary
+	ft.enter("aggregateXX: x %s limit %d", ac.x, ac.limit)
+	defer func() {
+		ft.leave(ac, err)
+	}()
+	if ft.root == noIndex {
+		ft.log("empty set (no root)")
+	} else if ac.maybeIncludeNode(ft.root, emptyPrefix) {
+		ft.log("whole set")
+	} else {
+		// We need to aggregate up to ac.limit number of items starting
+		// from x and wrapping around if necessary
+		return ft.aggregateInverse(ac)
+	}
+	return nil
+}
+
+// aggregateSimple aggregates simple (normal) intervals of the form [x, y) where x < y.
+func (ft *FPTree) aggregateSimple(ac *aggContext) (err error) {
+	// "proper" interval: [x, lca); (lca, y)
+	ft.enter("aggregateSimple: x %s y %s limit %d", ac.x, ac.y, ac.limit)
+	defer func() {
+		ft.leave(ac, err)
+	}()
+	p := commonPrefix(ac.x, ac.y)
+	lcaIdx, lcaPrefix, fullPrefixFound := ft.followPrefix(ft.root, p, emptyPrefix)
+	ft.log("commonPrefix %s lcaPrefix %s lca %d found %v", p, lcaPrefix, lcaIdx, fullPrefixFound)
+	switch {
+	case fullPrefixFound && !ft.np.leaf(lcaIdx):
+		if lcaPrefix != p {
+			panic("BUG: bad followedPrefix")
+		}
+		if _, err := ft.aggregateLeft(ft.np.left(lcaIdx), ac.x, p.left(), ac); err != nil {
+			return err
+		}
+		if ac.limit != 0 {
+			if _, err := ft.aggregateRight(ft.np.right(lcaIdx), ac.y, p.right(), ac); err != nil {
+				return err
+			}
+		}
+	case lcaIdx == noIndex || !ft.np.leaf(lcaIdx):
+		ft.log("commonPrefix %s NOT found b/c no items have it", p)
+	case ac.nodeAtOrAfterX(lcaIdx, lcaPrefix) && ac.nodeBelowY(lcaIdx, lcaPrefix) &&
+		ac.maybeIncludeNode(lcaIdx, lcaPrefix):
+		ft.log("commonPrefix %s -- lca node %d included in full", p, lcaIdx)
+	case ft.np.leaf(lcaIdx) && ft.np.value(lcaIdx) != nil:
+		// a leaf 1-node with a value that could not be included should be skipped
+		return nil
+	default:
+		ft.log("commonPrefix %s -- lca %d", p, lcaIdx)
+		_, err := ft.aggregateEdge(ac.x, ac.y, lcaIdx, lcaPrefix, ac)
+		return err
+	}
+	return nil
+}
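The aggregation paths are selected purely by comparing the bounds; a schematic sketch of the classification that aggregateInterval (below) performs with x.Compare(y), for illustration only:

```go
package fptree_test

import (
	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
)

// classifyInterval mirrors the itype dispatch in aggregateInterval.
func classifyInterval(x, y rangesync.KeyBytes) string {
	switch x.Compare(y) {
	case 0:
		return "[x, x): the whole set, handled by aggregateXX"
	case -1:
		return "[x, y), x < y: a normal interval, handled by aggregateSimple"
	default:
		return "[x, y), x > y: a wrapped interval, handled by aggregateInverse"
	}
}
```

+
+// aggregateInverse aggregates inverse intervals of the form [x, y) where x > y.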
+func (ft *FPTree) aggregateInverse(ac *aggContext) (err error) {
+	// inverse interval: [min, y); [x, max]
+
+	// First, we handle the [x, max] part.
+	// For this, we process the subtree rooted in the LCA of 0x000000... (all 0s) and x.
+	ft.enter("aggregateInverse: x %s y %s limit %d", ac.x, ac.y, ac.limit)
+	defer func() {
+		ft.leave(ac, err)
+	}()
+	pf0 := preFirst0(ac.x)
+	idx0, followedPrefix, found := ft.followPrefix(ft.root, pf0, emptyPrefix)
+	ft.log("pf0 %s idx0 %d found %v followedPrefix %s", pf0, idx0, found, followedPrefix)
+	switch {
+	case found && !ft.np.leaf(idx0):
+		if followedPrefix != pf0 {
+			panic("BUG: bad followedPrefix")
+		}
+		cont, err := ft.aggregateLeft(idx0, ac.x, pf0, ac)
+		if err != nil {
+			return err
+		}
+		if !cont {
+			return nil
+		}
+	case idx0 == noIndex || !ft.np.leaf(idx0):
+		// nothing to do
+	case ac.nodeAtOrAfterX(idx0, followedPrefix) && ac.maybeIncludeNode(idx0, followedPrefix):
+		// node is fully included
+	case ac.pruneX(idx0, followedPrefix):
+		// the node is below X
+		ft.log("node %d p %s pruned", idx0, followedPrefix)
+	default:
+		_, err := ft.aggregateEdge(ac.x, nil, idx0, followedPrefix, ac)
+		if err != nil {
+			return err
+		}
+	}
+
+	if ac.limit == 0 && !ac.easySplit {
+		return nil
+	}
+
+	// Then we handle the [min, y) part.
+	// For this, we process the subtree rooted in the LCA of y and 0xffffff... (all 1s).
+	pf1 := preFirst1(ac.y)
+	idx1, followedPrefix, found := ft.followPrefix(ft.root, pf1, emptyPrefix)
+	ft.log("pf1 %s idx1 %d found %v", pf1, idx1, found)
+	switch {
+	case found && !ft.np.leaf(idx1):
+		if followedPrefix != pf1 {
+			panic("BUG: bad followedPrefix")
+		}
+		if _, err := ft.aggregateRight(idx1, ac.y, pf1, ac); err != nil {
+			return err
+		}
+	case idx1 == noIndex || !ft.np.leaf(idx1):
+		// nothing to do
+	case ac.nodeBelowY(idx1, followedPrefix) && ac.maybeIncludeNode(idx1, followedPrefix):
+		// node is fully included
+	case ac.pruneY(idx1, followedPrefix):
+		// the node is at or after Y
+		ft.log("node %d p %s pruned", idx1, followedPrefix)
+		return nil
+	default:
+		_, err := ft.aggregateEdge(nil, ac.y, idx1, followedPrefix, ac)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// aggregateInterval aggregates an interval, updating the aggContext accordingly.
+func (ft *FPTree) aggregateInterval(ac *aggContext) (err error) {
+	ft.enter("aggregateInterval: x %s y %s limit %d", ac.x, ac.y, ac.limit)
+	defer func() {
+		ft.leave(ac, err)
+	}()
+	ac.itype = ac.x.Compare(ac.y)
+	if ft.root == noIndex {
+		return nil
+	}
+	ac.total = ft.np.count(ft.root)
+	switch ac.itype {
+	case 0:
+		return ft.aggregateXX(ac)
+	case -1:
+		return ft.aggregateSimple(ac)
+	default:
+		return ft.aggregateInverse(ac)
+	}
+}
+
+// startFromPrefix returns a SeqResult that begins with the first item that sorts after
+// all the keys sharing the prefix p.
+func (ft *FPTree) startFromPrefix(ac *aggContext, p prefix) rangesync.SeqResult {
+	k := make(rangesync.KeyBytes, ft.keyLen)
+	p.idAfter(k)
+	ft.log("startFromPrefix: p: %s idAfter: %s", p, k)
+	return ft.From(k, 1)
+}
+
+// nextFromPrefix returns the first item that sorts after the prefix p, or nil if the
+// tree has no items.
+func (ft *FPTree) nextFromPrefix(ac *aggContext, p prefix) (rangesync.KeyBytes, error) {
+	id, err := ft.startFromPrefix(ac, p).First()
+	if err != nil {
+		return nil, err
+	}
+	if id == nil {
+		return nil, nil
+	}
+	return id.Clone(), nil
+}
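For reference, a short usage sketch of the exported query defined just below: fetch a bounded chunk of [x, y) and use Next as the cursor for the following call. The chunkOf helper is our own illustration; only FingerprintInterval and the FPResult fields come from the package:

```go
package fptree_test

import (
	"fmt"

	"github.com/spacemeshos/go-spacemesh/sync2/fptree"
	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
)

// chunkOf aggregates up to limit items of [x, y) and returns the cursor for
// the next chunk. FPResult.Items starts at the first item of the range;
// FPResult.Next is the first item beyond the aggregated ones.
func chunkOf(ft *fptree.FPTree, x, y rangesync.KeyBytes, limit int) (rangesync.KeyBytes, error) {
	fpr, err := ft.FingerprintInterval(x, y, limit)
	if err != nil {
		return nil, err
	}
	fmt.Printf("fp=%s count=%d\n", fpr.FP, fpr.Count)
	return fpr.Next, nil // pass as x of the next FingerprintInterval call
}
```

+
+// FingerprintInterval performs a range fingerprint query with the specified bounds and limit.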
+func (ft *FPTree) FingerprintInterval(x, y rangesync.KeyBytes, limit int) (fpr FPResult, err error) {
+	ft.np.lockRead()
+	defer ft.np.unlockRead()
+	return ft.fingerprintInterval(x, y, limit)
+}
+
+func (ft *FPTree) fingerprintInterval(x, y rangesync.KeyBytes, limit int) (fpr FPResult, err error) {
+	ft.enter("fingerprintInterval: x %s y %s limit %d", x, y, limit)
+	defer func() {
+		ft.leave(fpr.FP, fpr.Count, fpr.IType, fpr.Items, fpr.Next, err)
+	}()
+	ac := aggContext{np: ft.np, x: x, y: y, limit: limit}
+	if err := ft.aggregateInterval(&ac); err != nil {
+		return FPResult{}, err
+	}
+	fpr = FPResult{
+		FP:    ac.fp,
+		Count: ac.count,
+		IType: ac.itype,
+		Items: rangesync.EmptySeqResult(),
+	}
+
+	if ac.total == 0 {
+		return fpr, nil
+	}
+
+	if ac.items.Seq != nil {
+		ft.log("fingerprintInterval: items %v", ac.items)
+		fpr.Items = ac.items
+	} else {
+		fpr.Items = ft.From(x, 1)
+		ft.log("fingerprintInterval: start from x: %v", fpr.Items)
+	}
+
+	if ac.next != nil {
+		ft.log("fingerprintInterval: next %s", ac.next)
+		fpr.Next = ac.next
+	} else if (fpr.IType == 0 && limit < 0) || fpr.Count == 0 {
+		next, err := fpr.Items.First()
+		if err != nil {
+			return FPResult{}, err
+		}
+		if next != nil {
+			fpr.Next = next.Clone()
+		}
+		ft.log("fingerprintInterval: next at start %s", fpr.Next)
+	} else if ac.lastPrefix != nil {
+		fpr.Next, err = ft.nextFromPrefix(&ac, *ac.lastPrefix)
+		if err != nil {
+			return FPResult{}, err
+		}
+		ft.log("fingerprintInterval: next at lastPrefix %s -> %s", *ac.lastPrefix, fpr.Next)
+	} else {
+		next, err := ft.From(y, 1).First()
+		if err != nil {
+			return FPResult{}, err
+		}
+		fpr.Next = next.Clone()
+		ft.log("fingerprintInterval: next at y: %s", fpr.Next)
+	}
+
+	return fpr, nil
+}
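Split (defined below, on top of easySplit and fingerprintInterval) is what range reconciliation uses to subdivide a range whose fingerprints disagree. A hedged sketch of that driving pattern, modeled on the recursive splitting in this package's tests; the halve helper and its depth cap are our own illustration:

```go
package fptree_test

import (
	"github.com/spacemeshos/go-spacemesh/sync2/fptree"
	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
)

// halve splits [x, y) roughly in half by item count and recurses into both
// parts, stopping at trivially small ranges or at the given depth.
func halve(ft *fptree.FPTree, x, y rangesync.KeyBytes, depth int) error {
	fpr, err := ft.FingerprintInterval(x, y, -1)
	if err != nil {
		return err
	}
	if fpr.Count <= 1 || depth == 0 {
		return nil // small enough to compare item by item
	}
	sr, err := ft.Split(x, y, int(fpr.Count/2))
	if err != nil {
		return err
	}
	if err := halve(ft, x, sr.Middle, depth-1); err != nil {
		return err
	}
	return halve(ft, sr.Middle, y, depth-1)
}
```

+
+// easySplit splits an interval in two parts, trying to do so in such a way that the
+// first part has close to limit items, while not making any idStore queries, so that the
+// database is not accessed. If the split can't be done, which includes the situation
+// where one of the sides has 0 items, easySplit returns the errEasySplitFailed error.
+// easySplit never fails for a tree with values.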
+func (ft *FPTree) easySplit(x, y rangesync.KeyBytes, limit int) (sr SplitResult, err error) {
+	ft.enter("easySplit: x %s y %s limit %d", x, y, limit)
+	defer func() {
+		ft.leave(sr.Part0.FP, sr.Part0.Count, sr.Part0.IType, sr.Part0.Items, sr.Part0.Next,
+			sr.Part1.FP, sr.Part1.Count, sr.Part1.IType, sr.Part1.Items, sr.Part1.Next, err)
+	}()
+	if limit < 0 {
+		panic("BUG: easySplit with limit < 0")
+	}
+	ac := aggContext{np: ft.np, x: x, y: y, limit: limit, easySplit: true}
+	if err := ft.aggregateInterval(&ac); err != nil {
+		return SplitResult{}, err
+	}
+
+	if ac.total == 0 {
+		return SplitResult{}, nil
+	}
+
+	if ac.count0 == 0 || ac.count == 0 {
+		// need to get some items on both sides for the easy split to succeed
+		ft.log("easySplit failed: one side missing: count0 %d count %d", ac.count0, ac.count)
+		return SplitResult{}, errEasySplitFailed
+	}
+
+	// It should not be possible to have ac.lastPrefix0 == nil or ac.lastPrefix == nil
+	// if both ac.count0 and ac.count are non-zero, because of how
+	// aggContext.maybeIncludeNode works
+	if ac.lastPrefix0 == nil || ac.lastPrefix == nil {
+		panic("BUG: easySplit lastPrefix or lastPrefix0 not set")
+	}
+
+	// ac.items / ac.next are only set in aggregateEdge, which fails with
+	// errEasySplitFailed if easySplit is enabled, so we can ignore them here
+	middle := make(rangesync.KeyBytes, ft.keyLen)
+	ac.lastPrefix0.idAfter(middle)
+	ft.log("easySplit: lastPrefix0 %s middle %s", ac.lastPrefix0, middle)
+	items := ft.From(x, 1)
+	part0 := FPResult{
+		FP:    ac.fp0,
+		Count: ac.count0,
+		IType: ac.itype,
+		Items: items,
+		// Next is only used during splitting itself, and thus not included
+	}
+	items = ft.startFromPrefix(&ac, *ac.lastPrefix0)
+	part1 := FPResult{
+		FP:    ac.fp,
+		Count: ac.count,
+		IType: ac.itype,
+		Items: items,
+		// Next is only used during splitting itself, and thus not included
+	}
+	return SplitResult{
+		Part0:  part0,
+		Part1:  part1,
+		Middle: middle,
+	}, nil
+}
+
+// Split splits an interval in two parts.
+func (ft *FPTree) Split(x, y rangesync.KeyBytes, limit int) (sr SplitResult, err error) {
+	ft.np.lockRead()
+	defer ft.np.unlockRead()
+	sr, err = ft.easySplit(x, y, limit)
+	if err == nil {
+		return sr, nil
+	}
+	if err != errEasySplitFailed {
+		return SplitResult{}, err
+	}
+
+	fpr0, err := ft.fingerprintInterval(x, y, limit)
+	if err != nil {
+		return SplitResult{}, err
+	}
+
+	if fpr0.Count == 0 {
+		return SplitResult{}, errors.New("can't split empty range")
+	}
+
+	fpr1, err := ft.fingerprintInterval(fpr0.Next, y, -1)
+	if err != nil {
+		return SplitResult{}, err
+	}
+
+	if fpr1.Count == 0 {
+		return SplitResult{}, errors.New("split produced empty 2nd range")
+	}
+
+	return SplitResult{
+		Part0:  fpr0,
+		Part1:  fpr1,
+		Middle: fpr0.Next,
+	}, nil
+}
+
+// dumpNode prints the node structure to the writer.
+func (ft *FPTree) dumpNode(w io.Writer, idx nodeIndex, indent, dir string) {
+	if idx == noIndex {
+		return
+	}
+
+	count, fp, leaf := ft.np.info(idx)
+	countStr := strconv.Itoa(int(count))
+	if leaf {
+		countStr = "LEAF:" + countStr
+	}
+	var valStr string
+	if v := ft.np.value(idx); v != nil {
+		valStr = fmt.Sprintf(" <%s>", v.ShortString())
+	}
+	fmt.Fprintf(w, "%s%sidx=%d %s %s [%d]%s\n", indent, dir, idx, fp, countStr, ft.np.refCount(idx), valStr)
+	if !leaf {
+		indent += "  "
+		ft.dumpNode(w, ft.np.left(idx), indent, "l: ")
+		ft.dumpNode(w, ft.np.right(idx), indent, "r: ")
+	}
+}
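Dump and DumpToString (below) are debugging aids; a small sketch of how they are typically wired into tests (the verbosity gating is our own choice, not the package's):

```go
package fptree_test

import (
	"testing"

	"github.com/spacemeshos/go-spacemesh/sync2/fptree"
)

// logTree dumps the tree layout into the test log when -v is given.
func logTree(t *testing.T, ft *fptree.FPTree) {
	if testing.Verbose() {
		t.Logf("tree:\n%s", ft.DumpToString())
	}
}
```

+
+// Dump prints the tree structure to the writer.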
+func (ft *FPTree) Dump(w io.Writer) { + ft.np.lockRead() + defer ft.np.unlockRead() + if ft.root == noIndex { + fmt.Fprintln(w, "empty tree") + } else { + ft.dumpNode(w, ft.root, "", "") + } +} + +// DumpToString returns the tree structure as a string. +func (ft *FPTree) DumpToString() string { + var sb strings.Builder + ft.Dump(&sb) + return sb.String() +} + +// Count returns the number of items in the tree. +func (ft *FPTree) Count() int { + ft.np.lockRead() + defer ft.np.unlockRead() + if ft.root == noIndex { + return 0 + } + return int(ft.np.count(ft.root)) +} + +// EnableTrace enables or disables tracing for the tree. +func (ft *FPTree) EnableTrace(enable bool) { + ft.traceEnabled = enable +} diff --git a/sync2/fptree/fptree_test.go b/sync2/fptree/fptree_test.go new file mode 100644 index 0000000000..d156ae3d94 --- /dev/null +++ b/sync2/fptree/fptree_test.go @@ -0,0 +1,1243 @@ +package fptree_test + +import ( + "fmt" + "math/rand/v2" + "slices" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sync2/fptree" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/sqlstore" +) + +const ( + testKeyLen = 32 + testDepth = 24 +) + +func requireEmpty(t *testing.T, sr rangesync.SeqResult) { + for range sr.Seq { + require.Fail(t, "expected an empty sequence") + } + require.NoError(t, sr.Error()) +} + +func firstKey(t *testing.T, sr rangesync.SeqResult) rangesync.KeyBytes { + k, err := sr.First() + require.NoError(t, err) + return k +} + +func testFPTree(t *testing.T, makeFPTrees mkFPTreesFunc) { + type rangeTestCase struct { + xIdx, yIdx int + x, y string + limit int + fp string + count uint32 + itype int + startIdx, endIdx int + } + for _, tc := range []struct { + name string + ids []string + ranges []rangeTestCase + x, y string + }{ + { + name: "empty", + ids: nil, + ranges: []rangeTestCase{ + { + x: "123456789abcdef0000000000000000000000000000000000000000000000000", + y: "123456789abcdef0000000000000000000000000000000000000000000000000", + limit: -1, + fp: "000000000000000000000000", + count: 0, + itype: 0, + startIdx: -1, + endIdx: -1, + }, + { + x: "123456789abcdef0000000000000000000000000000000000000000000000000", + y: "123456789abcdef0000000000000000000000000000000000000000000000000", + limit: 1, + fp: "000000000000000000000000", + count: 0, + itype: 0, + startIdx: -1, + endIdx: -1, + }, + { + x: "123456789abcdef0000000000000000000000000000000000000000000000000", + y: "223456789abcdef0000000000000000000000000000000000000000000000000", + limit: 1, + fp: "000000000000000000000000", + count: 0, + itype: -1, + startIdx: -1, + endIdx: -1, + }, + { + x: "223456789abcdef0000000000000000000000000000000000000000000000000", + y: "123456789abcdef0000000000000000000000000000000000000000000000000", + limit: 1, + fp: "000000000000000000000000", + count: 0, + itype: 1, + startIdx: -1, + endIdx: -1, + }, + }, + }, + { + name: "ids1", + ids: []string{ + "0000000000000000000000000000000000000000000000000000000000000000", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "abcdef1234567890000000000000000000000000000000000000000000000000", + }, + ranges: []rangeTestCase{ + { + xIdx: 0, + yIdx: 0, + limit: -1, + fp: "642464b773377bbddddddddd", + count: 5, + itype: 0, + startIdx: 0, + endIdx: 0, + }, + { + 
xIdx: 0, + yIdx: 0, + limit: 0, + fp: "000000000000000000000000", + count: 0, + itype: 0, + startIdx: 0, + endIdx: 0, + }, + { + xIdx: 0, + yIdx: 0, + limit: 3, + fp: "4761032dcfe98ba555555555", + count: 3, + itype: 0, + startIdx: 0, + endIdx: 3, + }, + { + xIdx: 4, + yIdx: 4, + limit: -1, + fp: "642464b773377bbddddddddd", + count: 5, + itype: 0, + startIdx: 4, + endIdx: 4, + }, + { + xIdx: 4, + yIdx: 4, + limit: 1, + fp: "abcdef123456789000000000", + count: 1, + itype: 0, + startIdx: 4, + endIdx: 0, + }, + { + xIdx: 0, + yIdx: 1, + limit: -1, + fp: "000000000000000000000000", + count: 1, + itype: -1, + startIdx: 0, + endIdx: 1, + }, + { + xIdx: 0, + yIdx: 3, + limit: -1, + fp: "4761032dcfe98ba555555555", + count: 3, + itype: -1, + startIdx: 0, + endIdx: 3, + }, + { + xIdx: 0, + yIdx: 4, + limit: 3, + fp: "4761032dcfe98ba555555555", + count: 3, + itype: -1, + startIdx: 0, + endIdx: 3, + }, + { + xIdx: 0, + yIdx: 4, + limit: 0, + fp: "000000000000000000000000", + count: 0, + itype: -1, + startIdx: 0, + endIdx: 0, + }, + { + xIdx: 1, + yIdx: 4, + limit: -1, + fp: "cfe98ba54761032ddddddddd", + count: 3, + itype: -1, + startIdx: 1, + endIdx: 4, + }, + { + xIdx: 1, + yIdx: 0, + limit: -1, + fp: "642464b773377bbddddddddd", + count: 4, + itype: 1, + startIdx: 1, + endIdx: 0, + }, + { + xIdx: 2, + yIdx: 0, + limit: -1, + fp: "761032cfe98ba54ddddddddd", + count: 3, + itype: 1, + startIdx: 2, + endIdx: 0, + }, + { + xIdx: 2, + yIdx: 0, + limit: 0, + fp: "000000000000000000000000", + count: 0, + itype: 1, + startIdx: 2, + endIdx: 2, + }, + { + xIdx: 3, + yIdx: 1, + limit: -1, + fp: "2345679abcdef01888888888", + count: 3, + itype: 1, + startIdx: 3, + endIdx: 1, + }, + { + xIdx: 3, + yIdx: 2, + limit: -1, + fp: "317131e226622ee888888888", + count: 4, + itype: 1, + startIdx: 3, + endIdx: 2, + }, + { + xIdx: 3, + yIdx: 2, + limit: 3, + fp: "2345679abcdef01888888888", + count: 3, + itype: 1, + startIdx: 3, + endIdx: 1, + }, + { + x: "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0", + y: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + limit: -1, + fp: "000000000000000000000000", + count: 0, + itype: -1, + startIdx: 0, + endIdx: 0, + }, + }, + }, + { + name: "ids2", + ids: []string{ + "6e476ca729c3840d0118785496e488124ee7dade1aef0c87c6edc78f72e4904f", + "829977b444c8408dcddc1210536f3b3bdc7fd97777426264b9ac8f70b97a7fd1", + "a280bcb8123393e0d4a15e5c9850aab5dddffa03d5efa92e59bc96202e8992bc", + "e93163f908630280c2a8bffd9930aa684be7a3085432035f5c641b0786590d1d", + }, + ranges: []rangeTestCase{ + { + xIdx: 0, + yIdx: 0, + limit: -1, + fp: "a76fc452775b55e0dacd8be5", + count: 4, + itype: 0, + startIdx: 0, + endIdx: 0, + }, + { + xIdx: 0, + yIdx: 0, + limit: 3, + fp: "4e5ea7ab7f38576018653418", + count: 3, + itype: 0, + startIdx: 0, + endIdx: 3, + }, + { + xIdx: 0, + yIdx: 3, + limit: -1, + fp: "4e5ea7ab7f38576018653418", + count: 3, + itype: -1, + startIdx: 0, + endIdx: 3, + }, + { + xIdx: 3, + yIdx: 1, + limit: -1, + fp: "87760f5e21a0868dc3b0c7a9", + count: 2, + itype: 1, + startIdx: 3, + endIdx: 1, + }, + { + xIdx: 3, + yIdx: 2, + limit: -1, + fp: "05ef78ea6568c6000e6cd5b9", + count: 3, + itype: 1, + startIdx: 3, + endIdx: 2, + }, + }, + }, + { + name: "ids3", + ids: []string{ + "0451cd036aff0367b07590032da827b516b63a4c1b36ea9a253dcf9a7e084980", + "0e75d10a8e98a4307dd9d0427dc1d2ebf9e45b602d159ef62c5da95197159844", + "18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + 
"1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + "33940245f4aace670c84f471ff4e862d1d82ce0ada9b98a753038b4f9e60e330", + "366d9e7adb3932e52e0a92a0afc75a2875995e7de8e0c4159e22eb97526a3547", + "66883aa35d2c8d293f07c5c5c40c63416317423418fe5c7fd17b5fb68b3e976e", + "80fce3e9654459cff3441e1a96413f0872e0b6f093879609696042fcfe1c8115", + "8b2025fbe0bbebea4baee48bac9a63a4013a2ec898d7b0a518eccdb99bdb368e", + "8e3e609653adfddcdcb6ddda7461db3a2fc822c3f96874a002f715b80865e575", + "9b25e39d6cc3beac3ecc12140f46a699880ac8303555c694fd40ba8e61bb8b47", + "a3c8628a1b28d1ba6f3d8beb4a29315c02789c5b53a095fa7865c9b3041502d6", + "a98fdcab5e351a1bfd25ddcf9973e9c56a4b688d78743a8a03fa3b1d53da4949", + "ac9c015dd51defacfc14bd4c9c8eedb89aad884bef493553a189a2915c828e95", + "ba745196493a8368ef091860f2692978b381f67566d3413e85167672d672c8ac", + "c26353d8bc9a1eea8e79fd693c1a1e58dacded75ceda84ed6c356bcf02b6d0f1", + "c3f126a37c2e33b6258c87fd043026dacf0b8dd4df7a9afd7cdc293b075e1878", + "cefd0cc8b32929df07b6ebb5b6e433f28d5460f143814f3f651330ea15e5d6e7", + "d9390718256e71edfe671334edbfcbed8b4de3221db55805ebf606c73fe969f1", + "db7ee147da05a5cbec3f59b020cbdba88e40ab6b212ae93c98d5a210d83a4a7b", + "deab906f979a647eff85f3a54e5edd665f2536e0005812aee2e5e411ae71855e", + "e0b6ab7f483527771faadbee8b4ed99ae96167d054ae5c513faf00c78aa36bdd", + "e4ed6f5dcf179a4f10521d58d65d423098af5f6f18c42f3125a5917d338b7477", + "e53de3ec53ba88029a2a0459a3ab82cdb3726c8aeccabf38a04e048b9add92ef", + "f2aff99498615c44d94266060e948c11bb275ec37d0d3c651bb3ba0039a11a64", + "f7f81332b63b79718f0321660a5cd8f6970474ff873afcdebb0d3436a2ad12ac", + "fb42c36089a4883bc7ceaae9a57924d78557edb63ede3d5a2cf2d1f08db799d0", + "fe494ce48f5826c00f6bc6af74258ec6e47b92365850deed95b5bfcaeccc6be8", + }, + ranges: []rangeTestCase{ + { + x: "582485793d71c3e8429b9b2c8df360c2ea7bf90080d5bf375fe4618b00f59c0b", + y: "7eff517d2f11ed32f935be3001499ac779160a4891a496f88da0ceb33e3496cc", + limit: -1, + fp: "66883aa35d2c8d293f07c5c5", + count: 1, + itype: -1, + startIdx: 10, + endIdx: 11, + }, + }, + }, + { + name: "ids4", + ids: []string{ + "06a1f93f0dd88b60473d73127196631134382d59b7cd9b3e6bd6b4f25dd1c782", + "488da52a035df8674aa658d30ff58de82c9dc2ae9c474e004d585c52979eacbb", + "b5527010e990254702f77ffc8a6d6b499040bc3dc61b169a56fbc690e970c046", + "e10fc3141c5e3a00861a4dddb495a33736f845bff62fd295985b7dfa6bcbfc91", + }, + ranges: []rangeTestCase{ + { + xIdx: 2, + yIdx: 0, + limit: 1, + fp: "b5527010e990254702f77ffc", + count: 1, + itype: 1, + startIdx: 2, + endIdx: 3, + }, + }, + }, + { + name: "ids6", + ids: []string{ + "2727d39a2150ef91ef09fa0b60950a189d73e53fd73c1fc7a74e0a393582e51e", + "96a3a7cfdc9ec9101fd4a8bdf831c54053c2cd0b06a6914772edb68a0153fdec", + "b80318c43da5e4b56aa3b7f408a8f86c98418e5b364ef67a37db6017097c2ebc", + "b899092149e332f9686e02e2878e63b7ac85694eeadfe02c94f4f15627f41bcc", + }, + ranges: []rangeTestCase{ + { + xIdx: 3, + yIdx: 3, + limit: 2, + fp: "9fbedabb68b3dd688767f8e9", + count: 2, + itype: 0, + startIdx: 3, + endIdx: 1, + }, + }, + }, + { + name: "ids7", + ids: []string{ + "3595ec355452c94143c6bdae281b162e5b0997e6392dd1a345146861b8fb4586", + "68d02e8f0c69b0b16dc73dda147a231a09b32d709b9b4028f13ee7ffa2e820c8", + "7079bb2d00f961b4dc42911e2009411ceb7b8c950492a627111b60773a31c2ce", + "ad69fbf959a0b0ba1042a2b13d1b2c9a17f8507c642e55dd93277fe8dab378a6", + }, 
+ ranges: []rangeTestCase{ + { + x: "4844a20cd5a83c101cc522fa37539412d0aac4c76a48b940e1845c3f2fe79c85", + y: "cb93566c2037bc8353162e9988974e4585c14f656bf6aed8fa51d00e1ae594de", + limit: -1, + fp: "b5c06e5b553061bfa1c70e75", + count: 3, + itype: -1, + startIdx: 1, + endIdx: 0, + }, + }, + }, + { + name: "ids8", + ids: []string{ + "0e69888877324da35693decc7ded1b2bac16d394ced869af494568d66473a6f0", + "3a78db9e386493402561d9c6f69a6b434a62388f61d06d960598ebf29a3a2187", + "66c9aa8f3be7da713db66e56cc165a46764f88d3113244dd5964bb0a10ccacc3", + "90b25f2d1ee9c9e2d20df5f2226d14ee4223ea27ba565a49aa66a9c44a51c241", + "9e11fdb099f1118144738f9b68ca601e74b97280fd7bbc97cfc377f432e9b7b5", + "c1690e47798295cca02392cbfc0a86cb5204878c04a29b3ae7701b6b51681128", + }, + ranges: []rangeTestCase{ + { + x: "9e11fdb099f1118144738f9b68ca601e74b97280fd7bbc97cfc377f432e9b7b5", + y: "0e69880000000000000000000000000000000000000000000000000000000000", + limit: -1, + fp: "5f78f3f7e073844de4501d50", + count: 2, + itype: 1, + startIdx: 4, + endIdx: 0, + }, + }, + }, + { + name: "ids9", + ids: []string{ + "03744b955a21408f78eb4c1e51f897ed90c22cf561b8ecef25a4b6ec68f3e895", + "691a62eb05d21ee9407fd48d252b5e80a525fd017e941fba383ceabe2ce0c0ee", + "73e10ac8b36bc20195c5d1b162d05402eaef6622accf648399cb60874ac22165", + "845c0a945137ed6b52fbb96a57909869cf34f41100a3a60e5d385d28c42621e1", + "bc1ffc4d9fddbd9f3cd17c0fe53c6b86a2e36256f37e1e73c11e4c9effa911bf", + }, + ranges: []rangeTestCase{ + { + x: "1a4f33388cab82533de99d9370fe367f654c76cd7e71a28334d993a31aa3e87a", + y: "6c5fe0023abc90d0a9327083ebc73c442cec8854f99e378551b502448f2ce000", + limit: -1, + fp: "691a62eb05d21ee9407fd48d", + count: 1, + itype: -1, + startIdx: 1, + endIdx: 2, + }, + }, + }, + { + name: "ids10", + ids: []string{ + "0aea5e19b9f53af915110ba1e05494666e8a1f4bb597d6ca0193c34b525f3480", + "219d9f504af986492356061a68cd2355fd423768c70e511cd7802cd4fdbde1c5", + "277a6bbc173628948456cbeb90309ae70ab837296f504640b53a891a3ddefb65", + "2ff6f89a1f0655255a74ff0dc4eda3a67ff69bc9667261763536917db15d9fe2", + "46b9e5fb278225f28885717512a4b2e5fbbc79b61bde8417cc2e5caf0ad86b17", + "a732516bf7198a3c3cb4edc1c3b1ec11a2545844c45464df44e31135ad84fee0", + "ea238facb9e3b3b6b9ca66bd9472b505e982ed937b22eb127269723124bb9ce8", + "ff90f791d2678d09d12f1a672de85c5127ef1f8a47ae5e8f3b61de06fd803db7", + }, + ranges: []rangeTestCase{ + { + x: "64015400af6cc54ce62fe1b478b38abfef5ab609182d6df0fd46f16c880263b2", + y: "0fcc4ed4c932e1f6ba53418a0116d20ab119c1152644abe5ee1ab30599cd3780", + limit: -1, + fp: "b86b774f25688e7a41409aba", + count: 4, + itype: 1, + startIdx: 5, + endIdx: 1, + }, + }, + }, + { + name: "ids11", + ids: []string{ + "05ce2ac65bf22e2d196814d881125ce5e4f93078ab357e151c7bfccd9ef24f1d", + "81f9f4becc8f91f1c37075ec810828b13d4e8d98b8207c467537043a1bb5d72c", + "a15ecd17ec6674a14faf67649e0058366bf852bd51a0c41c15542861eaf55bac", + "baeaf7d94cc800d38215396e46ba9e1293107a7e5c5d1cd5771f341e570b9f95", + "bd666290c1e339e8cc9d4d1aaf3ce68169dfffbfbe112e22818c72eb373160fd", + "d598253954cbf6719829dd4dca89106622cfb87666991214fece997855478a1c", + "d9e7a5bfa187a248e894e5e72874b3bf40b0863f707c72ae70e2042ba497d3ec", + "e58ededd4c54788c451ede2a3b92e62e1148fcd4184262dab28056f03b639ef5", + }, + ranges: []rangeTestCase{ + { + xIdx: 7, + yIdx: 2, + limit: -1, + fp: "61b900a5db29c7509f06bf1e", + count: 3, + itype: 1, + startIdx: 7, + endIdx: 2, + }, + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + trees := makeFPTrees(t) + ft := trees[0] + var hs []rangesync.KeyBytes + for _, hex := range tc.ids { + h := 
rangesync.MustParseHexKeyBytes(hex) + hs = append(hs, h) + require.NoError(t, ft.RegisterKey(h)) + fptree.AnalyzeTreeNodeRefs(t, trees...) + } + + var sb strings.Builder + ft.Dump(&sb) + t.Logf("tree:\n%s", sb.String()) + + fptree.CheckTree(t, ft) + for _, k := range hs { + require.True(t, ft.CheckKey(k), "checkKey(%s)", k.ShortString()) + } + require.False(t, ft.CheckKey(rangesync.RandomKeyBytes(testKeyLen)), "checkKey(random)") + + for _, rtc := range tc.ranges { + var x, y rangesync.KeyBytes + var name string + if rtc.x != "" { + x = rangesync.MustParseHexKeyBytes(rtc.x) + y = rangesync.MustParseHexKeyBytes(rtc.y) + name = fmt.Sprintf("%s-%s_%d", rtc.x, rtc.y, rtc.limit) + } else { + x = hs[rtc.xIdx] + y = hs[rtc.yIdx] + name = fmt.Sprintf("%d-%d_%d", rtc.xIdx, rtc.yIdx, rtc.limit) + } + t.Run(name, func(t *testing.T) { + fpr, err := ft.FingerprintInterval(x, y, rtc.limit) + require.NoError(t, err) + assert.Equal(t, rtc.fp, fpr.FP.String(), "fp") + assert.Equal(t, rtc.count, fpr.Count, "count") + assert.Equal(t, rtc.itype, fpr.IType, "itype") + + if rtc.startIdx == -1 { + requireEmpty(t, fpr.Items) + } else { + require.NotNil(t, fpr.Items, "items") + expK := rangesync.KeyBytes(hs[rtc.startIdx]) + assert.Equal(t, expK, firstKey(t, fpr.Items), "items") + } + + if rtc.endIdx == -1 { + require.Nil(t, fpr.Next, "next") + } else { + require.NotNil(t, fpr.Next, "next") + expK := rangesync.KeyBytes(hs[rtc.endIdx]) + assert.Equal(t, expK, fpr.Next, "next") + } + }) + } + + ft.Release() + require.Zero(t, ft.PoolNodeCount()) + }) + } +} + +type mkFPTreesFunc func(t *testing.T) []*fptree.FPTree + +func makeFPTreeWithValues(t *testing.T) []*fptree.FPTree { + ft := fptree.NewFPTreeWithValues(0, testKeyLen) + return []*fptree.FPTree{ft} +} + +func makeInMemoryFPTree(t *testing.T) []*fptree.FPTree { + store := fptree.NewFPTreeWithValues(0, testKeyLen) + ft := fptree.NewFPTree(0, store, testKeyLen, testDepth) + return []*fptree.FPTree{ft, store} +} + +func makeDBBackedFPTree(t *testing.T) []*fptree.FPTree { + db := sqlstore.CreateDB(t, testKeyLen) + st := sqlstore.SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + sts, err := st.Snapshot(db) + require.NoError(t, err) + store := fptree.NewDBBackedStore(db, sts, 0, testKeyLen) + ft := fptree.NewFPTree(0, store, testKeyLen, testDepth) + return []*fptree.FPTree{ft, store.FPTree} +} + +func TestFPTree(t *testing.T) { + t.Run("values in fpTree", func(t *testing.T) { + testFPTree(t, makeFPTreeWithValues) + }) + t.Run("in-memory fptree-based id store", func(t *testing.T) { + testFPTree(t, makeInMemoryFPTree) + }) + t.Run("db-backed store", func(t *testing.T) { + testFPTree(t, makeDBBackedFPTree) + }) +} + +func TestFPTreeAsStore(t *testing.T) { + s := fptree.NewFPTreeWithValues(0, testKeyLen) + + sr := s.All() + for range sr.Seq { + require.Fail(t, "sequence not empty") + } + require.NoError(t, sr.Error()) + + sr = s.From(rangesync.MustParseHexKeyBytes( + "0000000000000000000000000000000000000000000000000000000000000000"), + 1) + for range sr.Seq { + require.Fail(t, "sequence not empty") + } + require.NoError(t, sr.Error()) + + for _, h := range []string{ + "0000000000000000000000000000000000000000000000000000000000000000", + "1234561111111111111111111111111111111111111111111111111111111111", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + 
"8888889999999999999999999999999999999999999999999999999999999999", + "abcdef1234567890000000000000000000000000000000000000000000000000", + } { + s.RegisterKey(rangesync.MustParseHexKeyBytes(h)) + } + + sr = s.All() + for range 3 { // make sure seq is reusable + var r []string + n := 15 + for k := range sr.Seq { + r = append(r, k.String()) + n-- + if n == 0 { + break + } + } + require.NoError(t, sr.Error()) + require.Equal(t, []string{ + "0000000000000000000000000000000000000000000000000000000000000000", + "1234561111111111111111111111111111111111111111111111111111111111", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "8888889999999999999999999999999999999999999999999999999999999999", + "abcdef1234567890000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + "1234561111111111111111111111111111111111111111111111111111111111", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "8888889999999999999999999999999999999999999999999999999999999999", + "abcdef1234567890000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + }, r) + } + + sr = s.From(rangesync.MustParseHexKeyBytes( + "5555555555555555555555555555555555555555555555555555555555555555"), + 1) + for range 3 { // make sure seq is reusable + var r []string + n := 15 + for k := range sr.Seq { + r = append(r, k.String()) + n-- + if n == 0 { + break + } + } + require.NoError(t, sr.Error()) + require.Equal(t, []string{ + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "8888889999999999999999999999999999999999999999999999999999999999", + "abcdef1234567890000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + "1234561111111111111111111111111111111111111111111111111111111111", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "8888889999999999999999999999999999999999999999999999999999999999", + "abcdef1234567890000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + "1234561111111111111111111111111111111111111111111111111111111111", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + }, r) + } +} + +type noIDStore struct{} + +var _ sqlstore.IDStore = noIDStore{} + +func (noIDStore) Clone() sqlstore.IDStore { return &noIDStore{} } +func (noIDStore) RegisterKey(h rangesync.KeyBytes) error { return nil } +func (noIDStore) All() rangesync.SeqResult { panic("no ID store") } +func (noIDStore) Release() {} + +func (noIDStore) From(from rangesync.KeyBytes, sizeHint int) rangesync.SeqResult { + return rangesync.EmptySeqResult() +} + +// TestFPTreeNoIDStoreCalls tests that an fpTree can avoid using an idStore if X has only +// 0 bits below max-depth and Y has only 1 bits below max-depth. 
It also checks that an fpTree +// can avoid using an idStore in "relaxed count" mode for splitting ranges. +func TestFPTreeNoIDStoreCalls(t *testing.T) { + ft := fptree.NewFPTree(0, &noIDStore{}, testKeyLen, testDepth) + hashes := []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("1111111111111111111111111111111111111111111111111111111111111111"), + rangesync.MustParseHexKeyBytes("2222222222222222222222222222222222222222222222222222222222222222"), + rangesync.MustParseHexKeyBytes("4444444444444444444444444444444444444444444444444444444444444444"), + rangesync.MustParseHexKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"), + } + for _, h := range hashes { + ft.RegisterKey(h) + } + + for _, tc := range []struct { + x, y rangesync.KeyBytes + limit int + fp string + count uint32 + }{ + { + x: hashes[0], + y: hashes[0], + limit: -1, + fp: "ffffffffffffffffffffffff", + count: 4, + }, + { + x: rangesync.MustParseHexKeyBytes( + "1111110000000000000000000000000000000000000000000000000000000000"), + y: rangesync.MustParseHexKeyBytes( + "1111120000000000000000000000000000000000000000000000000000000000"), + limit: -1, + fp: "111111111111111111111111", + count: 1, + }, + { + x: rangesync.MustParseHexKeyBytes( + "0000000000000000000000000000000000000000000000000000000000000000"), + y: rangesync.MustParseHexKeyBytes( + "9000000000000000000000000000000000000000000000000000000000000000"), + limit: -1, + fp: "ffffffffffffffffffffffff", + count: 4, + }, + } { + fpr, err := ft.FingerprintInterval(tc.x, tc.y, tc.limit) + require.NoError(t, err) + require.Equal(t, tc.fp, fpr.FP.String(), "fp") + require.Equal(t, tc.count, fpr.Count, "count") + } +} + +func TestFPTreeClone(t *testing.T) { + store := fptree.NewFPTreeWithValues(10, testKeyLen) + ft1 := fptree.NewFPTree(10, store, testKeyLen, testDepth) + hashes := []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("1111111111111111111111111111111111111111111111111111111111111111"), + rangesync.MustParseHexKeyBytes("3333333333333333333333333333333333333333333333333333333333333333"), + rangesync.MustParseHexKeyBytes("4444444444444444444444444444444444444444444444444444444444444444"), + } + + ft1.RegisterKey(hashes[0]) + fptree.AnalyzeTreeNodeRefs(t, ft1, store) + + ft1.RegisterKey(hashes[1]) + + fpr, err := ft1.FingerprintInterval(hashes[0], hashes[0], -1) + require.NoError(t, err) + require.Equal(t, "222222222222222222222222", fpr.FP.String(), "fp") + require.Equal(t, uint32(2), fpr.Count, "count") + require.Equal(t, 0, fpr.IType, "itype") + + fptree.AnalyzeTreeNodeRefs(t, ft1, store) + + ft2 := ft1.Clone().(*fptree.FPTree) + + fpr, err = ft1.FingerprintInterval(hashes[0], hashes[0], -1) + require.NoError(t, err) + require.Equal(t, "222222222222222222222222", fpr.FP.String(), "fp") + require.Equal(t, uint32(2), fpr.Count, "count") + require.Equal(t, 0, fpr.IType, "itype") + + fptree.AnalyzeTreeNodeRefs(t, ft1, ft2, store, ft2.IDStore().(*fptree.FPTree)) + + t.Logf("add hash to copy") + ft2.RegisterKey(hashes[2]) + + fpr, err = ft2.FingerprintInterval(hashes[0], hashes[0], -1) + require.NoError(t, err) + require.Equal(t, "666666666666666666666666", fpr.FP.String(), "fp") + require.Equal(t, uint32(3), fpr.Count, "count") + require.Equal(t, 0, fpr.IType, "itype") + + // original tree unchanged + fpr, err = ft1.FingerprintInterval(hashes[0], hashes[0], -1) + require.NoError(t, err) + require.Equal(t, "222222222222222222222222", fpr.FP.String(), "fp") + require.Equal(t, uint32(2), fpr.Count, "count") + require.Equal(t, 0, fpr.IType, 
"itype") + + fptree.AnalyzeTreeNodeRefs(t, ft1, ft2, store, ft2.IDStore().(*fptree.FPTree)) + + ft1.Release() + ft2.Release() + fptree.AnalyzeTreeNodeRefs(t, ft1, ft2, store, ft2.IDStore().(*fptree.FPTree)) + + require.Zero(t, ft1.PoolNodeCount()) + require.Zero(t, ft2.PoolNodeCount()) +} + +func TestRandomClone(t *testing.T) { + trees := []*fptree.FPTree{ + fptree.NewFPTree(1000, fptree.NewFPTreeWithValues(1000, testKeyLen), testKeyLen, testDepth), + } + for range 100 { + n := len(trees) + for range rand.IntN(20) { + trees = append(trees, trees[rand.IntN(n)].Clone().(*fptree.FPTree)) + } + for range rand.IntN(100) { + trees[rand.IntN(len(trees))].RegisterKey(rangesync.RandomKeyBytes(testKeyLen)) + } + + trees = slices.DeleteFunc(trees, func(ft *fptree.FPTree) bool { + if n == 1 { + return false + } + n-- + if rand.IntN(3) == 0 { + ft.Release() + return true + } + return false + }) + allTrees := slices.Clone(trees) + for _, ft := range trees { + allTrees = append(allTrees, ft.IDStore().(*fptree.FPTree)) + } + fptree.AnalyzeTreeNodeRefs(t, allTrees...) + for _, ft := range trees { + fptree.CheckTree(t, ft) + fptree.CheckTree(t, ft.IDStore().(*fptree.FPTree)) + } + if t.Failed() { + break + } + } + for _, ft := range trees { + ft.Release() + } + for _, ft := range trees { + require.Zero(t, ft.PoolNodeCount()) + } +} + +type hashList []rangesync.KeyBytes + +func (l hashList) findGTE(h rangesync.KeyBytes) int { + p, _ := slices.BinarySearchFunc(l, h, func(a, b rangesync.KeyBytes) int { + return a.Compare(b) + }) + return p +} + +func (l hashList) keyAt(p int) rangesync.KeyBytes { + if p == len(l) { + p = 0 + } + return rangesync.KeyBytes(l[p]) +} + +type fpResultWithBounds struct { + fp rangesync.Fingerprint + //nolint:unused + count uint32 + itype int + start rangesync.KeyBytes + //nolint:unused + next rangesync.KeyBytes +} + +func toFPResultWithBounds(t *testing.T, fpr fptree.FPResult) fpResultWithBounds { + return fpResultWithBounds{ + fp: fpr.FP, + count: fpr.Count, + itype: fpr.IType, + next: fpr.Next, + start: firstKey(t, fpr.Items), + } +} + +func dumbFP(hs hashList, x, y rangesync.KeyBytes, limit int) fpResultWithBounds { + var fpr fpResultWithBounds + l := len(hs) + if l == 0 { + return fpr + } + fpr.itype = x.Compare(y) + switch fpr.itype { + case -1: + p := hs.findGTE(x) + pY := hs.findGTE(y) + fpr.start = hs.keyAt(p) + for { + if p >= pY || limit == 0 { + fpr.next = hs.keyAt(p) + break + } + fpr.fp.Update(hs.keyAt(p)) + limit-- + fpr.count++ + p++ + } + case 1: + p := hs.findGTE(x) + fpr.start = hs.keyAt(p) + for { + if p >= len(hs) || limit == 0 { + fpr.next = hs.keyAt(p) + break + } + fpr.fp.Update(hs.keyAt(p)) + limit-- + fpr.count++ + p++ + } + if limit == 0 { + return fpr + } + pY := hs.findGTE(y) + p = 0 + for { + if p == pY || limit == 0 { + fpr.next = hs.keyAt(p) + break + } + fpr.fp.Update(hs.keyAt(p)) + limit-- + fpr.count++ + p++ + } + default: + pX := hs.findGTE(x) + p := pX + fpr.start = hs.keyAt(p) + fpr.next = fpr.start + for { + if limit == 0 { + fpr.next = hs.keyAt(p) + break + } + fpr.fp.Update(hs.keyAt(p)) + limit-- + fpr.count++ + p = (p + 1) % l + if p == pX { + break + } + } + } + return fpr +} + +func verifyInterval(t *testing.T, hs hashList, ft *fptree.FPTree, x, y rangesync.KeyBytes, limit int) fptree.FPResult { + expFPR := dumbFP(hs, x, y, limit) + fpr, err := ft.FingerprintInterval(x, y, limit) + require.NoError(t, err) + require.Equal(t, expFPR, toFPResultWithBounds(t, fpr), + "x=%s y=%s limit=%d", x.String(), y.String(), limit) + + require.Equal(t, 
expFPR, toFPResultWithBounds(t, fpr), + "x=%s y=%s limit=%d", x.String(), y.String(), limit) + + return fpr +} + +func verifySubIntervals( + t *testing.T, + hs hashList, + ft *fptree.FPTree, + x, y rangesync.KeyBytes, + limit, d int, +) fptree.FPResult { + fpr := verifyInterval(t, hs, ft, x, y, limit) + if fpr.Count > 1 { + c := int((fpr.Count + 1) / 2) + if limit >= 0 { + require.Less(t, c, limit) + } + part := verifyInterval(t, hs, ft, x, y, c) + m := make(rangesync.KeyBytes, len(x)) + copy(m, part.Next) + verifySubIntervals(t, hs, ft, x, m, -1, d+1) + verifySubIntervals(t, hs, ft, m, y, -1, d+1) + } + return fpr +} + +func testFPTreeManyItems(t *testing.T, trees []*fptree.FPTree, randomXY bool, numItems, maxDepth, repeat int) { + ft := trees[0] + hs := make(hashList, numItems) + var fp rangesync.Fingerprint + for i := range hs { + h := rangesync.RandomKeyBytes(testKeyLen) + hs[i] = h + ft.RegisterKey(h) + fp.Update(h) + } + fptree.AnalyzeTreeNodeRefs(t, trees...) + slices.SortFunc(hs, func(a, b rangesync.KeyBytes) int { + return a.Compare(b) + }) + + fptree.CheckTree(t, ft) + for _, k := range hs { + require.True(t, ft.CheckKey(k), "checkKey(%s)", k.ShortString()) + } + + fpr, err := ft.FingerprintInterval(hs[0], hs[0], -1) + require.NoError(t, err) + require.Equal(t, fp, fpr.FP, "fp") + require.Equal(t, uint32(numItems), fpr.Count, "count") + require.Equal(t, 0, fpr.IType, "itype") + for i := 0; i < repeat; i++ { + var x, y rangesync.KeyBytes + if randomXY { + x = rangesync.RandomKeyBytes(testKeyLen) + y = rangesync.RandomKeyBytes(testKeyLen) + } else { + x = hs[rand.IntN(numItems)] + y = hs[rand.IntN(numItems)] + } + verifySubIntervals(t, hs, ft, x, y, -1, 0) + } +} + +func repeatTestFPTreeManyItems( + t *testing.T, + makeFPTrees mkFPTreesFunc, +) { + const ( + repeatOuter = 3 + repeatInner = 5 + numItems = 1 << 10 + maxDepth = 12 + ) + for _, tc := range []struct { + name string + randomXY bool + }{ + { + name: "bounds from the set", + randomXY: false, + }, + { + name: "random bounds", + randomXY: true, + }, + } { + for i := 0; i < repeatOuter; i++ { + testFPTreeManyItems(t, makeFPTrees(t), tc.randomXY, numItems, maxDepth, repeatInner) + } + } +} + +func TestFPTreeManyItems(t *testing.T) { + t.Run("values in fpTree", func(t *testing.T) { + repeatTestFPTreeManyItems(t, makeFPTreeWithValues) + }) + t.Run("in-memory fptree-based id store", func(t *testing.T) { + repeatTestFPTreeManyItems(t, makeInMemoryFPTree) + }) + t.Run("db-backed store", func(t *testing.T) { + repeatTestFPTreeManyItems(t, makeDBBackedFPTree) + }) +} + +func verifyEasySplit( + t *testing.T, + ft *fptree.FPTree, + x, y rangesync.KeyBytes, + depth, + maxDepth int, +) ( + succeeded, failed int, +) { + fpr, err := ft.FingerprintInterval(x, y, -1) + require.NoError(t, err) + if fpr.Count <= 1 { + return 0, 0 + } + a := firstKey(t, fpr.Items) + require.NoError(t, err) + b := fpr.Next + require.NotNil(t, b) + + m := fpr.Count / 2 + sr, err := ft.EasySplit(x, y, int(m)) + if err != nil { + require.ErrorIs(t, err, fptree.ErrEasySplitFailed) + failed++ + sr, err = ft.Split(x, y, int(m)) + require.NoError(t, err) + } + require.NoError(t, err) + require.NotZero(t, sr.Part0.Count) + require.NotZero(t, sr.Part1.Count) + require.Equal(t, fpr.Count, sr.Part0.Count+sr.Part1.Count) + require.Equal(t, fpr.IType, sr.Part0.IType) + require.Equal(t, fpr.IType, sr.Part1.IType) + fp := sr.Part0.FP + fp.Update(sr.Part1.FP[:]) + require.Equal(t, fpr.FP, fp) + require.Equal(t, a, firstKey(t, sr.Part0.Items)) + precMiddle := firstKey(t, 
sr.Part1.Items)
+
+	fpr11, err := ft.FingerprintInterval(x, precMiddle, -1)
+	require.NoError(t, err)
+	require.Equal(t, sr.Part0.Count, fpr11.Count)
+	require.Equal(t, sr.Part0.FP, fpr11.FP)
+	require.Equal(t, a, firstKey(t, fpr11.Items))
+
+	fpr12, err := ft.FingerprintInterval(precMiddle, y, -1)
+	require.NoError(t, err)
+	require.Equal(t, sr.Part1.Count, fpr12.Count)
+	require.Equal(t, sr.Part1.FP, fpr12.FP)
+	require.Equal(t, precMiddle, firstKey(t, fpr12.Items))
+
+	fpr11, err = ft.FingerprintInterval(x, sr.Middle, -1)
+	require.NoError(t, err)
+	require.Equal(t, sr.Part0.Count, fpr11.Count)
+	require.Equal(t, sr.Part0.FP, fpr11.FP)
+	require.Equal(t, a, firstKey(t, fpr11.Items))
+
+	fpr12, err = ft.FingerprintInterval(sr.Middle, y, -1)
+	require.NoError(t, err)
+	require.Equal(t, sr.Part1.Count, fpr12.Count)
+	require.Equal(t, sr.Part1.FP, fpr12.FP)
+	require.Equal(t, precMiddle, firstKey(t, fpr12.Items))
+
+	if maxDepth > 0 && depth >= maxDepth {
+		return 1, 0
+	}
+	s1, f1 := verifyEasySplit(t, ft, x, sr.Middle, depth+1, maxDepth)
+	s2, f2 := verifyEasySplit(t, ft, sr.Middle, y, depth+1, maxDepth)
+	return succeeded + s1 + s2 + 1, failed + f1 + f2
+}
+
+func TestEasySplit(t *testing.T) {
+	maxDepth := 17
+	count := 10000
+	for range 5 {
+		store := fptree.NewFPTreeWithValues(10000, testKeyLen)
+		ft := fptree.NewFPTree(10000, store, testKeyLen, maxDepth)
+		for range count {
+			h := rangesync.RandomKeyBytes(testKeyLen)
+			ft.RegisterKey(h)
+		}
+		x := firstKey(t, ft.All()).Clone()
+		x.Trim(maxDepth)
+
+		succeeded, failed := verifyEasySplit(t, ft, x, x, 0, maxDepth-2)
+		successRate := float64(succeeded) * 100 / float64(succeeded+failed)
+		t.Logf("succeeded %d, failed %d, success rate %.2f%%",
+			succeeded, failed, successRate)
+		require.GreaterOrEqual(t, successRate, 95.0)
+	}
+}
+
+func TestEasySplitFPTreeWithValues(t *testing.T) {
+	count := 10000
+
+	for range 5 {
+		ft := fptree.NewFPTreeWithValues(10000, testKeyLen)
+		for range count {
+			h := rangesync.RandomKeyBytes(testKeyLen)
+			ft.RegisterKey(h)
+		}
+
+		x := firstKey(t, ft.All()).Clone()
+		_, failed := verifyEasySplit(t, ft, x, x, 0, -1)
+		require.Zero(t, failed)
+	}
+}
diff --git a/sync2/fptree/nodepool.go b/sync2/fptree/nodepool.go
new file mode 100644
index 0000000000..b04ac77cb5
--- /dev/null
+++ b/sync2/fptree/nodepool.go
@@ -0,0 +1,213 @@
+package fptree
+
+import (
+	"slices"
+	"sync"
+
+	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
+)
+
+// nodeIndex represents an index of a node in the node pool.
+type nodeIndex uint32
+
+const (
+	// noIndex represents an invalid node index.
+	noIndex = ^nodeIndex(0)
+	// leafFlag is a flag that indicates that a node is a leaf node.
+	leafFlag = uint32(1 << 31)
+)
+
+// node represents an fpTree node.
+type node struct {
+	// Fingerprint
+	fp rangesync.Fingerprint
+	// Item count
+	c uint32
+	// Left child, noIndex if not present.
+	l nodeIndex
+	// Right child, noIndex if not present.
+	r nodeIndex
+}
+
+// nodePool represents a pool of tree nodes.
+// The pool is shared between the original tree and its clones.
+type nodePool struct {
+	mtx     sync.RWMutex
+	rcPool  rcPool[node, uint32]
+	leafMap map[uint32]rangesync.KeyBytes
+}
+
+// init pre-allocates the node pool with n nodes.
+func (np *nodePool) init(n int) {
+	np.rcPool.init(n)
+}
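The pool packs the leaf marker into the count field: the high bit of c (leafFlag) marks leaves and the remaining 31 bits hold the item count, which is what the `&^ leafFlag` masking in the accessors below undoes. A round-trip sketch, for illustration only (packCount and unpackCount are not part of the package):

```go
package fptree

// packCount and unpackCount illustrate the c field encoding used by node:
// bit 31 is the leaf marker, bits 0-30 carry the item count.
func packCount(count uint32, leaf bool) uint32 {
	if leaf {
		return count | leafFlag
	}
	return count
}

func unpackCount(c uint32) (count uint32, leaf bool) {
	return c &^ leafFlag, c&leafFlag != 0
}
```

+
+// lockWrite locks the node pool for writing.
+// There can only be one writer at a time.
+// This blocks until all other reader and writer locks are released.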
+func (np *nodePool) lockWrite() { np.mtx.Lock() }
+
+// unlockWrite unlocks the node pool for writing.
+func (np *nodePool) unlockWrite() { np.mtx.Unlock() }
+
+// lockRead locks the node pool for reading.
+// There can be multiple reader locks held at a time.
+// This blocks until the writer lock is released, if it's held.
+func (np *nodePool) lockRead() { np.mtx.RLock() }
+
+// unlockRead unlocks the node pool for reading.
+func (np *nodePool) unlockRead() { np.mtx.RUnlock() }
+
+// add adds a new node to the pool.
+func (np *nodePool) add(
+	fp rangesync.Fingerprint,
+	c uint32,
+	left, right nodeIndex,
+	v rangesync.KeyBytes,
+	replaceIdx nodeIndex,
+) nodeIndex {
+	if c == 1 || (left == noIndex && right == noIndex) { // single item or no children: a leaf
+		c |= leafFlag
+	}
+	newNode := node{fp: fp, c: c, l: noIndex, r: noIndex}
+	if left != noIndex {
+		newNode.l = left
+	}
+	if right != noIndex {
+		newNode.r = right
+	}
+	var idx uint32
+	if replaceIdx != noIndex {
+		np.rcPool.replace(uint32(replaceIdx), newNode)
+		idx = uint32(replaceIdx)
+	} else {
+		idx = np.rcPool.add(newNode)
+	}
+	if v != nil {
+		if c != 1|leafFlag {
+			panic("BUG: non-leaf node with a value")
+		}
+		if np.leafMap == nil {
+			np.leafMap = make(map[uint32]rangesync.KeyBytes)
+		}
+		np.leafMap[idx] = slices.Clone(v)
+	} else if replaceIdx != noIndex {
+		delete(np.leafMap, idx) // drop the stale value of the replaced leaf
+	}
+	return nodeIndex(idx)
+}
+
+// value returns the value of the node at the given index.
+func (np *nodePool) value(idx nodeIndex) rangesync.KeyBytes {
+	if idx == noIndex {
+		return nil
+	}
+	return np.leafMap[uint32(idx)]
+}
+
+// left returns the left child of the node at the given index.
+func (np *nodePool) left(idx nodeIndex) nodeIndex {
+	if idx == noIndex {
+		return noIndex
+	}
+	node := np.rcPool.item(uint32(idx))
+	if node.c&leafFlag != 0 || node.l == noIndex {
+		return noIndex
+	}
+	return node.l
+}
+
+// right returns the right child of the node at the given index.
+func (np *nodePool) right(idx nodeIndex) nodeIndex {
+	if idx == noIndex {
+		return noIndex
+	}
+	node := np.rcPool.item(uint32(idx))
+	if node.c&leafFlag != 0 || node.r == noIndex {
+		return noIndex
+	}
+	return node.r
+}
+
+// leaf returns true if this is a leaf node.
+func (np *nodePool) leaf(idx nodeIndex) bool {
+	if idx == noIndex {
+		panic("BUG: bad node index")
+	}
+	node := np.rcPool.item(uint32(idx))
+	return node.c&leafFlag != 0
+}
+
+// count returns the number of set items to which the node at the given index corresponds.
+func (np *nodePool) count(idx nodeIndex) uint32 {
+	if idx == noIndex {
+		return 0
+	}
+	node := np.rcPool.item(uint32(idx))
+	if node.c == 1 {
+		panic("BUG: single-count node w/o the leaf flag")
+	}
+	return node.c &^ leafFlag
+}
+
+// info returns the count, fingerprint, and leaf flag of the node at the given index.
+func (np *nodePool) info(idx nodeIndex) (count uint32, fp rangesync.Fingerprint, leaf bool) {
+	if idx == noIndex {
+		panic("BUG: bad node index")
+	}
+	node := np.rcPool.item(uint32(idx))
+	if node.c == 1 {
+		panic("BUG: single-count node w/o the leaf flag")
+	}
+	return node.c &^ leafFlag, node.fp, node.c&leafFlag != 0
+}
+
+// releaseOne drops one reference to the node at the given index, freeing it if the reference count reaches zero.
+func (np *nodePool) releaseOne(idx nodeIndex) bool {
+	if idx == noIndex {
+		return false
+	}
+	if np.rcPool.release(uint32(idx)) {
+		delete(np.leafMap, uint32(idx))
+		return true
+	}
+	return false
+}
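// A hedged sketch of the locking discipline the lock helpers above appear
// intended for when several tree clones share one nodePool (snapshotInfo is a
// hypothetical helper, not part of this change):
//
//	func snapshotInfo(np *nodePool, idx nodeIndex) (uint32, rangesync.Fingerprint, bool) {
//		np.lockRead() // multiple readers may hold this; writers are excluded
//		defer np.unlockRead()
//		return np.info(idx)
//	}
+// release drops one reference to the node at the given index and, if the node
+// is freed as a result, recursively releases its children.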
+func (np *nodePool) release(idx nodeIndex) bool {
+	if idx == noIndex {
+		return false
+	}
+	node := np.rcPool.item(uint32(idx))
+	if !np.rcPool.release(uint32(idx)) {
+		return false
+	}
+	if node.c&leafFlag == 0 {
+		if node.l != noIndex {
+			np.release(node.l)
+		}
+		if node.r != noIndex {
+			np.release(node.r)
+		}
+	} else {
+		delete(np.leafMap, uint32(idx))
+	}
+	return true
+}
+
+// ref adds a reference to the given node.
+func (np *nodePool) ref(idx nodeIndex) {
+	np.rcPool.ref(uint32(idx))
+}
+
+// refCount returns the reference count for the node at the given index.
+func (np *nodePool) refCount(idx nodeIndex) uint32 {
+	return np.rcPool.refCount(uint32(idx))
+}
+
+// nodeCount returns the number of nodes in the pool.
+func (np *nodePool) nodeCount() int {
+	return np.rcPool.count()
+}
diff --git a/sync2/fptree/nodepool_test.go b/sync2/fptree/nodepool_test.go
new file mode 100644
index 0000000000..32dcb2bbd8
--- /dev/null
+++ b/sync2/fptree/nodepool_test.go
@@ -0,0 +1,80 @@
+package fptree
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
+)
+
+func TestNodePool(t *testing.T) {
+	var np nodePool
+	require.Zero(t, np.nodeCount())
+	idx1 := np.add(rangesync.MustParseHexFingerprint("000000000000000000000001"), 1, noIndex, noIndex,
+		rangesync.KeyBytes("foo"), noIndex)
+	idx2 := np.add(rangesync.MustParseHexFingerprint("000000000000000000000002"), 1, noIndex, noIndex,
+		rangesync.KeyBytes("bar"), noIndex)
+	idx3 := np.add(rangesync.MustParseHexFingerprint("000000000000000000000003"), 2, idx1, idx2, nil, noIndex)
+
+	require.Equal(t, nodeIndex(0), idx1)
+	require.Equal(t, rangesync.KeyBytes("foo"), np.value(idx1))
+	require.Equal(t, noIndex, np.left(idx1))
+	require.Equal(t, noIndex, np.right(idx1))
+	require.True(t, np.leaf(idx1))
+	require.Equal(t, uint32(1), np.count(idx1))
+	count, fp, leaf := np.info(idx1)
+	require.Equal(t, uint32(1), count)
+	require.Equal(t, rangesync.MustParseHexFingerprint("000000000000000000000001"), fp)
+	require.True(t, leaf)
+	require.Equal(t, uint32(1), np.refCount(idx1))
+
+	require.Equal(t, nodeIndex(1), idx2)
+	require.Equal(t, rangesync.KeyBytes("bar"), np.value(idx2))
+	require.Equal(t, noIndex, np.left(idx2))
+	require.Equal(t, noIndex, np.right(idx2))
+	require.True(t, np.leaf(idx2))
+	require.Equal(t, uint32(1), np.count(idx2))
+	count, fp, leaf = np.info(idx2)
+	require.Equal(t, uint32(1), count)
+	require.Equal(t, rangesync.MustParseHexFingerprint("000000000000000000000002"), fp)
+	require.True(t, leaf)
+	require.Equal(t, uint32(1), np.refCount(idx2))
+
+	require.Equal(t, nodeIndex(2), idx3)
+	require.Nil(t, np.value(idx3)) // an inner node carries no value
+	require.Equal(t, idx1, np.left(idx3))
+	require.Equal(t, idx2, np.right(idx3))
+	require.False(t, np.leaf(idx3))
+	require.Equal(t, uint32(2), np.count(idx3))
+	count, fp, leaf = np.info(idx3)
+	require.Equal(t, uint32(2), count)
+	require.Equal(t, rangesync.MustParseHexFingerprint("000000000000000000000003"), fp)
+	require.False(t, leaf)
+	require.Equal(t, uint32(1), np.refCount(idx3))
+
+	require.Equal(t, 3, np.nodeCount())
+
+	np.ref(idx2)
+	require.Equal(t, uint32(2), np.refCount(idx2))
+
+	np.release(idx3)
+	require.Equal(t, 1, np.nodeCount())
+	require.Equal(t, uint32(1), np.refCount(idx2))
+	count, fp, leaf = np.info(idx2)
+	require.Equal(t, uint32(1), count)
+	require.Equal(t, rangesync.MustParseHexFingerprint("000000000000000000000002"), fp)
+	require.True(t, leaf)
+
+	require.Equal(t, idx2, np.add(
rangesync.MustParseHexFingerprint("000000000000000000000004"), 1, noIndex, noIndex, + rangesync.KeyBytes("bar2"), idx2)) + count, fp, leaf = np.info(idx2) + require.Equal(t, uint32(1), count) + require.Equal(t, rangesync.MustParseHexFingerprint("000000000000000000000004"), fp) + require.True(t, leaf) + require.Equal(t, rangesync.KeyBytes("bar2"), np.value(idx2)) + + np.release(idx2) + require.Zero(t, np.nodeCount()) +} diff --git a/sync2/fptree/prefix.go b/sync2/fptree/prefix.go new file mode 100644 index 0000000000..6fd9ffea76 --- /dev/null +++ b/sync2/fptree/prefix.go @@ -0,0 +1,198 @@ +package fptree + +import ( + "fmt" + "math/bits" + "strings" + + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" +) + +const ( + // prefixBytes is the number of bytes in a prefix. + prefixBytes = rangesync.FingerprintSize + // maxPrefixLen is the maximum length of a prefix in bits. + maxPrefixLen = prefixBytes * 8 +) + +// prefix is a prefix of a key, represented as a bit string. +type prefix struct { + // the bytes of the prefix, starting from the highest byte. + b [prefixBytes]byte + // length of the prefix in bits. + l uint16 +} + +// emptyPrefix is the empty prefix (length 0). +var emptyPrefix = prefix{} + +// prefixFromKeyBytes returns a prefix made from a key by using the maximum possible +// number of its bytes. +func prefixFromKeyBytes(k rangesync.KeyBytes) (p prefix) { + p.l = uint16(copy(p.b[:], k) * 8) + return p +} + +// len returns the length of the prefix. +func (p prefix) len() int { + return int(p.l) +} + +// left returns the prefix with one more 0 bit. +func (p prefix) left() prefix { + if p.l == maxPrefixLen { + panic("BUG: max prefix len reached") + } + p.b[p.l/8] &^= 1 << (7 - p.l%8) + p.l++ + return p +} + +// right returns the prefix with one more 1 bit. +func (p prefix) right() prefix { + if p.l == maxPrefixLen { + panic("BUG: max prefix len reached") + } + p.b[p.l/8] |= 1 << (7 - p.l%8) + p.l++ + return p +} + +// String implements fmt.Stringer. +func (p prefix) String() string { + if p.len() == 0 { + return "<0>" + } + var sb strings.Builder + for _, b := range p.b[:(p.l+7)/8] { + sb.WriteString(fmt.Sprintf("%08b", b)) + } + return fmt.Sprintf("<%d:%s>", p.l, sb.String()[:p.l]) +} + +// highBit returns the highest bit of the prefix as bool (false=0, true=1). +// If the prefix is empty, it returns false. +func (p prefix) highBit() bool { + return p.l != 0 && p.b[0]&0x80 != 0 +} + +// minID sets the key to the smallest key with the prefix. +func (p prefix) minID(k rangesync.KeyBytes) { + nb := (p.l + 7) / 8 + if len(k) < int(nb) { + panic("BUG: id slice too small") + } + copy(k[:nb], p.b[:nb]) + clear(k[nb:]) +} + +// idAfter sets the key to the key immediately after the largest key with the prefix. +// idAfter returns true if the resulting id is zero, meaning wraparound. +func (p prefix) idAfter(k rangesync.KeyBytes) bool { + nb := (p.l + 7) / 8 + if len(k) < int(nb) { + panic("BUG: id slice too small") + } + // Copy prefix bits to the key, set all the bits after the prefix to 1, then + // increment the key. + copy(k[:nb], p.b[:nb]) + if p.l%8 != 0 { + k[nb-1] |= (1<<(8-p.l%8) - 1) + } + for i := int(nb); i < len(k); i++ { + k[i] = 0xff + } + return k.Inc() +} + +// shift removes the highest bit from the prefix. 
+func (p prefix) shift() prefix { + switch l := p.len(); l { + case 0: + panic("BUG: can't shift zero prefix") + case 1: + return emptyPrefix + default: + var c byte + for nb := int((p.l+7)/8) - 1; nb >= 0; nb-- { + c, p.b[nb] = (p.b[nb]&0x80)>>7, (p.b[nb]<<1)|c + } + p.l-- + return p + } +} + +// match returns true if the prefix matches the key, that is, +// all the prefix bits are equal to the corresponding bits of the key. +func (p prefix) match(b rangesync.KeyBytes) bool { + if int(p.l) > len(b)*8 { + panic("BUG: id slice too small") + } + if p.l == 0 { + return true + } + bi := p.l / 8 + for i, v := range p.b[:bi] { + if b[i] != v { + return false + } + } + s := p.l % 8 + return s == 0 || p.b[bi]>>(8-s) == b[bi]>>(8-s) +} + +// preFirst0 returns the longest prefix of the key that consists entirely of binary 1s. +func preFirst0(k rangesync.KeyBytes) prefix { + var p prefix + nb := min(prefixBytes, len(k)) + for n, b := range k[:nb] { + if b != 0xff { + nOnes := bits.LeadingZeros8(^b) + if nOnes != 0 { + p.b[n] = 0xff << (8 - nOnes) + p.l += uint16(nOnes) + } + break + } + p.b[n] = 0xff + p.l += 8 + } + return p +} + +// preFirst1 returns the longest prefix of the key that consists entirely of binary 0s. +func preFirst1(k rangesync.KeyBytes) prefix { + var p prefix + nb := min(prefixBytes, len(k)) + for _, b := range k[:nb] { + if b != 0 { + p.l += uint16(bits.LeadingZeros8(b)) + break + } + p.l += 8 + } + return p +} + +// commonPrefix returns common prefix between two keys. +func commonPrefix(a, b rangesync.KeyBytes) prefix { + var p prefix + nb := min(prefixBytes, len(a), len(b)) + for n, v1 := range a[:nb] { + v2 := b[n] + p.b[n] = v1 + if v1 != v2 { + nEqBits := bits.LeadingZeros8(v1 ^ v2) + if nEqBits != 0 { + // Clear unused bits in the last used prefix byte + p.b[n] &^= 1<<(8-nEqBits) - 1 + p.l += uint16(nEqBits) + } else { + p.b[n] = 0 + } + break + } + p.l += 8 + } + return p +} diff --git a/sync2/fptree/prefix_test.go b/sync2/fptree/prefix_test.go new file mode 100644 index 0000000000..3394948191 --- /dev/null +++ b/sync2/fptree/prefix_test.go @@ -0,0 +1,328 @@ +package fptree + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" +) + +func verifyPrefix(t *testing.T, p prefix) { + for i := (p.len() + 7) / 8; i < prefixBytes; i++ { + require.Zero(t, p.b[i], "p.bs[%d]", i) + } +} + +func TestPrefix(t *testing.T) { + for _, tc := range []struct { + p prefix + s string + left prefix + right prefix + shift prefix + minID string + idAfter string + }{ + { + p: emptyPrefix, + s: "<0>", + left: prefix{b: [prefixBytes]byte{0}, l: 1}, + right: prefix{b: [prefixBytes]byte{0x80}, l: 1}, + minID: "0000000000000000000000000000000000000000000000000000000000000000", + idAfter: "0000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0}, l: 1}, + s: "<1:0>", + left: prefix{b: [prefixBytes]byte{0}, l: 2}, + right: prefix{b: [prefixBytes]byte{0x40}, l: 2}, + shift: emptyPrefix, + minID: "0000000000000000000000000000000000000000000000000000000000000000", + idAfter: "8000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0x80}, l: 1}, + s: "<1:1>", + left: prefix{b: [prefixBytes]byte{0x80}, l: 2}, + right: prefix{b: [prefixBytes]byte{0xc0}, l: 2}, + shift: emptyPrefix, + minID: "8000000000000000000000000000000000000000000000000000000000000000", + idAfter: 
"0000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0}, l: 2}, + s: "<2:00>", + left: prefix{b: [prefixBytes]byte{0}, l: 3}, + right: prefix{b: [prefixBytes]byte{0x20}, l: 3}, + shift: prefix{b: [prefixBytes]byte{0}, l: 1}, + minID: "0000000000000000000000000000000000000000000000000000000000000000", + idAfter: "4000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0x40}, l: 2}, + s: "<2:01>", + left: prefix{b: [prefixBytes]byte{0x40}, l: 3}, + right: prefix{b: [prefixBytes]byte{0x60}, l: 3}, + shift: prefix{b: [prefixBytes]byte{0x80}, l: 1}, + minID: "4000000000000000000000000000000000000000000000000000000000000000", + idAfter: "8000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0x80}, l: 2}, + s: "<2:10>", + left: prefix{b: [prefixBytes]byte{0x80}, l: 3}, + right: prefix{b: [prefixBytes]byte{0xa0}, l: 3}, + shift: prefix{b: [prefixBytes]byte{0}, l: 1}, + minID: "8000000000000000000000000000000000000000000000000000000000000000", + idAfter: "c000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0xc0}, l: 2}, + s: "<2:11>", + left: prefix{b: [prefixBytes]byte{0xc0}, l: 3}, + right: prefix{b: [prefixBytes]byte{0xe0}, l: 3}, + shift: prefix{b: [prefixBytes]byte{0x80}, l: 1}, + minID: "c000000000000000000000000000000000000000000000000000000000000000", + idAfter: "0000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xff}, l: 24}, + s: "<24:111111111111111111111111>", + left: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xff, 0}, l: 25}, + right: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xff, 0x80}, l: 25}, + shift: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xfe}, l: 23}, + minID: "ffffff0000000000000000000000000000000000000000000000000000000000", + idAfter: "0000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xff, 0}, l: 25}, + s: "<25:1111111111111111111111110>", + left: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xff, 0}, l: 26}, + right: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xff, 0x40}, l: 26}, + shift: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xfe}, l: 24}, + minID: "ffffff0000000000000000000000000000000000000000000000000000000000", + idAfter: "ffffff8000000000000000000000000000000000000000000000000000000000", + }, + } { + t.Run(fmt.Sprint(tc.p), func(t *testing.T) { + require.Equal(t, tc.s, tc.p.String()) + require.Equal(t, tc.left, tc.p.left()) + verifyPrefix(t, tc.p.left()) + require.Equal(t, tc.right, tc.p.right()) + verifyPrefix(t, tc.p.right()) + if tc.p != emptyPrefix { + require.Equal(t, tc.shift, tc.p.shift()) + verifyPrefix(t, tc.p.shift()) + } + + minID := make(rangesync.KeyBytes, 32) + tc.p.minID(minID) + require.Equal(t, tc.minID, minID.String()) + + idAfter := make(rangesync.KeyBytes, 32) + tc.p.idAfter(idAfter) + require.Equal(t, tc.idAfter, idAfter.String()) + }) + } +} + +func TestCommonPrefix(t *testing.T) { + for _, tc := range []struct { + a, b, p string + }{ + { + a: "0000000000000000000000000000000000000000000000000000000000000000", + b: "8000000000000000000000000000000000000000000000000000000000000000", + p: "<0>", + }, + { + a: "A000000000000000000000000000000000000000000000000000000000000000", + b: "8000000000000000000000000000000000000000000000000000000000000000", + p: "<2:10>", + }, + { + a: 
"A000000000000000000000000000000000000000000000000000000000000000", + b: "A800000000000000000000000000000000000000000000000000000000000000", + p: "<4:1010>", + }, + { + a: "ABCDEF1234567890000000000000000000000000000000000000000000000000", + b: "ABCDEF1234567800000000000000000000000000000000000000000000000000", + p: "<56:10101011110011011110111100010010001101000101011001111000>", + }, + { + a: "ABCDEF1234567890123456789ABCDEF000000000000000000000000000000000", + b: "ABCDEF1234567890123456789ABCDEF000000000000000000000000000000000", + p: "<96:1010101111001101111011110001001000110100010101100111100010010000" + + "00010010001101000101011001111000>", + }, + } { + a := rangesync.MustParseHexKeyBytes(tc.a) + b := rangesync.MustParseHexKeyBytes(tc.b) + require.Equal(t, tc.p, commonPrefix(a, b).String()) + verifyPrefix(t, commonPrefix(a, b)) + } +} + +func TestPreFirst0(t *testing.T) { + for _, tc := range []struct { + k, exp string + }{ + { + k: "00000000", + exp: "<0>", + }, + { + k: "10000000", + exp: "<0>", + }, + { + k: "40000000", + exp: "<0>", + }, + { + k: "00040000", + exp: "<0>", + }, + { + k: "80000000", + exp: "<1:1>", + }, + { + k: "c0000000", + exp: "<2:11>", + }, + { + k: "cc000000", + exp: "<2:11>", + }, + { + k: "ffc00000", + exp: "<10:1111111111>", + }, + { + k: "ffffffff", + exp: "<32:11111111111111111111111111111111>", + }, + } { + k := rangesync.MustParseHexKeyBytes(tc.k) + require.Equal(t, tc.exp, preFirst0(k).String(), "k=%s", tc.k) + verifyPrefix(t, preFirst0(k)) + } +} + +func TestPreFirst1(t *testing.T) { + for _, tc := range []struct { + k, exp string + }{ + { + k: "ffffffff", + exp: "<0>", + }, + { + k: "80000000", + exp: "<0>", + }, + { + k: "c0000000", + exp: "<0>", + }, + { + k: "ffffffc0", + exp: "<0>", + }, + { + k: "70000000", + exp: "<1:0>", + }, + { + k: "30000000", + exp: "<2:00>", + }, + { + k: "00300000", + exp: "<10:0000000000>", + }, + { + k: "00000000", + exp: "<32:00000000000000000000000000000000>", + }, + } { + k := rangesync.MustParseHexKeyBytes(tc.k) + require.Equal(t, tc.exp, preFirst1(k).String(), "k=%s", tc.k) + verifyPrefix(t, preFirst1(k)) + } +} + +func TestMatch(t *testing.T) { + for _, tc := range []struct { + k string + p prefix + match bool + }{ + { + k: "12345678", + p: emptyPrefix, + match: true, + }, + { + k: "12345678", + p: prefix{l: 1}, + match: true, + }, + { + k: "12345678", + p: prefix{l: 3}, + match: true, + }, + { + k: "12345678", + p: prefix{l: 4}, + match: false, + }, + { + k: "12345678", + p: prefix{b: [prefixBytes]byte{0x80}, l: 1}, + match: false, + }, + { + k: "12345678", + p: prefix{b: [prefixBytes]byte{0x80}, l: 2}, + match: false, + }, + { + k: "12345678", + p: prefix{b: [prefixBytes]byte{0x10}, l: 4}, + match: true, + }, + { + k: "12345678", + p: prefix{b: [prefixBytes]byte{0x12, 0x34, 0x50}, l: 20}, + match: true, + }, + { + k: "12345678", + p: prefix{b: [prefixBytes]byte{0x12, 0x34, 0x50}, l: 24}, + match: false, + }, + { + k: "12345678", + p: prefix{b: [prefixBytes]byte{0x12, 0x34, 0x56, 0x78}, l: 32}, + match: true, + }, + } { + k := rangesync.MustParseHexKeyBytes(tc.k) + require.Equal(t, tc.match, tc.p.match(k), "k=%s p=%s", tc.k, tc.p) + } +} + +func TestPrefixFromKeyBytes(t *testing.T) { + p := prefixFromKeyBytes(rangesync.MustParseHexKeyBytes( + "123456789abcdef0123456789abcdef111111111111111111111111111111111")) + require.Equal(t, + "<96:000100100011010001010110011110001001101010111100"+ + "110111101111000000010010001101000101011001111000>", + p.String()) +} diff --git a/sync2/fptree/refcountpool.go 
b/sync2/fptree/refcountpool.go
new file mode 100644
index 0000000000..cc0c33b1d1
--- /dev/null
+++ b/sync2/fptree/refcountpool.go
@@ -0,0 +1,120 @@
+package fptree
+
+import (
+	"strconv"
+	"sync/atomic"
+)
+
+// freeBit is a bit that indicates that an entry is free.
+const freeBit = 1 << 31
+
+// freeListMask is a mask that extracts the free list index from a refCount.
+const freeListMask = freeBit - 1
+
+// poolEntry is an entry in the rcPool.
+type poolEntry[T any, I ~uint32] struct {
+	refCount uint32
+	content  T
+}
+
+// rcPool is a reference-counted pool of items.
+// The zero value is a valid, empty rcPool.
+// Unlike sync.Pool, rcPool does not shrink, but uint32 indices can be used
+// to reference items instead of larger 64-bit pointers, and the items
+// can be shared between multiple owners, such as a tree and its clones.
+type rcPool[T any, I ~uint32] struct {
+	entries []poolEntry[T, I]
+	// freeList is 1-based so that rcPool doesn't need a constructor
+	freeList   uint32
+	allocCount atomic.Int64
+}
+
+// init pre-allocates the rcPool with n items.
+func (rc *rcPool[T, I]) init(n int) {
+	rc.entries = make([]poolEntry[T, I], 0, n)
+	rc.freeList = 0
+	rc.allocCount.Store(0)
+}
+
+// count returns the number of items in the rcPool.
+func (rc *rcPool[T, I]) count() int {
+	return int(rc.allocCount.Load())
+}
+
+// item returns the item at the given index.
+func (rc *rcPool[T, I]) item(idx I) T {
+	return rc.entry(idx).content
+}
+
+// entry returns the pool entry at the given index.
+func (rc *rcPool[T, I]) entry(idx I) *poolEntry[T, I] {
+	entry := &rc.entries[idx]
+	if entry.refCount&freeBit != 0 {
+		panic("BUG: referencing a free rcPool entry " + strconv.Itoa(int(idx)))
+	}
+	return entry
+}
+
+// replace replaces the item at the given index.
+func (rc *rcPool[T, I]) replace(idx I, item T) {
+	entry := &rc.entries[idx]
+	if entry.refCount&freeBit != 0 {
+		panic("BUG: replace of a free rcPool[T, I] entry")
+	}
+	if entry.refCount != 1 {
+		panic("BUG: bad rcPool[T, I] entry refcount for replace")
+	}
+	entry.content = item
+}
+
+// add adds an item to the rcPool and returns its index.
+func (rc *rcPool[T, I]) add(item T) I {
+	var idx I
+	if rc.freeList != 0 {
+		idx = I(rc.freeList - 1)
+		rc.freeList = rc.entries[idx].refCount & freeListMask
+		if rc.freeList > uint32(len(rc.entries)) {
+			panic("BUG: bad freeList linkage")
+		}
+		rc.entries[idx].refCount = 1
+	} else {
+		idx = I(len(rc.entries))
+		rc.entries = append(rc.entries, poolEntry[T, I]{refCount: 1})
+	}
+	rc.entries[idx].content = item
+	rc.allocCount.Add(1)
+	return idx
+}
+
+// release releases the item at the given index.
+func (rc *rcPool[T, I]) release(idx I) bool {
+	entry := &rc.entries[idx]
+	if entry.refCount&freeBit != 0 {
+		panic("BUG: release of a free rcPool[T, I] entry")
+	}
+	if entry.refCount == 0 {
+		panic("BUG: bad rcPool[T, I] entry refcount")
+	}
+	entry.refCount--
+	if entry.refCount == 0 {
+		if rc.freeList > uint32(len(rc.entries)) {
+			panic("BUG: bad freeList")
+		}
+		entry.refCount = rc.freeList | freeBit
+		rc.freeList = uint32(idx + 1)
+		rc.allocCount.Add(-1)
+		return true
+	}
+
+	return false
+}
+
+// ref adds a reference to the item at the given index.
+func (rc *rcPool[T, I]) ref(idx I) {
+	rc.entries[idx].refCount++
+}
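// A standalone sketch (hypothetical, mirroring the constants above, not part
// of the change) of how a freed entry reuses its refCount field: the high bit
// marks the entry as free and the low 31 bits hold a 1-based link to the next
// free entry:
//
//	package main
//
//	import "fmt"
//
//	const (
//		freeBit      = uint32(1) << 31
//		freeListMask = freeBit - 1
//	)
//
//	func main() {
//		// an entry freed while the free-list head was entry 3 (stored 1-based)
//		refCount := uint32(3+1) | freeBit
//		fmt.Println(refCount&freeBit != 0)       // true: the entry is free
//		fmt.Println((refCount & freeListMask) - 1) // 3: next free entry's index
//	}
+// refCount returns the reference count for the item at the given index.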
+func (rc *rcPool[T, I]) refCount(idx I) uint32 {
+	return rc.entries[idx].refCount
+}
diff --git a/sync2/fptree/refcountpool_test.go b/sync2/fptree/refcountpool_test.go
new file mode 100644
index 0000000000..d8fe7b8420
--- /dev/null
+++ b/sync2/fptree/refcountpool_test.go
@@ -0,0 +1,65 @@
+package fptree
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestRCPool(t *testing.T) {
+	type foo struct {
+		//nolint:unused
+		x int
+	}
+	type fooIndex uint32
+
+	var pool rcPool[foo, fooIndex]
+	idx1 := pool.add(foo{x: 1})
+	foo1 := pool.item(idx1)
+	require.Equal(t, 1, pool.count())
+	idx2 := pool.add(foo{x: 2})
+	foo2 := pool.item(idx2)
+	require.Equal(t, 2, pool.count())
+	require.Equal(t, foo{x: 1}, foo1)
+	require.Equal(t, foo{x: 2}, foo2)
+	idx3 := pool.add(foo{x: 3})
+	idx4 := pool.add(foo{x: 4})
+	require.Equal(t, fooIndex(3), idx4)
+	pool.ref(idx4)
+	require.Equal(t, 4, pool.count())
+
+	require.False(t, pool.release(idx4))
+	// not yet released due to an extra ref
+	require.Equal(t, fooIndex(4), pool.add(foo{x: 5}))
+	require.Equal(t, 5, pool.count())
+
+	require.True(t, pool.release(idx4))
+	// idx4 was freed
+	require.Equal(t, idx4, pool.add(foo{x: 6}))
+	require.Equal(t, 5, pool.count())
+
+	// free item used just once
+	require.Equal(t, fooIndex(5), pool.add(foo{x: 7}))
+	require.Equal(t, 6, pool.count())
+
+	// form a free list containing several items
+	require.True(t, pool.release(idx3))
+	require.True(t, pool.release(idx2))
+	require.True(t, pool.release(idx1))
+	require.Equal(t, 3, pool.count())
+
+	// the free list is LIFO
+	require.Equal(t, idx1, pool.add(foo{x: 8}))
+	require.Equal(t, idx2, pool.add(foo{x: 9}))
+	require.Equal(t, idx3, pool.add(foo{x: 10}))
+	require.Equal(t, 6, pool.count())
+
+	// the free list is exhausted
+	idx5 := pool.add(foo{x: 11})
+	require.Equal(t, fooIndex(6), idx5)
+	require.Equal(t, 7, pool.count())
+
+	// replace the item
+	pool.replace(idx5, foo{x: 12})
+	require.Equal(t, foo{x: 12}, pool.item(idx5))
+}
diff --git a/sync2/fptree/testtree.go b/sync2/fptree/testtree.go
new file mode 100644
index 0000000000..bfce519657
--- /dev/null
+++ b/sync2/fptree/testtree.go
@@ -0,0 +1,97 @@
+package fptree
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
+)
+
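// A hedged sketch of how a test might drive the helpers below; the test name,
// tree parameters, and keyLen value are illustrative, not prescribed:
//
//	func TestTreeInvariants(t *testing.T) {
//		const keyLen = 32
//		store := NewFPTreeWithValues(16, keyLen)
//		ft := NewFPTree(16, store, keyLen, 12)
//		for range 16 {
//			ft.RegisterKey(rangesync.RandomKeyBytes(keyLen))
//		}
//		CheckTree(t, ft)           // per-node fingerprint/count invariants
//		AnalyzeTreeNodeRefs(t, ft) // pool refcounts match tree references
//	}
+// checkNode checks that the tree node at the given index is correct and also recursively
+// checks its children.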
+func checkNode(t *testing.T, ft *FPTree, idx nodeIndex, depth int) { + left := ft.np.left(idx) + right := ft.np.right(idx) + if left == noIndex && right == noIndex { + if ft.np.count(idx) != 1 { + assert.Equal(t, depth, ft.maxDepth) + } else if ft.maxDepth == 0 && ft.idStore == nil { + assert.NotNil(t, ft.np.value(idx), "leaf node must have a value if there's no idStore") + } + } else { + if ft.maxDepth != 0 { + assert.Less(t, depth, ft.maxDepth) + } + var expFP rangesync.Fingerprint + var expCount uint32 + if left != noIndex { + checkNode(t, ft, left, depth+1) + count, fp, _ := ft.np.info(left) + expFP.Update(fp[:]) + expCount += count + } + if right != noIndex { + checkNode(t, ft, right, depth+1) + count, fp, _ := ft.np.info(right) + expFP.Update(fp[:]) + expCount += count + } + count, fp, _ := ft.np.info(idx) + assert.Equal(t, expFP, fp, "node fp at depth %d", depth) + assert.Equal(t, expCount, count, "node count at depth %d", depth) + } +} + +// CheckTree checks that the tree has correct structure. +func CheckTree(t *testing.T, ft *FPTree) { + if ft.root != noIndex { + checkNode(t, ft, ft.root, 0) + } +} + +// analyzeTreeNodeRefs checks that the reference counts in the node pool are correct. +func analyzeTreeNodeRefs(t *testing.T, np *nodePool, trees ...*FPTree) { + m := make(map[nodeIndex]map[nodeIndex]bool) + var rec func(*FPTree, nodeIndex, nodeIndex) + rec = func(ft *FPTree, idx, from nodeIndex) { + if idx == noIndex { + return + } + if _, ok := m[idx]; !ok { + m[idx] = make(map[nodeIndex]bool) + } + m[idx][from] = true + rec(ft, np.left(idx), idx) + rec(ft, np.right(idx), idx) + } + for n, ft := range trees { + treeRef := nodeIndex(-n - 1) + rec(ft, ft.root, treeRef) + } + for n, entry := range np.rcPool.entries { + if entry.refCount&freeBit != 0 { + continue + } + numTreeRefs := len(m[nodeIndex(n)]) + if numTreeRefs == 0 { + assert.Fail(t, "analyzeUnref: NOT REACHABLE", "idx: %d", n) + } else { + assert.Equal(t, numTreeRefs, int(entry.refCount), "analyzeRef: refCount for %d", n) + } + } +} + +// AnalyzeTreeNodeRefs checks that the reference counts are correct for the given trees in +// their respective node pools. +func AnalyzeTreeNodeRefs(t *testing.T, trees ...*FPTree) { + t.Helper() + // group trees by node pool they use + nodePools := make(map[*nodePool][]*FPTree) + for _, ft := range trees { + nodePools[ft.np] = append(nodePools[ft.np], ft) + } + for np, trees := range nodePools { + analyzeTreeNodeRefs(t, np, trees...) + } +} diff --git a/sync2/fptree/trace.go b/sync2/fptree/trace.go new file mode 100644 index 0000000000..6cc24f1eab --- /dev/null +++ b/sync2/fptree/trace.go @@ -0,0 +1,102 @@ +package fptree + +import ( + "fmt" + "os" + "strings" + + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" +) + +// trace represents a logging facility for tracing FPTree operations, using indentation to +// show their nested structure. +type trace struct { + traceEnabled bool + traceStack []string +} + +func (t *trace) out(msg string) { + fmt.Fprintf(os.Stderr, "TRACE: %s%s\n", strings.Repeat(" ", len(t.traceStack)), msg) +} + +// enter marks the entry to a function, printing the log message with the given format +// string and arguments. +func (t *trace) enter(format string, args ...any) { + if !t.traceEnabled { + return + } + for n, arg := range args { + if sr, ok := arg.(rangesync.SeqResult); ok { + args[n] = formatSeqResult(sr) + } + } + msg := fmt.Sprintf(format, args...) 
+	t.out("ENTER: " + msg)
+	t.traceStack = append(t.traceStack, msg)
+}
+
+// leave marks the exit from a function, printing the results of the function call
+// together with the same log message that was used in the corresponding enter
+// call.
+func (t *trace) leave(results ...any) {
+	if !t.traceEnabled {
+		return
+	}
+	if len(t.traceStack) == 0 {
+		panic("BUG: trace stack underflow")
+	}
+	for n, r := range results {
+		if err, ok := r.(error); ok {
+			results = []any{fmt.Sprintf("<error: %v>", err)}
+			break
+		}
+		if sr, ok := r.(rangesync.SeqResult); ok {
+			results[n] = formatSeqResult(sr)
+		}
+	}
+	msg := t.traceStack[len(t.traceStack)-1]
+	if len(results) != 0 {
+		var r []string
+		for _, res := range results {
+			r = append(r, fmt.Sprint(res))
+		}
+		msg += " => " + strings.Join(r, ", ")
+	}
+	t.traceStack = t.traceStack[:len(t.traceStack)-1]
+	t.out("LEAVE: " + msg)
+}
+
+// log prints a log message with the given format string and arguments.
+func (t *trace) log(format string, args ...any) {
+	if t.traceEnabled {
+		for n, arg := range args {
+			if sr, ok := arg.(rangesync.SeqResult); ok {
+				args[n] = formatSeqResult(sr)
+			}
+		}
+		msg := fmt.Sprintf(format, args...)
+		t.out(msg)
+	}
+}
+
+// seqFormatter is a lazy formatter for SeqResult.
+type seqFormatter struct {
+	sr rangesync.SeqResult
+}
+
+// String implements fmt.Stringer. It shows the first element of the sequence, if any.
+func (f seqFormatter) String() string {
+	for k := range f.sr.Seq {
+		return k.String()
+	}
+	if err := f.sr.Error(); err != nil {
+		return fmt.Sprintf("<error: %v>", err)
+	}
+	return "<empty>"
+}
+
+// formatSeqResult returns a fmt.Stringer for the SeqResult that
+// formats the sequence result lazily.
+func formatSeqResult(sr rangesync.SeqResult) fmt.Stringer {
+	return seqFormatter{sr: sr}
+}
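// A hedged sketch of how the trace facility above might be wired into an
// FPTree operation; tracedCount is a hypothetical method, not part of this
// change:
//
//	func (t *trace) tracedCount(x, y rangesync.KeyBytes) (count uint32, err error) {
//		t.enter("count: x=%s y=%s", x, y)
//		defer func() { t.leave(count, err) }()
//		t.log("walking the tree")
//		return 0, nil
//	}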