Skip to content

Commit

Permalink
channel equality (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
seiyab authored May 5, 2024
1 parent d15c348 commit 6f4df62
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 8 deletions.
19 changes: 12 additions & 7 deletions deepequal.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package teq

import (
"bytes"
"fmt"
"reflect"
"unsafe"
)
Expand Down Expand Up @@ -73,7 +72,7 @@ func (teq Teq) deepValueEqual(

eqFn, ok := eqs[v1.Kind()]
if !ok {
panic("not implemented")
panic("equality is not defined for " + v1.Type().String())
}
var n next = func(v1, v2 reflect.Value) bool {
return teq.deepValueEqual(v1, v2, visited, depth+1)
Expand All @@ -86,7 +85,7 @@ type next func(v1, v2 reflect.Value) bool
var eqs = map[reflect.Kind]func(v1, v2 reflect.Value, nx next) bool{
reflect.Array: arrayEq,
reflect.Slice: sliceEq,
reflect.Chan: todo,
reflect.Chan: chanEq,
reflect.Interface: interfaceEq,
reflect.Pointer: pointerEq,
reflect.Struct: structEq,
Expand Down Expand Up @@ -119,10 +118,6 @@ func hard(k reflect.Kind) bool {
return false
}

func todo(v1, v2 reflect.Value, nx next) bool {
panic(fmt.Sprintf("not implemented (%s, %s)", v1.Type(), v1.Kind()))
}

func arrayEq(v1, v2 reflect.Value, nx next) bool {
for i := 0; i < v1.Len(); i++ {
if !nx(v1.Index(i), v2.Index(i)) {
Expand Down Expand Up @@ -154,6 +149,16 @@ func sliceEq(v1, v2 reflect.Value, nx next) bool {
return true
}

func chanEq(v1, v2 reflect.Value, _ next) bool {
if v1.CanInterface() && v2.CanInterface() {
return v1.Interface() == v2.Interface()
}
if v1.CanInterface() != v2.CanInterface() {
return false
}
panic("failed to compare channels")
}

func interfaceEq(v1, v2 reflect.Value, nx next) bool {
if v1.IsNil() || v2.IsNil() {
return v1.IsNil() == v2.IsNil()
Expand Down
13 changes: 12 additions & 1 deletion format.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ func (teq Teq) format(v reflect.Value, depth int) lines {
var fmts = map[reflect.Kind]func(reflect.Value, func(reflect.Value) lines) lines{
reflect.Array: arrayFmt,
reflect.Slice: sliceFmt,
reflect.Chan: chanFmt,
reflect.Interface: interfaceFmt,
reflect.Pointer: pointerFmt,
reflect.Struct: structFmt,
Expand Down Expand Up @@ -163,6 +164,13 @@ func sliceFmt(v reflect.Value, next func(reflect.Value) lines) lines {
return result
}

func chanFmt(v reflect.Value, _ func(reflect.Value) lines) lines {
if v.IsNil() {
return linesOf(fmt.Sprintf("%s(<nil>)", v.Type()))
}
return linesOf(fmt.Sprintf("%s(0x%x)", v.Type(), v.Pointer()))
}

func interfaceFmt(v reflect.Value, next func(reflect.Value) lines) lines {
open := fmt.Sprintf("%s(", v.Type().String())
close := ")"
Expand Down Expand Up @@ -230,7 +238,10 @@ func mapFmt(v reflect.Value, next func(reflect.Value) lines) lines {
}

func funcFmt(v reflect.Value, _ func(reflect.Value) lines) lines {
return linesOf(fmt.Sprintf("%s(%v)", v.Type(), v.Pointer()))
if v.IsNil() {
return linesOf(fmt.Sprintf("%s(<nil>)", v.Type()))
}
return linesOf(fmt.Sprintf("%s(0x%x)", v.Type(), v.Pointer()))
}

func intFmt(v reflect.Value, _ func(reflect.Value) lines) lines {
Expand Down
31 changes: 31 additions & 0 deletions teq_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func TestEqual(t *testing.T) {
{"slices", slices()},
{"maps", maps()},
{"interfaces", interfaces()},
{"channels", channels()},
{"recursions", recursions()},
}

Expand Down Expand Up @@ -282,6 +283,36 @@ differences:
}
}

func channels() []test {
c1 := make(chan int)
c2 := make(chan int)
return []test{
{c1, c1, nil, false},
{c1, c2, []string{fmt.Sprintf("expected %p, got %p", c1, c2)}, false},
{[]chan int{c1}, []chan int{c1}, nil, false},
{[]chan int{c1}, []chan int{c2}, []string{fmt.Sprintf(`not equal
differences:
--- expected
+++ actual
@@ -1,3 +1,3 @@
[]chan int{
- chan int(%p),
+ chan int(%p),
}
`, c1, c2)}, false},
{[]chan int{c1}, []chan int{nil}, []string{fmt.Sprintf(`not equal
differences:
--- expected
+++ actual
@@ -1,3 +1,3 @@
[]chan int{
- chan int(%p),
+ chan int(<nil>),
}
`, c1)}, false},
}
}

func recursions() []test {
type privateRecursiveStruct struct {
i int
Expand Down

0 comments on commit 6f4df62

Please sign in to comment.