diff --git a/diagnose.go b/diagnose.go index 9417aa63..f164e099 100644 --- a/diagnose.go +++ b/diagnose.go @@ -210,6 +210,9 @@ func (di *diagnose) diag(cborSequence bool) (string, error) { } case io.EOF: + if firstItem { + return di.w.String(), err + } return di.w.String(), nil default: diff --git a/diagnose_test.go b/diagnose_test.go index 22cde2be..dc4b9ca3 100644 --- a/diagnose_test.go +++ b/diagnose_test.go @@ -6,6 +6,7 @@ package cbor import ( "bytes" "fmt" + "io" "reflect" "strings" "testing" @@ -1079,3 +1080,41 @@ func TestDiagnoseNotwellformedData(t *testing.T) { t.Errorf("Diagnose(0x%x) returned error %q", cborData, err) } } + +func TestDiagnoseEmptyData(t *testing.T) { + var emptyData []byte + + defaultMode, _ := DiagOptions{}.DiagMode() + sequenceMode, _ := DiagOptions{CBORSequence: true}.DiagMode() + + testCases := []struct { + name string + dm DiagMode + }{ + {"default", defaultMode}, + {"sequence", sequenceMode}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s, err := tc.dm.Diagnose(emptyData) + if len(s) != 0 { + t.Errorf("Diagnose() didn't return empty notation for empty data") + } + if err != io.EOF { + t.Errorf("Diagnose() didn't return io.EOF for empty data") + } + + s, rest, err := tc.dm.DiagnoseFirst(emptyData) + if len(s) != 0 { + t.Errorf("DiagnoseFirst() didn't return empty notation for empty data") + } + if len(rest) != 0 { + t.Errorf("DiagnoseFirst() didn't return empty rest for empty data") + } + if err != io.EOF { + t.Errorf("DiagnoseFirst() didn't return io.EOF for empty data") + } + }) + } +}