From a19b9a6e7c9e61b67e20b6c862b49ba0dab1ef27 Mon Sep 17 00:00:00 2001 From: Adam Date: Wed, 15 May 2019 13:43:08 +1000 Subject: [PATCH] Rework the type system --- README.md | 21 +- dataloaden.go | 15 +- example/pkgname/user.go | 2 +- example/pkgname/userloader_gen.go | 18 +- example/slice/user.go | 2 +- example/slice/usersliceloader_gen.go | 24 ++- example/slice/usersliceloader_test.go | 205 ++++++++++++++++++++ example/user.go | 2 +- example/userloader_gen.go | 18 +- go.mod | 2 +- go.sum | 10 +- pkg/generator/generator.go | 118 +++++++---- pkg/generator/generator_test.go | 26 +++ pkg/generator/template.go | 101 +++++----- pkg/generator/testdata/mismatch/mismatch.go | 5 + 15 files changed, 431 insertions(+), 138 deletions(-) create mode 100644 example/slice/usersliceloader_test.go create mode 100644 pkg/generator/generator_test.go create mode 100644 pkg/generator/testdata/mismatch/mismatch.go diff --git a/README.md b/README.md index 13dcadc..f5bcb91 100644 --- a/README.md +++ b/README.md @@ -10,19 +10,17 @@ get used. #### Getting started -First grab it: +From inside the package you want to have the dataloader in: ```bash -go get -u github.com/vektah/dataloaden +go run github.com/vektah/dataloaden UserLoader string *github.com/dataloaden/example.User ``` -then from inside the package you want to have the dataloader in: -```bash -dataloaden github.com/dataloaden/example.User -``` +This will generate a dataloader called `UserLoader` that looks up `*github.com/dataloaden/example.User`'s objects +based on a `string` key. In another file in the same package, create the constructor method: ```go -func NewLoader() *UserLoader { +func NewUserLoader() *UserLoader { return &UserLoader{ wait: 2 * time.Millisecond, maxBatch: 100, @@ -41,7 +39,7 @@ func NewLoader() *UserLoader { Then wherever you want to call the dataloader ```go -loader := NewLoader() +loader := NewUserLoader() user, err := loader.Load("123") ``` @@ -51,13 +49,14 @@ function once. It also caches values and wont request duplicates in a batch. #### Returning Slices -You may want to generate a dataloader that returns slices instead of single values. This can be done using the `-slice` flag: +You may want to generate a dataloader that returns slices instead of single values. Both key and value types can be a +simple go type expression: ```bash -dataloaden -slice github.com/dataloaden/example.User +go run github.com/vektah/dataloaden UserSliceLoader string []*github.com/dataloaden/example.User ``` -Now each key is expected to return a slice of values and the `fetch` function has the return type `[][]User`. +Now each key is expected to return a slice of values and the `fetch` function has the return type `[][]*User`. #### Using with go modules diff --git a/dataloaden.go b/dataloaden.go index b1899e0..3419286 100644 --- a/dataloaden.go +++ b/dataloaden.go @@ -1,7 +1,6 @@ package main import ( - "flag" "fmt" "os" @@ -9,24 +8,20 @@ import ( ) func main() { - keyType := flag.String("keys", "int", "what type should the keys be") - slice := flag.Bool("slice", false, "this dataloader will return slices") - - flag.Parse() - - if flag.NArg() != 1 { - flag.Usage() + if len(os.Args) != 4 { + fmt.Println("usage: name keyType valueType") + fmt.Println(" example:") + fmt.Println(" dataloaden 'UserLoader int []*github.com/my/package.User'") os.Exit(1) } wd, err := os.Getwd() - if err != nil { fmt.Fprintln(os.Stderr, err.Error()) os.Exit(2) } - if err := generator.Generate(flag.Arg(0), *keyType, *slice, wd); err != nil { + if err := generator.Generate(os.Args[1], os.Args[2], os.Args[3], wd); err != nil { fmt.Fprintln(os.Stderr, err.Error()) os.Exit(2) } diff --git a/example/pkgname/user.go b/example/pkgname/user.go index 754c48c..f9a4bf5 100644 --- a/example/pkgname/user.go +++ b/example/pkgname/user.go @@ -1,3 +1,3 @@ package differentpkg -//go:generate go run github.com/vektah/dataloaden -keys string github.com/vektah/dataloaden/example.User +//go:generate go run github.com/vektah/dataloaden UserLoader string *github.com/vektah/dataloaden/example.User diff --git a/example/pkgname/userloader_gen.go b/example/pkgname/userloader_gen.go index 2ac65da..3495d73 100644 --- a/example/pkgname/userloader_gen.go +++ b/example/pkgname/userloader_gen.go @@ -48,13 +48,13 @@ type UserLoader struct { // the current batch. keys will continue to be collected until timeout is hit, // then everything will be sent to the fetch method and out to the listeners - batch *userBatch + batch *userLoaderBatch // mutex to prevent races mu sync.Mutex } -type userBatch struct { +type userLoaderBatch struct { keys []string data []*example.User error []error @@ -62,12 +62,12 @@ type userBatch struct { done chan struct{} } -// Load a user by key, batching and caching will be applied automatically +// Load a User by key, batching and caching will be applied automatically func (l *UserLoader) Load(key string) (*example.User, error) { return l.LoadThunk(key)() } -// LoadThunk returns a function that when called will block waiting for a user. +// LoadThunk returns a function that when called will block waiting for a User. // This method should be used if you want one goroutine to make requests to many // different data loaders without blocking until the thunk is called. func (l *UserLoader) LoadThunk(key string) func() (*example.User, error) { @@ -79,7 +79,7 @@ func (l *UserLoader) LoadThunk(key string) func() (*example.User, error) { } } if l.batch == nil { - l.batch = &userBatch{done: make(chan struct{})} + l.batch = &userLoaderBatch{done: make(chan struct{})} } batch := l.batch pos := batch.keyIndex(l, key) @@ -128,7 +128,7 @@ func (l *UserLoader) LoadAll(keys []string) ([]*example.User, []error) { return users, errors } -// LoadAllThunk returns a function that when called will block waiting for a users. +// LoadAllThunk returns a function that when called will block waiting for a Users. // This method should be used if you want one goroutine to make requests to many // different data loaders without blocking until the thunk is called. func (l *UserLoader) LoadAllThunk(keys []string) func() ([]*example.User, []error) { @@ -178,7 +178,7 @@ func (l *UserLoader) unsafeSet(key string, value *example.User) { // keyIndex will return the location of the key in the batch, if its not found // it will add the key to the batch -func (b *userBatch) keyIndex(l *UserLoader, key string) int { +func (b *userLoaderBatch) keyIndex(l *UserLoader, key string) int { for i, existingKey := range b.keys { if key == existingKey { return i @@ -202,7 +202,7 @@ func (b *userBatch) keyIndex(l *UserLoader, key string) int { return pos } -func (b *userBatch) startTimer(l *UserLoader) { +func (b *userLoaderBatch) startTimer(l *UserLoader) { time.Sleep(l.wait) l.mu.Lock() @@ -218,7 +218,7 @@ func (b *userBatch) startTimer(l *UserLoader) { b.end(l) } -func (b *userBatch) end(l *UserLoader) { +func (b *userLoaderBatch) end(l *UserLoader) { b.data, b.error = l.fetch(b.keys) close(b.done) } diff --git a/example/slice/user.go b/example/slice/user.go index 9b43ab4..767f2c1 100644 --- a/example/slice/user.go +++ b/example/slice/user.go @@ -1,4 +1,4 @@ -//go:generate go run github.com/vektah/dataloaden -keys int -slice github.com/vektah/dataloaden/example.User +//go:generate go run github.com/vektah/dataloaden UserSliceLoader int []github.com/vektah/dataloaden/example.User package slice diff --git a/example/slice/usersliceloader_gen.go b/example/slice/usersliceloader_gen.go index c6dfe82..c2d6e83 100644 --- a/example/slice/usersliceloader_gen.go +++ b/example/slice/usersliceloader_gen.go @@ -48,13 +48,13 @@ type UserSliceLoader struct { // the current batch. keys will continue to be collected until timeout is hit, // then everything will be sent to the fetch method and out to the listeners - batch *userSliceBatch + batch *userSliceLoaderBatch // mutex to prevent races mu sync.Mutex } -type userSliceBatch struct { +type userSliceLoaderBatch struct { keys []int data [][]example.User error []error @@ -62,12 +62,12 @@ type userSliceBatch struct { done chan struct{} } -// Load a user by key, batching and caching will be applied automatically +// Load a User by key, batching and caching will be applied automatically func (l *UserSliceLoader) Load(key int) ([]example.User, error) { return l.LoadThunk(key)() } -// LoadThunk returns a function that when called will block waiting for a user. +// LoadThunk returns a function that when called will block waiting for a User. // This method should be used if you want one goroutine to make requests to many // different data loaders without blocking until the thunk is called. func (l *UserSliceLoader) LoadThunk(key int) func() ([]example.User, error) { @@ -79,7 +79,7 @@ func (l *UserSliceLoader) LoadThunk(key int) func() ([]example.User, error) { } } if l.batch == nil { - l.batch = &userSliceBatch{done: make(chan struct{})} + l.batch = &userSliceLoaderBatch{done: make(chan struct{})} } batch := l.batch pos := batch.keyIndex(l, key) @@ -128,7 +128,7 @@ func (l *UserSliceLoader) LoadAll(keys []int) ([][]example.User, []error) { return users, errors } -// LoadAllThunk returns a function that when called will block waiting for a users. +// LoadAllThunk returns a function that when called will block waiting for a Users. // This method should be used if you want one goroutine to make requests to many // different data loaders without blocking until the thunk is called. func (l *UserSliceLoader) LoadAllThunk(keys []int) func() ([][]example.User, []error) { @@ -153,7 +153,11 @@ func (l *UserSliceLoader) Prime(key int, value []example.User) bool { l.mu.Lock() var found bool if _, found = l.cache[key]; !found { - l.unsafeSet(key, value) + // make a copy when writing to the cache, its easy to pass a pointer in from a loop var + // and end up with the whole cache pointing to the same value. + cpy := make([]example.User, len(value)) + copy(cpy, value) + l.unsafeSet(key, cpy) } l.mu.Unlock() return !found @@ -175,7 +179,7 @@ func (l *UserSliceLoader) unsafeSet(key int, value []example.User) { // keyIndex will return the location of the key in the batch, if its not found // it will add the key to the batch -func (b *userSliceBatch) keyIndex(l *UserSliceLoader, key int) int { +func (b *userSliceLoaderBatch) keyIndex(l *UserSliceLoader, key int) int { for i, existingKey := range b.keys { if key == existingKey { return i @@ -199,7 +203,7 @@ func (b *userSliceBatch) keyIndex(l *UserSliceLoader, key int) int { return pos } -func (b *userSliceBatch) startTimer(l *UserSliceLoader) { +func (b *userSliceLoaderBatch) startTimer(l *UserSliceLoader) { time.Sleep(l.wait) l.mu.Lock() @@ -215,7 +219,7 @@ func (b *userSliceBatch) startTimer(l *UserSliceLoader) { b.end(l) } -func (b *userSliceBatch) end(l *UserSliceLoader) { +func (b *userSliceLoaderBatch) end(l *UserSliceLoader) { b.data, b.error = l.fetch(b.keys) close(b.done) } diff --git a/example/slice/usersliceloader_test.go b/example/slice/usersliceloader_test.go new file mode 100644 index 0000000..857b197 --- /dev/null +++ b/example/slice/usersliceloader_test.go @@ -0,0 +1,205 @@ +package slice + +import ( + "fmt" + "strconv" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vektah/dataloaden/example" +) + +func TestUserLoader(t *testing.T) { + var fetches [][]int + var mu sync.Mutex + + dl := &UserSliceLoader{ + wait: 10 * time.Millisecond, + maxBatch: 5, + fetch: func(keys []int) (users [][]example.User, errors []error) { + mu.Lock() + fetches = append(fetches, keys) + mu.Unlock() + + users = make([][]example.User, len(keys)) + errors = make([]error, len(keys)) + + for i, key := range keys { + if key%10 == 0 { // anything ending in zero is bad + errors[i] = fmt.Errorf("users not found") + } else { + users[i] = []example.User{ + {ID: strconv.Itoa(key), Name: "user " + strconv.Itoa(key)}, + {ID: strconv.Itoa(key), Name: "user " + strconv.Itoa(key)}, + } + } + } + return users, errors + }, + } + + t.Run("fetch concurrent data", func(t *testing.T) { + t.Run("load user successfully", func(t *testing.T) { + t.Parallel() + u, err := dl.Load(1) + require.NoError(t, err) + require.Equal(t, u[0].ID, "1") + require.Equal(t, u[1].ID, "1") + }) + + t.Run("load failed user", func(t *testing.T) { + t.Parallel() + u, err := dl.Load(10) + require.Error(t, err) + require.Nil(t, u) + }) + + t.Run("load many users", func(t *testing.T) { + t.Parallel() + u, err := dl.LoadAll([]int{2, 10, 20, 4}) + require.Equal(t, u[0][0].Name, "user 2") + require.Error(t, err[1]) + require.Error(t, err[2]) + require.Equal(t, u[3][0].Name, "user 4") + }) + + t.Run("load thunk", func(t *testing.T) { + t.Parallel() + thunk1 := dl.LoadThunk(5) + thunk2 := dl.LoadThunk(50) + + u1, err1 := thunk1() + require.NoError(t, err1) + require.Equal(t, "user 5", u1[0].Name) + + u2, err2 := thunk2() + require.Error(t, err2) + require.Nil(t, u2) + }) + }) + + t.Run("it sent two batches", func(t *testing.T) { + mu.Lock() + defer mu.Unlock() + + require.Len(t, fetches, 2) + assert.Len(t, fetches[0], 5) + assert.Len(t, fetches[1], 3) + }) + + t.Run("fetch more", func(t *testing.T) { + + t.Run("previously cached", func(t *testing.T) { + t.Parallel() + u, err := dl.Load(1) + require.NoError(t, err) + require.Equal(t, u[0].ID, "1") + }) + + t.Run("load many users", func(t *testing.T) { + t.Parallel() + u, err := dl.LoadAll([]int{2, 4}) + require.NoError(t, err[0]) + require.NoError(t, err[1]) + require.Equal(t, u[0][0].Name, "user 2") + require.Equal(t, u[1][0].Name, "user 4") + }) + }) + + t.Run("no round trips", func(t *testing.T) { + mu.Lock() + defer mu.Unlock() + + require.Len(t, fetches, 2) + }) + + t.Run("fetch partial", func(t *testing.T) { + t.Run("errors not in cache cache value", func(t *testing.T) { + t.Parallel() + u, err := dl.Load(20) + require.Nil(t, u) + require.Error(t, err) + }) + + t.Run("load all", func(t *testing.T) { + t.Parallel() + u, err := dl.LoadAll([]int{1, 4, 10, 9, 5}) + require.Equal(t, u[0][0].ID, "1") + require.Equal(t, u[1][0].ID, "4") + require.Error(t, err[2]) + require.Equal(t, u[3][0].ID, "9") + require.Equal(t, u[4][0].ID, "5") + }) + }) + + t.Run("one partial trip", func(t *testing.T) { + mu.Lock() + defer mu.Unlock() + + require.Len(t, fetches, 3) + require.Len(t, fetches[2], 3) // E1 U9 E2 in some random order + }) + + t.Run("primed reads dont hit the fetcher", func(t *testing.T) { + dl.Prime(99, []example.User{ + {ID: "U99", Name: "Primed user"}, + {ID: "U99", Name: "Primed user"}, + }) + u, err := dl.Load(99) + require.NoError(t, err) + require.Equal(t, "Primed user", u[0].Name) + + require.Len(t, fetches, 3) + }) + + t.Run("priming in a loop is safe", func(t *testing.T) { + users := [][]example.User{ + {{ID: "123", Name: "Alpha"}, {ID: "123", Name: "Alpha"}}, + {{ID: "124", Name: "Omega"}, {ID: "124", Name: "Omega"}}, + } + for _, user := range users { + id, _ := strconv.Atoi(user[0].ID) + dl.Prime(id, user) + } + + u, err := dl.Load(123) + require.NoError(t, err) + require.Equal(t, "Alpha", u[0].Name) + + u, err = dl.Load(124) + require.NoError(t, err) + require.Equal(t, "Omega", u[0].Name) + + require.Len(t, fetches, 3) + }) + + t.Run("cleared results will go back to the fetcher", func(t *testing.T) { + dl.Clear(99) + u, err := dl.Load(99) + require.NoError(t, err) + require.Equal(t, "user 99", u[0].Name) + + require.Len(t, fetches, 4) + }) + + t.Run("load all thunk", func(t *testing.T) { + thunk1 := dl.LoadAllThunk([]int{5, 6}) + thunk2 := dl.LoadAllThunk([]int{6, 60}) + + users1, err1 := thunk1() + + require.NoError(t, err1[0]) + require.NoError(t, err1[1]) + require.Equal(t, "user 5", users1[0][0].Name) + require.Equal(t, "user 6", users1[1][0].Name) + + users2, err2 := thunk2() + + require.NoError(t, err2[0]) + require.Error(t, err2[1]) + require.Equal(t, "user 6", users2[0][0].Name) + }) +} diff --git a/example/user.go b/example/user.go index 587f1c6..24d2863 100644 --- a/example/user.go +++ b/example/user.go @@ -1,4 +1,4 @@ -//go:generate go run github.com/vektah/dataloaden -keys string github.com/vektah/dataloaden/example.User +//go:generate go run github.com/vektah/dataloaden UserLoader string *github.com/vektah/dataloaden/example.User package example diff --git a/example/userloader_gen.go b/example/userloader_gen.go index d16ca20..470ba6a 100644 --- a/example/userloader_gen.go +++ b/example/userloader_gen.go @@ -46,13 +46,13 @@ type UserLoader struct { // the current batch. keys will continue to be collected until timeout is hit, // then everything will be sent to the fetch method and out to the listeners - batch *userBatch + batch *userLoaderBatch // mutex to prevent races mu sync.Mutex } -type userBatch struct { +type userLoaderBatch struct { keys []string data []*User error []error @@ -60,12 +60,12 @@ type userBatch struct { done chan struct{} } -// Load a user by key, batching and caching will be applied automatically +// Load a User by key, batching and caching will be applied automatically func (l *UserLoader) Load(key string) (*User, error) { return l.LoadThunk(key)() } -// LoadThunk returns a function that when called will block waiting for a user. +// LoadThunk returns a function that when called will block waiting for a User. // This method should be used if you want one goroutine to make requests to many // different data loaders without blocking until the thunk is called. func (l *UserLoader) LoadThunk(key string) func() (*User, error) { @@ -77,7 +77,7 @@ func (l *UserLoader) LoadThunk(key string) func() (*User, error) { } } if l.batch == nil { - l.batch = &userBatch{done: make(chan struct{})} + l.batch = &userLoaderBatch{done: make(chan struct{})} } batch := l.batch pos := batch.keyIndex(l, key) @@ -126,7 +126,7 @@ func (l *UserLoader) LoadAll(keys []string) ([]*User, []error) { return users, errors } -// LoadAllThunk returns a function that when called will block waiting for a users. +// LoadAllThunk returns a function that when called will block waiting for a Users. // This method should be used if you want one goroutine to make requests to many // different data loaders without blocking until the thunk is called. func (l *UserLoader) LoadAllThunk(keys []string) func() ([]*User, []error) { @@ -176,7 +176,7 @@ func (l *UserLoader) unsafeSet(key string, value *User) { // keyIndex will return the location of the key in the batch, if its not found // it will add the key to the batch -func (b *userBatch) keyIndex(l *UserLoader, key string) int { +func (b *userLoaderBatch) keyIndex(l *UserLoader, key string) int { for i, existingKey := range b.keys { if key == existingKey { return i @@ -200,7 +200,7 @@ func (b *userBatch) keyIndex(l *UserLoader, key string) int { return pos } -func (b *userBatch) startTimer(l *UserLoader) { +func (b *userLoaderBatch) startTimer(l *UserLoader) { time.Sleep(l.wait) l.mu.Lock() @@ -216,7 +216,7 @@ func (b *userBatch) startTimer(l *UserLoader) { b.end(l) } -func (b *userBatch) end(l *UserLoader) { +func (b *userLoaderBatch) end(l *UserLoader) { b.data, b.error = l.fetch(b.keys) close(b.done) } diff --git a/go.mod b/go.mod index 01c7c28..ba56ca0 100644 --- a/go.mod +++ b/go.mod @@ -5,5 +5,5 @@ require ( github.com/pkg/errors v0.8.1 github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/testify v1.2.1 - golang.org/x/tools v0.0.0-20190125232054-d66bd3c5d5a6 + golang.org/x/tools v0.0.0-20190515012406-7d7faa4812bd ) diff --git a/go.sum b/go.sum index e6bd4bb..a350afb 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,15 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/golangci/golangci-lint v1.13.1 h1:e4khAARYjxlcEtEZgPqqaeoIVWlHmsZ4c+g5nJUpdUQ= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.2.1 h1:52QO5WkIUcHGIR7EnGagH88x1bUzqGXTC5/1bDTUQ7U= github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -golang.org/x/tools v0.0.0-20190125232054-d66bd3c5d5a6 h1:iZgcI2DDp6zW5v9Z/5+f0NuqoxNdmzg4hivjk2WLXpY= -golang.org/x/tools v0.0.0-20190125232054-d66bd3c5d5a6/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20190515012406-7d7faa4812bd h1:oMEQDWVXVNpceQoVd1JN3CQ7LYJJzs5qWqZIUcxXHHw= +golang.org/x/tools v0.0.0-20190515012406-7d7faa4812bd/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= diff --git a/pkg/generator/generator.go b/pkg/generator/generator.go index e40642d..ff618e7 100644 --- a/pkg/generator/generator.go +++ b/pkg/generator/generator.go @@ -5,6 +5,7 @@ import ( "fmt" "io/ioutil" "path/filepath" + "regexp" "strings" "unicode" @@ -14,26 +15,76 @@ import ( ) type templateData struct { - LoaderName string - BatchName string - Package string + Package string + Name string + KeyType *goType + ValType *goType +} + +type goType struct { + Modifiers string + ImportPath string + ImportName string Name string - KeyType string - ValType string - Import string - Slice bool } -func Generate(typename string, keyType string, slice bool, wd string) error { - data, err := getData(typename, keyType, slice, wd) +func (t *goType) String() string { + if t.ImportName != "" { + return t.Modifiers + t.ImportName + "." + t.Name + } + + return t.Modifiers + t.Name +} + +func (t *goType) IsPtr() bool { + return strings.HasPrefix(t.Modifiers, "*") +} + +func (t *goType) IsSlice() bool { + return strings.HasPrefix(t.Modifiers, "[]") +} + +var partsRe = regexp.MustCompile(`^([\[\]\*]*)(.*?)(\.\w*)?$`) + +func parseType(str string) (*goType, error) { + parts := partsRe.FindStringSubmatch(str) + if len(parts) != 4 { + return nil, fmt.Errorf("type must be in the form []*github.com/import/path.Name") + } + + t := &goType{ + Modifiers: parts[1], + ImportPath: parts[2], + Name: strings.TrimPrefix(parts[3], "."), + } + + if t.Name == "" { + t.Name = t.ImportPath + t.ImportPath = "" + } + + if t.ImportPath != "" { + p, err := packages.Load(&packages.Config{Mode: packages.NeedName}, t.ImportPath) + if err != nil { + return nil, err + } + if len(p) != 1 { + return nil, fmt.Errorf("not found") + } + + t.ImportName = p[0].Name + } + + return t, nil +} + +func Generate(name string, keyType string, valueType string, wd string) error { + data, err := getData(name, keyType, valueType, wd) if err != nil { return err } - filename := data.Name + "loader_gen.go" - if data.Slice { - filename = data.Name + "sliceloader_gen.go" - } + filename := strings.ToLower(data.Name) + "_gen.go" if err := writeTemplate(filepath.Join(wd, filename), data); err != nil { return err @@ -42,41 +93,34 @@ func Generate(typename string, keyType string, slice bool, wd string) error { return nil } -func getData(typeName string, keyType string, slice bool, wd string) (templateData, error) { +func getData(name string, keyType string, valueType string, wd string) (templateData, error) { var data templateData - parts := strings.Split(typeName, ".") - if len(parts) < 2 { - return templateData{}, fmt.Errorf("type must be in the form package.Name") - } - - name := parts[len(parts)-1] - importPath := strings.Join(parts[:len(parts)-1], ".") genPkg := getPackage(wd) if genPkg == nil { return templateData{}, fmt.Errorf("unable to find package info for " + wd) } + var err error + data.Name = name data.Package = genPkg.Name - data.LoaderName = name + "Loader" - data.BatchName = lcFirst(name) + "Batch" - data.Name = lcFirst(name) - data.KeyType = keyType - data.Slice = slice - - prefix := "*" - if slice { - prefix = "[]" - data.LoaderName = name + "SliceLoader" - data.BatchName = lcFirst(name) + "SliceBatch" + data.KeyType, err = parseType(keyType) + if err != nil { + return templateData{}, fmt.Errorf("key type: %s", err.Error()) + } + data.ValType, err = parseType(valueType) + if err != nil { + return templateData{}, fmt.Errorf("key type: %s", err.Error()) } // if we are inside the same package as the type we don't need an import and can refer directly to the type - if genPkg.PkgPath == importPath { - data.ValType = prefix + name - } else { - data.Import = importPath - data.ValType = prefix + filepath.Base(data.Import) + "." + name + if genPkg.PkgPath == data.ValType.ImportPath { + data.ValType.ImportName = "" + data.ValType.ImportPath = "" + } + if genPkg.PkgPath == data.KeyType.ImportPath { + data.KeyType.ImportName = "" + data.KeyType.ImportPath = "" } return data, nil diff --git a/pkg/generator/generator_test.go b/pkg/generator/generator_test.go new file mode 100644 index 0000000..ee8d2fe --- /dev/null +++ b/pkg/generator/generator_test.go @@ -0,0 +1,26 @@ +package generator + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseType(t *testing.T) { + require.Equal(t, &goType{Name: "string"}, parse("string")) + require.Equal(t, &goType{Name: "Time", ImportPath: "time", ImportName: "time"}, parse("time.Time")) + require.Equal(t, &goType{ + Name: "Foo", + ImportPath: "github.com/vektah/dataloaden/pkg/generator/testdata/mismatch", + ImportName: "mismatched", + }, parse("github.com/vektah/dataloaden/pkg/generator/testdata/mismatch.Foo")) +} + +func parse(s string) *goType { + t, err := parseType(s) + if err != nil { + panic(err) + } + + return t +} diff --git a/pkg/generator/template.go b/pkg/generator/template.go index e84c0a0..48f5ba2 100644 --- a/pkg/generator/template.go +++ b/pkg/generator/template.go @@ -2,7 +2,11 @@ package generator import "text/template" -var tpl = template.Must(template.New("generated").Parse(` +var tpl = template.Must(template.New("generated"). + Funcs(template.FuncMap{ + "lcFirst": lcFirst, + }). + Parse(` // Code generated by github.com/vektah/dataloaden, DO NOT EDIT. package {{.Package}} @@ -11,13 +15,14 @@ import ( "sync" "time" - {{if .Import}}"{{.Import}}"{{end}} + {{if .KeyType.ImportPath}}"{{.KeyType.ImportPath}}"{{end}} + {{if .ValType.ImportPath}}"{{.ValType.ImportPath}}"{{end}} ) -// {{.LoaderName}}Config captures the config to create a new {{.LoaderName}} -type {{.LoaderName}}Config struct { +// {{.Name}}Config captures the config to create a new {{.Name}} +type {{.Name}}Config struct { // Fetch is a method that provides the data for the loader - Fetch func(keys []{{.KeyType}}) ([]{{.ValType}}, []error) + Fetch func(keys []{{.KeyType.String}}) ([]{{.ValType.String}}, []error) // Wait is how long wait before sending a batch Wait time.Duration @@ -26,19 +31,19 @@ type {{.LoaderName}}Config struct { MaxBatch int } -// New{{.LoaderName}} creates a new {{.LoaderName}} given a fetch, wait, and maxBatch -func New{{.LoaderName}}(config {{.LoaderName}}Config) *{{.LoaderName}} { - return &{{.LoaderName}}{ +// New{{.Name}} creates a new {{.Name}} given a fetch, wait, and maxBatch +func New{{.Name}}(config {{.Name}}Config) *{{.Name}} { + return &{{.Name}}{ fetch: config.Fetch, wait: config.Wait, maxBatch: config.MaxBatch, } } -// {{.LoaderName}} batches and caches requests -type {{.LoaderName}} struct { +// {{.Name}} batches and caches requests +type {{.Name}} struct { // this method provides the data for the loader - fetch func(keys []{{.KeyType}}) ([]{{.ValType}}, []error) + fetch func(keys []{{.KeyType.String}}) ([]{{.ValType.String}}, []error) // how long to done before sending a batch wait time.Duration @@ -49,51 +54,51 @@ type {{.LoaderName}} struct { // INTERNAL // lazily created cache - cache map[{{.KeyType}}]{{.ValType}} + cache map[{{.KeyType.String}}]{{.ValType.String}} // the current batch. keys will continue to be collected until timeout is hit, // then everything will be sent to the fetch method and out to the listeners - batch *{{.BatchName}} + batch *{{.Name|lcFirst}}Batch // mutex to prevent races mu sync.Mutex } -type {{.BatchName}} struct { +type {{.Name|lcFirst}}Batch struct { keys []{{.KeyType}} - data []{{.ValType}} + data []{{.ValType.String}} error []error closing bool done chan struct{} } -// Load a {{.Name}} by key, batching and caching will be applied automatically -func (l *{{.LoaderName}}) Load(key {{.KeyType}}) ({{.ValType}}, error) { +// Load a {{.ValType.Name}} by key, batching and caching will be applied automatically +func (l *{{.Name}}) Load(key {{.KeyType.String}}) ({{.ValType.String}}, error) { return l.LoadThunk(key)() } -// LoadThunk returns a function that when called will block waiting for a {{.Name}}. +// LoadThunk returns a function that when called will block waiting for a {{.ValType.Name}}. // This method should be used if you want one goroutine to make requests to many // different data loaders without blocking until the thunk is called. -func (l *{{.LoaderName}}) LoadThunk(key {{.KeyType}}) func() ({{.ValType}}, error) { +func (l *{{.Name}}) LoadThunk(key {{.KeyType.String}}) func() ({{.ValType.String}}, error) { l.mu.Lock() if it, ok := l.cache[key]; ok { l.mu.Unlock() - return func() ({{.ValType}}, error) { + return func() ({{.ValType.String}}, error) { return it, nil } } if l.batch == nil { - l.batch = &{{.BatchName}}{done: make(chan struct{})} + l.batch = &{{.Name|lcFirst}}Batch{done: make(chan struct{})} } batch := l.batch pos := batch.keyIndex(l, key) l.mu.Unlock() - return func() ({{.ValType}}, error) { + return func() ({{.ValType.String}}, error) { <-batch.done - var data {{.ValType}} + var data {{.ValType.String}} if pos < len(batch.data) { data = batch.data[pos] } @@ -118,53 +123,59 @@ func (l *{{.LoaderName}}) LoadThunk(key {{.KeyType}}) func() ({{.ValType}}, erro // LoadAll fetches many keys at once. It will be broken into appropriate sized // sub batches depending on how the loader is configured -func (l *{{.LoaderName}}) LoadAll(keys []{{.KeyType}}) ([]{{.ValType}}, []error) { - results := make([]func() ({{.ValType}}, error), len(keys)) +func (l *{{.Name}}) LoadAll(keys []{{.KeyType}}) ([]{{.ValType.String}}, []error) { + results := make([]func() ({{.ValType.String}}, error), len(keys)) for i, key := range keys { results[i] = l.LoadThunk(key) } - {{.Name}}s := make([]{{.ValType}}, len(keys)) + {{.ValType.Name|lcFirst}}s := make([]{{.ValType.String}}, len(keys)) errors := make([]error, len(keys)) for i, thunk := range results { - {{.Name}}s[i], errors[i] = thunk() + {{.ValType.Name|lcFirst}}s[i], errors[i] = thunk() } - return {{.Name}}s, errors + return {{.ValType.Name|lcFirst}}s, errors } -// LoadAllThunk returns a function that when called will block waiting for a {{.Name}}s. +// LoadAllThunk returns a function that when called will block waiting for a {{.ValType.Name}}s. // This method should be used if you want one goroutine to make requests to many // different data loaders without blocking until the thunk is called. -func (l *{{.LoaderName}}) LoadAllThunk(keys []{{.KeyType}}) (func() ([]{{.ValType}}, []error)) { - results := make([]func() ({{.ValType}}, error), len(keys)) +func (l *{{.Name}}) LoadAllThunk(keys []{{.KeyType}}) (func() ([]{{.ValType.String}}, []error)) { + results := make([]func() ({{.ValType.String}}, error), len(keys)) for i, key := range keys { results[i] = l.LoadThunk(key) } - return func() ([]{{.ValType}}, []error) { - {{.Name}}s := make([]{{.ValType}}, len(keys)) + return func() ([]{{.ValType.String}}, []error) { + {{.ValType.Name|lcFirst}}s := make([]{{.ValType.String}}, len(keys)) errors := make([]error, len(keys)) for i, thunk := range results { - {{.Name}}s[i], errors[i] = thunk() + {{.ValType.Name|lcFirst}}s[i], errors[i] = thunk() } - return {{.Name}}s, errors + return {{.ValType.Name|lcFirst}}s, errors } } // Prime the cache with the provided key and value. If the key already exists, no change is made // and false is returned. // (To forcefully prime the cache, clear the key first with loader.clear(key).prime(key, value).) -func (l *{{.LoaderName}}) Prime(key {{.KeyType}}, value {{.ValType}}) bool { +func (l *{{.Name}}) Prime(key {{.KeyType}}, value {{.ValType.String}}) bool { l.mu.Lock() var found bool if _, found = l.cache[key]; !found { - {{- if .Slice }} - l.unsafeSet(key, value) - {{- else }} + {{- if .ValType.IsPtr }} // make a copy when writing to the cache, its easy to pass a pointer in from a loop var // and end up with the whole cache pointing to the same value. cpy := *value l.unsafeSet(key, &cpy) + {{- else if .ValType.IsSlice }} + // make a copy when writing to the cache, its easy to pass a pointer in from a loop var + // and end up with the whole cache pointing to the same value. + cpy := make({{.ValType.String}}, len(value)) + copy(cpy, value) + l.unsafeSet(key, cpy) + {{- else }} + l.unsafeSet(key, value) {{- end }} } l.mu.Unlock() @@ -172,22 +183,22 @@ func (l *{{.LoaderName}}) Prime(key {{.KeyType}}, value {{.ValType}}) bool { } // Clear the value at key from the cache, if it exists -func (l *{{.LoaderName}}) Clear(key {{.KeyType}}) { +func (l *{{.Name}}) Clear(key {{.KeyType}}) { l.mu.Lock() delete(l.cache, key) l.mu.Unlock() } -func (l *{{.LoaderName}}) unsafeSet(key {{.KeyType}}, value {{.ValType}}) { +func (l *{{.Name}}) unsafeSet(key {{.KeyType}}, value {{.ValType.String}}) { if l.cache == nil { - l.cache = map[{{.KeyType}}]{{.ValType}}{} + l.cache = map[{{.KeyType}}]{{.ValType.String}}{} } l.cache[key] = value } // keyIndex will return the location of the key in the batch, if its not found // it will add the key to the batch -func (b *{{.BatchName}}) keyIndex(l *{{.LoaderName}}, key {{.KeyType}}) int { +func (b *{{.Name|lcFirst}}Batch) keyIndex(l *{{.Name}}, key {{.KeyType}}) int { for i, existingKey := range b.keys { if key == existingKey { return i @@ -211,7 +222,7 @@ func (b *{{.BatchName}}) keyIndex(l *{{.LoaderName}}, key {{.KeyType}}) int { return pos } -func (b *{{.BatchName}}) startTimer(l *{{.LoaderName}}) { +func (b *{{.Name|lcFirst}}Batch) startTimer(l *{{.Name}}) { time.Sleep(l.wait) l.mu.Lock() @@ -227,7 +238,7 @@ func (b *{{.BatchName}}) startTimer(l *{{.LoaderName}}) { b.end(l) } -func (b *{{.BatchName}}) end(l *{{.LoaderName}}) { +func (b *{{.Name|lcFirst}}Batch) end(l *{{.Name}}) { b.data, b.error = l.fetch(b.keys) close(b.done) } diff --git a/pkg/generator/testdata/mismatch/mismatch.go b/pkg/generator/testdata/mismatch/mismatch.go new file mode 100644 index 0000000..79c8ba2 --- /dev/null +++ b/pkg/generator/testdata/mismatch/mismatch.go @@ -0,0 +1,5 @@ +package mismatched + +type Foo struct { + Name string +}