diff --git a/errors.go b/errors.go index 842ee80..66d7b65 100644 --- a/errors.go +++ b/errors.go @@ -267,3 +267,42 @@ func Cause(err error) error { } return err } + +// Trace returns the underlying stack trace of the error, if possible. +// An error value has a stack trace if it implements the following +// interface: +// +// type stackTracer interface { +// StackTrace() errors.StackTrace +// } +// +// If the error does not implement StackTrace, nil will be returned. +// If the error is nil, nil will be returned without further +// investigation. +func Trace(err error) (frames []Frame) { + if err == nil { + return nil + } + + type stackTracer interface { + StackTrace() StackTrace + } + + type causer interface { + Cause() error + } + + for err != nil { + if stackTrace, ok := err.(stackTracer); ok { + frames = ([]Frame)(stackTrace.StackTrace()) + } + + if cause, ok := err.(causer); ok { + err = cause.Cause() + } else { + break + } + } + + return +} diff --git a/stack.go b/stack.go index 6b1f289..f962cb7 100644 --- a/stack.go +++ b/stack.go @@ -11,13 +11,18 @@ import ( // Frame represents a program counter inside a stack frame. type Frame uintptr +// Name returns the name of function for this Frame's pc. +func (f Frame) Name() string { + return funcname(runtime.FuncForPC(f.pc()).Name()) +} + // pc returns the program counter for this frame; // multiple frames may have the same PC value. func (f Frame) pc() uintptr { return uintptr(f) - 1 } -// file returns the full path to the file that contains the +// File returns the full path to the file that contains the // function for this Frame's pc. -func (f Frame) file() string { +func (f Frame) File() string { fn := runtime.FuncForPC(f.pc()) if fn == nil { return "unknown" @@ -26,9 +31,9 @@ func (f Frame) file() string { return file } -// line returns the line number of source code of the +// Line returns the line number of source code of the // function for this Frame's pc. -func (f Frame) line() int { +func (f Frame) Line() int { fn := runtime.FuncForPC(f.pc()) if fn == nil { return 0 @@ -62,13 +67,12 @@ func (f Frame) Format(s fmt.State, verb rune) { fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file) } default: - io.WriteString(s, path.Base(f.file())) + io.WriteString(s, path.Base(f.File())) } case 'd': - fmt.Fprintf(s, "%d", f.line()) + fmt.Fprintf(s, "%d", f.Line()) case 'n': - name := runtime.FuncForPC(f.pc()).Name() - io.WriteString(s, funcname(name)) + io.WriteString(s, f.Name()) case 'v': f.Format(s, 's') io.WriteString(s, ":") diff --git a/stack_test.go b/stack_test.go index 510c27a..fe9dca0 100644 --- a/stack_test.go +++ b/stack_test.go @@ -33,7 +33,7 @@ func TestFrameLine(t *testing.T) { }} for _, tt := range tests { - got := tt.Frame.line() + got := tt.Frame.Line() want := tt.want if want != got { t.Errorf("Frame(%v): want: %v, got: %v", uintptr(tt.Frame), want, got)