From c727c31ab629ce52efcbe8c35fcdc52e8c456b0a Mon Sep 17 00:00:00 2001 From: Braden Walker Date: Sat, 8 Jun 2024 00:20:22 -0400 Subject: [PATCH] container/tree and container/xheap accept compare functions since the standard library settled on them over less --- container/tree/btree.go | 47 ++++++++++++++++++------------------ container/tree/btree_test.go | 34 +++++++++++++++++--------- container/tree/map.go | 8 +++++- container/tree/set.go | 8 +++++- container/xheap/xheap.go | 15 ++++++++++++ test_all_versions.sh | 2 +- xsort/xsort.go | 13 ++++++++++ xsort/xsort_old.go | 27 +++++++++++++++++++++ 8 files changed, 117 insertions(+), 37 deletions(-) diff --git a/container/tree/btree.go b/container/tree/btree.go index f2c6960..051ede4 100644 --- a/container/tree/btree.go +++ b/container/tree/btree.go @@ -3,7 +3,6 @@ package tree import ( "github.com/bradenaw/juniper/iterator" "github.com/bradenaw/juniper/xslices" - "github.com/bradenaw/juniper/xsort" ) // Maximum number of children each node can have. @@ -35,19 +34,19 @@ const minKVs = maxKVs / 2 // - Notably, most nodes are leaves so we can do better space-wise if we can elide the children // array from internal nodes entirely. type btree[K, V any] struct { - root *node[K, V] - less xsort.Less[K] - size int + root *node[K, V] + compare func(K, K) int + size int // incremented when tree structure changes - used to quickly avoid reseeking cursor moving // through an unchanging tree gen int } -func newBtree[K any, V any](less xsort.Less[K]) *btree[K, V] { +func newBtree[K any, V any](compare func(K, K) int) *btree[K, V] { return &btree[K, V]{ - less: less, - root: &node[K, V]{}, - size: 0, + compare: compare, + root: &node[K, V]{}, + size: 0, } } @@ -205,7 +204,7 @@ func (t *btree[K, V]) Cursor() cursor[K, V] { func (t *btree[K, V]) insertIntoLeaf(x *node[K, V], k K, v V) { idx := 0 for idx < int(x.n) { - if t.less(k, x.keys[idx]) { + if t.compare(k, x.keys[idx]) < 0 { break } idx++ @@ -219,7 +218,7 @@ func (t *btree[K, V]) insertIntoLeaf(x *node[K, V], k K, v V) { // adds a separater to x's parent, which may cause it to overflow and also need a split. func (t *btree[K, V]) overfill(x *node[K, V], k K, v V, afterK *node[K, V]) { for { - all := newAmalgam1(t.less, &x.keys, &x.values, &x.children, k, v, afterK) + all := newAmalgam1(t.compare, &x.keys, &x.values, &x.children, k, v, afterK) left := x right := &node[K, V]{} @@ -491,9 +490,10 @@ func (t *btree[K, V]) searchNode(k K, x *node[K, V]) (idx int, inNode bool) { // benchmark suggests that linear search is in fact faster than binary search, at least for int // keys and branchFactor <= 32. for i := 0; i < int(x.n); i++ { - if t.less(k, x.keys[i]) { + c := t.compare(k, x.keys[i]) + if c < 0 { return i, false - } else if !t.less(x.keys[i], k) { + } else if c == 0 { return i, true } } @@ -559,7 +559,7 @@ type amalgam1[K any, V any] struct { // [a c d e] // 0 1 2 extraChild 3 func newAmalgam1[K any, V any]( - less xsort.Less[K], + compare func(K, K) int, keys *[maxKVs]K, values *[maxKVs]V, children *[branchFactor]*node[K, V], @@ -569,7 +569,7 @@ func newAmalgam1[K any, V any]( ) amalgam1[K, V] { extraIdx := func() int { for i := range *keys { - if less(extraKey, keys[i]) { + if compare(extraKey, keys[i]) < 0 { return i } } @@ -736,7 +736,8 @@ func (c *cursor[K, V]) SeekLastLess(k K) { if !c.seek(k) { return } - if xsort.LessOrEqual(c.t.less, k, c.k) { + + if c.t.compare(k, c.k) <= 0 { c.Prev() } } @@ -745,7 +746,7 @@ func (c *cursor[K, V]) SeekLastLessOrEqual(k K) { if !c.seek(k) { return } - if c.t.less(k, c.k) { + if c.t.compare(k, c.k) < 0 { c.Prev() } } @@ -754,7 +755,7 @@ func (c *cursor[K, V]) SeekFirstGreaterOrEqual(k K) { if !c.seek(k) { return } - if xsort.Greater(c.t.less, k, c.k) { + if c.t.compare(k, c.k) > 0 { c.Next() } } @@ -763,7 +764,7 @@ func (c *cursor[K, V]) SeekFirstGreater(k K) { if !c.seek(k) { return } - if xsort.GreaterOrEqual(c.t.less, k, c.k) { + if c.t.compare(k, c.k) >= 0 { c.Next() } } @@ -843,7 +844,7 @@ func (c *cursor[K, V]) lost() bool { // zero value also. Unlinking a node during merge sets n=0, so that's handled here too. return c.gen != c.t.gen && c.curr != nil && - (c.i >= int(c.curr.n) || !xsort.Equal(c.t.less, c.k, c.curr.keys[c.i])) + (c.i >= int(c.curr.n) || c.t.compare(c.k, c.curr.keys[c.i]) != 0) } func (c *cursor[K, V]) Forward() iterator.Iterator[KVPair[K, V]] { @@ -907,11 +908,11 @@ func (t *btree[K, V]) Range(lower Bound[K], upper Bound[K]) iterator.Iterator[KV switch upper.type_ { case boundInclude: return iterator.While(c.Forward(), func(pair KVPair[K, V]) bool { - return xsort.LessOrEqual(t.less, pair.Key, upper.key) + return t.compare(pair.Key, upper.key) <= 0 }) case boundExclude: return iterator.While(c.Forward(), func(pair KVPair[K, V]) bool { - return t.less(pair.Key, upper.key) + return t.compare(pair.Key, upper.key) < 0 }) case boundUnbounded: return c.Forward() @@ -935,11 +936,11 @@ func (t *btree[K, V]) RangeReverse(lower Bound[K], upper Bound[K]) iterator.Iter switch lower.type_ { case boundInclude: return iterator.While(c.Backward(), func(pair KVPair[K, V]) bool { - return xsort.GreaterOrEqual(t.less, pair.Key, lower.key) + return t.compare(pair.Key, lower.key) >= 0 }) case boundExclude: return iterator.While(c.Backward(), func(pair KVPair[K, V]) bool { - return xsort.Greater(t.less, pair.Key, lower.key) + return t.compare(pair.Key, lower.key) > 0 }) case boundUnbounded: return c.Backward() diff --git a/container/tree/btree_test.go b/container/tree/btree_test.go index e9b15ea..486280e 100644 --- a/container/tree/btree_test.go +++ b/container/tree/btree_test.go @@ -13,13 +13,23 @@ import ( "github.com/bradenaw/juniper/xsort" ) +func compare[T byte | uint16](x, y T) int { + if x < y { + return -1 + } + if x > y { + return 1 + } + return 0 +} + func orderedhashmapKVPairToKVPair[K any, V any](kv orderedhashmap.KVPair[uint16, int]) KVPair[uint16, int] { return KVPair[uint16, int]{kv.K, kv.V} } func FuzzBtree(f *testing.F) { f.Fuzz(func(t *testing.T, b []byte) { - tree := newBtree[uint16, int](xsort.OrderedLess[uint16]) + tree := newBtree[uint16, int](compare[uint16]) cursor := tree.Cursor() cursor.SeekFirst() oracle := orderedhashmap.NewMap[uint16, int](xsort.OrderedLess[uint16]) @@ -389,7 +399,7 @@ func TestRotateLeft(t *testing.T) { } func TestMergeMulti(t *testing.T) { - tree := newBtree[uint16, int](xsort.OrderedLess[uint16]) + tree := newBtree[uint16, int](compare[uint16]) i := 0 for treeHeight(tree) < 3 { tree.Put(uint16(i), i) @@ -506,8 +516,8 @@ func requireTreesEqual(t *testing.T, a, b *btree[byte, int]) { func makeTree(t *testing.T, root *node[byte, int]) *btree[byte, int] { tree := &btree[byte, int]{ - root: root, - less: xsort.OrderedLess[byte], + root: root, + compare: compare[byte], } tree.size = numItems(tree) checkTree(t, tree) @@ -623,8 +633,8 @@ func checkTree[K comparable, V comparable](t *testing.T, tree *btree[K, V]) { left := x.children[i] right := x.children[i+1] k := x.keys[i] - require2.True(t, tree.less(left.keys[int(left.n)-1], k)) - require2.True(t, tree.less(k, right.keys[0])) + require2.Less(t, tree.compare(left.keys[int(left.n)-1], k), 0) + require2.Less(t, tree.compare(k, right.keys[0]), 0) } } if x == tree.root { @@ -634,7 +644,9 @@ func checkTree[K comparable, V comparable](t *testing.T, tree *btree[K, V]) { } else { require2.GreaterOrEqual(t, int(x.n), minKVs) } - require2.True(t, xsort.SliceIsSorted(x.keys[:int(x.n)], tree.less)) + require2.True(t, xsort.SliceIsSorted(x.keys[:int(x.n)], func(a, b K) bool { + return tree.compare(a, b) < 0 + })) require2.True(t, xslices.All(x.keys[int(x.n):], isZero[K])) require2.True(t, xslices.All(x.values[int(x.n):], isZero[V])) require2.Truef( @@ -740,7 +752,7 @@ func TestAmalgam1(t *testing.T) { ) a := newAmalgam1( - xsort.OrderedLess[byte], + compare[byte], &keys, &values, &children, @@ -773,7 +785,7 @@ func TestAmalgam1(t *testing.T) { } func TestRange(t *testing.T) { - tree := newBtree[uint16, int](xsort.OrderedLess[uint16]) + tree := newBtree[uint16, int](compare[uint16]) for i := 0; i < 128; i++ { tree.Put(uint16(i), i) @@ -808,7 +820,7 @@ func TestRange(t *testing.T) { } func TestGetContains(t *testing.T) { - tree := newBtree[uint16, int](xsort.OrderedLess[uint16]) + tree := newBtree[uint16, int](compare[uint16]) for i := 0; i < 128; i++ { tree.Put(uint16(i*2), i*4) @@ -826,7 +838,7 @@ func TestGetContains(t *testing.T) { } func TestDelete(t *testing.T) { - tree := newBtree[uint16, int](xsort.OrderedLess[uint16]) + tree := newBtree[uint16, int](compare[uint16]) for i := 0; i < 128; i++ { tree.Put(uint16(i)+1, i*2) } diff --git a/container/tree/map.go b/container/tree/map.go index cc79bf0..56805c5 100644 --- a/container/tree/map.go +++ b/container/tree/map.go @@ -24,7 +24,13 @@ type Map[K any, V any] struct { // pair of keys while they are in the map. func NewMap[K any, V any](less xsort.Less[K]) Map[K, V] { return Map[K, V]{ - t: newBtree[K, V](less), + t: newBtree[K, V](xsort.LessCompare(less)), + } +} + +func NewMapCmp[K any, V any](compare func(K, K) int) Map[K, V] { + return Map[K, V]{ + t: newBtree[K, V](compare), } } diff --git a/container/tree/set.go b/container/tree/set.go index dab9441..bf0c33e 100644 --- a/container/tree/set.go +++ b/container/tree/set.go @@ -17,7 +17,13 @@ type Set[T any] struct { // any pair of items while they are in the set. func NewSet[T any](less xsort.Less[T]) Set[T] { return Set[T]{ - t: newBtree[T, struct{}](less), + t: newBtree[T, struct{}](xsort.LessCompare(less)), + } +} + +func NewSetCmp[T any](compare func(T, T) int) Set[T] { + return Set[T]{ + t: newBtree[T, struct{}](compare), } } diff --git a/container/xheap/xheap.go b/container/xheap/xheap.go index 2fdcd44..ef18b92 100644 --- a/container/xheap/xheap.go +++ b/container/xheap/xheap.go @@ -37,6 +37,12 @@ func New[T any](less xsort.Less[T], initial []T) Heap[T] { } } +func NewCmp[T any](compare func(T, T) int, initial []T) Heap[T] { + return New(func(a, b T) bool { + return compare(a, b) < 0 + }, initial) +} + // Len returns the current number of elements in the heap. func (h Heap[T]) Len() int { return h.inner.Len() @@ -129,6 +135,15 @@ func NewPriorityQueue[K comparable, P any]( return h } +func NewPriorityQueueCmp[K comparable, P any]( + compare func(P, P) int, + initial []KP[K, P], +) PriorityQueue[K, P] { + return NewPriorityQueue(func(a, b P) bool { + return compare(a, b) < 0 + }, initial) +} + // Len returns the current number of elements in the priority queue. func (h PriorityQueue[K, P]) Len() int { return h.inner.Len() diff --git a/test_all_versions.sh b/test_all_versions.sh index 7be6ab7..f8fa3bc 100644 --- a/test_all_versions.sh +++ b/test_all_versions.sh @@ -6,7 +6,7 @@ go_versions=(1.18 1.19 1.20 1.21) latest="${go_versions[-1]}" if ! go version | grep "go$latest"; then - echo >2 "go version expected $latest, got $(go version)" + echo >&2 "go version expected $latest, got $(go version)" exit 1 fi diff --git a/xsort/xsort.go b/xsort/xsort.go index 9490192..9723c5f 100644 --- a/xsort/xsort.go +++ b/xsort/xsort.go @@ -47,6 +47,19 @@ func Reverse[T any](less Less[T]) Less[T] { } } +// LessCompare returns a comparison function (as defined in [sort.SortFunc]) that matches less. +func LessCompare[T any](less Less[T]) func(T, T) int { + return func(a, b T) int { + if less(a, b) { + return -1 + } else if less(b, a) { + return 1 + } else { + return 0 + } + } +} + // Slice sorts x in-place using the given less function to compare items. // // Follows the same rules as sort.Slice. diff --git a/xsort/xsort_old.go b/xsort/xsort_old.go index 3b749ae..9c8f88c 100644 --- a/xsort/xsort_old.go +++ b/xsort/xsort_old.go @@ -10,3 +10,30 @@ import ( func OrderedLess[T constraints.Ordered](a, b T) bool { return a < b } + +func Compare[T constraints.Ordered](x, y T) int { + // Copied from the standard library, here for versions older than 1.21 when it was added. + // https://cs.opensource.google/go/go/+/refs/tags/go1.22.4:src/cmp/cmp.go;l=40 + + xNaN := isNaN(x) + yNaN := isNaN(y) + if xNaN && yNaN { + return 0 + } + if xNaN || x < y { + return -1 + } + if yNaN || x > y { + return +1 + } + return 0 +} + +// Copied from the standard library, here for versions older than 1.21 when it was added. +// https://cs.opensource.google/go/go/+/refs/tags/go1.22.4:src/cmp/cmp.go;l=40 +// +// isNaN reports whether x is a NaN without requiring the math package. +// This will always return false if T is not floating-point. +func isNaN[T constraints.Ordered](x T) bool { + return x != x +}