diff --git a/hashset/hashset.go b/hashset/hashset.go new file mode 100644 index 0000000..76fde67 --- /dev/null +++ b/hashset/hashset.go @@ -0,0 +1,47 @@ +// Package hashset provides a map-based Set. +package hashset + +import "iter" + +// HashSet is a generic set based on a hash table (map). +type HashSet[T comparable] struct { + m map[T]struct{} +} + +// New creates a new HashSet. +func New[T comparable]() *HashSet[T] { + return &HashSet[T]{m: make(map[T]struct{})} +} + +// Add adds a value to the set. +func (hs *HashSet[T]) Add(val T) { + hs.m[val] = struct{}{} +} + +// Contains reports whether the set contains the given value. +func (hs *HashSet[T]) Contains(val T) bool { + _, ok := hs.m[val] + return ok +} + +// Len returns the size/length of the set - the number of values it contains. +func (hs *HashSet[T]) Len() int { + return len(hs.m) +} + +// Delete removes a value from the set; if the value doesn't exist in the +// set, this is a no-op. +func (hs *HashSet[T]) Delete(val T) { + delete(hs.m, val) +} + +// All returns an iterator over all the values in the set. +func (hs *HashSet[T]) All() iter.Seq[T] { + return func(yield func(T) bool) { + for val := range hs.m { + if !yield(val) { + return + } + } + } +} diff --git a/hashset/hashset_test.go b/hashset/hashset_test.go new file mode 100644 index 0000000..9061f3e --- /dev/null +++ b/hashset/hashset_test.go @@ -0,0 +1,96 @@ +package hashset + +import ( + "slices" + "testing" +) + +func TestAll(t *testing.T) { + hs := New[int]() + + checkAll := func(wantSorted []int) { + t.Helper() + if hs.Len() != len(wantSorted) { + t.Errorf("got len=%v, want %v", hs.Len(), len(wantSorted)) + } + + got := slices.Sorted(hs.All()) + if !slices.Equal(got, wantSorted) { + t.Errorf("got %v, want %v", got, wantSorted) + } + } + checkAll([]int{}) + hs.Add(10) + checkAll([]int{10}) + + hs.Add(20) + hs.Add(13) + checkAll([]int{10, 13, 20}) + + hs.Add(18) + checkAll([]int{10, 13, 18, 20}) + + hs.Delete(18) + checkAll([]int{10, 13, 20}) + hs.Delete(10) + checkAll([]int{13, 20}) + + hs.Add(50) + hs.Add(5) + checkAll([]int{5, 13, 20, 50}) + + hs.Add(60) + hs.Add(60) + hs.Add(60) + checkAll([]int{5, 13, 20, 50, 60}) + + hs.Delete(60) + hs.Delete(60) + hs.Delete(60) + checkAll([]int{5, 13, 20, 50}) + + hs.Delete(50) + hs.Delete(20) + hs.Delete(5) + checkAll([]int{13}) + + hs.Delete(13) + checkAll([]int{}) +} + +func TestContains(t *testing.T) { + hs := New[string]() + + checkContains := func(v string, want bool) { + t.Helper() + got := hs.Contains(v) + if got != want { + t.Errorf("contains(%v)=%v, want %v", v, got, want) + } + } + + checkContains("joe", false) + hs.Add("joe") + checkContains("joe", true) + hs.Delete("joe") + checkContains("joe", false) + + hs.Add("bee") + hs.Add("geranium") + checkContains("joe", false) + checkContains("bee", true) + checkContains("geranium", true) + + hs.Add("cheese") + hs.Add("io") + hs.Add("joe") + + for _, v := range []string{"joe", "bee", "geranium", "io", "cheese"} { + checkContains(v, true) + hs.Delete(v) + } + + for _, v := range []string{"joe", "bee", "geranium", "io", "cheese"} { + checkContains(v, false) + } +}