From 121f7cfb5ff4ec9afa2244b32578bccf61d933ce Mon Sep 17 00:00:00 2001 From: leaxoy Date: Wed, 3 Jan 2024 16:58:51 +0800 Subject: [PATCH] feat(maps, sets): update map, set --- maps/maps.go | 39 ++++++++----- sets/set.go | 131 +++++++++++++++++-------------------------- sets/set_test.go | 141 +++++++++++------------------------------------ 3 files changed, 109 insertions(+), 202 deletions(-) diff --git a/maps/maps.go b/maps/maps.go index c504009..0ab4a35 100644 --- a/maps/maps.go +++ b/maps/maps.go @@ -2,6 +2,7 @@ package maps import ( "github.com/go-board/std/iter" + "github.com/go-board/std/iter/collector" "github.com/go-board/std/tuple" ) @@ -12,7 +13,7 @@ func (e MapEntry[K, V]) Key() K { return e.inner.First } func (e MapEntry[K, V]) Value() V { return e.inner.Second } func entry[K, V any](key K, value V) MapEntry[K, V] { - return MapEntry[K, V]{inner: tuple.PairOf(key, value)} + return MapEntry[K, V]{inner: tuple.MakePair(key, value)} } // Entries returns all entry of a map as an [iter.Seq] @@ -26,7 +27,12 @@ func Entries[K comparable, V any, M ~map[K]V](m M) iter.Seq[MapEntry[K, V]] { } } -// Keys return key slice of a map. +// EntrySlice return entry slice of a map. +func EntrySlice[K comparable, V any, M ~map[K]V](m M) []MapEntry[K, V] { + return collector.Collect(Entries(m), collector.ToSlice[MapEntry[K, V]]()) +} + +// Keys return key's [iter.Seq] of a map. func Keys[K comparable, V any, M ~map[K]V](m M) iter.Seq[K] { return func(yield func(K) bool) { for k := range m { @@ -37,7 +43,12 @@ func Keys[K comparable, V any, M ~map[K]V](m M) iter.Seq[K] { } } -// Values returns value slice of a map. +// KeySlice return key slice of a map. +func KeySlice[K comparable, V any, M ~map[K]V](m M) []K { + return collector.Collect(Keys(m), collector.ToSlice[K]()) +} + +// Values return value's [iter.Seq] of a map. func Values[K comparable, V any, M ~map[K]V](m M) iter.Seq[V] { return func(yield func(V) bool) { for _, v := range m { @@ -48,13 +59,15 @@ func Values[K comparable, V any, M ~map[K]V](m M) iter.Seq[V] { } } +// ValueSlice return value slice of a map. +func ValueSlice[K comparable, V any, M ~map[K]V](m M) []V { + return collector.Collect(Values(m), collector.ToSlice[V]()) +} + +// Collect collects [iter.Seq] into a map func Collect[K comparable, V any](s iter.Seq[MapEntry[K, V]]) map[K]V { - m := make(map[K]V) - iter.CollectFunc(s, func(x MapEntry[K, V]) bool { - m[x.Key()] = x.Value() - return true - }) - return m + extract := func(e MapEntry[K, V]) (K, V) { return e.Key(), e.Value() } + return collector.Collect(s, collector.ToMap(extract)) } // ForEach iter over the map, and call the udf on each k-v pair. @@ -121,10 +134,10 @@ func MergeFunc[K comparable, V any, M ~map[K]V](ms iter.Seq[M], onConflict func( } // Invert maps k-v to v-k, when key conflict, the back element will overwrite the previous one. -func Invert[K, V comparable, M1 ~map[K]V, M2 ~map[V]K](m M1) M2 { - m2 := make(M2) +func Invert[K, V comparable, M ~map[K]V](m M) map[V]K { + rs := make(map[V]K) for k, v := range m { - m2[v] = k + rs[v] = k } - return m2 + return rs } diff --git a/sets/set.go b/sets/set.go index 1ca9b44..f5bc788 100644 --- a/sets/set.go +++ b/sets/set.go @@ -1,8 +1,6 @@ package sets import ( - "encoding/json" - "github.com/go-board/std/core" "github.com/go-board/std/iter" ) @@ -11,52 +9,42 @@ var unit = struct{}{} type HashSet[E comparable] struct{ inner map[E]core.Unit } -var _set HashSet[core.Unit] - -var ( - _ json.Marshaler = _set - _ json.Unmarshaler = _set -) +func New[E comparable]() HashSet[E] { + return HashSet[E]{inner: make(map[E]core.Unit)} +} // FromSlice returns a new empty hash set. func FromSlice[E comparable](elements ...E) HashSet[E] { - inner := make(map[E]core.Unit, len(elements)) - for _, element := range elements { - inner[element] = unit + set := New[E]() + for _, elem := range elements { + set.Add(elem) } - return HashSet[E]{inner: inner} + return set } -func FromMapKeys[E comparable, TValue any, M ~map[E]TValue](m M) HashSet[E] { - inner := make(map[E]core.Unit, len(m)) +func FromMapKeys[E comparable, V any, M ~map[E]V](m M) HashSet[E] { + set := New[E]() for key := range m { - inner[key] = unit + set.Add(key) } - return HashSet[E]{inner: inner} + return set } // FromIter create a HashSet from [Seq]. func FromIter[E comparable](s iter.Seq[E]) HashSet[E] { - set := HashSet[E]{inner: make(map[E]core.Unit)} + set := New[E]() set.AddIter(s) return set } -// Add adds the given keys to the set. -func (self HashSet[E]) Add(keys ...E) { - for _, key := range keys { - self.inner[key] = unit - } -} - -// AddAll adds all elements from another [HashSet]. -func (self HashSet[E]) AddAll(other HashSet[E]) { - self.AddIter(other.Iter()) +// Add adds the given key to the set. +func (self HashSet[E]) Add(key E) { + self.inner[key] = unit } -// AddIter adds all elements in Iter. -func (self HashSet[E]) AddIter(s iter.Seq[E]) { - iter.ForEach(s, func(e E) { self.inner[e] = unit }) +// AddIter adds all elements in [iter.Seq] to the set. +func (self HashSet[E]) AddIter(it iter.Seq[E]) { + iter.ForEach(it, self.Add) } // Remove removes the given key from the set. @@ -64,6 +52,7 @@ func (self HashSet[E]) Remove(key E) { delete(self.inner, key) } +// RemoveIter removes all elements in [iter.Seq]. func (self HashSet[E]) RemoveIter(it iter.Seq[E]) { iter.ForEach(it, self.Remove) } // Clear removes all keys from the set. @@ -73,18 +62,13 @@ func (self HashSet[E]) Clear() { } } -func (self HashSet[E]) Filter(fn func(E) bool) HashSet[E] { - return FromIter(iter.Filter(self.Iter(), fn)) -} - +// Retain keep element that match the given predicate function. +// +// Otherwise, remove from [HashSet]. func (self HashSet[E]) Retain(fn func(E) bool) { iter.ForEach(iter.Filter(self.Iter(), func(e E) bool { return !fn(e) }), self.Remove) } -func (self HashSet[E]) Map(fn func(E) E) HashSet[E] { - return FromIter(iter.Map(self.Iter(), fn)) -} - // Contains returns true if the given key is in the set. func (self HashSet[E]) Contains(key E) bool { _, exists := self.inner[key] @@ -93,12 +77,12 @@ func (self HashSet[E]) Contains(key E) bool { // ContainsAll returns true if all the given keys are in the set. func (self HashSet[E]) ContainsAll(keys iter.Seq[E]) bool { - return iter.All(keys, func(e E) bool { return self.Contains(e) }) + return iter.All(keys, self.Contains) } // ContainsAny returns true if any of the given keys are in the set. func (self HashSet[E]) ContainsAny(keys iter.Seq[E]) bool { - return iter.Any(keys, func(e E) bool { return self.Contains(e) }) + return iter.Any(keys, self.Contains) } // Size returns the number of elements in the set. @@ -112,11 +96,7 @@ func (self HashSet[E]) IsEmpty() bool { } func (self HashSet[E]) ToMap() map[E]struct{} { - m := make(map[E]struct{}, len(self.inner)) - for k, v := range self.inner { - m[k] = v - } - return m + return self.Clone().inner } // Clone returns a copy of the set. @@ -126,11 +106,7 @@ func (self HashSet[E]) Clone() HashSet[E] { // DeepCloneBy returns a copy of the set and clone each element use given clone func. func (self HashSet[E]) DeepCloneBy(clone func(E) E) HashSet[E] { - other := FromSlice[E]() - for key := range self.inner { - other.Add(clone(key)) - } - return other + return FromIter(iter.Map(self.Iter(), clone)) } // SupersetOf returns true if the given set is a superset of this set. @@ -138,15 +114,23 @@ func (self HashSet[E]) SupersetOf(other HashSet[E]) bool { return iter.All(other.Iter(), self.Contains) } +func (self HashSet[E]) SupersetOfIter(it iter.Seq[E]) bool { + return iter.All(it, self.Contains) +} + // SubsetOf returns true if the given set is a subset of this set. func (self HashSet[E]) SubsetOf(other HashSet[E]) bool { return iter.All(self.Iter(), other.Contains) } +func (self HashSet[E]) SubsetOfIter(it iter.Seq[E]) bool { + return iter.All(self.Iter(), FromIter(it).Contains) +} + // Union returns a new set containing all the elements that are in either set. func (self HashSet[E]) Union(other HashSet[E]) HashSet[E] { union := self.Clone() - union.AddAll(other) + union.AddIter(other.Iter()) return union } @@ -156,35 +140,31 @@ func (self HashSet[E]) UnionIter(it iter.Seq[E]) HashSet[E] { return union } -// UnionAssign union another [HashSet] into self -func (self HashSet[E]) UnionAssign(other HashSet[E]) { - self.AddAll(other) -} - // Intersection returns a new set containing all the elements that are in both sets. func (self HashSet[E]) Intersection(other HashSet[E]) HashSet[E] { return FromIter(iter.Filter(self.Iter(), other.Contains)) } +func (self HashSet[E]) IntersectionIter(it iter.Seq[E]) HashSet[E] { + return self.Intersection(FromIter(it)) +} + // Difference returns a new set containing all the elements that are in this set but not in the other set. func (self HashSet[E]) Difference(other HashSet[E]) HashSet[E] { return FromIter(iter.Filter(self.Iter(), func(e E) bool { return !other.Contains(e) })) } +func (self HashSet[E]) DifferenceIter(it iter.Seq[E]) HashSet[E] { + return self.Difference(FromIter(it)) +} + // SymmetricDifference returns a new set containing all the elements that are in this set or the other set but not in both. func (self HashSet[E]) SymmetricDifference(other HashSet[E]) HashSet[E] { - diff := FromSlice[E]() - for key := range self.inner { - if !other.Contains(key) { - diff.Add(key) - } - } - for key := range other.inner { - if !self.Contains(key) { - diff.Add(key) - } - } - return diff + return self.Union(other).Difference(self.Intersection(other)) +} + +func (self HashSet[E]) SymmetricDifferenceIter(it iter.Seq[E]) HashSet[E] { + return self.SymmetricDifference(FromIter(it)) } // Equal returns true if the given set is equal to this set. @@ -213,26 +193,17 @@ func (self HashSet[E]) Iter() iter.Seq[E] { func (self HashSet[E]) IterMut() iter.Seq[*SetItem[E]] { return func(yield func(*SetItem[E]) bool) { for key := range self.inner { - if !yield(&SetItem[E]{item: key, s: self}) { + if !yield(&SetItem[E]{item: key, set: self}) { break } } } } -func (self HashSet[E]) MarshalJSON() ([]byte, error) { - return json.Marshal(self.inner) -} - -func (self HashSet[E]) UnmarshalJSON(v []byte) error { - return json.Unmarshal(v, &self.inner) -} - type SetItem[E comparable] struct { item E - s HashSet[E] + set HashSet[E] } -func (s *SetItem[E]) Remove() { s.s.Remove(s.item) } - func (s *SetItem[E]) Value() E { return s.item } +func (s *SetItem[E]) Remove() { s.set.Remove(s.item) } diff --git a/sets/set_test.go b/sets/set_test.go index 16e1c64..9a2447b 100644 --- a/sets/set_test.go +++ b/sets/set_test.go @@ -1,17 +1,15 @@ package sets import ( - "encoding/json" "testing" - "github.com/go-board/std/iter" - "github.com/frankban/quicktest" + "github.com/go-board/std/iter" ) -func seq[E any, S ~[]E](s S) iter.Seq[E] { +func seq[E any](elems ...E) iter.Seq[E] { return func(yield func(E) bool) { - for _, e := range s { + for _, e := range elems { if !yield(e) { break } @@ -26,27 +24,15 @@ func (i item) Clone() item { return item{key: i.key} } func TestHashSet_Add(t *testing.T) { a := quicktest.New(t) s := FromSlice[int]() - s.Add(1, 2) + s.AddIter(seq(1, 2)) a.Assert(s.Contains(1), quicktest.IsTrue) a.Assert(s.Contains(2), quicktest.IsTrue) a.Assert(s.Contains(3), quicktest.IsFalse) } -func TestHashSet_AddAll(t *testing.T) { - a := quicktest.New(t) - s := FromSlice[int]() - s.AddAll(FromSlice(1, 2)) - a.Assert(s.Contains(1), quicktest.IsTrue) - a.Assert(s.Contains(2), quicktest.IsTrue) - a.Assert(s.Contains(3), quicktest.IsFalse) - a.Assert(s.Size(), quicktest.Equals, 2) -} - func TestHashSet_Remove(t *testing.T) { a := quicktest.New(t) - s := FromSlice[int]() - s.Add(1) - s.Add(2) + s := FromSlice[int](1, 2) s.Remove(1) a.Assert(s.Contains(1), quicktest.IsFalse) a.Assert(s.Contains(2), quicktest.IsTrue) @@ -55,9 +41,7 @@ func TestHashSet_Remove(t *testing.T) { func TestHashSet_Clear(t *testing.T) { a := quicktest.New(t) - s := FromSlice[int]() - s.Add(1) - s.Add(2) + s := FromSlice[int](1, 2) s.Clear() a.Assert(s.Contains(1), quicktest.IsFalse) a.Assert(s.Contains(2), quicktest.IsFalse) @@ -66,9 +50,7 @@ func TestHashSet_Clear(t *testing.T) { func TestHashSet_Contains(t *testing.T) { a := quicktest.New(t) - s := FromSlice[int]() - s.Add(1) - s.Add(2) + s := FromSlice[int](1, 2) a.Assert(s.Contains(1), quicktest.IsTrue) a.Assert(s.Contains(2), quicktest.IsTrue) a.Assert(s.Contains(3), quicktest.IsFalse) @@ -76,31 +58,22 @@ func TestHashSet_Contains(t *testing.T) { func TestHashSet_ContainsAll(t *testing.T) { a := quicktest.New(t) - s := FromSlice[int]() - s.Add(1) - s.Add(2) - s.Add(3) - a.Assert(s.ContainsAll(seq([]int{1, 2, 3})), quicktest.IsTrue) - a.Assert(s.ContainsAll(seq([]int{1, 2, 4})), quicktest.IsFalse) + s := FromSlice[int](1, 2, 3) + a.Assert(s.ContainsAll(seq(1, 2, 3)), quicktest.IsTrue) + a.Assert(s.ContainsAll(seq(1, 2, 4)), quicktest.IsFalse) } func TestHashSet_ContainsAny(t *testing.T) { a := quicktest.New(t) - s := FromSlice[int]() - s.Add(1) - s.Add(2) - s.Add(3) - a.Assert(true, quicktest.Equals, s.ContainsAny(seq([]int{1, 2, 4}))) - a.Assert(true, quicktest.Equals, s.ContainsAny(seq([]int{1, 2, 3}))) - a.Assert(false, quicktest.Equals, s.ContainsAny(seq([]int{5, 6, 7}))) + s := FromSlice[int](1, 2, 3) + a.Assert(true, quicktest.Equals, s.ContainsAny(seq(1, 2, 4))) + a.Assert(true, quicktest.Equals, s.ContainsAny(seq(1, 2, 3))) + a.Assert(false, quicktest.Equals, s.ContainsAny(seq(5, 6, 7))) } func TestHashSet_Size(t *testing.T) { a := quicktest.New(t) - s := FromSlice[int]() - s.Add(1) - s.Add(2) - s.Add(3) + s := FromSlice[int](1, 2, 3) a.Assert(s.Size(), quicktest.Equals, 3) s.Add(5) a.Assert(s.Size(), quicktest.Equals, 4) @@ -118,10 +91,7 @@ func TestHashSet_IsEmpty(t *testing.T) { func TestHashSet_Clone(t *testing.T) { a := quicktest.New(t) - s1 := FromSlice[int]() - s1.Add(1) - s1.Add(2) - s1.Add(3) + s1 := FromSlice[int](1, 2, 3) s2 := s1.Clone() a.Assert(s1.Equal(s2), quicktest.IsTrue) s2.Add(4) @@ -130,10 +100,7 @@ func TestHashSet_Clone(t *testing.T) { func TestHashSet_DeepCloneBy(t *testing.T) { a := quicktest.New(t) - s1 := FromSlice[int]() - s1.Add(1) - s1.Add(2) - s1.Add(3) + s1 := FromSlice[int](1, 2, 3) s2 := s1.DeepCloneBy(func(i int) int { return i }) a.Assert(s1.Equal(s2), quicktest.IsTrue) s2.Add(4) @@ -143,14 +110,10 @@ func TestHashSet_DeepCloneBy(t *testing.T) { func TestHashSet_SupersetOf(t *testing.T) { a := quicktest.New(t) s1 := FromSlice[int]() - s1.Add(1) - s1.Add(2) - s1.Add(3) + s1.AddIter(seq[int](1, 2, 3)) s2 := FromSlice[int]() a.Assert(s1.SupersetOf(s2), quicktest.IsTrue) - s2.Add(1) - s2.Add(2) - s2.Add(3) + s2.AddIter(seq(1, 2, 3)) a.Assert(s1.SupersetOf(s2), quicktest.IsTrue) s2.Add(4) a.Assert(s1.SupersetOf(s2), quicktest.IsFalse) @@ -158,15 +121,10 @@ func TestHashSet_SupersetOf(t *testing.T) { func TestHashSet_SubsetOf(t *testing.T) { a := quicktest.New(t) - s1 := FromSlice[int]() - s1.Add(1) - s1.Add(2) - s1.Add(3) + s1 := FromSlice[int](1, 2, 3) s2 := FromSlice[int]() a.Assert(s2.SubsetOf(s1), quicktest.IsTrue) - s2.Add(1) - s2.Add(2) - s2.Add(3) + s2.AddIter(seq(1, 2, 3)) a.Assert(s2.SubsetOf(s1), quicktest.IsTrue) s2.Add(4) a.Assert(s2.SubsetOf(s1), quicktest.IsFalse) @@ -174,16 +132,10 @@ func TestHashSet_SubsetOf(t *testing.T) { func TestHashSet_Union(t *testing.T) { a := quicktest.New(t) - s1 := FromSlice[int]() - s1.Add(1) - s1.Add(2) - s1.Add(3) - s2 := FromSlice[int]() - s2.Add(1) - s2.Add(4) - s2.Add(5) + s1 := FromSlice[int](1, 2, 3) + s2 := FromSlice[int](1, 4, 5) s3 := s1.Union(s2) - a.Assert(s3.ContainsAll(seq([]int{1, 2, 3, 4, 5})), quicktest.IsTrue) + a.Assert(s3.ContainsAll(seq(1, 2, 3, 4, 5)), quicktest.IsTrue) a.Assert(s3.Size(), quicktest.Equals, 5) } @@ -198,7 +150,7 @@ func TestHashSet_Intersection(t *testing.T) { s2.Add(4) s2.Add(5) s3 := s1.Intersection(s2) - a.Assert(s3.ContainsAll(seq([]int{1})), quicktest.IsTrue) + a.Assert(s3.ContainsAll(seq(1)), quicktest.IsTrue) a.Assert(s3.Size(), quicktest.Equals, 1) } @@ -213,35 +165,23 @@ func TestHashSet_Difference(t *testing.T) { s2.Add(4) s2.Add(5) s3 := s1.Difference(s2) - a.Assert(s3.ContainsAll(seq([]int{2, 3})), quicktest.IsTrue) + a.Assert(s3.ContainsAll(seq(2, 3)), quicktest.IsTrue) a.Assert(s3.Size(), quicktest.Equals, 2) } func TestHashSet_SymmetricDifference(t *testing.T) { a := quicktest.New(t) - s1 := FromSlice[int]() - s1.Add(1) - s1.Add(2) - s1.Add(3) - s2 := FromSlice[int]() - s2.Add(1) - s2.Add(4) - s2.Add(5) + s1 := FromSlice[int](1, 2, 3) + s2 := FromSlice[int](1, 4, 5) s3 := s1.SymmetricDifference(s2) - a.Assert(s3.ContainsAll(seq([]int{2, 3, 4, 5})), quicktest.IsTrue) + a.Assert(s3.ContainsAll(seq(2, 3, 4, 5)), quicktest.IsTrue) a.Assert(s3.Size(), quicktest.Equals, 4) } func TestHashSet_Equal(t *testing.T) { a := quicktest.New(t) - s1 := FromSlice[int]() - s1.Add(1) - s1.Add(2) - s1.Add(3) - s2 := FromSlice[int]() - s2.Add(1) - s2.Add(2) - s2.Add(3) + s1 := FromSlice[int](1, 2, 3) + s2 := FromSlice[int](1, 2, 3) a.Assert(s1.Equal(s2), quicktest.IsTrue) s2.Add(4) a.Assert(s1.Equal(s2), quicktest.IsFalse) @@ -254,24 +194,7 @@ func TestHashSet_Iter(t *testing.T) { a.Assert(iter.Size(s1.Iter()), quicktest.Equals, 5) } -func TestHashSet_Marshal(t *testing.T) { - a := quicktest.New(t) - s := FromSlice(5, 1, 4, 2, 3, 1, 2, 3) - b, err := json.Marshal(s) - a.Assert(err, quicktest.IsNil) - a.Logf("%s\n ", b) -} - -func TestHashSet_UnmarshalJSON(t *testing.T) { - a := quicktest.New(t) - s := FromSlice[int]() - err := json.Unmarshal([]byte(`{"1":{},"2":{},"3":{},"4":{},"5":{}}`), &s) - a.Assert(err, quicktest.IsNil) - a.Logf("%+v\n", s) - a.Assert(s.Equal(FromSlice(1, 2, 3, 4, 5)), quicktest.IsTrue) -} - -func TestIterMut(t *testing.T) { +func TestHashSet_IterMut(t *testing.T) { s := FromSlice[int](1, 2, 3, 4) iter.ForEach(s.IterMut(), func(s *SetItem[int]) { if s.Value()%2 == 0 {