From 28d80e981b3f06a87fc36e29ef45bee580572370 Mon Sep 17 00:00:00 2001 From: CharlesCheung <61726649+CharlesCheung96@users.noreply.github.com> Date: Wed, 4 Sep 2024 19:35:38 +0800 Subject: [PATCH] better kv map & fix some tests (#254) * use more safety less function in btree_map * fix some test --- heartbeatpb/table_span.go | 20 +- maintainer/maintainer.go | 2 +- maintainer/scheduler_test.go | 6 +- maintainer/split/splitter_test.go | 2 +- pkg/eventservice/event_broker_test.go | 9 +- pkg/eventservice/event_service.go | 12 +- .../event_service_performance_test.go | 14 +- pkg/eventservice/event_service_test.go | 329 ++++++++++-------- utils/btree_map.go | 79 ++--- 9 files changed, 247 insertions(+), 226 deletions(-) diff --git a/heartbeatpb/table_span.go b/heartbeatpb/table_span.go index 7fe71f8a..21b3e385 100644 --- a/heartbeatpb/table_span.go +++ b/heartbeatpb/table_span.go @@ -10,23 +10,25 @@ var DDLSpanSchemaID int64 = 0 // DDLSpan is the special span for Table Trigger Event Dispatcher var DDLSpan = &TableSpan{TableID: 0, StartKey: nil, EndKey: nil} +func LessTableSpan(t1, t2 *TableSpan) bool { + return t1.Less(t2) +} + // Less compares two Spans, defines the order between spans. -func (s *TableSpan) Less(other any) bool { - tbl := other.(*TableSpan) - if s.TableID < tbl.TableID { +func (s *TableSpan) Less(other *TableSpan) bool { + if s.TableID < other.TableID { return true } - if bytes.Compare(s.StartKey, tbl.StartKey) < 0 { + if bytes.Compare(s.StartKey, other.StartKey) < 0 { return true } return false } -func (s *TableSpan) Equal(inferior any) bool { - tbl := inferior.(*TableSpan) - return s.TableID == tbl.TableID && - bytes.Equal(s.StartKey, tbl.StartKey) && - bytes.Equal(s.EndKey, tbl.EndKey) +func (s *TableSpan) Equal(other *TableSpan) bool { + return s.TableID == other.TableID && + bytes.Equal(s.StartKey, other.StartKey) && + bytes.Equal(s.EndKey, other.EndKey) } func (s *TableSpan) Copy() *TableSpan { diff --git a/maintainer/maintainer.go b/maintainer/maintainer.go index 92b3c926..c240ff89 100644 --- a/maintainer/maintainer.go +++ b/maintainer/maintainer.go @@ -495,7 +495,7 @@ func (m *Maintainer) onBootstrapDone(cachedResp map[common.NodeID]*heartbeatpb.M if stm.State == scheduler.SchedulerStatusWorking { tableMap, ok := workingMap[span.TableID] if !ok { - tableMap = utils.NewBtreeMap[*heartbeatpb.TableSpan, *scheduler.StateMachine]() + tableMap = utils.NewBtreeMap[*heartbeatpb.TableSpan, *scheduler.StateMachine](heartbeatpb.LessTableSpan) workingMap[span.TableID] = tableMap } tableMap.ReplaceOrInsert(span, stm) diff --git a/maintainer/scheduler_test.go b/maintainer/scheduler_test.go index ab2849f9..82242247 100644 --- a/maintainer/scheduler_test.go +++ b/maintainer/scheduler_test.go @@ -233,7 +233,7 @@ func TestFinishBootstrap(t *testing.T) { DDLStatus: nil, }, }, NewReplicaSet(model.ChangeFeedID{}, dispatcherID2, 1, span, 1)) - cached := utils.NewBtreeMap[*heartbeatpb.TableSpan, *scheduler.StateMachine]() + cached := utils.NewBtreeMap[*heartbeatpb.TableSpan, *scheduler.StateMachine](heartbeatpb.LessTableSpan) cached.ReplaceOrInsert(span, stm2) require.False(t, s.bootstrapped) s.FinishBootstrap(map[uint64]utils.Map[*heartbeatpb.TableSpan, *scheduler.StateMachine]{ @@ -334,7 +334,7 @@ func TestSplitTableWhenBootstrapFinished(t *testing.T) { {TableID: 1, StartKey: appendNew(totalSpan.StartKey, 'a'), EndKey: appendNew(totalSpan.StartKey, 'b')}, // 1 region // 1 region {TableID: 1, StartKey: appendNew(totalSpan.StartKey, 'b'), EndKey: appendNew(totalSpan.StartKey, 'c')}, } - cached := utils.NewBtreeMap[*heartbeatpb.TableSpan, *scheduler.StateMachine]() + cached := utils.NewBtreeMap[*heartbeatpb.TableSpan, *scheduler.StateMachine](heartbeatpb.LessTableSpan) for _, span := range reportedSpans { dispatcherID1 := common.NewDispatcherID() stm1 := scheduler.NewStateMachine(dispatcherID1, map[model.CaptureID]scheduler.InferiorStatus{ @@ -355,7 +355,7 @@ func TestSplitTableWhenBootstrapFinished(t *testing.T) { CheckpointTs: 10, }, }, NewReplicaSet(model.ChangeFeedID{}, ddlDispatcherID, heartbeatpb.DDLSpanSchemaID, heartbeatpb.DDLSpan, 1)) - ddlCache := utils.NewBtreeMap[*heartbeatpb.TableSpan, *scheduler.StateMachine]() + ddlCache := utils.NewBtreeMap[*heartbeatpb.TableSpan, *scheduler.StateMachine](heartbeatpb.LessTableSpan) ddlCache.ReplaceOrInsert(heartbeatpb.DDLSpan, ddlStm) require.False(t, s.bootstrapped) diff --git a/maintainer/split/splitter_test.go b/maintainer/split/splitter_test.go index 3c24e969..13e8baaf 100644 --- a/maintainer/split/splitter_test.go +++ b/maintainer/split/splitter_test.go @@ -87,7 +87,7 @@ func TestMapFindHole(t *testing.T) { } for i, cs := range cases { - m := utils.NewBtreeMap[*heartbeatpb.TableSpan, *scheduler.StateMachine]() + m := utils.NewBtreeMap[*heartbeatpb.TableSpan, *scheduler.StateMachine](heartbeatpb.LessTableSpan) for _, span := range cs.spans { m.ReplaceOrInsert(span, &scheduler.StateMachine{}) } diff --git a/pkg/eventservice/event_broker_test.go b/pkg/eventservice/event_broker_test.go index 424a1827..2c7b623b 100644 --- a/pkg/eventservice/event_broker_test.go +++ b/pkg/eventservice/event_broker_test.go @@ -26,6 +26,7 @@ func TestNewDispatcherStat(t *testing.T) { require.Equal(t, startTs, stat.spanSubscription.watermark.Load()) require.Equal(t, 0, int(stat.spanSubscription.newEventCount.Load())) require.NotEmpty(t, stat.workerIndex) + require.Nil(t, stat.filter) } func TestDispatcherStatUpdateWatermark(t *testing.T) { @@ -91,6 +92,7 @@ func TestDispatcherStatUpdateWatermark(t *testing.T) { wg.Wait() } + func TestScanTaskPool_PushTask(t *testing.T) { pool := newScanTaskPool() span := newTableSpan(1, "a", "b") @@ -150,20 +152,19 @@ func TestScanTaskPool_PushTask(t *testing.T) { receivedTask := <-pool.pendingTaskQueue[dispatcherStat.workerIndex] require.Equal(t, expectedTask, receivedTask) - // Verify that the task is set to nil in the taskSet + // Verify that the task is removed from taskSet task, ok = pool.taskSet[dispatcherInfo.GetID()] - require.True(t, ok) + require.False(t, ok) require.Nil(t, task) } func newTableSpan(tableID uint64, start, end string) *heartbeatpb.TableSpan { - res := &heartbeatpb.TableSpan{ + return &heartbeatpb.TableSpan{ TableID: tableID, StartKey: []byte(start), EndKey: []byte(end), } - return res } func TestResolvedTsCache(t *testing.T) { diff --git a/pkg/eventservice/event_service.go b/pkg/eventservice/event_service.go index ed6ede5f..3096af02 100644 --- a/pkg/eventservice/event_service.go +++ b/pkg/eventservice/event_service.go @@ -21,14 +21,6 @@ const ( defaultScanWorkerCount = 8192 ) -// EventService accepts the requests of pulling events. -// The EventService is a singleton in the system. -type EventService interface { - Name() string - Run(ctx context.Context) error - Close(context.Context) error -} - type DispatcherInfo interface { // GetID returns the ID of the dispatcher. GetID() common.DispatcherID @@ -43,6 +35,8 @@ type DispatcherInfo interface { GetFilterConfig() *config.FilterConfig } +// EventService accepts the requests of pulling events. +// The EventService is a singleton in the system. type eventService struct { mc messaging.MessageCenter eventStore eventstore.EventStore @@ -54,7 +48,7 @@ type eventService struct { tz *time.Location } -func NewEventService() EventService { +func NewEventService() common.SubModule { mc := appcontext.GetService[messaging.MessageCenter](appcontext.MessageCenter) eventStore := appcontext.GetService[eventstore.EventStore](appcontext.EventStore) schemaStore := appcontext.GetService[schemastore.SchemaStore](appcontext.SchemaStore) diff --git a/pkg/eventservice/event_service_performance_test.go b/pkg/eventservice/event_service_performance_test.go index 7be9bb3d..9e10cc1f 100644 --- a/pkg/eventservice/event_service_performance_test.go +++ b/pkg/eventservice/event_service_performance_test.go @@ -8,7 +8,6 @@ import ( "time" "github.com/flowbehappy/tigate/pkg/common" - appcontext "github.com/flowbehappy/tigate/pkg/common/context" "github.com/flowbehappy/tigate/pkg/messaging" "github.com/pingcap/log" "go.uber.org/zap" @@ -48,18 +47,7 @@ func TestEventServiceOneMillionTable(t *testing.T) { } }() - appcontext.SetService(appcontext.MessageCenter, mc) - appcontext.SetService(appcontext.EventStore, mockStore) - es := NewEventService() - esImpl := es.(*eventService) - wg.Add(1) - go func() { - defer wg.Done() - err := es.Run(ctx) - if err != nil { - t.Errorf("EventService.Run() error = %v", err) - } - }() + esImpl := initEventService(ctx, t, mc, mockStore) start := time.Now() dispatchers := make([]DispatcherInfo, 0, tableNum) diff --git a/pkg/eventservice/event_service_test.go b/pkg/eventservice/event_service_test.go index 480d8b63..70078104 100644 --- a/pkg/eventservice/event_service_test.go +++ b/pkg/eventservice/event_service_test.go @@ -11,11 +11,14 @@ import ( "github.com/flowbehappy/tigate/downstreamadapter/eventcollector" "github.com/flowbehappy/tigate/downstreamadapter/sink" "github.com/flowbehappy/tigate/downstreamadapter/writer" + "github.com/flowbehappy/tigate/eventpb" "github.com/flowbehappy/tigate/heartbeatpb" "github.com/flowbehappy/tigate/logservice/eventstore" + "github.com/flowbehappy/tigate/logservice/schemastore" "github.com/flowbehappy/tigate/pkg/common" appcontext "github.com/flowbehappy/tigate/pkg/common/context" "github.com/flowbehappy/tigate/pkg/config" + "github.com/flowbehappy/tigate/pkg/filter" "github.com/flowbehappy/tigate/pkg/messaging" "github.com/pingcap/log" "github.com/pingcap/tiflow/cdc/model" @@ -24,6 +27,159 @@ import ( "go.uber.org/zap" ) +func initEventService( + ctx context.Context, t *testing.T, + mc messaging.MessageCenter, mockStore eventstore.EventStore, +) *eventService { + appcontext.SetService(appcontext.MessageCenter, mc) + appcontext.SetService(appcontext.EventStore, mockStore) + appcontext.SetService(appcontext.SchemaStore, newMockSchemaStore()) + es := NewEventService() + esImpl := es.(*eventService) + go func() { + err := esImpl.Run(ctx) + if err != nil { + t.Errorf("EventService.Run() error = %v", err) + } + }() + return esImpl +} + +func TestEventServiceBasic(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockStore := newMockEventStore() + mc := &mockMessageCenter{ + messageCh: make(chan *messaging.TargetMessage, 100), + } + esImpl := initEventService(ctx, t, mc, mockStore) + + acceptorInfo := newMockAcceptorInfo(common.NewDispatcherID(), 1) + // register acceptor + esImpl.acceptorInfoCh <- acceptorInfo + // wait for eventService to process the acceptorInfo + time.Sleep(time.Second * 2) + + require.Equal(t, 1, len(esImpl.brokers)) + require.NotNil(t, esImpl.brokers[acceptorInfo.GetClusterID()]) + + // add events to logpuller + txnEvent := &common.TxnEvent{ + DispatcherID: acceptorInfo.GetID(), + Span: acceptorInfo.span, + StartTs: 1, + CommitTs: 5, + Rows: []*common.RowChangedEvent{ + { + PhysicalTableID: 1, + StartTs: 1, + CommitTs: 5, + }, + { + PhysicalTableID: 1, + StartTs: 1, + CommitTs: 5, + }, + }, + } + + sourceSpanStat, ok := mockStore.spans[acceptorInfo.span.TableID] + require.True(t, ok) + + sourceSpanStat.update([]*common.TxnEvent{txnEvent}, txnEvent.CommitTs) + + expectedEvent := &common.TxnEvent{ + DispatcherID: acceptorInfo.GetID(), + StartTs: 1, + CommitTs: 5, + Rows: []*common.RowChangedEvent{ + { + PhysicalTableID: 1, + StartTs: 1, + CommitTs: 5, + }, + { + PhysicalTableID: 1, + StartTs: 1, + CommitTs: 5, + }, + }, + } + + // receive events from msg center + for { + msg := <-mc.messageCh + txn := msg.Message[0].(*common.TxnEvent) + if len(txn.Rows) == 0 { + log.Info("received watermark", zap.Uint64("ts", txn.ResolvedTs)) + continue + } + require.NotNil(t, msg) + require.Equal(t, "event-collector", msg.Topic) + require.Equal(t, expectedEvent, msg.Message[0].(*common.TxnEvent)) + return + } +} + +// The test mainly focus on the communication between dispatcher and event service. +// When dispatcher created and register in event service, event service need to send events to dispatcher. +func TestDispatcherCommunicateWithEventService(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverId := messaging.NewServerId() + mc := messaging.NewMessageCenter(ctx, serverId, 1, config.NewDefaultMessageCenterConfig()) + mockStore := newMockEventStore() + _ = initEventService(ctx, t, mc, mockStore) + appcontext.SetService(appcontext.EventCollector, eventcollector.NewEventCollector(100*1024*1024*1024, serverId)) // 100GB for demo + + db, _ := newTestMockDB(t) + defer db.Close() + + mysqlSink := sink.NewMysqlSink(model.DefaultChangeFeedID("test1"), 8, writer.NewMysqlConfig(), db) + tableSpan := &heartbeatpb.TableSpan{TableID: 1, StartKey: nil, EndKey: nil} + startTs := uint64(1) + id := common.NewDispatcherID() + tableEventDispatcher := dispatcher.NewDispatcher(id, tableSpan, mysqlSink, startTs, nil, nil, 0) + appcontext.GetService[*eventcollector.EventCollector](appcontext.EventCollector).RegisterDispatcher( + eventcollector.RegisterInfo{ + Dispatcher: tableEventDispatcher, + StartTs: startTs, + FilterConfig: &eventpb.FilterConfig{Rules: []string{"*.*"}}, + }, + ) + + time.Sleep(1 * time.Second) + // add events to logpuller + txnEvent := &common.TxnEvent{ + ClusterID: 1, + Span: tableSpan, + StartTs: 1, + CommitTs: 5, + Rows: []*common.RowChangedEvent{ + { + PhysicalTableID: 1, + StartTs: 1, + CommitTs: 5, + }, + }, + } + + sourceSpanStat, ok := mockStore.spans[tableSpan.TableID] + require.True(t, ok) + + sourceSpanStat.update([]*common.TxnEvent{txnEvent}, txnEvent.CommitTs) + + // <-tableEventDispatcher.GetEventChan() +} + +func newTestMockDB(t *testing.T) (db *sql.DB, mock sqlmock.Sqlmock) { + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + require.Nil(t, err) + return +} + var _ messaging.MessageCenter = &mockMessageCenter{} // mockMessageCenter is a mock implementation of the MessageCenter interface @@ -120,7 +276,9 @@ func (m *mockDispatcherInfo) GetChangefeedID() (namespace, id string) { } func (m *mockDispatcherInfo) GetFilterConfig() *tconfig.FilterConfig { - return nil + return &tconfig.FilterConfig{ + Rules: []string{"*.*"}, + } } type mockSpanStats struct { @@ -143,6 +301,8 @@ func (m *mockSpanStats) update(event []*common.TxnEvent, watermark uint64) { } +var _ eventstore.EventStore = &mockEventStore{} + // mockEventStore is a mock implementation of the EventStore interface type mockEventStore struct { spans map[uint64]*mockSpanStats @@ -192,7 +352,7 @@ func (m *mockEventStore) UnregisterDispatcher(dispatcherID common.DispatcherID) return nil } -func (m *mockEventStore) GetIterator(dataRange *common.DataRange) (eventstore.EventIterator, error) { +func (m *mockEventStore) GetIterator(dispatcherID common.DispatcherID, dataRange *common.DataRange) (eventstore.EventIterator, error) { iter := &mockEventIterator{ events: make([]*common.TxnEvent, 0), } @@ -295,153 +455,36 @@ func TestMockEventIterator(t *testing.T) { require.NotNil(t, row) } -func TestEventServiceBasic(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - mockStore := newMockEventStore() - mc := &mockMessageCenter{ - messageCh: make(chan *messaging.TargetMessage, 100), - } - - appcontext.SetService(appcontext.MessageCenter, mc) - appcontext.SetService(appcontext.EventStore, mockStore) - es := NewEventService() - esImpl := es.(*eventService) - go func() { - err := es.Run(ctx) - if err != nil { - t.Errorf("EventService.Run() error = %v", err) - } - }() - - acceptorInfo := newMockAcceptorInfo(common.NewDispatcherID(), 1) - // register acceptor - esImpl.acceptorInfoCh <- acceptorInfo - // wait for eventService to process the acceptorInfo - time.Sleep(time.Second * 2) - - require.Equal(t, 1, len(esImpl.brokers)) - require.NotNil(t, esImpl.brokers[acceptorInfo.GetClusterID()]) +type mockSchemaStore struct { + schemastore.SchemaStore - // add events to logpuller - txnEvent := &common.TxnEvent{ - DispatcherID: acceptorInfo.GetID(), - Span: acceptorInfo.span, - StartTs: 1, - CommitTs: 5, - Rows: []*common.RowChangedEvent{ - { - PhysicalTableID: 1, - StartTs: 1, - CommitTs: 5, - }, - { - PhysicalTableID: 1, - StartTs: 1, - CommitTs: 5, - }, - }, - } - - sourceSpanStat, ok := mockStore.spans[acceptorInfo.span.TableID] - require.True(t, ok) - - sourceSpanStat.update([]*common.TxnEvent{txnEvent}, txnEvent.CommitTs) - - expectedEvent := &common.TxnEvent{ - DispatcherID: acceptorInfo.GetID(), - StartTs: 1, - CommitTs: 5, - Rows: []*common.RowChangedEvent{ - { - PhysicalTableID: 1, - StartTs: 1, - CommitTs: 5, - }, - { - PhysicalTableID: 1, - StartTs: 1, - CommitTs: 5, - }, - }, - } + dispatchers map[common.DispatcherID]common.TableID + resolvedTs uint64 +} - // receive events from msg center - for { - msg := <-mc.messageCh - txn := msg.Message[0].(*common.TxnEvent) - if len(txn.Rows) == 0 { - log.Info("received watermark", zap.Uint64("ts", txn.ResolvedTs)) - continue - } - require.NotNil(t, msg) - require.Equal(t, acceptorInfo.GetTopic(), msg.Topic) - require.Equal(t, expectedEvent, msg.Message[0].(*common.TxnEvent)) - return +func newMockSchemaStore() *mockSchemaStore { + return &mockSchemaStore{ + dispatchers: make(map[common.DispatcherID]common.TableID), + resolvedTs: 0, } } -func newTestMockDB(t *testing.T) (db *sql.DB, mock sqlmock.Sqlmock) { - db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) - require.Nil(t, err) - return +func (m *mockSchemaStore) RegisterDispatcher( + dispatcherID common.DispatcherID, span *heartbeatpb.TableSpan, + startTS common.Ts, filter filter.Filter, +) error { + m.dispatchers[dispatcherID] = common.TableID(span.TableID) + return nil } -// The test mainly focus on the communication between dispatcher and event service. -// When dispatcher created and register in event service, event service need to send events to dispatcher. -func TestDispatcherCommunicateWithEventService(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverId := messaging.NewServerId() - appcontext.SetService(appcontext.MessageCenter, messaging.NewMessageCenter(ctx, serverId, 1, config.NewDefaultMessageCenterConfig())) - appcontext.SetService(appcontext.EventCollector, eventcollector.NewEventCollector(100*1024*1024*1024, serverId)) // 100GB for demo - - mockStore := newMockEventStore() - appcontext.SetService(appcontext.EventStore, mockStore) - eventService := NewEventService() - go func() { - err := eventService.Run(ctx) - if err != nil { - t.Errorf("EventService.Run() error = %v", err) - } - }() - - db, _ := newTestMockDB(t) - defer db.Close() - - mysqlSink := sink.NewMysqlSink(model.DefaultChangeFeedID("test1"), 8, writer.NewMysqlConfig(), db) - tableSpan := &heartbeatpb.TableSpan{TableID: 1, StartKey: nil, EndKey: nil} - startTs := uint64(1) - id := common.NewDispatcherID() - tableEventDispatcher := dispatcher.NewDispatcher(id, tableSpan, mysqlSink, startTs, nil, nil, 0) - appcontext.GetService[*eventcollector.EventCollector](appcontext.EventCollector).RegisterDispatcher( - eventcollector.RegisterInfo{ - Dispatcher: tableEventDispatcher, - StartTs: startTs, - FilterConfig: nil, - }, - ) - - time.Sleep(1 * time.Second) - // add events to logpuller - txnEvent := &common.TxnEvent{ - ClusterID: 1, - Span: tableSpan, - StartTs: 1, - CommitTs: 5, - Rows: []*common.RowChangedEvent{ - { - PhysicalTableID: 1, - }, - }, - } - - sourceSpanStat, ok := mockStore.spans[tableSpan.TableID] - require.True(t, ok) - - sourceSpanStat.update([]*common.TxnEvent{txnEvent}, txnEvent.CommitTs) +func (m *mockSchemaStore) UnregisterDispatcher(dispatcherID common.DispatcherID) error { + delete(m.dispatchers, dispatcherID) + return nil +} - // <-tableEventDispatcher.GetEventChan() +func (m *mockSchemaStore) GetNextDDLEvents(id common.TableID, start, end common.Ts) ([]common.DDLEvent, common.Ts, error) { + return nil, end, nil +} +func (m *mockSchemaStore) GetNextTableTriggerEvents(f filter.Filter, start common.Ts, limit int) ([]common.DDLEvent, common.Ts, error) { + return nil, m.resolvedTs, nil } diff --git a/utils/btree_map.go b/utils/btree_map.go index 1b9132f4..d961d70c 100644 --- a/utils/btree_map.go +++ b/utils/btree_map.go @@ -17,75 +17,68 @@ import ( "github.com/google/btree" ) -// MapKey is the comparable key of the map -type MapKey interface { - Less(other any) bool -} +const defaultDegree = 16 // ItemIterator iterates the map, return false to stop the iteration -type ItemIterator[Key MapKey, Value any] func(key Key, value Value) bool +type ItemIterator[KeyT, ValueT any] func(KeyT, ValueT) bool // Map is the general interface of a map -type Map[Key MapKey, Value any] interface { +type Map[KeyT, ValueT any] interface { Len() int - Has(Key) bool - Get(Key) (Value, bool) - Delete(Key) (Value, bool) - ReplaceOrInsert(Key, Value) (Value, bool) - Ascend(iterator ItemIterator[Key, Value]) -} - -// Item is a btree item that wraps a (key) and an item (value). -type Item[Key MapKey, T any] struct { - Key Key - Value T + Has(KeyT) bool + Get(KeyT) (ValueT, bool) + Delete(KeyT) (ValueT, bool) + ReplaceOrInsert(KeyT, ValueT) (ValueT, bool) + Ascend(iterator ItemIterator[KeyT, ValueT]) } -// lessItem compares two Spans, defines the order between spans. -func lessItem[Key MapKey, T any](a, b Item[Key, T]) bool { - return a.Key.Less(b.Key) +// Item is a btree item that wraps a (key) and an item (value). +type Item[KeyT, ValueT any] struct { + k KeyT + v ValueT } // BtreeMap is a specialized btree map that map a Span to a value. -type BtreeMap[Key MapKey, T any] struct { - tree *btree.BTreeG[Item[Key, T]] +type BtreeMap[KeyT, ValueT any] struct { + tree *btree.BTreeG[Item[KeyT, ValueT]] } // NewBtreeMap returns a new BtreeMap. -func NewBtreeMap[Key MapKey, T any]() *BtreeMap[Key, T] { - const defaultDegree = 16 - return NewBtreeMapWithDegree[Key, T](defaultDegree) +func NewBtreeMap[KeyT, ValueT any](lessKeyF func(KeyT, KeyT) bool) *BtreeMap[KeyT, ValueT] { + return NewBtreeMapWithDegree[KeyT, ValueT](defaultDegree, nil) } // NewBtreeMapWithDegree returns a new BtreeMap with the given degree. -func NewBtreeMapWithDegree[Key MapKey, T any](degree int) *BtreeMap[Key, T] { - return &BtreeMap[Key, T]{ - tree: btree.NewG(degree, lessItem[Key, T]), +func NewBtreeMapWithDegree[KeyT, ValueT any](degree int, lessKey func(KeyT, KeyT) bool) *BtreeMap[KeyT, ValueT] { + return &BtreeMap[KeyT, ValueT]{ + tree: btree.NewG(degree, func(a, b Item[KeyT, ValueT]) bool { + return lessKey(a.k, b.k) + }), } } // Len returns the number of items currently in the tree. -func (m *BtreeMap[Key, T]) Len() int { +func (m *BtreeMap[KeyT, ValueT]) Len() int { return m.tree.Len() } // Has returns true if the given key is in the tree. -func (m *BtreeMap[Key, T]) Has(key Key) bool { - return m.tree.Has(Item[Key, T]{Key: key}) +func (m *BtreeMap[KeyT, ValueT]) Has(key KeyT) bool { + return m.tree.Has(Item[KeyT, ValueT]{k: key}) } // Get looks for the key item in the tree, returning it. // It returns (zeroValue, false) if unable to find that item. -func (m *BtreeMap[Key, T]) Get(key Key) (T, bool) { - item, ok := m.tree.Get(Item[Key, T]{Key: key}) - return item.Value, ok +func (m *BtreeMap[KeyT, ValueT]) Get(key KeyT) (ValueT, bool) { + item, ok := m.tree.Get(Item[KeyT, ValueT]{k: key}) + return item.v, ok } // Delete removes an item equal to the passed in item from the tree, returning // it. If no such item exists, returns (zeroValue, false). -func (m *BtreeMap[Key, T]) Delete(key Key) (T, bool) { - item, ok := m.tree.Delete(Item[Key, T]{Key: key}) - return item.Value, ok +func (m *BtreeMap[KeyT, ValueT]) Delete(key KeyT) (ValueT, bool) { + item, ok := m.tree.Delete(Item[KeyT, ValueT]{k: key}) + return item.v, ok } // ReplaceOrInsert adds the given item to the tree. If an item in the tree @@ -93,15 +86,15 @@ func (m *BtreeMap[Key, T]) Delete(key Key) (T, bool) { // and the second return value is true. Otherwise, (zeroValue, false) // // nil cannot be added to the tree (will panic). -func (m *BtreeMap[Key, T]) ReplaceOrInsert(key Key, value T) (T, bool) { - old, ok := m.tree.ReplaceOrInsert(Item[Key, T]{Key: key, Value: value}) - return old.Value, ok +func (m *BtreeMap[KeyT, ValueT]) ReplaceOrInsert(key KeyT, value ValueT) (ValueT, bool) { + old, ok := m.tree.ReplaceOrInsert(Item[KeyT, ValueT]{k: key, v: value}) + return old.v, ok } // Ascend calls the iterator for every value in the tree within the range // [first, last], until iterator returns false. -func (m *BtreeMap[Key, T]) Ascend(iterator ItemIterator[Key, T]) { - m.tree.Ascend(func(item Item[Key, T]) bool { - return iterator(item.Key, item.Value) +func (m *BtreeMap[KeyT, ValueT]) Ascend(iterator ItemIterator[KeyT, ValueT]) { + m.tree.Ascend(func(item Item[KeyT, ValueT]) bool { + return iterator(item.k, item.v) }) }