Skip to content

Commit

Permalink
Fix of circular dependencies validation (#49)
Browse files Browse the repository at this point in the history
* Fix of circular dependencies validation.
* Fix for function source detection.

Signed-off-by: Pavel Patrin <[email protected]>
  • Loading branch information
pavelpatrin authored Aug 17, 2024
1 parent dfcc976 commit 1d483ab
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 84 deletions.
7 changes: 7 additions & 0 deletions factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,13 @@ func splitFuncName(funcFullName string) (string, string) {
}
}

// If the name contains no package path.
if lastPackageChunkIndex == -1 {
packageName := fullNameChunks[0]
funcName := strings.Join(fullNameChunks[1:], ".")
return packageName, funcName
}

// Prepare package name and function name.
packageName := strings.Join(fullNameChunks[:lastPackageChunkIndex+1], ".")
funcName := strings.Join(fullNameChunks[lastPackageChunkIndex+1:], ".")
Expand Down
31 changes: 31 additions & 0 deletions factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,34 @@ func TestFactoryInfo(t *testing.T) {
type globalType struct{}

func globalFunc(string) {}

// TestSplitFuncName tests splitting of function name.
func TestSplitFuncName(t *testing.T) {
tests := []struct {
name string
arg string
want1 string
want2 string
}{{
name: "SplitPublicPackage",
arg: "github.com/NVIDIA/gontainer/app.WithApp.func1",
want1: "github.com/NVIDIA/gontainer/app",
want2: "WithApp.func1",
}, {
name: "SplitMainPackage",
arg: "main.main.func1",
want1: "main",
want2: "main.func1",
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got1, got2 := splitFuncName(tt.arg)
if got1 != tt.want1 {
t.Errorf("splitFuncName() got1 = %v, want %v", got1, tt.want1)
}
if got2 != tt.want2 {
t.Errorf("splitFuncName() got2 = %v, want %v", got2, tt.want2)
}
})
}
}
97 changes: 47 additions & 50 deletions registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"fmt"
"reflect"
"runtime"
"slices"
)

// registry contains all defined factories metadata.
Expand Down Expand Up @@ -52,9 +51,8 @@ func (r *registry) registerFactory(ctx context.Context, factory *Factory) error
func (r *registry) validateFactories() error {
var errs []error

// Validate all factories.
// Validate all input types are resolvable.
for _, factory := range r.factories {
// Validate all input types are resolvable.
for index, factoryInType := range factory.factoryInTypes {
// Is this type a special factory context type?
if isContextInterface(factoryInType) {
Expand All @@ -77,73 +75,72 @@ func (r *registry) validateFactories() error {
typeFactories, _ := r.findFactoriesFor(factoryInType)
if len(typeFactories) == 0 {
errs = append(errs, fmt.Errorf(
"failed to validate service '%s' (argument %d) of '%s' from '%s': %w",
"failed to validate argument '%s' (index %d) of factory '%s' from '%s': %w",
factoryInType, index, factory.Name(), factory.Source(), ErrServiceNotResolved,
))
continue
}
}
}

// Validate all input types have no circular dependencies.
for index, factoryInType := range factory.factoryInTypes {
// Is this type a special factory context type?
if isContextInterface(factoryInType) {
// Validate all output types are unique.
for _, factory := range r.factories {
for index, factoryOutType := range factory.factoryOutTypes {
// Factories returning `any` could be duplicated.
if isEmptyInterface(factoryOutType) {
continue
}

// Validate dependencies graph of this type.
validationQueue := []reflect.Type{factoryInType}
var validatedTypes []reflect.Type
for len(validationQueue) > 0 {
validatingType := validationQueue[0]
validationQueue = validationQueue[1:]
// Validate uniqueness of the every factory output type.
factoriesForSameOutType, _ := r.findFactoriesFor(factoryOutType)
if len(factoriesForSameOutType) > 1 {
errs = append(errs, fmt.Errorf(
"failed to validate output '%s' (index %d) of factory '%s' from '%s': %w",
factoryOutType, index, factory.Name(), factory.Source(), ErrServiceDuplicated,
))
}
}
}

// Validate for circular dependencies.
for index := range r.factories {
factories := []*Factory{r.factories[index]}
recursion:
for len(factories) > 0 {
factory := factories[0]
factories = factories[1:]

for _, factoryInType := range factory.factoryInTypes {
// Is this type a special factory context type?
if isContextInterface(factoryInType) {
continue
}

// Is this type wrapped to the `Optional[type]`?
innerType, isOptional := isOptionalType(validatingType)
innerType, isOptional := isOptionalType(factoryInType)
if isOptional {
validatingType = innerType
factoryInType = innerType
}

// Is this type wrapped to the `Multiple[type]`?
innerType, isMultiple := isMultipleType(validatingType)
innerType, isMultiple := isMultipleType(factoryInType)
if isMultiple {
validatingType = innerType
factoryInType = innerType
}

// Was this type already validated before?
if slices.Contains(validatedTypes, validatingType) {
errs = append(errs, fmt.Errorf(
"failed to validate service '%s' (argument %d) of '%s' from '%s': %w",
validatingType, index, factory.Name(), factory.Source(), ErrCircularDependency,
))
break
}

// Register type as validated.
validatedTypes = append(validatedTypes, validatingType)

// Walk through all input types of all factories.
typeFactories, _ := r.findFactoriesFor(validatingType)
for _, typeFactory := range typeFactories {
validationQueue = append(validationQueue, typeFactory.factoryInTypes...)
// Walk through all factories for this in argument type.
factoriesForType, _ := r.findFactoriesFor(factoryInType)
for _, factoryForType := range factoriesForType {
if factoryForType == r.factories[index] {
errs = append(errs, fmt.Errorf(
"failed to validate factory '%s' from '%s': %w",
r.factories[index].Name(), r.factories[index].Source(), ErrCircularDependency,
))
break recursion
}
}
}
}

// Validate all output types are unique.
for index, factoryOutType := range factory.factoryOutTypes {
// Factories returning `any` could be duplicated.
if isEmptyInterface(factoryOutType) {
continue
}

// Validate uniqueness of the every factory output type.
factoriesForSameOutType, _ := r.findFactoriesFor(factoryOutType)
if len(factoriesForSameOutType) > 1 {
errs = append(errs, fmt.Errorf(
"failed to validate service '%s' (output %d) of '%s' from '%s': %w",
factoryOutType, index, factory.Name(), factory.Source(), ErrServiceDuplicated,
))
factories = append(factories, factoriesForType...)
}
}
}
Expand Down
63 changes: 29 additions & 34 deletions registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ func TestRegistryValidateFactories(t *testing.T) {
equal(t, len(errs), 2)

equal(t, errors.Is(errs[0], ErrServiceNotResolved), true)
equal(t, errs[0].Error(), "failed to validate service 'bool' (argument 0) "+
"of 'Factory[func(bool) error]' from 'github.com/NVIDIA/gontainer': "+
equal(t, errs[0].Error(), "failed to validate argument 'bool' (index 0) "+
"of factory 'Factory[func(bool) error]' from 'github.com/NVIDIA/gontainer': "+
"service not resolved")

equal(t, errors.Is(errs[1], ErrServiceNotResolved), true)
equal(t, errs[1].Error(), "failed to validate service 'string' (argument 0) "+
"of 'Factory[func(string) error]' from 'github.com/NVIDIA/gontainer': "+
equal(t, errs[1].Error(), "failed to validate argument 'string' (index 0) "+
"of factory 'Factory[func(string) error]' from 'github.com/NVIDIA/gontainer': "+
"service not resolved")
},
},
Expand All @@ -85,13 +85,13 @@ func TestRegistryValidateFactories(t *testing.T) {
equal(t, len(errs), 2)

equal(t, errors.Is(errs[0], ErrServiceDuplicated), true)
equal(t, errs[0].Error(), "failed to validate service 'string' (output 0) of "+
"'Factory[func() (string, error)]' from 'github.com/NVIDIA/gontainer': "+
equal(t, errs[0].Error(), "failed to validate output 'string' (index 0) "+
"of factory 'Factory[func() (string, error)]' from 'github.com/NVIDIA/gontainer': "+
"service duplicated")

equal(t, errors.Is(errs[1], ErrServiceDuplicated), true)
equal(t, errs[1].Error(), "failed to validate service 'string' (output 0) of "+
"'Factory[func() (string, error)]' from 'github.com/NVIDIA/gontainer': "+
equal(t, errs[1].Error(), "failed to validate output 'string' (index 0) "+
"of factory 'Factory[func() (string, error)]' from 'github.com/NVIDIA/gontainer': "+
"service duplicated")
},
},
Expand All @@ -111,19 +111,16 @@ func TestRegistryValidateFactories(t *testing.T) {
equal(t, len(errs), 3)

equal(t, errors.Is(errs[0], ErrCircularDependency), true)
equal(t, errs[0].Error(), "failed to validate service 'bool' (argument 0) "+
"of 'Factory[func(bool) (int, error)]' from 'github.com/NVIDIA/gontainer': "+
"circular dependency")
equal(t, errs[0].Error(), "failed to validate factory 'Factory[func(bool) (int, error)]' "+
"from 'github.com/NVIDIA/gontainer': circular dependency")

equal(t, errors.Is(errs[1], ErrCircularDependency), true)
equal(t, errs[1].Error(), "failed to validate service 'string' (argument 0) "+
"of 'Factory[func(string) (bool, error)]' from 'github.com/NVIDIA/gontainer': "+
"circular dependency")
equal(t, errs[1].Error(), "failed to validate factory 'Factory[func(string) (bool, error)]' "+
"from 'github.com/NVIDIA/gontainer': circular dependency")

equal(t, errors.Is(errs[2], ErrCircularDependency), true)
equal(t, errs[2].Error(), "failed to validate service 'int' (argument 0) "+
"of 'Factory[func(int) (string, error)]' from 'github.com/NVIDIA/gontainer': "+
"circular dependency")
equal(t, errs[2].Error(), "failed to validate factory 'Factory[func(int) (string, error)]' "+
"from 'github.com/NVIDIA/gontainer': circular dependency")
},
},
{
Expand All @@ -145,34 +142,32 @@ func TestRegistryValidateFactories(t *testing.T) {
equal(t, len(errs), 6)

equal(t, errors.Is(errs[0], ErrServiceNotResolved), true)
equal(t, errs[0].Error(), "failed to validate service 'struct { X int }' (argument 0) "+
"of 'Factory[func(struct { X int }) string]' from 'github.com/NVIDIA/gontainer': "+
equal(t, errs[0].Error(), "failed to validate argument 'struct { X int }' (index 0) "+
"of factory 'Factory[func(struct { X int }) string]' from 'github.com/NVIDIA/gontainer': "+
"service not resolved")

equal(t, errors.Is(errs[1], ErrServiceDuplicated), true)
equal(t, errs[1].Error(), "failed to validate service 'string' (output 0) "+
"of 'Factory[func(struct { X int }) string]' from 'github.com/NVIDIA/gontainer': "+
equal(t, errs[1].Error(), "failed to validate output 'string' (index 0) "+
"of factory 'Factory[func(struct { X int }) string]' from 'github.com/NVIDIA/gontainer': "+
"service duplicated")

equal(t, errors.Is(errs[2], ErrServiceDuplicated), true)
equal(t, errs[2].Error(), "failed to validate service 'string' (output 0) "+
"of 'Factory[func(context.Context) (string, error)]' from 'github.com/NVIDIA/gontainer': "+
equal(t, errs[2].Error(), "failed to validate output 'string' (index 0) "+
"of factory 'Factory[func(context.Context) (string, error)]' from 'github.com/NVIDIA/gontainer': "+
"service duplicated")

equal(t, errors.Is(errs[3], ErrCircularDependency), true)
equal(t, errs[3].Error(), "failed to validate service 'bool' (argument 0) "+
"of 'Factory[func(bool) (int, error)]' from 'github.com/NVIDIA/gontainer': "+
"circular dependency")
equal(t, errors.Is(errs[3], ErrServiceDuplicated), true)
equal(t, errs[3].Error(), "failed to validate output 'string' (index 1) "+
"of factory 'Factory[func(int) (bool, string)]' from 'github.com/NVIDIA/gontainer': "+
"service duplicated")

equal(t, errors.Is(errs[4], ErrCircularDependency), true)
equal(t, errs[4].Error(), "failed to validate service 'int' (argument 0) "+
"of 'Factory[func(int) (bool, string)]' from 'github.com/NVIDIA/gontainer': "+
"circular dependency")
equal(t, errs[4].Error(), "failed to validate factory 'Factory[func(bool) (int, error)]' "+
"from 'github.com/NVIDIA/gontainer': circular dependency")

equal(t, errors.Is(errs[5], ErrServiceDuplicated), true)
equal(t, errs[5].Error(), "failed to validate service 'string' (output 1) "+
"of 'Factory[func(int) (bool, string)]' from 'github.com/NVIDIA/gontainer': "+
"service duplicated")
equal(t, errors.Is(errs[5], ErrCircularDependency), true)
equal(t, errs[5].Error(), "failed to validate factory 'Factory[func(int) (bool, string)]' "+
"from 'github.com/NVIDIA/gontainer': circular dependency")
},
},
}
Expand Down

0 comments on commit 1d483ab

Please sign in to comment.