From 484d1349171c872b0cdc9204a4c8a65cec20bbcd Mon Sep 17 00:00:00 2001 From: Roger Peppe Date: Wed, 12 May 2021 12:27:12 +0100 Subject: [PATCH] WIP: add support for registering source files From [an idea](https://twitter.com/bradfitz/status/1387817724634492928) from Brad Fitzpatrick, this makes it possible to show source code in test failures even when the source code isn't currently available. Still to do: the `quicktest-generate` command, which would generate code looking something like the following: ``` package {{.Pkg}} import ( "embed" "testing" qt "github.com/frankban/quicktest" ) //go:embed *_test.go var _quicktestFiles embed.FS func init() { qt.RegisterSource({{.Package}}, _quicktestFiles) } ``` --- report.go | 48 ++++++++++++++++++++++++++++++++++-------- sourceregister_1.16.go | 46 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 9 deletions(-) create mode 100644 sourceregister_1.16.go diff --git a/report.go b/report.go index 04956d7..f321f4e 100644 --- a/report.go +++ b/report.go @@ -10,6 +10,8 @@ import ( "go/printer" "go/token" "io" + "io/ioutil" + "log" "reflect" "runtime" "strings" @@ -127,7 +129,7 @@ func writeStack(w io.Writer) { } fmt.Fprint(w, prefixf(prefix, "%s:%d", frame.File, frame.Line)) if strings.HasSuffix(frame.File, ".go") { - stmt, err := sg.Get(frame.File, frame.Line) + stmt, err := sg.Get(fileToPackage(frame.Function), frame.File, frame.Line) if err != nil { fmt.Fprint(w, prefixf(prefix+prefix, "<%s>", err)) } else { @@ -141,22 +143,50 @@ func writeStack(w io.Writer) { } } +func fileToPackage(fn string) string { + if i := strings.LastIndex(fn, "."); i >= 0 { + return fn[0:i] + } + return "" +} + type stmtGetter struct { fset *token.FileSet files map[string]*ast.File config *printer.Config } -// Get returns the lines of code of the statement at the given file and line. -func (sg *stmtGetter) Get(file string, line int) (string, error) { - f := sg.files[file] - if f == nil { - var err error - f, err = parser.ParseFile(sg.fset, file, nil, parser.ParseComments) +var registeredSourceForPackage = func(pkg, path string) []byte { + return nil +} + +func (sg *stmtGetter) parseFile(pkg, file string) (*ast.File, error) { + if f := sg.files[file]; f != nil { + return f, nil + } + data := registeredSourceForPackage(pkg, file) + if data == nil { + data1, err := ioutil.ReadFile(file) if err != nil { - return "", fmt.Errorf("cannot parse source file: %s", err) + return nil, err } - sg.files[file] = f + data = data1 + } else { + log.Printf("got registered source for package %q, file %q", pkg, file) + } + f, err := parser.ParseFile(sg.fset, file, data, parser.ParseComments) + if err != nil { + return nil, err + } + sg.files[file] = f + return f, nil +} + +// Get returns the lines of code of the statement at the given file and line. +func (sg *stmtGetter) Get(pkg string, file string, line int) (string, error) { + f, err := sg.parseFile(pkg, file) + if err != nil { + return "", fmt.Errorf("cannot parse source file: %s", err) } var stmt string ast.Inspect(f, func(n ast.Node) bool { diff --git a/sourceregister_1.16.go b/sourceregister_1.16.go new file mode 100644 index 0000000..7e2c8ce --- /dev/null +++ b/sourceregister_1.16.go @@ -0,0 +1,46 @@ +//+build go1.16 + +package quicktest + +import ( + "embed" + "fmt" + "path/filepath" + "sync" +) + +var ( + sourceRegisterMu sync.Mutex + sourceRegister = make(map[string]embed.FS) +) + +func init() { + registeredSourceForPackage = func(pkg, path string) []byte { + sourceRegisterMu.Lock() + defer sourceRegisterMu.Unlock() + fs, ok := sourceRegister[pkg] + if !ok { + return nil + } + data, _ := fs.ReadFile(filepath.Base(path)) + return data + } +} + +// RegisterSource registers Go source files for the given package. +// +// You shouldn't usually need to call this function directly - instead +// use a "go generate" directive as follows: +// +// //go:generate quicktest-generate +// +// and use the "go generate" command to generate the small +// amount of boilerplate required. +func RegisterSource(pkg string, files embed.FS) { + sourceRegisterMu.Lock() + defer sourceRegisterMu.Unlock() + if _, ok := sourceRegister[pkg]; ok { + panic(fmt.Errorf("package source for %q registered more than once", pkg)) + } + sourceRegister[pkg] = files +}