diff --git a/tools/please_go/generate/BUILD b/tools/please_go/generate/BUILD index 228a0b88..03c31004 100644 --- a/tools/please_go/generate/BUILD +++ b/tools/please_go/generate/BUILD @@ -2,7 +2,7 @@ subinclude("///go//build_defs:go") filegroup( name = "srcs", - srcs = glob(["*.go"]), + srcs = glob(["*.go"], exclude=["*_test.go"]), visibility = ["//tools/please_go:bootstrap"], ) @@ -15,3 +15,12 @@ go_library( "//third_party/go:mod", ], ) + +go_test( + name = "generate_test", + srcs = glob(["*_test.go"]), + deps = [ + "//third_party/go:testify", + ":generate", + ], +) \ No newline at end of file diff --git a/tools/please_go/generate/generate.go b/tools/please_go/generate/generate.go index 19625135..f1e2e916 100644 --- a/tools/please_go/generate/generate.go +++ b/tools/please_go/generate/generate.go @@ -139,8 +139,7 @@ func (g *Generate) targetsInDir(dir string) []string { // for that package. Currently, we don't generate BUILD files for any other reason so this assumption holds // true. We may want to check that the BUILD file contains a go_library() target otherwise. if g.isBuildFile(path) { - pleasePkgDir := strings.TrimPrefix(strings.TrimPrefix(filepath.Dir(path), g.srcRoot), "/") - ret = append(ret, g.libTargetForPleasePackage(pleasePkgDir)) + ret = append(ret, g.libTargetForPleasePackage(trimPath(filepath.Dir(path), g.srcRoot))) } return nil }) @@ -150,6 +149,16 @@ func (g *Generate) targetsInDir(dir string) []string { return ret } +func (g *Generate) isBuildFile(file string) bool { + base := filepath.Base(file) + for _, file := range g.buildFileNames { + if base == file { + return true + } + } + return false +} + func (g *Generate) writeInstallFilegroup() error { buildFile, err := parseOrCreateBuildFile(g.srcRoot, g.buildFileNames) if err != nil { @@ -218,7 +227,8 @@ func (g *Generate) generateAll(dir string) error { if path != dir && strings.HasPrefix(info.Name(), "_") { return filepath.SkipDir } - if err := g.generate(filepath.Clean(strings.TrimPrefix(path, g.srcRoot))); err != nil { + + if err := g.generate(trimPath(path, g.srcRoot)); err != nil { switch err.(type) { case *build.NoGoError: // We might walk into a dir that has no .go files for the current arch. This shouldn't @@ -258,31 +268,13 @@ func (g *Generate) importDir(target string) (*build.Package, error) { return pkg, nil } -// trimPath is like strings.TrimPrefix but is path aware. It removes base from target if target starts with base, -// otherwise returns target unmodified. -func trimPath(target, base string) string { - baseParts := strings.Split(filepath.Clean(base), "/") - targetParts := strings.Split(filepath.Clean(target), "/") - - if len(targetParts) < len(baseParts) { - return target - } - - for i := range baseParts { - if baseParts[i] != targetParts[i] { - return target - } - } - return strings.Join(targetParts[len(baseParts):], "/") -} - func (g *Generate) generate(dir string) error { pkg, err := g.importDir(dir) if err != nil { return err } - lib := g.libRule(pkg) + lib := g.libRule(pkg, dir) if lib == nil { return nil } @@ -430,17 +422,13 @@ func (g *Generate) depTargets(imports []string) []string { return deps } -func (g *Generate) libRule(pkg *build.Package) *Rule { - // The name of the target should match the dir it's in, or the basename of the module if it's in the repo root. - name := filepath.Base(pkg.Dir) - if strings.HasSuffix(pkg.Dir, g.srcRoot) || name == "" { - name = filepath.Base(g.moduleName) - } - +func (g *Generate) libRule(pkg *build.Package, dir string) *Rule { if len(pkg.GoFiles) == 0 && len(pkg.CgoFiles) == 0 { return nil } + name := nameForLibInPkg(g.moduleName, trimPath(dir, g.srcRoot)) + return &Rule{ name: name, kind: packageKind(pkg), @@ -548,25 +536,52 @@ func (g *Generate) depTarget(importPath string) string { } subrepoName := g.subrepoName(module) - packageName := strings.TrimPrefix(importPath, module) - packageName = strings.TrimPrefix(packageName, "/") - name := filepath.Base(packageName) - if packageName == "" { - name = filepath.Base(module) - } + packageName := trimPath(importPath, module) + name := nameForLibInPkg(module, packageName) target := buildTarget(name, packageName, subrepoName) g.knownImportTargets[importPath] = target return target } +// nameForLibInPkg returns the lib target name for a target in pkg. The pkg should be the relative pkg part excluding +// the module, e.g. pkg would be asset, and module would be github.com/stretchr/testify for +// github.com/stretchr/testify/assert, +func nameForLibInPkg(module, pkg string) string { + name := filepath.Base(pkg) + if pkg == "" || pkg == "." { + name = filepath.Base(module) + } + + if name == "all" { + return "lib" + } + + return name +} + +// trimPath is like strings.TrimPrefix but is path aware. It removes base from target if target starts with base, +// otherwise returns target unmodified. +func trimPath(target, base string) string { + baseParts := filepath.SplitList(base) + targetParts := filepath.SplitList(target) + + if len(targetParts) < len(baseParts) { + return target + } + + for i := range baseParts { + if baseParts[i] != targetParts[i] { + return target + } + } + return strings.Join(targetParts[len(baseParts):], "/") +} + // libTargetForPleasePackage returns the build label for the go_library() target that would be generated for a package // at this path within the generated Please repo. func (g *Generate) libTargetForPleasePackage(pkg string) string { - if pkg == "" || pkg == "." { - return buildTarget(filepath.Base(g.moduleName), "", "") - } - return buildTarget(filepath.Base(pkg), pkg, "") + return buildTarget(nameForLibInPkg(g.moduleName, pkg), pkg, "") } func (g *Generate) subrepoName(module string) string { @@ -576,13 +591,29 @@ func (g *Generate) subrepoName(module string) string { return filepath.Join(g.thirdPartyFolder, strings.ReplaceAll(module, "/", "_")) } -func buildTarget(name, pkg, subrepo string) string { - if name == "" { - name = filepath.Base(pkg) +func buildTarget(name, pkgDir, subrepo string) string { + bs := new(strings.Builder) + if subrepo != "" { + bs.WriteString("///") + bs.WriteString(subrepo) + } + + // Bit of a special case here where we assume all build targets are absolute which is fine for our use case. + bs.WriteString("//") + + if pkgDir == "." { + pkgDir = "" } - target := fmt.Sprintf("%v:%v", pkg, name) - if subrepo == "" { - return fmt.Sprintf("//%v", target) + + if pkgDir != "" { + bs.WriteString(pkgDir) + if filepath.Base(pkgDir) != name { + bs.WriteString(":") + bs.WriteString(name) + } + } else { + bs.WriteString(":") + bs.WriteString(name) } - return fmt.Sprintf("///%v//%v", subrepo, target) + return bs.String() } diff --git a/tools/please_go/generate/generate_test.go b/tools/please_go/generate/generate_test.go new file mode 100644 index 00000000..f473d401 --- /dev/null +++ b/tools/please_go/generate/generate_test.go @@ -0,0 +1,160 @@ +package generate + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestTrimPath(t *testing.T) { + tests := []struct { + name string + base string + target string + expected string + }{ + { + name: "trims base path", + base: "third_party/go/_foo#dl", + target: "third_party/go/_foo#dl/foo", + expected: "foo", + }, + { + name: "returns target if base is shorter", + base: "foo/bar/baz", + target: "foo/bar", + expected: "foo/bar", + }, + { + name: "returns target if not in base", + base: "foo/bar", + target: "bar/baz", + expected: "bar/baz", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.expected, trimPath(test.target, test.base)) + }) + } +} + +func TestBuildTarget(t *testing.T) { + tests := []struct { + test string + name, pkg, subrepo string + expected string + }{ + { + test: "fully qualified", + subrepo: "subrepo", + pkg: "pkg", + name: "name", + expected: "///subrepo//pkg:name", + }, + { + test: "fully qualified local package", + subrepo: "", + pkg: "pkg", + name: "name", + expected: "//pkg:name", + }, + { + test: "root package", + subrepo: "", + pkg: "", + name: "name", + expected: "//:name", + }, + { + test: "root package via .", + subrepo: "", + pkg: ".", + name: "name", + expected: "//:name", + }, + { + test: "root package in subrepo", + subrepo: "subrepo", + pkg: "", + name: "name", + expected: "///subrepo//:name", + }, + { + test: "pkg base matches name", + subrepo: "", + pkg: "foo", + name: "foo", + expected: "//foo", + }, + } + + for _, test := range tests { + t.Run(test.test, func(t *testing.T) { + assert.Equal(t, test.expected, buildTarget(test.name, test.pkg, test.subrepo)) + }) + } +} + +func TestDepTarget(t *testing.T) { + tests := []struct { + name string + deps []string + importTarget string + expected string + }{ + { + name: "resolves local import", + importTarget: "github.com/this/module/foo", + expected: "//foo", + }, + { + name: "resolves local import in base", + importTarget: "github.com/this/module", + expected: "//:module", + }, + { + name: "resolves import to another module", + importTarget: "github.com/some/module/foo", + expected: "///third_party/go/github.com_some_module//foo", + deps: []string{"github.com/some/module"}, + }, + { + name: "resolves import to longest match", + importTarget: "github.com/some/module/foo/bar", + expected: "///third_party/go/github.com_some_module_foo//bar", + deps: []string{"github.com/some/module", "github.com/some/module/foo"}, + }, + { + name: "root package matches module base", + importTarget: "github.com/some/module", + expected: "///third_party/go/github.com_some_module//:module", + deps: []string{"github.com/some/module"}, + }, + { + name: "replaces all with lib locally", + importTarget: "github.com/this/module/all", + expected: "//all:lib", + }, + { + name: "replaces all with lib in another repo", + importTarget: "github.com/some/module/all", + expected: "///third_party/go/github.com_some_module//all:lib", + deps: []string{"github.com/some/module"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + g := &Generate{ + moduleName: "github.com/this/module", + thirdPartyFolder: "third_party/go", + replace: map[string]string{}, + knownImportTargets: map[string]string{}, + moduleDeps: test.deps, + } + + assert.Equal(t, test.expected, g.depTarget(test.importTarget)) + }) + } +} diff --git a/tools/please_go/generate/update.go b/tools/please_go/generate/update.go deleted file mode 100644 index acb3c82a..00000000 --- a/tools/please_go/generate/update.go +++ /dev/null @@ -1,149 +0,0 @@ -// TODO(jpoole): better name for this -package generate - -import ( - "fmt" - "github.com/bazelbuild/buildtools/build" - gobuild "go/build" - "os" - "path/filepath" - "strings" -) - -// Update updates an existing Please project. It may create new BUILD files, however it tries to respect existing build -// rules, updating them as appropriate. -func (g *Generate) Update(moduleName string, paths []string) error { - done := map[string]struct{}{} - g.moduleName = moduleName - for _, path := range paths { - if strings.HasSuffix(path, "/...") { - path = strings.TrimSuffix(path, "/...") - err := filepath.WalkDir(path, func(path string, info os.DirEntry, err error) error { - if info.IsDir() { - return nil - } - - if g.isBuildFile(path) { - if err := g.update(done, path); err != nil { - return err - } - } - return nil - }) - if err != nil { - return err - } - } else { - if err := g.update(done, path); err != nil { - return err - } - } - } - return nil -} - -func (g *Generate) isBuildFile(file string) bool { - base := filepath.Base(file) - for _, file := range g.buildFileNames { - if base == file { - return true - } - } - return false -} - -func (g *Generate) findBuildFile(dir string) string { - for _, name := range g.buildFileNames { - path := filepath.Join(dir, name) - if _, err := os.Lstat(path); err == nil { - return path - } - } - return "" -} - -func (g *Generate) loadBuildFile(path string) (*build.File, error) { - if !g.isBuildFile(path) { - path = g.findBuildFile(path) - if path == "" { - return nil, fmt.Errorf("faild to find build file in %v", path) - } - } - bs, err := os.ReadFile(path) - if err != nil { - return nil, err - } - - return build.ParseBuild(path, bs) -} - -func (g *Generate) update(done map[string]struct{}, path string) error { - if _, ok := done[path]; ok { - return nil - } - defer func() { - done[path] = struct{}{} - }() - - dir := path - if g.isBuildFile(path) { - dir = filepath.Dir(path) - } - - // TODO(jpoole): we should break this up and check each source file so we can split tests out across multiple targets - pkg, err := g.importDir(dir) - if err != nil { - if _, ok := err.(*gobuild.NoGoError); ok { - return nil - } - return err - } - - libRule := g.libRule(pkg) - testRule := g.testRule(pkg, libRule) - - file, err := g.loadBuildFile(path) - if err != nil { - return err - } - - libDone := false - testDone := false - - for _, stmt := range file.Stmt { - if call, ok := stmt.(*build.CallExpr); ok { - rule := build.NewRule(call) - if (rule.Kind() == "go_library" && !pkg.IsCommand()) || (rule.Kind() == "go_binary" && pkg.IsCommand()) { - if libDone { - return fmt.Errorf("too many go_library rules in %v", path) - } - populateRule(rule, libRule) - libDone = true - } - if rule.Kind() == "go_test" { - if testDone { - fmt.Fprintln(os.Stderr, "WARNING: too many go_test rules in ", path) - continue - } - if rule.Attr("external") != nil { - continue - } - populateRule(rule, testRule) - testDone = true - } - } - } - - if !libDone && libRule != nil { - r := NewRule("go_library", libRule.name) - populateRule(r, libRule) - file.Stmt = append(file.Stmt, r.Call) - } - - if !testDone && testRule != nil { - r := NewRule("go_test", testRule.name) - populateRule(r, testRule) - file.Stmt = append(file.Stmt, r.Call) - } - return os.WriteFile(file.Path, build.Format(file), 0664) -}