diff --git a/module.go b/module.go index a0d69ba..5b629ff 100644 --- a/module.go +++ b/module.go @@ -2,44 +2,70 @@ package espresso import ( "context" + "errors" "fmt" "reflect" ) -type moduleName string +var ErrModuleDependError = errors.New("depend error") -type Module[T any] struct { - name moduleName +type Module interface { + Name() moduleName + Check(context.Context) error } -func DefineModule[T any]() Module[T] { - var t T - typ := reflect.TypeOf(t) - if typ.Kind() == reflect.Ptr { - panic("T should be a type, not a pointer.") - } +type ModuleImplementer interface { + Check(context.Context) error +} - name := fmt.Sprintf("%T", t) +type moduleName string + +type ModuleType[T ModuleImplementer] struct { + name moduleName + depends []Module +} - return Module[T]{ - name: moduleName(name), +func DefineModule[T ModuleImplementer](depends ...Module) *ModuleType[T] { + var t T + name := reflect.TypeOf(t).Name() + return &ModuleType[T]{ + name: moduleName(name), + depends: depends, } } -func (m Module[T]) With(ctx context.Context, moduleInstance *T) context.Context { - return context.WithValue(ctx, m.name, moduleInstance) +func (m ModuleType[T]) Name() moduleName { + return m.name } -func (m Module[T]) Value(ctx context.Context) *T { +func (m ModuleType[T]) Value(ctx context.Context) T { + var n T v := ctx.Value(m.name) if v == nil { - return nil + return n } - ret, ok := v.(*T) + ret, ok := v.(T) if !ok { - return nil + return n } return ret } + +func (m *ModuleType[T]) Check(ctx context.Context) error { + var errs []error + for _, module := range m.depends { + if err := module.Check(ctx); err != nil { + errs = append(errs, fmt.Errorf("module %s: %w", module.Name(), err)) + } + } + + if len(errs) != 0 { + errs = append(errs, fmt.Errorf("module %s: %w", m.Name(), ErrModuleDependError)) + } else if err := m.Check(ctx); err != nil { + errs = append(errs, fmt.Errorf("module %s: %w", m.Name(), err)) + } + + return errors.Join(errs...) +}