Skip to content

Commit

Permalink
New Union and UnitWith methods for hashset
Browse files Browse the repository at this point in the history
  • Loading branch information
eliben committed Sep 6, 2024
1 parent 3bec5df commit 368c0a6
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 22 deletions.
21 changes: 21 additions & 0 deletions hashset/hashset.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ func New[T comparable]() *HashSet[T] {
return &HashSet[T]{m: make(map[T]struct{})}
}

// InitWith creates a new HashSet initialized with vals.
func InitWith[T comparable](vals ...T) *HashSet[T] {
hs := New[T]()
for _, v := range vals {
hs.Add(v)
}
return hs
}

// Add adds a value to the set.
func (hs *HashSet[T]) Add(val T) {
hs.m[val] = struct{}{}
Expand Down Expand Up @@ -45,3 +54,15 @@ func (hs *HashSet[T]) All() iter.Seq[T] {
}
}
}

// Union returns the set union of hs with other. It creates a new set.
func (hs *HashSet[T]) Union(other *HashSet[T]) *HashSet[T] {
result := New[T]()
for v := range hs.m {
result.Add(v)
}
for v := range other.m {
result.Add(v)
}
return result
}
59 changes: 37 additions & 22 deletions hashset/hashset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,57 +5,58 @@ import (
"testing"
)

func checkAll(t *testing.T, hs *HashSet[int], wantSorted []int) {
t.Helper()
if hs.Len() != len(wantSorted) {
t.Errorf("got len=%v, want %v", hs.Len(), len(wantSorted))
}

got := slices.Sorted(hs.All())
if !slices.Equal(got, wantSorted) {
t.Errorf("got %v, want %v", got, wantSorted)
}
}

func TestAll(t *testing.T) {
hs := New[int]()

checkAll := func(wantSorted []int) {
t.Helper()
if hs.Len() != len(wantSorted) {
t.Errorf("got len=%v, want %v", hs.Len(), len(wantSorted))
}

got := slices.Sorted(hs.All())
if !slices.Equal(got, wantSorted) {
t.Errorf("got %v, want %v", got, wantSorted)
}
}
checkAll([]int{})
checkAll(t, hs, []int{})
hs.Add(10)
checkAll([]int{10})
checkAll(t, hs, []int{10})

hs.Add(20)
hs.Add(13)
checkAll([]int{10, 13, 20})
checkAll(t, hs, []int{10, 13, 20})

hs.Add(18)
checkAll([]int{10, 13, 18, 20})
checkAll(t, hs, []int{10, 13, 18, 20})

hs.Delete(18)
checkAll([]int{10, 13, 20})
checkAll(t, hs, []int{10, 13, 20})
hs.Delete(10)
checkAll([]int{13, 20})
checkAll(t, hs, []int{13, 20})

hs.Add(50)
hs.Add(5)
checkAll([]int{5, 13, 20, 50})
checkAll(t, hs, []int{5, 13, 20, 50})

hs.Add(60)
hs.Add(60)
hs.Add(60)
checkAll([]int{5, 13, 20, 50, 60})
checkAll(t, hs, []int{5, 13, 20, 50, 60})

hs.Delete(60)
hs.Delete(60)
hs.Delete(60)
checkAll([]int{5, 13, 20, 50})
checkAll(t, hs, []int{5, 13, 20, 50})

hs.Delete(50)
hs.Delete(20)
hs.Delete(5)
checkAll([]int{13})
checkAll(t, hs, []int{13})

hs.Delete(13)
checkAll([]int{})
checkAll(t, hs, []int{})
}

func TestContains(t *testing.T) {
Expand Down Expand Up @@ -94,3 +95,17 @@ func TestContains(t *testing.T) {
checkContains(v, false)
}
}

func TestUnion(t *testing.T) {
hs1 := InitWith(10, 20, 30, 40)
hs2 := InitWith(11, 21, 30, 41)

u1 := hs1.Union(hs2)
checkAll(t, u1, []int{10, 11, 20, 21, 30, 40, 41})

u2 := hs1.Union(InitWith(20))
checkAll(t, u2, []int{10, 20, 30, 40})

u3 := hs1.Union(InitWith(90))
checkAll(t, u3, []int{10, 20, 30, 40, 90})
}

0 comments on commit 368c0a6

Please sign in to comment.