Skip to content

Commit

Permalink
Add generic thread-safe map implementation (#116)
Browse files Browse the repository at this point in the history
* add generic thread-safe map implementation

* fix tests and linter warnings

* use Go 1.23 iterator instead of range
  • Loading branch information
lovromazgon authored Sep 16, 2024
1 parent 76b42d0 commit 045c136
Show file tree
Hide file tree
Showing 3 changed files with 329 additions and 1 deletion.
141 changes: 141 additions & 0 deletions csync/map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Copyright © 2024 Meroxa, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package csync

import (
"iter"
"sync"
)

// Map is a thread-safe map.
type Map[K comparable, T any] struct {
m map[K]T
lock sync.RWMutex
}

// NewMap creates a new map.
func NewMap[K comparable, T any]() *Map[K, T] {
return &Map[K, T]{
m: make(map[K]T),
}
}

// Set sets the value for a key.
func (m *Map[K, T]) Set(key K, value T) {
m.lock.Lock()
m.m[key] = value
m.lock.Unlock()
}

// Get retrieves the value for a key.
func (m *Map[K, T]) Get(key K) (T, bool) {
m.lock.RLock()
value, ok := m.m[key]
m.lock.RUnlock()
return value, ok
}

// Delete removes the value for a key.
func (m *Map[K, T]) Delete(key K) {
m.lock.Lock()
delete(m.m, key)
m.lock.Unlock()
}

// Len returns the number of items in the map.
func (m *Map[K, T]) Len() int {
m.lock.RLock()
defer m.lock.RUnlock()
return len(m.m)
}

// Keys returns all the keys in the map.
func (m *Map[K, T]) Keys() []K {
m.lock.RLock()
defer m.lock.RUnlock()
keys := make([]K, 0, len(m.m))
for k := range m.m {
keys = append(keys, k)
}
return keys
}

// Values returns all the values in the map.
func (m *Map[K, T]) Values() []T {
m.lock.RLock()
defer m.lock.RUnlock()
values := make([]T, 0, len(m.m))
for _, v := range m.m {
values = append(values, v)
}
return values
}

// Clear removes all items from the map.
func (m *Map[K, T]) Clear() {
m.lock.Lock()
m.m = make(map[K]T)
m.lock.Unlock()
}

// Copy returns a new map with the same key-value pairs.
func (m *Map[K, T]) Copy() *Map[K, T] {
m.lock.RLock()
defer m.lock.RUnlock()

newMap := NewMap[K, T]()
for k, v := range m.m {
newMap.m[k] = v
}
return newMap
}

// Merge adds all key-value pairs from another map to this map.
func (m *Map[K, T]) Merge(other *Map[K, T]) {
other.lock.RLock()
defer other.lock.RUnlock()

m.lock.Lock()
defer m.lock.Unlock()

for k, v := range other.m {
m.m[k] = v
}
}

// ToGoMap returns a copy of the map as a Go map.
func (m *Map[K, T]) ToGoMap() map[K]T {
return m.Copy().m
}

// All returns an iterator over the map's key-value pairs. This can be used to
// iterate over the map using a for-range loop.
//
// Example:
//
// for key, value := range m.All() {
// fmt.Println(key, value)
// }
func (m *Map[K, T]) All() iter.Seq2[K, T] {
return func(yield func(K, T) bool) {
m.lock.RLock()
defer m.lock.RUnlock()
for k, v := range m.m {
if !yield(k, v) {
return
}
}
}
}
187 changes: 187 additions & 0 deletions csync/map_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
// Copyright © 2024 Meroxa, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package csync

import (
"sort"
"sync"
"testing"

"github.com/matryer/is"
)

func TestMap_NewMap(t *testing.T) {
is := is.New(t)
m := NewMap[string, int]()

is.True(m != nil)
is.Equal(m.Len(), 0)
}

func TestMap_SetAndGet(t *testing.T) {
is := is.New(t)
m := NewMap[string, int]()

m.Set("foo", 1)
v, ok := m.Get("foo")

is.True(ok)
is.Equal(v, 1)
}

func TestMap_GetNonExistent(t *testing.T) {
is := is.New(t)
m := NewMap[string, int]()

v, ok := m.Get("bar")
is.True(!ok)
is.Equal(v, 0)
}

func TestMap_Delete(t *testing.T) {
is := is.New(t)
m := NewMap[string, int]()

m.Set("foo", 1)
m.Delete("foo")

_, ok := m.Get("foo")
is.True(!ok)
}

func TestMap_All(t *testing.T) {
is := is.New(t)
m := NewMap[string, int]()

m.Set("foo", 1)
m.Set("bar", 2)

got := make(map[string]int)
for key, value := range m.All() {
got[key] = value
}

is.Equal(got, map[string]int{"foo": 1, "bar": 2})
}

func TestMap_Len(t *testing.T) {
is := is.New(t)
m := NewMap[string, int]()

is.Equal(m.Len(), 0)
m.Set("foo", 1)
is.Equal(m.Len(), 1)
}

func TestMap_Keys(t *testing.T) {
is := is.New(t)
m := NewMap[string, int]()

m.Set("foo", 1)
m.Set("bar", 2)

keys := m.Keys()
sort.Strings(keys)
is.Equal(keys, []string{"bar", "foo"})
}

func TestMap_Values(t *testing.T) {
is := is.New(t)
m := NewMap[string, int]()

m.Set("foo", 1)
m.Set("bar", 2)

values := m.Values()
sort.Ints(values)
is.Equal(values, []int{1, 2})
}

func TestMap_Clear(t *testing.T) {
is := is.New(t)
m := NewMap[string, int]()

m.Set("foo", 1)
m.Clear()
is.Equal(m.Len(), 0)
}

func TestMap_Copy(t *testing.T) {
is := is.New(t)
m := NewMap[string, int]()

m.Set("foo", 1)

got := m.Copy()
is.Equal(got.Len(), 1)

m.Set("foo", 2)
v, ok := got.Get("foo")
is.True(ok)
is.Equal(v, 1)
}

func TestMap_Merge(t *testing.T) {
is := is.New(t)
m1 := NewMap[string, int]()
m1.Set("foo", 1)
m2 := NewMap[string, int]()
m2.Set("bar", 2)

m1.Merge(m2)
is.Equal(m1.Len(), 2)

v, ok := m1.Get("bar")
is.True(ok)
is.Equal(v, 2)
}

func TestMap_ToGoMap(t *testing.T) {
is := is.New(t)
m := NewMap[string, int]()

m.Set("foo", 1)
m.Set("bar", 2)

got := m.ToGoMap()
is.Equal(got, map[string]int{"foo": 1, "bar": 2})
}

func TestMap_Concurrency(t *testing.T) {
is := is.New(t)
m := NewMap[int, int]()
var wg sync.WaitGroup

// Concurrent writes
for i := 0; i < 1000; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
m.Set(i, i)
}(i)
}

// Concurrent reads
for i := 0; i < 1000; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
_, _ = m.Get(i)
}(i)
}

wg.Wait()
is.Equal(m.Len(), 1000)
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/conduitio/conduit-commons

go 1.22.4
go 1.23

require (
github.com/bufbuild/buf v1.41.0
Expand Down

0 comments on commit 045c136

Please sign in to comment.