Skip to content

Commit

Permalink
refactor!: use an iterator function for Entries
Browse files Browse the repository at this point in the history
This change removes the need for a context to cancel a partial iteration.
  • Loading branch information
mdawar committed Aug 23, 2024
1 parent 516b0c1 commit b01a6a6
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 64 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func main() {
removed := m.RemoveExpired() // Returns the number of removed keys.

// Iterate over the map entries.
for entry := range m.Entries(context.TODO()) {
fmt.Println("Key:", entry.Key, "-", "Value:", entry.Value)
for key, value := range m.Entries() {
fmt.Println("Key:", key, "-", "Value:", value)
}
}
```
Expand Down
21 changes: 2 additions & 19 deletions examples_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package xmap_test

import (
"context"
"fmt"
"time"

Expand Down Expand Up @@ -69,23 +68,7 @@ func ExampleMap_Entries() {
m := xmap.New[string, int]()
defer m.Stop()

for entry := range m.Entries(context.TODO()) {
fmt.Println("Key:", entry.Key, "-", "Value:", entry.Value)
}
}

func ExampleMap_Entries_partial_iteration() {
m := xmap.New[string, int]()
defer m.Stop()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

for entry := range m.Entries(ctx) {
fmt.Println("Key:", entry.Key, "-", "Value:", entry.Value)
// With a partial iteration, the context must be canceled
// to prevent a deadlock (A read lock is held during the iteration).
cancel()
break
for k, v := range m.Entries() {
fmt.Println("Key:", k, "-", "Value:", v)
}
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module github.com/mdawar/xmap

go 1.21
go 1.23

require go.uber.org/goleak v1.3.0
36 changes: 8 additions & 28 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
package xmap

import (
"context"
"iter"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -169,44 +169,24 @@ func (m *Map[K, V]) GetWithExpiration(key K) (V, time.Time, bool) {
return zero, time.Time{}, false
}

// Entry represents a key/value pair in the Map.
type Entry[K comparable, V any] struct {
Key K
Value V
}

// Entries returns a read-only channel of Entry elements representing the
// current entries in the Map.
//
// This channel can be used in a for range loop to iterate over the current
// map entries. Only the elements that have not expired are sent on this
// channel. The channel is closed after all the entries have been sent.
// Entries returns an iterator over key-value pairs of the Map entries.
//
// Like the map type, the iteration order is not guaranteed.
// Only the entries that have not expired are produced during the iteration.
//
// A read lock is held during the iteration, so it's important to consume
// all of the elements sent on the channel or cancel the passed context
// if a full iteration is not needed.
func (m *Map[K, V]) Entries(ctx context.Context) <-chan Entry[K, V] {
ch := make(chan Entry[K, V])

go func() {
// Similar to the map type, the iteration order is not guaranteed.
func (m *Map[K, V]) Entries() iter.Seq2[K, V] {
return func(yield func(K, V) bool) {
m.mu.RLock()
defer m.mu.RUnlock()
defer close(ch)

for key, entry := range m.kv {
if !m.expired(entry) {
select {
case <-ctx.Done():
if !yield(key, entry.value) {
return
case ch <- Entry[K, V]{key, entry.value}:
}
}
}
}()

return ch
}
}

// Delete removes a key from the map.
Expand Down
23 changes: 9 additions & 14 deletions map_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package xmap_test

import (
"context"
"maps"
"testing"
"time"
Expand Down Expand Up @@ -653,8 +652,8 @@ func TestMapIterateOverMapEntries(t *testing.T) {
t.Helper()

gotEntries := make(map[string]int)
for entry := range m.Entries(context.Background()) {
gotEntries[entry.Key] = entry.Value
for k, v := range m.Entries() {
gotEntries[k] = v
}

if !maps.Equal(wantEntries, gotEntries) {
Expand Down Expand Up @@ -708,21 +707,17 @@ func TestMapPartialIterationOverEntries(t *testing.T) {
m.Set(entry.key, entry.value, entry.ttl)
}

// Number of entries consumed.
var gotCount int
// Consumed entries.
consumed := make(map[string]int)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

for range m.Entries(ctx) {
gotCount++
cancel() // Must cancel the context to release the lock.
break // Stop after consuming 1 entry.
for k, v := range m.Entries() {
consumed[k] = v
break // Stop after consuming 1 entry.
}

// Make sure we consume at least 1 entry.
if gotCount != 1 {
t.Errorf("want to consume 1 entry, got %d", gotCount)
if len(consumed) != 1 {
t.Errorf("want to consume 1 entry, got %d", len(consumed))
}

// Channel used to wait for stopping the map.
Expand Down

0 comments on commit b01a6a6

Please sign in to comment.