From 8aed3dde18fc103b7362c1f1ff4c7988ee560b13 Mon Sep 17 00:00:00 2001 From: Taichi Sasaki Date: Fri, 6 Sep 2024 00:50:06 +0900 Subject: [PATCH] Generate validation code for required attributes in inline struct http bodies (#3580) * Split test cases for Required * Generate validation code for required attributes in inline struct http bodies. --- http/codegen/server_decode_test.go | 2 + http/codegen/service_data.go | 5 +- .../testdata/payload_decode_functions.go | 77 +++++++++++++++++++ http/codegen/testdata/payload_dsls.go | 41 ++++++++-- 4 files changed, 118 insertions(+), 7 deletions(-) diff --git a/http/codegen/server_decode_test.go b/http/codegen/server_decode_test.go index d6143b417f..bd16aa86e0 100644 --- a/http/codegen/server_decode_test.go +++ b/http/codegen/server_decode_test.go @@ -151,6 +151,7 @@ func TestDecode(t *testing.T) { {"decode-body-user-nested", testdata.PayloadBodyNestedUserDSL, testdata.PayloadBodyNestedUserDecodeCode}, {"decode-body-user-validate", testdata.PayloadBodyUserValidateDSL, testdata.PayloadBodyUserValidateDecodeCode}, {"decode-body-object", testdata.PayloadBodyObjectDSL, testdata.PayloadBodyObjectDecodeCode}, + {"decode-body-object-required", testdata.PayloadBodyObjectRequiredDSL, testdata.PayloadBodyObjectRequiredDecodeCode}, {"decode-body-object-validate", testdata.PayloadBodyObjectValidateDSL, testdata.PayloadBodyObjectValidateDecodeCode}, {"decode-body-union", testdata.PayloadBodyUnionDSL, testdata.PayloadBodyUnionDecodeCode}, {"decode-body-union-validate", testdata.PayloadBodyUnionValidateDSL, testdata.PayloadBodyUnionValidateDecodeCode}, @@ -170,6 +171,7 @@ func TestDecode(t *testing.T) { {"decode-body-primitive-array-string-validate", testdata.PayloadBodyPrimitiveArrayStringValidateDSL, testdata.PayloadBodyPrimitiveArrayStringValidateDecodeCode}, {"decode-body-primitive-array-bool-validate", testdata.PayloadBodyPrimitiveArrayBoolValidateDSL, testdata.PayloadBodyPrimitiveArrayBoolValidateDecodeCode}, + {"decode-body-primitive-array-user-required", testdata.PayloadBodyPrimitiveArrayUserRequiredDSL, testdata.PayloadBodyPrimitiveArrayUserRequiredDecodeCode}, {"decode-body-primitive-array-user-validate", testdata.PayloadBodyPrimitiveArrayUserValidateDSL, testdata.PayloadBodyPrimitiveArrayUserValidateDecodeCode}, {"decode-body-primitive-field-array-user", testdata.PayloadBodyPrimitiveFieldArrayUserDSL, testdata.PayloadBodyPrimitiveFieldArrayUserDecodeCode}, {"decode-body-extend-primitive-field-array-user", testdata.PayloadExtendBodyPrimitiveFieldArrayUserDSL, testdata.PayloadBodyPrimitiveFieldArrayUserDecodeCode}, diff --git a/http/codegen/service_data.go b/http/codegen/service_data.go index 5881b8cfc7..c1818960c9 100644 --- a/http/codegen/service_data.go +++ b/http/codegen/service_data.go @@ -1992,6 +1992,9 @@ func buildRequestBodyType(body, att *expr.AttributeExpr, e *expr.HTTPEndpointExp } } } else { + // Generate validation code first because inline struct validation is removed. + ctx := codegen.NewAttributeContext(!expr.IsPrimitive(body.Type), false, !svr, "", sd.Scope) + validateRef = codegen.ValidationCode(body, nil, ctx, true, expr.IsAlias(body.Type), false, "body") if svr && expr.IsObject(body.Type) { // Body is an explicit object described in the design and in // this case the GoTypeRef is an inline struct definition. We @@ -2000,8 +2003,6 @@ func buildRequestBodyType(body, att *expr.AttributeExpr, e *expr.HTTPEndpointExp body.Validation = nil } varname = sd.Scope.GoTypeRef(body) - ctx := codegen.NewAttributeContext(false, false, !svr, "", sd.Scope) - validateRef = codegen.ValidationCode(body, nil, ctx, true, expr.IsAlias(body.Type), false, "body") desc = body.Description } var init *InitData diff --git a/http/codegen/testdata/payload_decode_functions.go b/http/codegen/testdata/payload_decode_functions.go index 35fbc65cf3..ef65072081 100644 --- a/http/codegen/testdata/payload_decode_functions.go +++ b/http/codegen/testdata/payload_decode_functions.go @@ -3917,6 +3917,40 @@ func DecodeMethodBodyObjectRequest(mux goahttp.Muxer, decoder func(*http.Request } ` +var PayloadBodyObjectRequiredDecodeCode = `// DecodeMethodBodyObjectRequiredRequest returns a decoder for requests sent to +// the ServiceBodyObjectRequired MethodBodyObjectRequired endpoint. +func DecodeMethodBodyObjectRequiredRequest(mux goahttp.Muxer, decoder func(*http.Request) goahttp.Decoder) func(*http.Request) (any, error) { + return func(r *http.Request) (any, error) { + var ( + body struct { + B *string ` + "`" + `form:"b" json:"b" xml:"b"` + "`" + ` + } + err error + ) + err = decoder(r).Decode(&body) + if err != nil { + if err == io.EOF { + return nil, goa.MissingPayloadError() + } + var gerr *goa.ServiceError + if errors.As(err, &gerr) { + return nil, gerr + } + return nil, goa.DecodePayloadError(err.Error()) + } + if body.B == nil { + err = goa.MergeErrors(err, goa.MissingFieldError("b", "body")) + } + if err != nil { + return nil, err + } + payload := NewMethodBodyObjectRequiredPayload(body) + + return payload, nil + } +} +` + var PayloadBodyObjectValidateDecodeCode = `// DecodeMethodBodyObjectValidateRequest returns a decoder for requests sent to // the ServiceBodyObjectValidate MethodBodyObjectValidate endpoint. func DecodeMethodBodyObjectValidateRequest(mux goahttp.Muxer, decoder func(*http.Request) goahttp.Decoder) func(*http.Request) (any, error) { @@ -3938,6 +3972,12 @@ func DecodeMethodBodyObjectValidateRequest(mux goahttp.Muxer, decoder func(*http } return nil, goa.DecodePayloadError(err.Error()) } + if body.B != nil { + err = goa.MergeErrors(err, goa.ValidatePattern("body.b", *body.B, "pattern")) + } + if err != nil { + return nil, err + } payload := NewMethodBodyObjectValidatePayload(body) return payload, nil @@ -4440,6 +4480,43 @@ func DecodeMethodBodyPrimitiveArrayBoolValidateRequest(mux goahttp.Muxer, decode } ` +var PayloadBodyPrimitiveArrayUserRequiredDecodeCode = `// DecodeMethodBodyPrimitiveArrayUserRequiredRequest returns a decoder for +// requests sent to the ServiceBodyPrimitiveArrayUserRequired +// MethodBodyPrimitiveArrayUserRequired endpoint. +func DecodeMethodBodyPrimitiveArrayUserRequiredRequest(mux goahttp.Muxer, decoder func(*http.Request) goahttp.Decoder) func(*http.Request) (any, error) { + return func(r *http.Request) (any, error) { + var ( + body []*PayloadTypeRequestBody + err error + ) + err = decoder(r).Decode(&body) + if err != nil { + if err == io.EOF { + return nil, goa.MissingPayloadError() + } + var gerr *goa.ServiceError + if errors.As(err, &gerr) { + return nil, gerr + } + return nil, goa.DecodePayloadError(err.Error()) + } + for _, e := range body { + if e != nil { + if err2 := ValidatePayloadTypeRequestBody(e); err2 != nil { + err = goa.MergeErrors(err, err2) + } + } + } + if err != nil { + return nil, err + } + payload := NewMethodBodyPrimitiveArrayUserRequiredPayloadType(body) + + return payload, nil + } +} +` + var PayloadBodyPrimitiveArrayUserValidateDecodeCode = `// DecodeMethodBodyPrimitiveArrayUserValidateRequest returns a decoder for // requests sent to the ServiceBodyPrimitiveArrayUserValidate // MethodBodyPrimitiveArrayUserValidate endpoint. diff --git a/http/codegen/testdata/payload_dsls.go b/http/codegen/testdata/payload_dsls.go index dfaa4c344b..561ed6f85f 100644 --- a/http/codegen/testdata/payload_dsls.go +++ b/http/codegen/testdata/payload_dsls.go @@ -1947,7 +1947,6 @@ var PayloadBodyStringValidateDSL = func() { Attribute("b", String, func() { Pattern("pattern") }) - Required("b") }) HTTP(func() { POST("/") @@ -2039,9 +2038,9 @@ var PayloadBodyObjectDSL = func() { }) } -var PayloadBodyObjectValidateDSL = func() { - Service("ServiceBodyObjectValidate", func() { - Method("MethodBodyObjectValidate", func() { +var PayloadBodyObjectRequiredDSL = func() { + Service("ServiceBodyObjectRequired", func() { + Method("MethodBodyObjectRequired", func() { Payload(func() { Attribute("b", String) Required("b") @@ -2057,6 +2056,24 @@ var PayloadBodyObjectValidateDSL = func() { }) } +var PayloadBodyObjectValidateDSL = func() { + Service("ServiceBodyObjectValidate", func() { + Method("MethodBodyObjectValidate", func() { + Payload(func() { + Attribute("b", String, func() { + Pattern("pattern") + }) + }) + HTTP(func() { + POST("/") + Body(func() { + Attribute("b", String) + }) + }) + }) + }) +} + var PayloadBodyUnionDSL = func() { var Union = Type("Union", func() { OneOf("Values", func() { @@ -2342,12 +2359,26 @@ var PayloadBodyPrimitiveArrayBoolValidateDSL = func() { }) } +var PayloadBodyPrimitiveArrayUserRequiredDSL = func() { + var PayloadType = Type("PayloadType", func() { + Attribute("a", String) + Required("a") + }) + Service("ServiceBodyPrimitiveArrayUserRequired", func() { + Method("MethodBodyPrimitiveArrayUserRequired", func() { + Payload(ArrayOf(PayloadType)) + HTTP(func() { + POST("/") + }) + }) + }) +} + var PayloadBodyPrimitiveArrayUserValidateDSL = func() { var PayloadType = Type("PayloadType", func() { Attribute("a", String, func() { Pattern("pattern") }) - Required("a") }) Service("ServiceBodyPrimitiveArrayUserValidate", func() { Method("MethodBodyPrimitiveArrayUserValidate", func() {