diff --git a/.gitignore b/.gitignore index 6f72f89..5024fcd 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,5 @@ go.work.sum # env file .env + +cov.html diff --git a/list/list.go b/list/list.go index 8ffb501..be76331 100644 --- a/list/list.go +++ b/list/list.go @@ -105,8 +105,8 @@ func (lst *List[T]) Remove(node *Node[T]) { lst.length-- } -// All returns an iterator over all the values in the list. -func (lst *List[T]) All() iter.Seq[T] { +// Values returns an iterator over all the values in the list. +func (lst *List[T]) Values() iter.Seq[T] { return func(yield func(T) bool) { for node := lst.front.next; node != lst.back; node = node.next { if !yield(node.Value) { @@ -116,6 +116,17 @@ func (lst *List[T]) All() iter.Seq[T] { } } +// Nodes returns an iterator over all the nodes in the list. +func (lst *List[T]) Nodes() iter.Seq[*Node[T]] { + return func(yield func(*Node[T]) bool) { + for node := lst.front.next; node != lst.back; node = node.next { + if !yield(node) { + return + } + } + } +} + func (lst *List[T]) debugPrint() { fmt.Println("-----------------------") for n := lst.front; n != nil; n = n.next { diff --git a/list/list_test.go b/list/list_test.go index 4941027..bb215be 100644 --- a/list/list_test.go +++ b/list/list_test.go @@ -11,10 +11,39 @@ func checkList[T comparable](t *testing.T, lst *List[T], want []T) { t.Errorf("got len=%v, want %v", lst.Len(), len(want)) } - got := slices.Collect(lst.All()) + got := slices.Collect(lst.Values()) if !slices.Equal(got, want) { t.Errorf("got %v, want %v", got, want) } + + if lst.front.prev != nil { + t.Errorf("got lst.front.prev=%p, want nil", lst.front.prev) + } + if lst.back.next != nil { + t.Errorf("got lst.back.next=%p, want nil", lst.back.next) + } + + if lst.Len() < 1 { + return + } + // Get all nodes from the list and verify list invariants. + nodes := slices.Collect(lst.Nodes()) + first := nodes[0] + last := nodes[len(nodes)-1] + + if lst.front.next != first || first.prev != lst.front { + t.Errorf("front mismatch: front.next=%p, first.prev=%p", lst.front.next, first.prev) + } + if lst.back.prev != last || last.next != lst.back { + t.Errorf("back mismatch: back.prev=%p, last.next=%p", lst.back.prev, last.next) + } + + for i := 0; i < len(nodes)-1; i++ { + j := i + 1 + if nodes[i].next != nodes[j] || nodes[j].prev != nodes[i] { + t.Errorf("node link mismatch at i=%d, j=%d", i, j) + } + } } func TestBasicInsertFront(t *testing.T) { @@ -93,3 +122,46 @@ func TestRemove(t *testing.T) { nl.Remove(nl.Back()) checkList(t, nl, []int{}) } + +func TestFrontBack(t *testing.T) { + // Empty list - nil for both + nl := New[int]() + if nl.Front() != nil { + t.Errorf("got front=%v, want nil", nl.Front()) + } + if nl.Back() != nil { + t.Errorf("got back=%v, want nil", nl.Back()) + } + + // Insert element + nl.InsertBack(50) + if nl.Front().Value != 50 || nl.Back().Value != 50 { + t.Errorf("got front=%v, back=%v, want 50", nl.Front().Value, nl.Back().Value) + } +} + +func TestNextPrev(t *testing.T) { + nl := New[string]() + nl.InsertBack("five") + nl.InsertBack("six") + nl.InsertBack("seven") + + var vals []string + for n := nl.Front(); n != nil; n = nl.Next(n) { + vals = append(vals, n.Value) + } + + wantVals := []string{"five", "six", "seven"} + if !slices.Equal(vals, wantVals) { + t.Errorf("got %v, want %v", vals, wantVals) + } + + var revVals []string + for n := nl.Back(); n != nil; n = nl.Prev(n) { + revVals = append(revVals, n.Value) + } + slices.Reverse(revVals) + if !slices.Equal(revVals, wantVals) { + t.Errorf("got %v, want %v", revVals, wantVals) + } +}