Skip to content

Commit

Permalink
container/tree and container/xheap accept compare functions since the…
Browse files Browse the repository at this point in the history
… standard library settled on them over less
  • Loading branch information
bradenaw committed Jun 8, 2024
1 parent 5b729a3 commit c727c31
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 37 deletions.
47 changes: 24 additions & 23 deletions container/tree/btree.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package tree
import (
"github.com/bradenaw/juniper/iterator"
"github.com/bradenaw/juniper/xslices"
"github.com/bradenaw/juniper/xsort"
)

// Maximum number of children each node can have.
Expand Down Expand Up @@ -35,19 +34,19 @@ const minKVs = maxKVs / 2
// - Notably, most nodes are leaves so we can do better space-wise if we can elide the children
// array from internal nodes entirely.
type btree[K, V any] struct {
root *node[K, V]
less xsort.Less[K]
size int
root *node[K, V]
compare func(K, K) int
size int
// incremented when tree structure changes - used to quickly avoid reseeking cursor moving
// through an unchanging tree
gen int
}

func newBtree[K any, V any](less xsort.Less[K]) *btree[K, V] {
func newBtree[K any, V any](compare func(K, K) int) *btree[K, V] {
return &btree[K, V]{
less: less,
root: &node[K, V]{},
size: 0,
compare: compare,
root: &node[K, V]{},
size: 0,
}
}

Expand Down Expand Up @@ -205,7 +204,7 @@ func (t *btree[K, V]) Cursor() cursor[K, V] {
func (t *btree[K, V]) insertIntoLeaf(x *node[K, V], k K, v V) {
idx := 0
for idx < int(x.n) {
if t.less(k, x.keys[idx]) {
if t.compare(k, x.keys[idx]) < 0 {
break
}
idx++
Expand All @@ -219,7 +218,7 @@ func (t *btree[K, V]) insertIntoLeaf(x *node[K, V], k K, v V) {
// adds a separator to x's parent, which may cause it to overflow and also need a split.
func (t *btree[K, V]) overfill(x *node[K, V], k K, v V, afterK *node[K, V]) {
for {
all := newAmalgam1(t.less, &x.keys, &x.values, &x.children, k, v, afterK)
all := newAmalgam1(t.compare, &x.keys, &x.values, &x.children, k, v, afterK)

left := x
right := &node[K, V]{}
Expand Down Expand Up @@ -491,9 +490,10 @@ func (t *btree[K, V]) searchNode(k K, x *node[K, V]) (idx int, inNode bool) {
// benchmark suggests that linear search is in fact faster than binary search, at least for int
// keys and branchFactor <= 32.
for i := 0; i < int(x.n); i++ {
if t.less(k, x.keys[i]) {
c := t.compare(k, x.keys[i])
if c < 0 {
return i, false
} else if !t.less(x.keys[i], k) {
} else if c == 0 {
return i, true
}
}
Expand Down Expand Up @@ -559,7 +559,7 @@ type amalgam1[K any, V any] struct {
// [a c d e]
// 0 1 2 extraChild 3
func newAmalgam1[K any, V any](
less xsort.Less[K],
compare func(K, K) int,
keys *[maxKVs]K,
values *[maxKVs]V,
children *[branchFactor]*node[K, V],
Expand All @@ -569,7 +569,7 @@ func newAmalgam1[K any, V any](
) amalgam1[K, V] {
extraIdx := func() int {
for i := range *keys {
if less(extraKey, keys[i]) {
if compare(extraKey, keys[i]) < 0 {
return i
}
}
Expand Down Expand Up @@ -736,7 +736,8 @@ func (c *cursor[K, V]) SeekLastLess(k K) {
if !c.seek(k) {
return
}
if xsort.LessOrEqual(c.t.less, k, c.k) {

if c.t.compare(k, c.k) <= 0 {
c.Prev()
}
}
Expand All @@ -745,7 +746,7 @@ func (c *cursor[K, V]) SeekLastLessOrEqual(k K) {
if !c.seek(k) {
return
}
if c.t.less(k, c.k) {
if c.t.compare(k, c.k) < 0 {
c.Prev()
}
}
Expand All @@ -754,7 +755,7 @@ func (c *cursor[K, V]) SeekFirstGreaterOrEqual(k K) {
if !c.seek(k) {
return
}
if xsort.Greater(c.t.less, k, c.k) {
if c.t.compare(k, c.k) > 0 {
c.Next()
}
}
Expand All @@ -763,7 +764,7 @@ func (c *cursor[K, V]) SeekFirstGreater(k K) {
if !c.seek(k) {
return
}
if xsort.GreaterOrEqual(c.t.less, k, c.k) {
if c.t.compare(k, c.k) >= 0 {
c.Next()
}
}
Expand Down Expand Up @@ -843,7 +844,7 @@ func (c *cursor[K, V]) lost() bool {
// zero value also. Unlinking a node during merge sets n=0, so that's handled here too.
return c.gen != c.t.gen &&
c.curr != nil &&
(c.i >= int(c.curr.n) || !xsort.Equal(c.t.less, c.k, c.curr.keys[c.i]))
(c.i >= int(c.curr.n) || c.t.compare(c.k, c.curr.keys[c.i]) != 0)
}

func (c *cursor[K, V]) Forward() iterator.Iterator[KVPair[K, V]] {
Expand Down Expand Up @@ -907,11 +908,11 @@ func (t *btree[K, V]) Range(lower Bound[K], upper Bound[K]) iterator.Iterator[KV
switch upper.type_ {
case boundInclude:
return iterator.While(c.Forward(), func(pair KVPair[K, V]) bool {
return xsort.LessOrEqual(t.less, pair.Key, upper.key)
return t.compare(pair.Key, upper.key) <= 0
})
case boundExclude:
return iterator.While(c.Forward(), func(pair KVPair[K, V]) bool {
return t.less(pair.Key, upper.key)
return t.compare(pair.Key, upper.key) < 0
})
case boundUnbounded:
return c.Forward()
Expand All @@ -935,11 +936,11 @@ func (t *btree[K, V]) RangeReverse(lower Bound[K], upper Bound[K]) iterator.Iter
switch lower.type_ {
case boundInclude:
return iterator.While(c.Backward(), func(pair KVPair[K, V]) bool {
return xsort.GreaterOrEqual(t.less, pair.Key, lower.key)
return t.compare(pair.Key, lower.key) >= 0
})
case boundExclude:
return iterator.While(c.Backward(), func(pair KVPair[K, V]) bool {
return xsort.Greater(t.less, pair.Key, lower.key)
return t.compare(pair.Key, lower.key) > 0
})
case boundUnbounded:
return c.Backward()
Expand Down
34 changes: 23 additions & 11 deletions container/tree/btree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,23 @@ import (
"github.com/bradenaw/juniper/xsort"
)

// compare is the three-way comparison used for the key types in these tests. It returns -1 when
// x < y, 1 when x > y, and 0 when the two are equal.
func compare[T byte | uint16](x, y T) int {
	switch {
	case x < y:
		return -1
	case x > y:
		return 1
	default:
		return 0
	}
}

// orderedhashmapKVPairToKVPair converts a key-value pair from the orderedhashmap package into this
// package's KVPair, copying the key and the value across unchanged.
func orderedhashmapKVPairToKVPair[K any, V any](kv orderedhashmap.KVPair[uint16, int]) KVPair[uint16, int] {
	key, value := kv.K, kv.V
	return KVPair[uint16, int]{key, value}
}

func FuzzBtree(f *testing.F) {
f.Fuzz(func(t *testing.T, b []byte) {
tree := newBtree[uint16, int](xsort.OrderedLess[uint16])
tree := newBtree[uint16, int](compare[uint16])
cursor := tree.Cursor()
cursor.SeekFirst()
oracle := orderedhashmap.NewMap[uint16, int](xsort.OrderedLess[uint16])
Expand Down Expand Up @@ -389,7 +399,7 @@ func TestRotateLeft(t *testing.T) {
}

func TestMergeMulti(t *testing.T) {
tree := newBtree[uint16, int](xsort.OrderedLess[uint16])
tree := newBtree[uint16, int](compare[uint16])
i := 0
for treeHeight(tree) < 3 {
tree.Put(uint16(i), i)
Expand Down Expand Up @@ -506,8 +516,8 @@ func requireTreesEqual(t *testing.T, a, b *btree[byte, int]) {

func makeTree(t *testing.T, root *node[byte, int]) *btree[byte, int] {
tree := &btree[byte, int]{
root: root,
less: xsort.OrderedLess[byte],
root: root,
compare: compare[byte],
}
tree.size = numItems(tree)
checkTree(t, tree)
Expand Down Expand Up @@ -623,8 +633,8 @@ func checkTree[K comparable, V comparable](t *testing.T, tree *btree[K, V]) {
left := x.children[i]
right := x.children[i+1]
k := x.keys[i]
require2.True(t, tree.less(left.keys[int(left.n)-1], k))
require2.True(t, tree.less(k, right.keys[0]))
require2.Less(t, tree.compare(left.keys[int(left.n)-1], k), 0)
require2.Less(t, tree.compare(k, right.keys[0]), 0)
}
}
if x == tree.root {
Expand All @@ -634,7 +644,9 @@ func checkTree[K comparable, V comparable](t *testing.T, tree *btree[K, V]) {
} else {
require2.GreaterOrEqual(t, int(x.n), minKVs)
}
require2.True(t, xsort.SliceIsSorted(x.keys[:int(x.n)], tree.less))
require2.True(t, xsort.SliceIsSorted(x.keys[:int(x.n)], func(a, b K) bool {
return tree.compare(a, b) < 0
}))
require2.True(t, xslices.All(x.keys[int(x.n):], isZero[K]))
require2.True(t, xslices.All(x.values[int(x.n):], isZero[V]))
require2.Truef(
Expand Down Expand Up @@ -740,7 +752,7 @@ func TestAmalgam1(t *testing.T) {
)

a := newAmalgam1(
xsort.OrderedLess[byte],
compare[byte],
&keys,
&values,
&children,
Expand Down Expand Up @@ -773,7 +785,7 @@ func TestAmalgam1(t *testing.T) {
}

func TestRange(t *testing.T) {
tree := newBtree[uint16, int](xsort.OrderedLess[uint16])
tree := newBtree[uint16, int](compare[uint16])

for i := 0; i < 128; i++ {
tree.Put(uint16(i), i)
Expand Down Expand Up @@ -808,7 +820,7 @@ func TestRange(t *testing.T) {
}

func TestGetContains(t *testing.T) {
tree := newBtree[uint16, int](xsort.OrderedLess[uint16])
tree := newBtree[uint16, int](compare[uint16])

for i := 0; i < 128; i++ {
tree.Put(uint16(i*2), i*4)
Expand All @@ -826,7 +838,7 @@ func TestGetContains(t *testing.T) {
}

func TestDelete(t *testing.T) {
tree := newBtree[uint16, int](xsort.OrderedLess[uint16])
tree := newBtree[uint16, int](compare[uint16])
for i := 0; i < 128; i++ {
tree.Put(uint16(i)+1, i*2)
}
Expand Down
8 changes: 7 additions & 1 deletion container/tree/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@ type Map[K any, V any] struct {
// pair of keys while they are in the map.
func NewMap[K any, V any](less xsort.Less[K]) Map[K, V] {
return Map[K, V]{
t: newBtree[K, V](less),
t: newBtree[K, V](xsort.LessCompare(less)),
}
}

// NewMapCmp returns a Map whose keys are ordered by compare, a three-way comparison function, rather
// than by a less function.
//
// compare(a, b) must be negative when a orders before b, zero when a and b are equivalent keys,
// and positive when a orders after b.
func NewMapCmp[K any, V any](compare func(K, K) int) Map[K, V] {
	tree := newBtree[K, V](compare)
	return Map[K, V]{t: tree}
}

Expand Down
8 changes: 7 additions & 1 deletion container/tree/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@ type Set[T any] struct {
// any pair of items while they are in the set.
func NewSet[T any](less xsort.Less[T]) Set[T] {
return Set[T]{
t: newBtree[T, struct{}](less),
t: newBtree[T, struct{}](xsort.LessCompare(less)),
}
}

// NewSetCmp returns a Set whose items are ordered by compare, a three-way comparison function,
// rather than by a less function.
//
// compare(a, b) must be negative when a orders before b, zero when a and b are equivalent items,
// and positive when a orders after b.
func NewSetCmp[T any](compare func(T, T) int) Set[T] {
	tree := newBtree[T, struct{}](compare)
	return Set[T]{t: tree}
}

Expand Down
15 changes: 15 additions & 0 deletions container/xheap/xheap.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ func New[T any](less xsort.Less[T], initial []T) Heap[T] {
}
}

// NewCmp is like New, but takes a three-way comparison function instead of a less function.
//
// An item a is treated as less than b exactly when compare(a, b) < 0.
func NewCmp[T any](compare func(T, T) int, initial []T) Heap[T] {
	less := func(a, b T) bool { return compare(a, b) < 0 }
	return New(less, initial)
}

// Len returns the current number of elements in the heap.
func (h Heap[T]) Len() int {
return h.inner.Len()
Expand Down Expand Up @@ -129,6 +135,15 @@ func NewPriorityQueue[K comparable, P any](
return h
}

// NewPriorityQueueCmp is like NewPriorityQueue, but takes a three-way comparison function over
// priorities instead of a less function.
//
// A priority a is treated as less than b exactly when compare(a, b) < 0.
func NewPriorityQueueCmp[K comparable, P any](
	compare func(P, P) int,
	initial []KP[K, P],
) PriorityQueue[K, P] {
	less := func(a, b P) bool { return compare(a, b) < 0 }
	return NewPriorityQueue(less, initial)
}

// Len returns the current number of elements in the priority queue.
func (h PriorityQueue[K, P]) Len() int {
return h.inner.Len()
Expand Down
2 changes: 1 addition & 1 deletion test_all_versions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ go_versions=(1.18 1.19 1.20 1.21)

latest="${go_versions[-1]}"
if ! go version | grep "go$latest"; then
echo >2 "go version expected $latest, got $(go version)"
echo >&2 "go version expected $latest, got $(go version)"
exit 1
fi

Expand Down
13 changes: 13 additions & 0 deletions xsort/xsort.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,19 @@ func Reverse[T any](less Less[T]) Less[T] {
}
}

// LessCompare adapts less into a three-way comparison function of the form accepted by
// [slices.SortFunc].
//
// The returned function yields -1 when less(a, b) holds, 1 when less(b, a) holds, and 0 when
// neither does. less(b, a) is only consulted when less(a, b) is false.
func LessCompare[T any](less Less[T]) func(T, T) int {
	return func(a, b T) int {
		switch {
		case less(a, b):
			return -1
		case less(b, a):
			return 1
		default:
			return 0
		}
	}
}

// Slice sorts x in-place using the given less function to compare items.
//
// Follows the same rules as sort.Slice.
Expand Down
27 changes: 27 additions & 0 deletions xsort/xsort_old.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,30 @@ import (
func OrderedLess[T constraints.Ordered](a, b T) bool {
return a < b
}

// Compare returns -1 if x is less than y, 0 if x equals y, and +1 if x is greater than y.
//
// A NaN is considered less than any non-NaN, and two NaNs compare as equal.
//
// Copied from the standard library, here for versions older than 1.21 when it was added.
// https://cs.opensource.google/go/go/+/refs/tags/go1.22.4:src/cmp/cmp.go;l=40
func Compare[T constraints.Ordered](x, y T) int {
	xNaN, yNaN := isNaN(x), isNaN(y)
	switch {
	case xNaN && yNaN:
		return 0
	case xNaN || x < y:
		return -1
	case yNaN || x > y:
		return +1
	default:
		return 0
	}
}

// isNaN reports whether x is a NaN without requiring the math package.
// This will always return false if T is not floating-point.
//
// Copied from the standard library, here for versions older than 1.21 when it was added.
// https://cs.opensource.google/go/go/+/refs/tags/go1.22.4:src/cmp/cmp.go;l=40
func isNaN[T constraints.Ordered](x T) bool {
	// A value that compares equal to itself is not a NaN.
	if x == x {
		return false
	}
	return true
}

0 comments on commit c727c31

Please sign in to comment.