Skip to content

Commit

Permalink
Format (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
seiyab authored May 3, 2024
1 parent 688ec97 commit 56bec35
Show file tree
Hide file tree
Showing 8 changed files with 667 additions and 7 deletions.
240 changes: 240 additions & 0 deletions format.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
package teq

import (
"fmt"
"reflect"
"sort"
"strings"

"github.com/pmezard/go-difflib/difflib"
)

func (teq Teq) report(expected, actual any) string {
simple := fmt.Sprintf("expected %v, got %v", expected, actual)
if expected == nil || actual == nil {
return simple
}
ve := reflect.ValueOf(expected)
va := reflect.ValueOf(actual)
if ve.Type() != va.Type() {
return simple
}
k := ve.Kind()
if k != reflect.Struct &&
k != reflect.Map &&
k != reflect.Slice &&
k != reflect.Array &&
k != reflect.String &&
k != reflect.Pointer {
return simple
}
if k == reflect.String {
if len(ve.String()) < 10 && len(va.String()) < 10 {
return simple
}
if strings.Contains(ve.String(), "\n") || strings.Contains(va.String(), "\n") {
r, ok := richReport(
difflib.SplitLines(ve.String()),
difflib.SplitLines(va.String()),
)
if !ok {
return simple
}
return r
}
}

r, ok := richReport(
teq.format(ve, 0).diffSequence(),
teq.format(va, 0).diffSequence(),
)
if !ok {
return simple
}
return r
}

func richReport(a []string, b []string) (string, bool) {
diff := difflib.UnifiedDiff{
A: a,
B: b,
FromFile: "expected",
ToFile: "actual",
Context: 1,
}
diffTxt, err := difflib.GetUnifiedDiffString(diff)
if err != nil {
return fmt.Sprintf("failed to get diff: %v", err), false
}
if diffTxt == "" {
return "", false
}
return strings.Join([]string{
"not equal",
"differences:",
diffTxt,
}, "\n"), true
}

func (teq Teq) format(v reflect.Value, depth int) lines {
if depth > teq.MaxDepth {
return linesOf("<max depth exceeded>")
}
if !v.IsValid() {
return linesOf("<invalid>")
}

ty := v.Type()
if fm, ok := teq.formats[ty]; ok {
return linesOf(fm(v))
}

fmtFn, ok := fmts[v.Kind()]
if !ok {
fmtFn = todoFmt
}
next := func(v reflect.Value) lines {
return teq.format(v, depth+1)
}
return fmtFn(v, next)
}

var fmts = map[reflect.Kind]func(reflect.Value, func(reflect.Value) lines) lines{
reflect.Array: arrayFmt,
reflect.Slice: sliceFmt,
reflect.Interface: todoFmt,
reflect.Pointer: pointerFmt,
reflect.Struct: structFmt,
reflect.Map: mapFmt,
reflect.Func: todoFmt,
reflect.Int: intFmt,
reflect.Int8: intFmt,
reflect.Int16: intFmt,
reflect.Int32: intFmt,
reflect.Int64: intFmt,
reflect.Uint: uintFmt,
reflect.Uint8: uintFmt,
reflect.Uint16: uintFmt,
reflect.Uint32: uintFmt,
reflect.Uint64: uintFmt,
reflect.Uintptr: uintFmt,
reflect.String: stringFmt,
reflect.Bool: boolFmt,
reflect.Float32: floatFmt,
reflect.Float64: floatFmt,
reflect.Complex64: complexFmt,
reflect.Complex128: complexFmt,
}

func todoFmt(v reflect.Value, next func(reflect.Value) lines) lines {
return linesOf(fmt.Sprintf("<%s>", v.String()))
}

func arrayFmt(v reflect.Value, next func(reflect.Value) lines) lines {
open := fmt.Sprintf("%s{", v.Type().String())
close := "}"
if v.Len() == 0 {
return linesOf(open + close)
}
result := make(lines, 0, v.Len()+2)
result = append(result, lineOf(open))
for i := 0; i < v.Len(); i++ {
elem := next(v.Index(i)).followedBy(",")
result = append(result, elem.indent()...)
}
result = append(result, lineOf(close))
return result

}

func sliceFmt(v reflect.Value, next func(reflect.Value) lines) lines {
open := fmt.Sprintf("[]%s{", v.Type().Elem().String())
close := "}"
if v.Len() == 0 {
return linesOf(open + close)
}
result := make(lines, 0, v.Len()+2)
result = append(result, lineOf(open))
for i := 0; i < v.Len(); i++ {
elem := next(v.Index(i)).followedBy(",")
result = append(result, elem.indent()...)
}
result = append(result, lineOf(close))
return result
}

func pointerFmt(v reflect.Value, next func(reflect.Value) lines) lines {
if v.IsNil() {
return linesOf("<nil>")
}
return next(v.Elem()).ledBy("*")
}

func structFmt(v reflect.Value, next func(reflect.Value) lines) lines {
open := fmt.Sprintf("%s{", v.Type().String())
close := "}"
if v.NumField() == 0 {
return linesOf(open + close)
}
result := make(lines, 0, v.NumField()+2)
result = append(result, lineOf(open))
for i := 0; i < v.NumField(); i++ {
entry := next(v.Field(i)).
ledBy(v.Type().Field(i).Name + ": ").
followedBy(",")
result = append(result, entry.indent()...)
}
result = append(result, lineOf(close))
return result
}

func mapFmt(v reflect.Value, next func(reflect.Value) lines) lines {
open := fmt.Sprintf("map[%s]%s{", v.Type().Key(), v.Type().Elem())
close := "}"
if v.Len() == 0 {
return linesOf(open + close)
}
result := make(lines, 0, v.Len()+2)
result = append(result, lineOf(open))

type entry struct {
key string
lines lines
}
entries := make([]entry, 0, v.Len())
for _, key := range v.MapKeys() {
var e entry
keyLines := next(key)
e.key = keyLines.key()
valLines := next(v.MapIndex(key))
e.lines = keyValue(keyLines, valLines)
entries = append(entries, e)
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].key < entries[j].key
})
for _, e := range entries {
result = append(result, e.lines.indent()...)
}
result = append(result, lineOf(close))
return result
}

func intFmt(v reflect.Value, _ func(reflect.Value) lines) lines {
return linesOf(fmt.Sprintf("%s(%d)", v.Type(), v.Int()))
}
func uintFmt(v reflect.Value, _ func(reflect.Value) lines) lines {
return linesOf(fmt.Sprintf("%s(%d)", v.Type(), v.Uint()))
}
func stringFmt(v reflect.Value, _ func(reflect.Value) lines) lines {
return linesOf(fmt.Sprintf("%q", v.String()))
}
func boolFmt(v reflect.Value, _ func(reflect.Value) lines) lines {
return linesOf(fmt.Sprintf("%t", v.Bool()))
}
func floatFmt(v reflect.Value, _ func(reflect.Value) lines) lines {
return linesOf(fmt.Sprintf("%s(%f)", v.Type(), v.Float()))
}
func complexFmt(v reflect.Value, _ func(reflect.Value) lines) lines {
return linesOf(fmt.Sprintf("%s(%f, %f)", v.Type(), real(v.Complex()), imag(v.Complex())))
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/seiyab/teq

go 1.18

require github.com/pmezard/go-difflib v1.0.0
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
77 changes: 77 additions & 0 deletions line.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package teq

import "strings"

func lineOf(text string) line {
return line{text: text}
}

func linesOf(texts ...string) lines {
result := make(lines, 0, len(texts))
for _, text := range texts {
result = append(result, line{text: text})
}
return result
}

type line struct {
indentDepth int
text string
}

type lines []line

func (l line) indent() line {
return line{
indentDepth: l.indentDepth + 1,
text: l.text,
}
}

func (ls lines) indent() lines {
result := make(lines, len(ls))
for i, l := range ls {
result[i] = l.indent()
}
return result
}

func (ls lines) diffSequence() []string {
result := make([]string, len(ls))
for i, l := range ls {
result[i] = strings.Repeat(" ", l.indentDepth) + l.text + "\n"
}
return result
}

func (ls lines) ledBy(s string) lines {
result := ls.clone()
result[0].text = s + result[0].text
return result
}

func (ls lines) followedBy(s string) lines {
result := ls.clone()
result[len(result)-1].text = result[len(result)-1].text + s
return result
}

func (ls lines) clone() lines {
result := make(lines, len(ls))
copy(result, ls)
return result
}

func (ls lines) key() string {
ts := make([]string, len(ls))
for i, l := range ls {
ts[i] = l.text
}
return strings.Join(ts, "")
}

func keyValue(key lines, value lines) lines {
result := key.followedBy(": " + value[0].text)
result = append(result, value[1:]...)
return result.followedBy(",")
}
25 changes: 24 additions & 1 deletion teq.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ type Teq struct {
MaxDepth int

transforms map[reflect.Type]func(reflect.Value) reflect.Value
formats map[reflect.Type]func(reflect.Value) string
}

func New() Teq {
return Teq{
MaxDepth: 1_000,

transforms: make(map[reflect.Type]func(reflect.Value) reflect.Value),
formats: make(map[reflect.Type]func(reflect.Value) string),
}
}

Expand All @@ -27,7 +29,7 @@ func (teq Teq) Equal(t TestingT, expected, actual any) bool {
}()
ok := teq.equal(expected, actual)
if !ok {
t.Errorf("expected %v, got %v", expected, actual)
t.Errorf(teq.report(expected, actual))
}
return ok
}
Expand Down Expand Up @@ -70,6 +72,27 @@ func (teq *Teq) AddTransform(transform any) {
teq.transforms[ty.In(0)] = reflectTransform
}

func (teq *Teq) AddFormat(format any) {
ty := reflect.TypeOf(format)
if ty.Kind() != reflect.Func {
panic("format must be a function")
}
if ty.NumIn() != 1 {
panic("format must have only one argument")
}
if ty.NumOut() != 1 {
panic("format must have only one return value")
}
if ty.Out(0).Kind() != reflect.String {
panic("format must return string")
}
formatValue := reflect.ValueOf(format)
reflectFormat := func(v reflect.Value) string {
return formatValue.Call([]reflect.Value{v})[0].String()
}
teq.formats[ty.In(0)] = reflectFormat
}

func (teq Teq) equal(x, y any) bool {
if x == nil || y == nil {
return x == y
Expand Down
Loading

0 comments on commit 56bec35

Please sign in to comment.