diff --git a/set.go b/set.go index 292089d..04c478b 100644 --- a/set.go +++ b/set.go @@ -73,6 +73,10 @@ type Set[T comparable] interface { // given items are in the set. ContainsAny(val ...T) bool + // ContainsAnyElement returns whether at least one of the + // given element are in the set. + ContainsAnyElement(other Set[T]) bool + // Difference returns the difference between this set // and other. The returned set will contain // all elements of this set that are not also diff --git a/set_test.go b/set_test.go index da305d5..7c5f611 100644 --- a/set_test.go +++ b/set_test.go @@ -398,6 +398,33 @@ func Test_ContainsAnySet(t *testing.T) { } } +func Test_ContainsAnyElement(t *testing.T) { + a := NewSet[int]() + a.Add(1) + a.Add(3) + a.Add(5) + + b := NewSet[int]() + a.Add(2) + a.Add(4) + a.Add(6) + + if ret := a.ContainsAnyElement(b); ret { + t.Errorf("set a not contain any element in set b") + } + + a.Add(10) + + if ret := a.ContainsAnyElement(b); ret { + t.Errorf("set a not contain any element in set b") + } + + b.Add(10) + + if ret := a.ContainsAnyElement(b); !ret { + t.Errorf("set a contain 10") + } +} func Test_ClearSet(t *testing.T) { a := makeSetInt([]int{2, 5, 9, 10}) diff --git a/threadsafe.go b/threadsafe.go index 93f20c8..664fc61 100644 --- a/threadsafe.go +++ b/threadsafe.go @@ -82,6 +82,19 @@ func (t *threadSafeSet[T]) ContainsAny(v ...T) bool { return ret } +func (t *threadSafeSet[T]) ContainsAnyElement(other Set[T]) bool { + o := other.(*threadSafeSet[T]) + + t.RLock() + o.RLock() + + ret := t.uss.ContainsAnyElement(o.uss) + + t.RUnlock() + o.RUnlock() + return ret +} + func (t *threadSafeSet[T]) IsEmpty() bool { return t.Cardinality() == 0 } diff --git a/threadsafe_test.go b/threadsafe_test.go index ca998c9..9037616 100644 --- a/threadsafe_test.go +++ b/threadsafe_test.go @@ -217,6 +217,27 @@ func Test_ContainsAnyConcurrent(t *testing.T) { wg.Wait() } +func Test_ContainsAnyElementConcurrent(t *testing.T) { + runtime.GOMAXPROCS(2) + + s, ss := NewSet[int](), NewSet[int]() + ints := rand.Perm(N) + for _, v := range ints { + s.Add(v) + ss.Add(v) + } + + var wg sync.WaitGroup + for range ints { + wg.Add(1) + go func() { + s.ContainsAnyElement(ss) + wg.Done() + }() + } + wg.Wait() +} + func Test_DifferenceConcurrent(t *testing.T) { runtime.GOMAXPROCS(2) diff --git a/threadunsafe.go b/threadunsafe.go index 7e3243b..c95d32b 100644 --- a/threadunsafe.go +++ b/threadunsafe.go @@ -109,6 +109,26 @@ func (s *threadUnsafeSet[T]) ContainsAny(v ...T) bool { return false } +func (s *threadUnsafeSet[T]) ContainsAnyElement(other Set[T]) bool { + o := other.(*threadUnsafeSet[T]) + + // loop over smaller set + if s.Cardinality() < other.Cardinality() { + for elem := range *s { + if o.contains(elem) { + return true + } + } + } else { + for elem := range *o { + if s.contains(elem) { + return true + } + } + } + return false +} + // private version of Contains for a single element v func (s *threadUnsafeSet[T]) contains(v T) (ok bool) { _, ok = (*s)[v]