Skip to content

Commit

Permalink
feat(gpd): add overlay support for gopackagesdriver (#4101)
Browse files Browse the repository at this point in the history
- Introduced the ability to handle in-memory overlays via
map[string][]byte in gopackagesdriver, allowing files to be dynamically
provided as overlays during package resolution.
- Updated ResolveImports to accept overlays and apply them when parsing
Go files.
- Modified tests to include new overlay functionality, ensuring correct
import resolution for files supplied via overlay maps.
- Refactored runForTest function to support DriverRequestJson input,
making it more flexible for testing overlays.
- Applied overlays in both the JSONPackagesDriver and PackageRegistry,
providing seamless integration with existing workflows.

This feature enhances package resolution by supporting temporary or
unsaved files, making it useful for IDEs, linters, and test scenarios
where in-memory file content is needed.

**What type of PR is this?**
Feature

**What does this PR do? Why is it needed?**
Implements the overlay function specified in the gopackagesdriver
DriverRequest

**Which issues(s) does this PR fix?**
Fixes #4100
  • Loading branch information
LWarrens authored Sep 12, 2024
1 parent ee74d01 commit 0433213
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 35 deletions.
2 changes: 1 addition & 1 deletion go/tools/gopackagesdriver/driver_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ type DriverRequest struct {
Tests bool `json:"tests"`
// Overlay maps file paths (relative to the driver's working directory) to the byte contents
// of overlay files.
// Overlay map[string][]byte `json:"overlay"`
Overlay map[string][]byte `json:"overlay"`
}

func ReadDriverRequest(r io.Reader) (*DriverRequest, error) {
Expand Down
15 changes: 13 additions & 2 deletions go/tools/gopackagesdriver/flatpackage.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
package main

import (
"bytes"
"encoding/json"
"fmt"
"go/parser"
"go/token"
"os"
"strconv"
"strings"
"io"
)

type ResolvePkgFunc func(importPath string) string
Expand Down Expand Up @@ -169,7 +171,9 @@ func (fp *FlatPackage) IsStdlib() bool {
return fp.Standard
}

func (fp *FlatPackage) ResolveImports(resolve ResolvePkgFunc) error {
// ResolveImports resolves imports for non-stdlib packages and integrates file overlays
// to allow modification of package imports without modifying disk files.
func (fp *FlatPackage) ResolveImports(resolve ResolvePkgFunc, overlays map[string][]byte) error {
// Stdlib packages are already complete import wise
if fp.IsStdlib() {
return nil
Expand All @@ -178,7 +182,14 @@ func (fp *FlatPackage) ResolveImports(resolve ResolvePkgFunc) error {
fset := token.NewFileSet()

for _, file := range fp.CompiledGoFiles {
f, err := parser.ParseFile(fset, file, nil, parser.ImportsOnly)
// Only assign overlayContent when an overlay for the file exists, since ParseFile checks by type.
// If overlay is assigned directly from the map, it will have []byte as type
// Empty []byte types are parsed into io.EOF
var overlayReader io.Reader
if content, ok := overlays[file]; ok {
overlayReader = bytes.NewReader(content)
}
f, err := parser.ParseFile(fset, file, overlayReader, parser.ImportsOnly)
if err != nil {
return err
}
Expand Down
107 changes: 81 additions & 26 deletions go/tools/gopackagesdriver/gopackagesdriver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"path/filepath"
"strings"
"testing"

"github.com/bazelbuild/rules_go/go/tools/bazel_testing"
)

Expand Down Expand Up @@ -91,7 +90,7 @@ const (
)

func TestBaseFileLookup(t *testing.T) {
resp := runForTest(t, ".", "file=hello.go")
resp := runForTest(t, DriverRequest{}, ".", "file=hello.go")

t.Run("roots", func(t *testing.T) {
if len(resp.Roots) != 1 {
Expand All @@ -106,12 +105,7 @@ func TestBaseFileLookup(t *testing.T) {
})

t.Run("package", func(t *testing.T) {
var pkg *FlatPackage
for _, p := range resp.Packages {
if p.ID == resp.Roots[0] {
pkg = p
}
}
pkg := findPackageByID(resp.Packages, resp.Roots[0])

if pkg == nil {
t.Errorf("Expected to find %q in resp.Packages", resp.Roots[0])
Expand Down Expand Up @@ -161,7 +155,7 @@ func TestBaseFileLookup(t *testing.T) {
}

func TestRelativeFileLookup(t *testing.T) {
resp := runForTest(t, "subhello", "file=./subhello.go")
resp := runForTest(t, DriverRequest{}, "subhello", "file=./subhello.go")

t.Run("roots", func(t *testing.T) {
if len(resp.Roots) != 1 {
Expand All @@ -176,12 +170,7 @@ func TestRelativeFileLookup(t *testing.T) {
})

t.Run("package", func(t *testing.T) {
var pkg *FlatPackage
for _, p := range resp.Packages {
if p.ID == resp.Roots[0] {
pkg = p
}
}
pkg := findPackageByID(resp.Packages, resp.Roots[0])

if pkg == nil {
t.Errorf("Expected to find %q in resp.Packages", resp.Roots[0])
Expand All @@ -197,7 +186,7 @@ func TestRelativeFileLookup(t *testing.T) {
}

func TestRelativePatternWildcardLookup(t *testing.T) {
resp := runForTest(t, "subhello", "./...")
resp := runForTest(t, DriverRequest{}, "subhello", "./...")

t.Run("roots", func(t *testing.T) {
if len(resp.Roots) != 1 {
Expand All @@ -212,12 +201,7 @@ func TestRelativePatternWildcardLookup(t *testing.T) {
})

t.Run("package", func(t *testing.T) {
var pkg *FlatPackage
for _, p := range resp.Packages {
if p.ID == resp.Roots[0] {
pkg = p
}
}
pkg := findPackageByID(resp.Packages, resp.Roots[0])

if pkg == nil {
t.Errorf("Expected to find %q in resp.Packages", resp.Roots[0])
Expand All @@ -233,7 +217,7 @@ func TestRelativePatternWildcardLookup(t *testing.T) {
}

func TestExternalTests(t *testing.T) {
resp := runForTest(t, ".", "file=hello_external_test.go")
resp := runForTest(t, DriverRequest{}, ".", "file=hello_external_test.go")
if len(resp.Roots) != 2 {
t.Errorf("Expected exactly two roots for package: %+v", resp.Roots)
}
Expand All @@ -259,7 +243,66 @@ func TestExternalTests(t *testing.T) {
}
}

func runForTest(t *testing.T, relativeWorkingDir string, args ...string) driverResponse {
func TestOverlay(t *testing.T) {
// format filepaths for overlay request using working directory
wd, err := os.Getwd()
if err != nil {
t.Fatal(err)
}

// format filepaths for overlay request
helloPath := path.Join(wd, "hello.go")
subhelloPath := path.Join(wd, "subhello/subhello.go")

expectedImportsPerFile := map[string][]string{
helloPath: []string{"fmt"},
subhelloPath: []string{"os", "encoding/json"},
}

overlayDriverRequest := DriverRequest {
Overlay: map[string][]byte {
helloPath: []byte (`
package hello
import "fmt"
import "unknown/unknown-package"
func main() {
invalid code
}`),
subhelloPath: []byte (`
package subhello
import "os"
import "encoding/json"
func main() {
fmt.Fprintln(os.Stderr, "Subdirectory Hello World!")
}
`),
},
}

// run the driver with the overlay
helloResp := runForTest(t, overlayDriverRequest, ".", "file=hello.go")
subhelloResp := runForTest(t, overlayDriverRequest, "subhello", "file=subhello.go")

// get root packages
helloPkg := findPackageByID(helloResp.Packages, helloResp.Roots[0])
subhelloPkg := findPackageByID(subhelloResp.Packages, subhelloResp.Roots[0])
if helloPkg == nil {
t.Fatalf("hello package not found in response root %q", helloResp.Roots[0])
}
if subhelloPkg == nil {
t.Fatalf("subhello package not found in response %q", subhelloResp.Roots[0])
}

helloPkgImportPaths := keysFromMap(helloPkg.Imports)
subhelloPkgImportPaths := keysFromMap(subhelloPkg.Imports)

expectSetEquality(t, expectedImportsPerFile[helloPath], helloPkgImportPaths, "hello imports")
expectSetEquality(t, expectedImportsPerFile[subhelloPath], subhelloPkgImportPaths, "subhello imports")
}


func runForTest(t *testing.T, driverRequest DriverRequest, relativeWorkingDir string, args ...string) driverResponse {
t.Helper()

// Remove most environment variables, other than those on an allowlist.
Expand All @@ -268,7 +311,7 @@ func runForTest(t *testing.T, relativeWorkingDir string, args ...string) driverR
// If Bazel is invoked when these variables, it assumes (correctly)
// that it's being invoked by a test, and it does different things that
// we don't want. For example, it randomizes the output directory, which
// is extremely expensive here. Out test framework creates an output
// is extremely expensive here. Our test framework creates an output
// directory shared among go_bazel_tests and points to it using .bazelrc.
//
// This only works if TEST_TMPDIR is not set when invoking bazel.
Expand Down Expand Up @@ -318,7 +361,11 @@ func runForTest(t *testing.T, relativeWorkingDir string, args ...string) driverR
buildWorkingDirectory = oldBuildWorkingDirectory
}()

in := strings.NewReader("{}")
driverRequestJson, err := json.Marshal(driverRequest)
if err != nil {
t.Fatalf("Error serializing driver request: %v\n", err)
}
in := bytes.NewReader(driverRequestJson)
out := &bytes.Buffer{}
if err := run(context.Background(), in, out, args); err != nil {
t.Fatalf("running gopackagesdriver: %v", err)
Expand All @@ -343,3 +390,11 @@ func assertSuffixesInList(t *testing.T, list []string, expectedSuffixes ...strin
}
}
}

// expectSetEquality checks if two slices are equal sets and logs an error if they are not
func expectSetEquality(t *testing.T, expected []string, actual []string, setName string) {
t.Helper()
if !equalSets(expected, actual) {
t.Errorf("Expected %s %v, got %s %v", setName, expected, actual, setName)
}
}
6 changes: 3 additions & 3 deletions go/tools/gopackagesdriver/json_packages_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type JSONPackagesDriver struct {
registry *PackageRegistry
}

func NewJSONPackagesDriver(jsonFiles []string, prf PathResolverFunc, bazelVersion bazelVersion) (*JSONPackagesDriver, error) {
func NewJSONPackagesDriver(jsonFiles []string, prf PathResolverFunc, bazelVersion bazelVersion, overlays map[string][]byte) (*JSONPackagesDriver, error) {
jpd := &JSONPackagesDriver{
registry: NewPackageRegistry(bazelVersion),
}
Expand All @@ -40,8 +40,8 @@ func NewJSONPackagesDriver(jsonFiles []string, prf PathResolverFunc, bazelVersio
return nil, fmt.Errorf("unable to resolve paths: %w", err)
}

if err := jpd.registry.ResolveImports(); err != nil {
return nil, fmt.Errorf("unable to resolve paths: %w", err)
if err := jpd.registry.ResolveImports(overlays); err != nil {
return nil, fmt.Errorf("unable to resolve imports: %w", err)
}

return jpd, nil
Expand Down
2 changes: 1 addition & 1 deletion go/tools/gopackagesdriver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func run(ctx context.Context, in io.Reader, out io.Writer, args []string) error
return fmt.Errorf("unable to build JSON files: %w", err)
}

driver, err := NewJSONPackagesDriver(jsonFiles, bazelJsonBuilder.PathResolver(), bazel.version)
driver, err := NewJSONPackagesDriver(jsonFiles, bazelJsonBuilder.PathResolver(), bazel.version, request.Overlay)
if err != nil {
return fmt.Errorf("unable to load JSON files: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions go/tools/gopackagesdriver/packageregistry.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (pr *PackageRegistry) ResolvePaths(prf PathResolverFunc) error {
// ResolveImports adds stdlib imports to packages. This is required because
// stdlib packages are not part of the JSON file exports as bazel is unaware of
// them.
func (pr *PackageRegistry) ResolveImports() error {
func (pr *PackageRegistry) ResolveImports(overlays map[string][]byte) error {
resolve := func(importPath string) string {
if pkgID, ok := pr.stdlib[importPath]; ok {
return pkgID
Expand All @@ -68,7 +68,7 @@ func (pr *PackageRegistry) ResolveImports() error {
}

for _, pkg := range pr.packagesByID {
if err := pkg.ResolveImports(resolve); err != nil {
if err := pkg.ResolveImports(resolve, overlays); err != nil {
return err
}
testFp := pkg.MoveTestFiles()
Expand Down
48 changes: 48 additions & 0 deletions go/tools/gopackagesdriver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,51 @@ func packageID(pattern string) string {

return fmt.Sprintf("//%s", pattern)
}

func findPackageByID(packages []*FlatPackage, id string) *FlatPackage {
for _, pkg := range packages {
if pkg.ID == id {
return pkg
}
}
return nil
}

// get map keys
func keysFromMap[K comparable, V any](m map[K]V) []K {
keys := make([]K, 0, len(m))
for k := range m {
keys = append(keys, k)
}
return keys
}

// contains checks if a slice contains an element
func contains[S ~[]E, E comparable](set S, element E) bool {
found := false
for _, setElement := range set {
if setElement == element {
found = true
break
}
}
return found
}

// containsAll checks if a slice contains all elements of another slice
func containsAll[S ~[]E, E comparable](set S, subset S) bool {
for _, subsetElement := range subset {
if !contains(set, subsetElement) {
return false
}
}
return true
}

// equalSets checks if two slices are equal sets
func equalSets[S ~[]E, E comparable](set1 S, set2 S) bool {
if len(set1) != len(set2) {
return false
}
return containsAll(set1, set2)
}

0 comments on commit 0433213

Please sign in to comment.