diff --git a/any_table.go b/any_table.go index ebea155..4a1d419 100644 --- a/any_table.go +++ b/any_table.go @@ -31,7 +31,7 @@ func (t AnyTable) UnmarshalYAML(data []byte) (any, error) { func (t AnyTable) Insert(txn WriteTxn, obj any) (old any, hadOld bool, err error) { var iobj object - iobj, hadOld, err = txn.getTxn().insert(t.Meta, Revision(0), obj) + iobj, hadOld, _, err = txn.getTxn().insert(t.Meta, Revision(0), obj) if hadOld { old = iobj.data } diff --git a/db_test.go b/db_test.go index 4f5053a..b093efc 100644 --- a/db_test.go +++ b/db_test.go @@ -181,6 +181,32 @@ func TestDB_Insert_SamePointer(t *testing.T) { require.NoError(t, err, "Insert failed") } +func TestDB_InsertWatch(t *testing.T) { + db, table := newTestDBWithMetrics(t, &NopMetrics{}, tagsIndex) + + txn := db.WriteTxn(table) + _, _, watch, err := table.InsertWatch(txn, testObject{ID: 42, Tags: part.NewSet("hello")}) + require.NoError(t, err, "Insert failed") + txn.Commit() + + select { + case <-watch: + t.Fatal("watch channel unexpectedly closed") + default: + } + + txn = db.WriteTxn(table) + _, _, err = table.Insert(txn, testObject{ID: 42, Tags: part.NewSet("hello", "world")}) + require.NoError(t, err, "Insert failed") + txn.Commit() + + select { + case <-watch: + case <-time.After(watchCloseTimeout): + t.Fatal("watch channel not closed") + } +} + func TestDB_LowerBound_ByRevision(t *testing.T) { t.Parallel() diff --git a/part/part_test.go b/part/part_test.go index de4d2af..9ededde 100644 --- a/part/part_test.go +++ b/part/part_test.go @@ -45,9 +45,10 @@ func Test_insertion_and_watches(t *testing.T) { txn := tree.Txn() txn.Insert([]byte("abc"), 1) - txn.Insert([]byte("ab"), 2) + _, _, watch_ab := txn.InsertWatch([]byte("ab"), 2) txn.Insert([]byte("abd"), 3) tree = txn.Commit() + assertOpen(t, watch_ab) _, w, f := tree.Get([]byte("ab")) assert.True(t, f) @@ -63,6 +64,7 @@ func Test_insertion_and_watches(t *testing.T) { _, _, tree = tree.Insert([]byte("ab"), 42) assertClosed(t, w) assertClosed(t, w2) + assertClosed(t, watch_ab) assertOpen(t, w3) diff --git a/part/txn.go b/part/txn.go index 943ab23..cc417cd 100644 --- a/part/txn.go +++ b/part/txn.go @@ -49,7 +49,15 @@ func (txn *Txn[T]) Clone() *Txn[T] { // Insert or update the tree with the given key and value. // Returns the old value if it exists. func (txn *Txn[T]) Insert(key []byte, value T) (old T, hadOld bool) { - old, hadOld, txn.root = txn.insert(txn.root, key, value) + old, hadOld, _ = txn.InsertWatch(key, value) + return +} + +// Insert or update the tree with the given key and value. +// Returns the old value if it exists and a watch channel that closes when the +// key changes again. +func (txn *Txn[T]) InsertWatch(key []byte, value T) (old T, hadOld bool, watch <-chan struct{}) { + old, hadOld, watch, txn.root = txn.insert(txn.root, key, value) if !hadOld { txn.size++ } @@ -61,7 +69,17 @@ func (txn *Txn[T]) Insert(key []byte, value T) (old T, hadOld bool) { // caller to not mutate the value in-place and to return a clone. // Returns the old value if it exists. func (txn *Txn[T]) Modify(key []byte, mod func(T) T) (old T, hadOld bool) { - old, hadOld, txn.root = txn.modify(txn.root, key, mod) + old, hadOld, _ = txn.ModifyWatch(key, mod) + return +} + +// Modify a value in the tree. If the key does not exist the modify +// function is called with the zero value for T. It is up to the +// caller to not mutate the value in-place and to return a clone. +// Returns the old value if it exists and a watch channel that closes +// when the key changes again. +func (txn *Txn[T]) ModifyWatch(key []byte, mod func(T) T) (old T, hadOld bool, watch <-chan struct{}) { + old, hadOld, watch, txn.root = txn.modify(txn.root, key, mod) if !hadOld { txn.size++ } @@ -166,11 +184,11 @@ func (txn *Txn[T]) cloneNode(n *header[T]) *header[T] { return n } -func (txn *Txn[T]) insert(root *header[T], key []byte, value T) (oldValue T, hadOld bool, newRoot *header[T]) { +func (txn *Txn[T]) insert(root *header[T], key []byte, value T) (oldValue T, hadOld bool, watch <-chan struct{}, newRoot *header[T]) { return txn.modify(root, key, func(_ T) T { return value }) } -func (txn *Txn[T]) modify(root *header[T], key []byte, mod func(T) T) (oldValue T, hadOld bool, newRoot *header[T]) { +func (txn *Txn[T]) modify(root *header[T], key []byte, mod func(T) T) (oldValue T, hadOld bool, watch <-chan struct{}, newRoot *header[T]) { fullKey := key this := root @@ -212,8 +230,10 @@ func (txn *Txn[T]) modify(root *header[T], key []byte, mod func(T) T) (oldValue this = txn.cloneNode(this) } var zero T - this.insert(idx, newLeaf(txn.opts, key, fullKey, mod(zero)).self()) + leaf := newLeaf(txn.opts, key, fullKey, mod(zero)) + this.insert(idx, leaf.self()) *thisp = this + watch = leaf.watch return } @@ -237,7 +257,9 @@ func (txn *Txn[T]) modify(root *header[T], key []byte, mod func(T) T) (oldValue hadOld = true this = txn.cloneNode(this) *thisp = this - this.getLeaf().value = mod(oldValue) + leaf := this.getLeaf() + leaf.value = mod(oldValue) + watch = leaf.watch } else { // Partially matching prefix. newNode := &node4[T]{ @@ -253,6 +275,7 @@ func (txn *Txn[T]) modify(root *header[T], key []byte, mod func(T) T) (oldValue key = key[len(common):] var zero T newLeaf := newLeaf(txn.opts, key, fullKey, mod(zero)) + watch = newLeaf.watch // Insert the two leaves into the node we created. If one has // a key that is a subset of the other, then we can insert them @@ -298,11 +321,14 @@ func (txn *Txn[T]) modify(root *header[T], key []byte, mod func(T) T) (oldValue hadOld = true leaf = txn.cloneNode(leaf.self()).getLeaf() leaf.value = mod(oldValue) + watch = leaf.watch this.setLeaf(leaf) } else { // Set the leaf var zero T - this.setLeaf(newLeaf(txn.opts, this.prefix, fullKey, mod(zero))) + leaf := newLeaf(txn.opts, this.prefix, fullKey, mod(zero)) + watch = leaf.watch + this.setLeaf(leaf) } default: @@ -316,6 +342,7 @@ func (txn *Txn[T]) modify(root *header[T], key []byte, mod func(T) T) (oldValue var zero T newLeaf := newLeaf(txn.opts, key, fullKey, mod(zero)) + watch = newLeaf.watch newNode := &node4[T]{ header: header[T]{prefix: common}, } diff --git a/table.go b/table.go index 1d75f07..c7d41fb 100644 --- a/table.go +++ b/table.go @@ -404,8 +404,13 @@ func (t *genTable[Obj]) ListWatch(txn ReadTxn, q Query[Obj]) (iter.Seq2[Obj, Rev } func (t *genTable[Obj]) Insert(txn WriteTxn, obj Obj) (oldObj Obj, hadOld bool, err error) { + oldObj, hadOld, _, err = t.InsertWatch(txn, obj) + return +} + +func (t *genTable[Obj]) InsertWatch(txn WriteTxn, obj Obj) (oldObj Obj, hadOld bool, watch <-chan struct{}, err error) { var old object - old, hadOld, err = txn.getTxn().insert(t, Revision(0), obj) + old, hadOld, watch, err = txn.getTxn().insert(t, Revision(0), obj) if hadOld { oldObj = old.data.(Obj) } @@ -414,7 +419,7 @@ func (t *genTable[Obj]) Insert(txn WriteTxn, obj Obj) (oldObj Obj, hadOld bool, func (t *genTable[Obj]) Modify(txn WriteTxn, obj Obj, merge func(old, new Obj) Obj) (oldObj Obj, hadOld bool, err error) { var old object - old, hadOld, err = txn.getTxn().modify(t, Revision(0), obj, + old, hadOld, _, err = txn.getTxn().modify(t, Revision(0), obj, func(old any) any { return merge(old.(Obj), obj) }) @@ -426,7 +431,7 @@ func (t *genTable[Obj]) Modify(txn WriteTxn, obj Obj, merge func(old, new Obj) O func (t *genTable[Obj]) CompareAndSwap(txn WriteTxn, rev Revision, obj Obj) (oldObj Obj, hadOld bool, err error) { var old object - old, hadOld, err = txn.getTxn().insert(t, rev, obj) + old, hadOld, _, err = txn.getTxn().insert(t, rev, obj) if hadOld { oldObj = old.data.(Obj) } diff --git a/txn.go b/txn.go index c634977..6c75c88 100644 --- a/txn.go +++ b/txn.go @@ -145,20 +145,20 @@ func (txn *txn) mustIndexWriteTxn(meta TableMeta, indexPos int) indexTxn { return indexTxn } -func (txn *txn) insert(meta TableMeta, guardRevision Revision, data any) (object, bool, error) { +func (txn *txn) insert(meta TableMeta, guardRevision Revision, data any) (object, bool, <-chan struct{}, error) { return txn.modify(meta, guardRevision, data, func(_ any) any { return data }) } -func (txn *txn) modify(meta TableMeta, guardRevision Revision, newData any, merge func(any) any) (object, bool, error) { +func (txn *txn) modify(meta TableMeta, guardRevision Revision, newData any, merge func(any) any) (object, bool, <-chan struct{}, error) { if txn.modifiedTables == nil { - return object{}, false, ErrTransactionClosed + return object{}, false, nil, ErrTransactionClosed } // Look up table and allocate a new revision. tableName := meta.Name() table := txn.modifiedTables[meta.tablePos()] if table == nil { - return object{}, false, tableError(tableName, ErrTableNotLockedForWriting) + return object{}, false, nil, tableError(tableName, ErrTableNotLockedForWriting) } oldRevision := table.revision table.revision++ @@ -169,7 +169,7 @@ func (txn *txn) modify(meta TableMeta, guardRevision Revision, newData any, merg idIndexTxn := txn.mustIndexWriteTxn(meta, PrimaryIndexPos) var obj object - oldObj, oldExists := idIndexTxn.Modify(idKey, + oldObj, oldExists, watch := idIndexTxn.ModifyWatch(idKey, func(old object) object { obj = object{ revision: revision, @@ -204,7 +204,7 @@ func (txn *txn) modify(meta TableMeta, guardRevision Revision, newData any, merg // the insert. idIndexTxn.Delete(idKey) table.revision = oldRevision - return object{}, false, ErrObjectNotFound + return object{}, false, watch, ErrObjectNotFound } if oldObj.revision != guardRevision { // Revert the change. We're assuming here that it's rarer for CompareAndSwap() to @@ -212,7 +212,7 @@ func (txn *txn) modify(meta TableMeta, guardRevision Revision, newData any, merg // (versus doing a Get() and then Insert()). idIndexTxn.Insert(idKey, oldObj) table.revision = oldRevision - return oldObj, true, ErrRevisionNotEqual + return oldObj, true, watch, ErrRevisionNotEqual } } @@ -266,7 +266,7 @@ func (txn *txn) modify(meta TableMeta, guardRevision Revision, newData any, merg }) } - return oldObj, oldExists, nil + return oldObj, oldExists, watch, nil } func (txn *txn) hasDeleteTrackers(meta TableMeta) bool { diff --git a/types.go b/types.go index 5492e64..418e597 100644 --- a/types.go +++ b/types.go @@ -143,6 +143,18 @@ type RWTable[Obj any] interface { // revision. Insert(WriteTxn, Obj) (oldObj Obj, hadOld bool, err error) + // InsertWatch an object into the table. Returns the object that was + // replaced if there was one and a watch channel that closes when the + // object is modified again. + // + // Possible errors: + // - ErrTableNotLockedForWriting: table was not locked for writing + // - ErrTransactionClosed: the write transaction already committed or aborted + // + // Each inserted or updated object will be assigned a new unique + // revision. + InsertWatch(WriteTxn, Obj) (oldObj Obj, hadOld bool, watch <-chan struct{}, err error) + // Modify an existing object or insert a new object into the table. If an old object // exists the [merge] function is called with the old and new objects. //