Skip to content

Commit

Permalink
Add Print All Types
Browse files Browse the repository at this point in the history
  • Loading branch information
tung.tq committed Oct 24, 2023
1 parent 8ea51a4 commit 68b10cc
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 47 deletions.
121 changes: 85 additions & 36 deletions svloc.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type universeData struct {

cleared bool

regList []*registeredService

shutdownFuncs []func() // list of shutdown funcs from earliest to latest
alreadyShutdown bool
}
Expand All @@ -58,12 +60,41 @@ func NewUniverse() *Universe {
// After called, all other calls will panic, excepts for Shutdown
func (u *Universe) CleanUp() {
u.data.mut.Lock()
defer u.data.mut.Unlock()

u.data.svcMap = nil
u.data.regList = nil
u.data.cleared = true
}

defer u.data.mut.Unlock()
type getLoc struct {
regType reflect.Type
loc string
}

u.data.cleared = true
func (u *Universe) getPrintTypeLocations() []getLoc {
d := u.data

d.mut.Lock()
cloneList := make([]*registeredService, len(d.regList))
copy(cloneList, d.regList)
d.mut.Unlock()

result := make([]getLoc, 0, len(cloneList))
for _, e := range cloneList {
result = append(result, e.getLastOverrideLoc())
}
return result
}

// PrintAllUsedTypes prints all types that have been initialized
func (u *Universe) PrintAllUsedTypes() {
locs := u.getPrintTypeLocations()
printSeparateLine()
for _, loc := range locs {
_, _ = fmt.Fprintf(os.Stderr, "%s %s\n", loc.regType.String(), loc.loc)
}
printSeparateLine()
}

type registeredService struct {
Expand All @@ -80,6 +111,8 @@ type registeredService struct {

overrideCallLocation string

regType reflect.Type

newFunc func(unv *Universe) any

wrappers []func(unv *Universe, svc any) any
Expand All @@ -97,11 +130,12 @@ func (s *registeredService) newService(unv *Universe) any {

svc := s.newServiceSlow(unv, callLoc)

unv.data.appendShutdownFunc(s.onShutdown)
unv.data.appendShutdownFunc(s)

return svc
}

// callNewFuncAndWrappers already locked
func (s *registeredService) callNewFuncAndWrappers(unv *Universe) {
newFunc := s.loc.newFn
if s.newFunc != nil {
Expand All @@ -125,13 +159,15 @@ func (s *registeredService) callNewFuncAndWrappers(unv *Universe) {
}()

newSvc := newFunc(newUnv)
regType := reflect.TypeOf(newSvc)

for _, wrapper := range s.wrappers {
newSvc = wrapper(newUnv, newSvc)
}

s.svc = newSvc
s.createUnv = newUnv
s.regType = regType
}

func (s *registeredService) newServiceSlow(unv *Universe, callLoc string) any {
Expand Down Expand Up @@ -225,14 +261,17 @@ func (u *universeData) getService(
return svc, nil
}

func (u *universeData) appendShutdownFunc(fn func()) {
func (u *universeData) appendShutdownFunc(s *registeredService) {
u.mut.Lock()
defer u.mut.Unlock()

if u.alreadyShutdown {
panic("svloc: can NOT call 'Get' after 'Shutdown'")
}

u.regList = append(u.regList, s)

fn := s.onShutdown
if fn == nil {
return
}
Expand All @@ -252,8 +291,8 @@ type locatorData struct {

// Get can be called multiple times but the newFn inside Register* will be called ONCE.
// It can panic if Universe.Shutdown already called
func (s *Locator[T]) Get(unv *Universe) T {
reg, err := unv.data.getService(&s.data, "Get")
func (l *Locator[T]) Get(unv *Universe) T {
reg, err := unv.data.getService(&l.data, "Get")
if err != nil {
panic(err.Error())
}
Expand All @@ -268,18 +307,18 @@ func (s *Locator[T]) Get(unv *Universe) T {
}

// Override the value returned by Get, it also prevents running of the function inside Register
func (s *Locator[T]) Override(unv *Universe, svc T) error {
return s.overrideFuncWithLoc(unv, func(unv *Universe) T {
func (l *Locator[T]) Override(unv *Universe, svc T) error {
return l.overrideFuncWithLoc(unv, func(unv *Universe) T {
return svc
}, getCallerLocation())
}

func (s *Locator[T]) doBeforeGet(
func (l *Locator[T]) doBeforeGet(
unv *Universe,
methodName string,
handler func(reg *registeredService),
) error {
reg, err := unv.data.getService(&s.data, methodName)
reg, err := unv.data.getService(&l.data, methodName)
if err != nil {
return err
}
Expand All @@ -298,60 +337,60 @@ func (s *Locator[T]) doBeforeGet(
}

// OverrideFunc ...
func (s *Locator[T]) OverrideFunc(unv *Universe, newFn func(unv *Universe) T) error {
return s.overrideFuncWithLoc(unv, newFn, getCallerLocation())
func (l *Locator[T]) OverrideFunc(unv *Universe, newFn func(unv *Universe) T) error {
return l.overrideFuncWithLoc(unv, newFn, getCallerLocation())
}

func (s *Locator[T]) overrideFuncWithLoc(
func (l *Locator[T]) overrideFuncWithLoc(
unv *Universe, newFn func(unv *Universe) T,
callLoc string,
) error {
if unv.prev != nil {
return errOverrideInsideNewFunctions
}
return s.doBeforeGet(unv, "Override", func(reg *registeredService) {
return l.doBeforeGet(unv, "Override", func(reg *registeredService) {
reg.overrideCallLocation = callLoc
reg.newFunc = func(unv *Universe) any {
return newFn(unv)
}
})
}

func (s *Locator[T]) panicOverrideError(err error) {
func (l *Locator[T]) panicOverrideError(err error) {
var val *T
svcType := reflect.TypeOf(val).Elem()

panic(fmt.Sprintf("Can NOT override service of type '%v', err: %v", svcType, err))
}

// MustOverride will panic if Override returns error
func (s *Locator[T]) MustOverride(unv *Universe, svc T) {
err := s.overrideFuncWithLoc(unv, func(unv *Universe) T {
func (l *Locator[T]) MustOverride(unv *Universe, svc T) {
err := l.overrideFuncWithLoc(unv, func(unv *Universe) T {
return svc
}, getCallerLocation())
if err != nil {
s.panicOverrideError(err)
l.panicOverrideError(err)
}
}

// MustOverrideFunc similar to OverrideFunc but panics if error returned
func (s *Locator[T]) MustOverrideFunc(unv *Universe, newFn func(unv *Universe) T) {
err := s.overrideFuncWithLoc(unv, newFn, getCallerLocation())
func (l *Locator[T]) MustOverrideFunc(unv *Universe, newFn func(unv *Universe) T) {
err := l.overrideFuncWithLoc(unv, newFn, getCallerLocation())
if err != nil {
s.panicOverrideError(err)
l.panicOverrideError(err)
}
}

// Wrap the original implementation with the object created by wrapper
func (s *Locator[T]) Wrap(unv *Universe, wrapper func(unv *Universe, svc T) T) (err error) {
return s.wrapWithLoc(unv, wrapper, getCallerLocation())
func (l *Locator[T]) Wrap(unv *Universe, wrapper func(unv *Universe, svc T) T) (err error) {
return l.wrapWithLoc(unv, wrapper, getCallerLocation())
}

func (s *Locator[T]) wrapWithLoc(
func (l *Locator[T]) wrapWithLoc(
unv *Universe, wrapper func(unv *Universe, svc T) T,
callLoc string,
) (err error) {
return s.doBeforeGet(unv, "Wrap", func(reg *registeredService) {
return l.doBeforeGet(unv, "Wrap", func(reg *registeredService) {
reg.wrappers = append(reg.wrappers, func(unv *Universe, svc any) any {
return wrapper(unv, svc.(T))
})
Expand All @@ -360,8 +399,8 @@ func (s *Locator[T]) wrapWithLoc(
}

// MustWrap similar to Wrap, but it will panic if not succeeded
func (s *Locator[T]) MustWrap(unv *Universe, wrapper func(unv *Universe, svc T) T) {
err := s.wrapWithLoc(unv, wrapper, getCallerLocation())
func (l *Locator[T]) MustWrap(unv *Universe, wrapper func(unv *Universe, svc T) T) {
err := l.wrapWithLoc(unv, wrapper, getCallerLocation())
if err != nil {
var val *T

Expand All @@ -376,24 +415,34 @@ func (s *Locator[T]) MustWrap(unv *Universe, wrapper func(unv *Universe, svc T)

// GetLastOverrideLocation returns the last location that Override* is called.
// If no Override* functions is called, returns the Register location
func (s *Locator[T]) GetLastOverrideLocation(unv *Universe) (string, error) {
reg, err := unv.data.getService(&s.data, "GetLastOverrideLocation")
func (l *Locator[T]) GetLastOverrideLocation(unv *Universe) (string, error) {
reg, err := unv.data.getService(&l.data, "GetLastOverrideLocation")
if err != nil {
return "", err
}
return reg.getLastOverrideLoc().loc, nil
}

reg.mut.Lock()
defer reg.mut.Unlock()
func (s *registeredService) getLastOverrideLoc() getLoc {
s.mut.Lock()
defer s.mut.Unlock()

if s.overrideCallLocation != "" {
return getLoc{
regType: s.regType,
loc: s.overrideCallLocation,
}
}

if reg.overrideCallLocation != "" {
return reg.overrideCallLocation, nil
return getLoc{
regType: s.regType,
loc: s.loc.registerLoc,
}
return s.data.registerLoc, nil
}

// GetWrapLocations returns Wrap* call's locations
func (s *Locator[T]) GetWrapLocations(unv *Universe) ([]string, error) {
reg, err := unv.data.getService(&s.data, "GetWrapLocations")
func (l *Locator[T]) GetWrapLocations(unv *Universe) ([]string, error) {
reg, err := unv.data.getService(&l.data, "GetWrapLocations")
if err != nil {
return nil, err
}
Expand Down
59 changes: 48 additions & 11 deletions svloc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package svloc

import (
"errors"
"reflect"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -727,7 +728,7 @@ func TestLocator_Do_Shutdown_Complex(t *testing.T) {
}

func TestSizeOfRegisteredService(t *testing.T) {
assert.Equal(t, 144, int(unsafe.Sizeof(registeredService{})))
assert.Equal(t, 160, int(unsafe.Sizeof(registeredService{})))
}

func TestUniverse_CleanUp(t *testing.T) {
Expand Down Expand Up @@ -847,8 +848,8 @@ func TestLocator_GetLastOverrideLocation(t *testing.T) {
loc, err := repoLoc.GetLastOverrideLocation(unv)
assert.Equal(t, nil, err)

expect := "svloc_test.go:841"
assert.Equal(t, expect, loc[len(loc)-len(expect):])
expect := "svloc_test.go:842"
assertSuffixEqual(t, expect, loc)
})

t.Run("after override", func(t *testing.T) {
Expand All @@ -863,8 +864,8 @@ func TestLocator_GetLastOverrideLocation(t *testing.T) {
loc, err := repoLoc.GetLastOverrideLocation(unv)
assert.Equal(t, nil, err)

expect := "svloc_test.go:861"
assert.Equal(t, expect, loc[len(loc)-len(expect):])
expect := "svloc_test.go:862"
assertSuffixEqual(t, expect, loc)
})

t.Run("after override func", func(t *testing.T) {
Expand All @@ -882,8 +883,8 @@ func TestLocator_GetLastOverrideLocation(t *testing.T) {
loc, err := repoLoc.GetLastOverrideLocation(unv)
assert.Equal(t, nil, err)

expect := "svloc_test.go:877"
assert.Equal(t, expect, loc[len(loc)-len(expect):])
expect := "svloc_test.go:878"
assertSuffixEqual(t, expect, loc)
})

t.Run("after clean up", func(t *testing.T) {
Expand Down Expand Up @@ -940,11 +941,11 @@ func TestLocator_GetWrapLocations(t *testing.T) {

assert.Equal(t, 2, len(locs))

expect := "svloc_test.go:924"
assert.Equal(t, expect, locs[0][len(locs[0])-len(expect):])
expect := "svloc_test.go:925"
assertSuffixEqual(t, expect, locs[0])

expect = "svloc_test.go:931"
assert.Equal(t, expect, locs[1][len(locs[1])-len(expect):])
expect = "svloc_test.go:932"
assertSuffixEqual(t, expect, locs[1])
})

t.Run("fail after clean up", func(t *testing.T) {
Expand Down Expand Up @@ -994,3 +995,39 @@ func BenchmarkLocator_Mutex_Lock_Unlock(b *testing.B) {
benchMutex.Unlock()
}
}

func TestUniverse_getPrintLocations(t *testing.T) {
t.Run("normal", func(t *testing.T) {
repoLoc := Register[Repo](func(unv *Universe) Repo {
return &UserRepo{}
})

svcLoc := Register[*UserService](func(unv *Universe) *UserService {
return NewService(repoLoc.Get(unv))
})

unv := NewUniverse()

svc := svcLoc.Get(unv)

assert.Equal(t, "hello: user_repo", svc.Hello())

locs := unv.getPrintTypeLocations()
assert.Equal(t, 2, len(locs))

unv.PrintAllUsedTypes()

loc1 := "svloc_test.go:1001"
loc2 := "svloc_test.go:1005"

assertSuffixEqual(t, loc1, locs[0].loc)
assertSuffixEqual(t, loc2, locs[1].loc)

assert.Equal(t, reflect.TypeOf(&UserRepo{}), locs[0].regType)
assert.Equal(t, reflect.TypeOf(&UserService{}), locs[1].regType)
})
}

func assertSuffixEqual(t *testing.T, suffix string, s string) {
assert.Equal(t, suffix, s[len(s)-len(suffix):])
}

0 comments on commit 68b10cc

Please sign in to comment.