diff --git a/core/core.go b/core/core.go index c6ce81b..6f9abf6 100644 --- a/core/core.go +++ b/core/core.go @@ -236,7 +236,7 @@ func reserveCoderDirective(r *owl.Resolver, name string) error { namedAdaptor := namedStringableAdaptors[d.Argv[0]] if namedAdaptor == nil { - return fmt.Errorf("directive %s: unregistered coder: %q", name, d.Argv[0]) + return fmt.Errorf("directive %s: %w: %q", name, ErrUnregisteredCoder, d.Argv[0]) } r.Context = context.WithValue(r.Context, CtxCustomCoder, namedAdaptor) @@ -248,7 +248,7 @@ func reserveCoderDirective(r *owl.Resolver, name string) error { func ensureDirectiveExecutorsRegistered(r *owl.Resolver) error { for _, d := range r.Directives { if decoderNamespace.LookupExecutor(d.Name) == nil { - return fmt.Errorf("unregistered directive: %q", d.Name) + return fmt.Errorf("%w: %q", ErrUnregisteredDirective, d.Name) } // NOTE: don't need to check encoderNamespace because a directive // will always be registered in both namespaces. See RegisterDirective(). diff --git a/core/core_test.go b/core/core_test.go index fe2ebfd..9897b4d 100644 --- a/core/core_test.go +++ b/core/core_test.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "net/url" + "strconv" "strings" "testing" "time" @@ -320,6 +321,52 @@ func TestCore_Decode_PointerTypes(t *testing.T) { assert.ErrorContains(err, "invalid place") } +type CommaSeparatedIntegerArray struct { + Value []int +} + +func (a CommaSeparatedIntegerArray) ToString() (string, error) { + var res = make([]string, len(a.Value)) + for i := range a.Value { + res[i] = strconv.Itoa(a.Value[i]) + } + return strings.Join(res, ","), nil +} + +func (pa *CommaSeparatedIntegerArray) FromString(value string) error { + a := CommaSeparatedIntegerArray{} + values := strings.Split(value, ",") + a.Value = make([]int, len(values)) + for i := range values { + if value, err := strconv.Atoi(values[i]); err != nil { + return err + } else { + a.Value[i] = value + } + } + *pa = a + return nil +} + +func TestCore_Decode_CustomTypeSliceValueWrapper(t *testing.T) { + assert := assert.New(t) + + type Input struct { + Ids CommaSeparatedIntegerArray `in:"form=ids"` + } + co, err := New(Input{}) + assert.NoError(err) + + // Missing fields. + r := newMultipartFormRequestFromMap(map[string]any{ + "ids": "1,2,3", + }) + gotValue, err := co.Decode(r) + assert.NoError(err) + got := gotValue.(*Input) + assert.Equal([]int{1, 2, 3}, got.Ids.Value) +} + // Test: register named coders and use them in the "coder" directive, // i.e. customizing the encoding/decoding for a specific struct field. diff --git a/core/directiveruntime.go b/core/directiveruntime.go index 0d60c31..b31df82 100644 --- a/core/directiveruntime.go +++ b/core/directiveruntime.go @@ -6,7 +6,6 @@ import ( "net/http" "reflect" - "github.com/ggicci/httpin/internal" "github.com/ggicci/owl" ) @@ -88,7 +87,7 @@ func (rtm *DirectiveRuntime) SetValue(value any) error { if !newValue.Type().AssignableTo(targetType) { return fmt.Errorf("%w: value of type %q is not assignable to type %q", - internal.ErrTypeMismatch, reflect.TypeOf(value), targetType) + ErrTypeMismatch, reflect.TypeOf(value), targetType) } rtm.Value.Elem().Set(newValue) diff --git a/core/error.go b/core/error.go index a804ecb..d555b9f 100644 --- a/core/error.go +++ b/core/error.go @@ -5,9 +5,16 @@ import ( "fmt" "strings" + "github.com/ggicci/httpin/internal" "github.com/ggicci/owl" ) +var ( + ErrUnregisteredDirective = errors.New("unregistered directive") + ErrUnregisteredCoder = errors.New("unregistered coder") + ErrTypeMismatch = internal.ErrTypeMismatch +) + type InvalidFieldError struct { // err is the underlying error thrown by the directive executor. err error diff --git a/httpin.go b/httpin.go index 02e27f9..efabee4 100644 --- a/httpin.go +++ b/httpin.go @@ -6,6 +6,7 @@ package httpin import ( "context" "fmt" + "io" "net/http" "reflect" @@ -22,6 +23,13 @@ const ( Input contextKey = iota ) +// Option is a collection of options for creating a Core instance. +var Option coreOptions = coreOptions{ + WithErrorHandler: core.WithErrorHandler, + WithMaxMemory: core.WithMaxMemory, + WithNestedDirectivesEnabled: core.WithNestedDirectivesEnabled, +} + // New calls core.New to create a new Core instance. Which is responsible for both: // // - decoding an HTTP request to an instance of the inputStruct; @@ -39,7 +47,9 @@ const ( // API, chained with other middlewares, and also reused in other APIs. You even // don't need to call the Deocde() method explicitly, the middleware will do it // for you and put the decoded instance to the request's context. -var New = core.New +func New(inputStruct any, opts ...core.Option) (*core.Core, error) { + return core.New(inputStruct, opts...) +} // File is the builtin type of httpin to manupulate file uploads. On the server // side, it is used to represent a file in a multipart/form-data request. On the @@ -48,11 +58,15 @@ type File = core.File // UploadFile is a helper function to create a File instance from a file path. // It is useful when you want to upload a file from the local file system. -var UploadFile = core.UploadFile +func UploadFile(path string) *File { + return core.UploadFile(path) +} // UploadStream is a helper function to create a File instance from a io.Reader. It // is useful when you want to upload a file from a stream. -var UploadStream = core.UploadStream +func UploadStream(r io.ReadCloser) *File { + return core.UploadStream(r) +} // Decode decodes an HTTP request to the given input struct. The input must be a // pointer to a struct instance. For example: @@ -61,12 +75,12 @@ var UploadStream = core.UploadStream // if err := Decode(req, &input); err != nil { ... } // // input is now populated with data from the request. -func Decode(req *http.Request, input any) error { +func Decode(req *http.Request, input any, opts ...core.Option) error { originalType := reflect.TypeOf(input) if originalType.Kind() != reflect.Ptr { return fmt.Errorf("httpin: input must be a pointer") } - co, err := New(originalType.Elem()) + co, err := New(originalType.Elem(), opts...) if err != nil { return err } @@ -83,15 +97,15 @@ func Decode(req *http.Request, input any) error { } // NewRequest wraps NewRequestWithContext using context.Background. -func NewRequest(method, url string, input any) (*http.Request, error) { +func NewRequest(method, url string, input any, opts ...core.Option) (*http.Request, error) { return NewRequestWithContext(context.Background(), method, url, input) } // NewRequestWithContext returns a new http.Request given a method, url and an // input struct instance. The fields of the input struct will be encoded to the // request by resolving the "in" tags and executing the directives. -func NewRequestWithContext(ctx context.Context, method, url string, input any) (*http.Request, error) { - co, err := New(input) +func NewRequestWithContext(ctx context.Context, method, url string, input any, opts ...core.Option) (*http.Request, error) { + co, err := New(input, opts...) if err != nil { return nil, err } @@ -145,3 +159,16 @@ func NewInput(inputStruct any, opts ...core.Option) func(http.Handler) http.Hand }) } } + +type coreOptions struct { + // WithErrorHandler overrides the default error handler. + WithErrorHandler func(core.ErrorHandler) core.Option + + // WithMaxMemory overrides the default maximum memory size (32MB) when reading + // the request body. See https://pkg.go.dev/net/http#Request.ParseMultipartForm + // for more details. + WithMaxMemory func(int64) core.Option + + // WithNestedDirectivesEnabled enables/disables nested directives. + WithNestedDirectivesEnabled func(bool) core.Option +}