diff --git a/core/core.go b/core/core.go index 6f9abf6..402f1c1 100644 --- a/core/core.go +++ b/core/core.go @@ -2,9 +2,11 @@ package core import ( "context" + "errors" "fmt" "mime" "net/http" + "reflect" "sort" "sync" @@ -30,7 +32,7 @@ type Core struct { // is responsible for both: // // - decoding an HTTP request to an instance of the inputStruct; -// - and encoding an instance of the inputStruct to an HTTP request. +// - encoding an instance of the inputStruct to an HTTP request. func New(inputStruct any, opts ...Option) (*Core, error) { resolver, err := buildResolver(inputStruct) if err != nil { @@ -52,55 +54,59 @@ func New(inputStruct any, opts ...Option) (*Core, error) { for _, opt := range allOptions { if err := opt(core); err != nil { - return nil, fmt.Errorf("httpin: invalid option: %w", err) + return nil, fmt.Errorf("invalid option: %w", err) } } return core, nil } -// Decode decodes an HTTP request to a struct instance. -// The return value is a pointer to the input struct. -// For example: +// Decode decodes an HTTP request to an instance of the input struct and returns +// its pointer. For example: // -// New(&Input{}).Decode(req) -> *Input // New(Input{}).Decode(req) -> *Input func (c *Core) Decode(req *http.Request) (any, error) { - var err error - ct, _, _ := mime.ParseMediaType(req.Header.Get("Content-Type")) - if ct == "multipart/form-data" { - err = req.ParseMultipartForm(c.maxMemory) + // Create the input struct instance. Used to be created by owl.Resolve(). + value := reflect.New(c.resolver.Type).Interface() + if err := c.DecodeTo(req, value); err != nil { + return nil, err } else { - err = req.ParseForm() + return value, nil } - if err != nil { - return nil, err +} + +// DecodeTo decodes an HTTP request to the given value. The value must be a pointer +// to the struct instance of the type that the Core instance holds. +func (c *Core) DecodeTo(req *http.Request, value any) (err error) { + if err = c.parseRequestForm(req); err != nil { + return fmt.Errorf("failed to parse request form: %w", err) } - rv, err := c.resolver.Resolve( + err = c.resolver.ResolveTo( + value, owl.WithNamespace(decoderNamespace), owl.WithValue(CtxRequest, req), owl.WithNestedDirectivesEnabled(c.enableNestedDirectives), ) - if err != nil { - return nil, NewInvalidFieldError(err) + if err != nil && !errors.Is(err, owl.ErrInvalidResolveTarget) { + return NewInvalidFieldError(err) } - return rv.Interface(), nil + return err } -// NewRequest wraps NewRequestWithContext using context.Background. +// NewRequest wraps NewRequestWithContext using context.Background(), see +// NewRequestWithContext. func (c *Core) NewRequest(method string, url string, input any) (*http.Request, error) { return c.NewRequestWithContext(context.Background(), method, url, input) } -// NewRequestWithContext returns a new http.Request given a method, url and an -// input struct instance. Note that the Core instance is bound to a specific -// type of struct. Which means when the given input is not the type of the -// struct that the Core instance holds, error of type mismatch will be returned. -// In order to avoid this error, you can always use httpin.NewRequest() function -// instead. Which will create a Core instance for you when needed. There's no -// performance penalty for doing so. Because there's a cache layer for all the -// Core instances. +// NewRequestWithContext turns the given input struct into an HTTP request. Note +// that the Core instance is bound to a specific type of struct. Which means +// when the given input is not the type of the struct that the Core instance +// holds, error of type mismatch will be returned. In order to avoid this error, +// you can always use httpin.NewRequest() instead. Which will create a Core +// instance for you on demand. There's no performance penalty for doing so. +// Because there's a cache layer for all the Core instances. func (c *Core) NewRequestWithContext(ctx context.Context, method string, url string, input any) (*http.Request, error) { c.prepareScanResolver() req, err := http.NewRequestWithContext(ctx, method, url, nil) @@ -168,6 +174,16 @@ func (c *Core) prepareScanResolver() { } } +func (c *Core) parseRequestForm(req *http.Request) (err error) { + ct, _, _ := mime.ParseMediaType(req.Header.Get("Content-Type")) + if ct == "multipart/form-data" { + err = req.ParseMultipartForm(c.maxMemory) + } else { + err = req.ParseForm() + } + return +} + // buildResolver builds a resolver for the inputStruct. It will run normalizations // on the resolver and cache it. func buildResolver(inputStruct any) (*owl.Resolver, error) { diff --git a/core/registry.go b/core/registry.go index 8c32510..5b0765b 100644 --- a/core/registry.go +++ b/core/registry.go @@ -14,13 +14,19 @@ var ( namedStringableAdaptors = make(map[string]*NamedAnyStringableAdaptor) ) -// RegisterCoder registers a custom stringable adaptor for the given type T. -// When a field of type T is encountered, the adaptor will be used to convert -// the value to a Stringable, which will be used to convert the value from/to string. +// RegisterCoder registers a custom coder for the given type T. When a field of +// type T is encountered, this coder will be used to convert the value to a +// Stringable, which will be used to convert the value from/to string. // -// NOTE: this function is designed to override the default Stringable adaptors that -// are registered by this package. For example, if you want to override the defualt -// behaviour of converting a bool value from/to string, you can do this: +// NOTE: this function is designed to override the default Stringable adaptors +// that are registered by this package. For example, if you want to override the +// defualt behaviour of converting a bool value from/to string, you can do this: +// +// func init() { +// core.RegisterCoder[bool](func(b *bool) (core.Stringable, error) { +// return (*YesNo)(b), nil +// }) +// } // // type YesNo bool // @@ -42,25 +48,18 @@ var ( // } // return nil // } -// -// func init() { -// core.RegisterCoder[bool](func(b *bool) (core.Stringable, error) { -// return (*YesNo)(b), nil -// }) -// } func RegisterCoder[T any](adapt func(*T) (Stringable, error)) { customStringableAdaptors[internal.TypeOf[T]()] = internal.NewAnyStringableAdaptor[T](adapt) } -// RegisterNamedCoder works similar to RegisterType, except that it binds the adaptor to a name. -// This is useful when you only want to override the types in a specific struct. -// You will be using the "encoder" and "decoder" directives to specify the name of the adaptor. -// -// For example: +// RegisterNamedCoder works similar to RegisterCoder, except that it binds the +// coder to a name. This is useful when you only want to override the types in +// a specific struct field. You will be using the "coder" or "decoder" directive +// to specify the name of the coder to use. For example: // // type MyStruct struct { -// Bool bool // this field will be encoded/decoded using the default bool coder -// YesNo bool `in:"encoder=yesno,decoder=yesno"` // this field will be encoded/decoded using the YesNo coder +// Bool bool // use default bool coder +// YesNo bool `in:"coder=yesno"` // use YesNo coder // } // // func init() { @@ -68,6 +67,27 @@ func RegisterCoder[T any](adapt func(*T) (Stringable, error)) { // return (*YesNo)(b), nil // }) // } +// +// type YesNo bool +// +// func (yn YesNo) String() string { +// if yn { +// return "yes" +// } +// return "no" +// } +// +// func (yn *YesNo) FromString(s string) error { +// switch s { +// case "yes": +// *yn = true +// case "no": +// *yn = false +// default: +// return fmt.Errorf("invalid YesNo value: %q", s) +// } +// return nil +// } func RegisterNamedCoder[T any](name string, adapt func(*T) (Stringable, error)) { namedStringableAdaptors[name] = &NamedAnyStringableAdaptor{ Name: name, @@ -76,8 +96,9 @@ func RegisterNamedCoder[T any](name string, adapt func(*T) (Stringable, error)) } } -// RegisterFileCoder registers the given type T as a file type. T must implement the Fileable interface. -// Remember if you don't register the type explicitly, it won't be recognized as a file type. +// RegisterFileCoder registers the given type T as a file type. T must implement +// the Fileable interface. Remember if you don't register the type explicitly, +// it won't be recognized as a file type. func RegisterFileCoder[T Fileable]() error { fileTypes[internal.TypeOf[T]()] = struct{}{} return nil diff --git a/go.mod b/go.mod index 0cf4b7e..b2364b0 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,12 @@ module github.com/ggicci/httpin go 1.20 require ( - github.com/ggicci/owl v0.7.0 + github.com/ggicci/owl v0.8.2 github.com/go-chi/chi/v5 v5.0.11 github.com/gorilla/mux v1.8.1 github.com/justinas/alice v1.2.0 - github.com/labstack/echo/v4 v4.11.4 - github.com/stretchr/testify v1.8.4 + github.com/labstack/echo/v4 v4.12.0 + github.com/stretchr/testify v1.9.0 ) require ( @@ -19,9 +19,9 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect - golang.org/x/crypto v0.17.0 // indirect - golang.org/x/net v0.19.0 // indirect - golang.org/x/sys v0.15.0 // indirect + golang.org/x/crypto v0.22.0 // indirect + golang.org/x/net v0.24.0 // indirect + golang.org/x/sys v0.19.0 // indirect golang.org/x/text v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 7500388..dee6dcd 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,19 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/ggicci/owl v0.7.0 h1:+AMlCR0AY7j72q7hjtN4pm8VJiikwpROtMgvPnXtuik= -github.com/ggicci/owl v0.7.0/go.mod h1:TRPWshRwYej6uES//YW5aNgLB370URwyta1Ytfs7KXs= +github.com/ggicci/owl v0.8.0 h1:PCueAADCWwuW2jv7fvp40eNjvrv3se/Rhkb+Ah6MPbM= +github.com/ggicci/owl v0.8.0/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4= +github.com/ggicci/owl v0.8.1 h1:vppxAqpNOYBdrPKpcq7lzLy40UmSMr8Oz+h2EsJVgew= +github.com/ggicci/owl v0.8.1/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4= +github.com/ggicci/owl v0.8.2 h1:og+lhqpzSMPDdEB+NJfzoAJARP7qCG3f8uUC3xvGukA= +github.com/ggicci/owl v0.8.2/go.mod h1:PHRD57u41vFN5UtFz2SF79yTVoM3HlWpjMiE+ZU2dj4= github.com/go-chi/chi/v5 v5.0.11 h1:BnpYbFZ3T3S1WMpD79r7R5ThWX40TaFB7L31Y8xqSwA= github.com/go-chi/chi/v5 v5.0.11/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo= github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA= -github.com/labstack/echo/v4 v4.11.4 h1:vDZmA+qNeh1pd/cCkEicDMrjtrnMGQ1QFI9gWN1zGq8= -github.com/labstack/echo/v4 v4.11.4/go.mod h1:noh7EvLwqDsmh/X/HWKPUl1AjzJrhyptRyEbQJfxen8= +github.com/labstack/echo/v4 v4.12.0 h1:IKpw49IMryVB2p1a4dzwlhP1O2Tf2E0Ir/450lH+kI0= +github.com/labstack/echo/v4 v4.12.0/go.mod h1:UP9Cr2DJXbOK3Kr9ONYzNowSh7HP0aG0ShAyycHSJvM= github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -19,20 +23,20 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= +golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= +golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/httpin.go b/httpin.go index efabee4..a861afb 100644 --- a/httpin.go +++ b/httpin.go @@ -5,7 +5,7 @@ package httpin import ( "context" - "fmt" + "errors" "io" "net/http" "reflect" @@ -68,42 +68,55 @@ func UploadStream(r io.ReadCloser) *File { return core.UploadStream(r) } -// Decode decodes an HTTP request to the given input struct. The input must be a -// pointer to a struct instance. For example: +// DecodeTo decodes an HTTP request and populates input with data from the HTTP request. +// The input must be a pointer to a struct instance. For example: // // input := &InputStruct{} -// if err := Decode(req, &input); err != nil { ... } +// if err := DecodeTo(req, input); err != nil { ... } // // input is now populated with data from the request. -func Decode(req *http.Request, input any, opts ...core.Option) error { - originalType := reflect.TypeOf(input) - if originalType.Kind() != reflect.Ptr { - return fmt.Errorf("httpin: input must be a pointer") - } - co, err := New(originalType.Elem(), opts...) +func DecodeTo(req *http.Request, input any, opts ...core.Option) error { + co, err := New(internal.DereferencedType(input), opts...) if err != nil { return err } - if value, err := co.Decode(req); err != nil { - return err + return co.DecodeTo(req, input) +} + +// Decode decodes an HTTP request to an instance of T and returns its pointer +// (*T). T must be a struct type. For example: +// +// if user, err := Decode[User](req); err != nil { ... } +// // now user is a *User instance, which has been populated with data from the request. +func Decode[T any](req *http.Request, opts ...core.Option) (*T, error) { + rt := internal.TypeOf[T]() + if rt.Kind() != reflect.Struct { + return nil, errors.New("generic type T must be a struct type") + } + co, err := New(rt, opts...) + if err != nil { + return nil, err + } + if v, err := co.Decode(req); err != nil { + return nil, err } else { - if originalType.Elem().Kind() == reflect.Ptr { - reflect.ValueOf(input).Elem().Set(reflect.ValueOf(value)) - } else { - reflect.ValueOf(input).Elem().Set(reflect.ValueOf(value).Elem()) - } - return nil + return v.(*T), nil } } -// NewRequest wraps NewRequestWithContext using context.Background. +// NewRequest wraps NewRequestWithContext using context.Background(), see NewRequestWithContext. func NewRequest(method, url string, input any, opts ...core.Option) (*http.Request, error) { return NewRequestWithContext(context.Background(), method, url, input) } -// NewRequestWithContext returns a new http.Request given a method, url and an -// input struct instance. The fields of the input struct will be encoded to the -// request by resolving the "in" tags and executing the directives. +// NewRequestWithContext turns the given input into an HTTP request. The input +// must be a struct instance. And its fields' "in" tags define how to bind the +// data from the struct to the HTTP request. Use it as the replacement of +// http.NewRequest(). +// +// addUserPayload := &AddUserRequest{...} +// addUserRequest, err := NewRequestWithContext(context.Background(), "GET", "http://example.com", addUserPayload) +// http.DefaultClient.Do(addUserRequest) func NewRequestWithContext(ctx context.Context, method, url string, input any, opts ...core.Option) (*http.Request, error) { co, err := New(input, opts...) if err != nil { @@ -116,9 +129,9 @@ func NewRequestWithContext(ctx context.Context, method, url string, input any, o // in an http.Handler and returns another http.Handler. // // The middleware created by NewInput is to add the decoding function to an -// existing http.Handler. This functionality will decode the HTTP request and -// put the decoded struct instance to the request's context. So that the next -// hop can get the decoded struct instance from the request's context. +// existing http.Handler. This functionality will decode the HTTP request into a +// struct instance and put its pointer to the request's context. So that the +// next hop can get the decoded struct instance from the request's context. // // We recommend using https://github.com/justinas/alice to chain your // middlewares. If you're using some popular web frameworks, they may have diff --git a/httpin_test.go b/httpin_test.go index efa48b9..8d9865a 100644 --- a/httpin_test.go +++ b/httpin_test.go @@ -20,57 +20,79 @@ type Pagination struct { PerPage int `in:"form=per_page,page_size"` } -func TestDecode(t *testing.T) { +func testcasePagination1100() (*http.Request, *Pagination) { r, _ := http.NewRequest("GET", "/", nil) r.Form = url.Values{ "page": {"1"}, "per_page": {"100"}, } - expected := &Pagination{ + return r, &Pagination{ Page: 1, PerPage: 100, } +} + +func TestDecodeTo(t *testing.T) { + r, expected := testcasePagination1100() func() { input := &Pagination{} - err := Decode(r, input) // pointer to a struct instance + err := DecodeTo(r, input) // pointer to a struct instance assert.NoError(t, err) assert.Equal(t, expected, input) }() func() { input := Pagination{} - err := Decode(r, &input) // addressable struct instance + err := DecodeTo(r, &input) // addressable struct instance assert.NoError(t, err) assert.Equal(t, expected, &input) }() func() { input := &Pagination{} - err := Decode(r, &input) // pointer to pointer of struct instance + err := DecodeTo(r, &input) // pointer to pointer of struct instance assert.NoError(t, err) assert.Equal(t, expected, input) }() func() { input := Pagination{} - err := Decode(r, input) // non-pointer struct instance should fail - assert.ErrorContains(t, err, "input must be a pointer") + err := DecodeTo(r, input) // non-pointer struct instance should fail + assert.ErrorContains(t, err, "invalid resolve target") }() } +func TestDecode(t *testing.T) { + r, expected := testcasePagination1100() + + p, err := Decode[Pagination](r) + assert.NoError(t, err) + assert.Equal(t, expected, p) +} + +func TestDecode_ErrNotAStruct(t *testing.T) { + r, _ := testcasePagination1100() + + _, err := Decode[int](r) + assert.ErrorContains(t, err, "T must be a struct type") + + _, err = Decode[*Pagination](r) + assert.ErrorContains(t, err, "T must be a struct type") +} + func TestDecode_ErrBuildResolverFailed(t *testing.T) { - r, _ := http.NewRequest("GET", "/", nil) - r.Form = url.Values{ - "page": {"1"}, - "per_page": {"100"}, - } + r, _ := testcasePagination1100() type Foo struct { Name string `in:"nonexistent=foo"` } - assert.Error(t, Decode(r, &Foo{})) + assert.Error(t, DecodeTo(r, &Foo{})) + + v, err := Decode[Foo](r) + assert.Nil(t, v) + assert.Error(t, err) } func TestDecode_ErrDecodeFailure(t *testing.T) { @@ -81,7 +103,11 @@ func TestDecode_ErrDecodeFailure(t *testing.T) { } p := &Pagination{} - assert.Error(t, Decode(r, p)) + assert.Error(t, DecodeTo(r, p)) + + v, err := Decode[Pagination](r) + assert.Nil(t, v) + assert.Error(t, err) } type EchoInput struct { diff --git a/internal/misc.go b/internal/misc.go index 83f67a0..5e6a401 100644 --- a/internal/misc.go +++ b/internal/misc.go @@ -7,7 +7,7 @@ import ( func IsNil(value reflect.Value) bool { switch value.Kind() { - case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Interface, reflect.Slice: + case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.Interface, reflect.Slice: return value.IsNil() default: return false @@ -30,3 +30,12 @@ func TypeOf[T any]() reflect.Type { func Pointerize[T any](v T) *T { return &v } + +// DereferencedType returns the underlying type of a pointer. +func DereferencedType(v any) reflect.Type { + rv := reflect.ValueOf(v) + for rv.Kind() == reflect.Pointer { + rv = rv.Elem() + } + return rv.Type() +} diff --git a/internal/misc_test.go b/internal/misc_test.go index 164f7fc..6348622 100644 --- a/internal/misc_test.go +++ b/internal/misc_test.go @@ -27,3 +27,15 @@ func TestTypeOf(t *testing.T) { func TestPointerize(t *testing.T) { assert.Equal(t, 102, *Pointerize[int](102)) } + +func TestDereferencedType(t *testing.T) { + type Object struct{} + + var o = new(Object) + var po = &o + var ppo = &po + assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(Object{})) + assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(o)) + assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(po)) + assert.Equal(t, reflect.TypeOf(Object{}), DereferencedType(ppo)) +}