From 09555500581dd59b9852c70c146da80e2abf9b91 Mon Sep 17 00:00:00 2001 From: Raphael Simon Date: Tue, 26 Nov 2024 13:12:58 -0800 Subject: [PATCH] Initial interceptors implementation --- README.md | 1 + codegen/service/client.go | 12 + codegen/service/client_test.go | 1 + codegen/service/convert_test.go | 24 -- codegen/service/endpoint.go | 66 +++- codegen/service/endpoint_test.go | 2 + codegen/service/interceptors.go | 215 +++++++++++++ codegen/service/interceptors_test.go | 209 ++++++++++++ codegen/service/service.go | 5 + .../service/templates/client_wrappers.go.tpl | 29 ++ .../templates/endpoint_wrappers.go.tpl | 30 ++ codegen/service/templates/interceptors.go.tpl | 153 +++++++++ .../templates/service_client_init.go.tpl | 12 +- .../templates/service_endpoints_init.go.tpl | 15 +- .../templates/service_interceptor.go.tpl | 0 codegen/service/testdata/client_code.go | 47 +++ codegen/service/testdata/endpoint_code.go | 181 +++++++++++ codegen/service/testdata/endpoint_dsls.go | 44 +++ .../interceptor-with-read-payload.golden | 41 +++ .../interceptor-with-read-result.golden | 41 +++ ...interceptor-with-read-write-payload.golden | 45 +++ .../interceptor-with-read-write-result.golden | 45 +++ .../interceptor-with-write-payload.golden | 41 +++ .../interceptor-with-write-result.golden | 41 +++ .../interceptors/multiple-interceptors.golden | 30 ++ .../single-api-server-interceptor.golden | 16 + .../single-client-interceptor.golden | 16 + .../single-method-server-interceptor.golden | 16 + .../single-service-server-interceptor.golden | 16 + ...ming-interceptors-with-read-payload.golden | 45 +++ ...aming-interceptors-with-read-result.golden | 45 +++ .../streaming-interceptors.golden | 16 + codegen/service/testdata/interceptors_dsls.go | 297 ++++++++++++++++++ codegen/service/testdata/service_dsls.go | 14 + codegen/service/testing.go | 42 +++ dsl/description.go | 2 + dsl/interceptor.go | 293 +++++++++++++++++ dsl/interceptor_test.go | 250 +++++++++++++++ dsl/meta.go | 39 +++ expr/api.go | 4 + expr/interceptor.go | 84 +++++ expr/method.go | 79 ++++- expr/root.go | 2 + expr/service.go | 4 + go.mod | 7 +- go.sum | 4 +- http/codegen/service_data.go | 48 +-- pkg/interceptor.go | 22 ++ 48 files changed, 2608 insertions(+), 83 deletions(-) create mode 100644 codegen/service/interceptors.go create mode 100644 codegen/service/interceptors_test.go create mode 100644 codegen/service/templates/client_wrappers.go.tpl create mode 100644 codegen/service/templates/endpoint_wrappers.go.tpl create mode 100644 codegen/service/templates/interceptors.go.tpl create mode 100644 codegen/service/templates/service_interceptor.go.tpl create mode 100644 codegen/service/testdata/interceptors/interceptor-with-read-payload.golden create mode 100644 codegen/service/testdata/interceptors/interceptor-with-read-result.golden create mode 100644 codegen/service/testdata/interceptors/interceptor-with-read-write-payload.golden create mode 100644 codegen/service/testdata/interceptors/interceptor-with-read-write-result.golden create mode 100644 codegen/service/testdata/interceptors/interceptor-with-write-payload.golden create mode 100644 codegen/service/testdata/interceptors/interceptor-with-write-result.golden create mode 100644 codegen/service/testdata/interceptors/multiple-interceptors.golden create mode 100644 codegen/service/testdata/interceptors/single-api-server-interceptor.golden create mode 100644 codegen/service/testdata/interceptors/single-client-interceptor.golden create mode 100644 codegen/service/testdata/interceptors/single-method-server-interceptor.golden create mode 100644 codegen/service/testdata/interceptors/single-service-server-interceptor.golden create mode 100644 codegen/service/testdata/interceptors/streaming-interceptors-with-read-payload.golden create mode 100644 codegen/service/testdata/interceptors/streaming-interceptors-with-read-result.golden create mode 100644 codegen/service/testdata/interceptors/streaming-interceptors.golden create mode 100644 codegen/service/testdata/interceptors_dsls.go create mode 100644 codegen/service/testing.go create mode 100644 dsl/interceptor.go create mode 100644 dsl/interceptor_test.go create mode 100644 expr/interceptor.go create mode 100644 pkg/interceptor.go diff --git a/README.md b/README.md index 7735bc20e3..cd38d1919e 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ Slack: Goa Slack: Sign-up BSky: Goa + Twitter: @goadesign

diff --git a/codegen/service/client.go b/codegen/service/client.go index a088de0d37..f991e81aed 100644 --- a/codegen/service/client.go +++ b/codegen/service/client.go @@ -45,6 +45,18 @@ func ClientFile(_ string, service *expr.ServiceExpr) *codegen.File { Source: readTemplate("service_client_method"), Data: m, }) + if len(m.ClientInterceptors) > 0 { + sections = append(sections, &codegen.SectionTemplate{ + Name: "client-wrapper", + Source: readTemplate("client_wrappers"), + Data: map[string]interface{}{ + "Method": m.Name, + "MethodVarName": codegen.Goify(m.Name, true), + "Service": svc.Name, + "ClientInterceptors": m.ClientInterceptors, + }, + }) + } } } diff --git a/codegen/service/client_test.go b/codegen/service/client_test.go index 2fa5152365..6cfda7d4df 100644 --- a/codegen/service/client_test.go +++ b/codegen/service/client_test.go @@ -32,6 +32,7 @@ func TestClient(t *testing.T) { {"client-streaming-payload-no-result", testdata.StreamingPayloadNoResultMethodDSL, testdata.StreamingPayloadNoResultMethodClient}, {"client-bidirectional-streaming", testdata.BidirectionalStreamingMethodDSL, testdata.BidirectionalStreamingMethodClient}, {"client-bidirectional-streaming-no-payload", testdata.BidirectionalStreamingNoPayloadMethodDSL, testdata.BidirectionalStreamingNoPayloadMethodClient}, + {"client-interceptor", testdata.EndpointWithClientInterceptorDSL, testdata.InterceptorClient}, } for _, c := range cases { t.Run(c.Name, func(t *testing.T) { diff --git a/codegen/service/convert_test.go b/codegen/service/convert_test.go index 52a0e1d23e..479d3c1c5e 100644 --- a/codegen/service/convert_test.go +++ b/codegen/service/convert_test.go @@ -14,7 +14,6 @@ import ( "goa.design/goa/v3/codegen" "goa.design/goa/v3/codegen/service/testdata" "goa.design/goa/v3/dsl" - "goa.design/goa/v3/eval" "goa.design/goa/v3/expr" ) @@ -257,30 +256,7 @@ func TestConvertFile(t *testing.T) { } } -// runDSL returns the DSL root resulting from running the given DSL. -func runDSL(t *testing.T, dsl func()) *expr.RootExpr { - // reset all roots and codegen data structures - Services = make(ServicesData) - eval.Reset() - expr.Root = new(expr.RootExpr) - expr.GeneratedResultTypes = new(expr.ResultTypesRoot) - require.NoError(t, eval.Register(expr.Root)) - require.NoError(t, eval.Register(expr.GeneratedResultTypes)) - expr.Root.API = expr.NewAPIExpr("test api", func() {}) - expr.Root.API.Servers = []*expr.ServerExpr{expr.Root.API.DefaultServer()} - - // run DSL (first pass) - require.True(t, eval.Execute(dsl, nil)) - - // run DSL (second pass) - require.NoError(t, eval.RunDSL()) - - // return generated root - return expr.Root -} - // Test fixtures - var obj = &expr.UserTypeExpr{ AttributeExpr: &expr.AttributeExpr{ Type: &expr.Object{ diff --git a/codegen/service/endpoint.go b/codegen/service/endpoint.go index aaefc4bbd2..45d476620e 100644 --- a/codegen/service/endpoint.go +++ b/codegen/service/endpoint.go @@ -25,6 +25,10 @@ type ( ServiceVarName string // Methods lists the endpoint struct methods. Methods []*EndpointMethodData + // HasServerInterceptors indicates if the service has server interceptors. + HasServerInterceptors bool + // HasClientInterceptors indicates if the service has client interceptors. + HasClientInterceptors bool // ClientInitArgs lists the arguments needed to instantiate the client. ClientInitArgs string // Schemes contains the security schemes types used by the @@ -44,6 +48,10 @@ type ( ServiceName string // ServiceVarName is the name of the owner service Go interface. ServiceVarName string + // ServerInterceptors contains the server-side interceptors for this method + ServerInterceptors []*InterceptorData + // ClientInterceptors contains the client-side interceptors for this method + ClientInterceptors []*InterceptorData } ) @@ -122,6 +130,18 @@ func EndpointFile(genpkg string, service *expr.ServiceExpr) *codegen.File { Data: m, FuncMap: map[string]any{"payloadVar": payloadVar}, }) + if len(m.ServerInterceptors) > 0 { + sections = append(sections, &codegen.SectionTemplate{ + Name: "endpoint-wrapper", + Source: readTemplate("endpoint_wrappers"), + Data: map[string]interface{}{ + "MethodVarName": codegen.Goify(m.Name, true), + "Method": m.Name, + "Service": svc.Name, + "ServerInterceptors": m.ServerInterceptors, + }, + }) + } } } @@ -133,25 +153,45 @@ func endpointData(service *expr.ServiceExpr) *EndpointsData { methods := make([]*EndpointMethodData, len(svc.Methods)) names := make([]string, len(svc.Methods)) for i, m := range svc.Methods { + serverInts, clientInts := buildMethodInterceptors(service.Method(m.Name), svc.Scope) methods[i] = &EndpointMethodData{ - MethodData: m, - ArgName: codegen.Goify(m.VarName, false), - ServiceName: svc.Name, - ServiceVarName: serviceInterfaceName, - ClientVarName: clientStructName, + MethodData: m, + ArgName: codegen.Goify(m.VarName, false), + ServiceName: svc.Name, + ServiceVarName: serviceInterfaceName, + ClientVarName: clientStructName, + ServerInterceptors: serverInts, + ClientInterceptors: clientInts, } names[i] = codegen.Goify(m.VarName, false) } desc := fmt.Sprintf("%s wraps the %q service endpoints.", endpointsStructName, service.Name) + var hasServerInterceptors, hasClientInterceptors bool + for _, m := range methods { + if len(m.ServerInterceptors) > 0 { + hasServerInterceptors = true + if hasClientInterceptors { + break + } + } + if len(m.ClientInterceptors) > 0 { + hasClientInterceptors = true + if hasServerInterceptors { + break + } + } + } return &EndpointsData{ - Name: service.Name, - Description: desc, - VarName: endpointsStructName, - ClientVarName: clientStructName, - ServiceVarName: serviceInterfaceName, - ClientInitArgs: strings.Join(names, ", "), - Methods: methods, - Schemes: svc.Schemes, + Name: service.Name, + Description: desc, + VarName: endpointsStructName, + ClientVarName: clientStructName, + ServiceVarName: serviceInterfaceName, + ClientInitArgs: strings.Join(names, ", "), + Methods: methods, + HasServerInterceptors: hasServerInterceptors, + HasClientInterceptors: hasClientInterceptors, + Schemes: svc.Schemes, } } diff --git a/codegen/service/endpoint_test.go b/codegen/service/endpoint_test.go index 0356bf958e..2132eeeeca 100644 --- a/codegen/service/endpoint_test.go +++ b/codegen/service/endpoint_test.go @@ -34,6 +34,8 @@ func TestEndpoint(t *testing.T) { {"endpoint-streaming-payload-no-result", testdata.StreamingPayloadNoResultMethodDSL, testdata.StreamingPayloadNoResultMethodEndpoint}, {"endpoint-bidirectional-streaming", testdata.BidirectionalStreamingEndpointDSL, testdata.BidirectionalStreamingMethodEndpoint}, {"endpoint-bidirectional-streaming-no-payload", testdata.BidirectionalStreamingNoPayloadMethodDSL, testdata.BidirectionalStreamingNoPayloadMethodEndpoint}, + {"endpoint-with-server-interceptor", testdata.EndpointWithServerInterceptorDSL, testdata.EndpointWithServerInterceptor}, + {"endpoint-with-multiple-interceptors", testdata.EndpointWithMultipleInterceptorsDSL, testdata.EndpointWithMultipleInterceptors}, } for _, c := range cases { t.Run(c.Name, func(t *testing.T) { diff --git a/codegen/service/interceptors.go b/codegen/service/interceptors.go new file mode 100644 index 0000000000..57d734e0c6 --- /dev/null +++ b/codegen/service/interceptors.go @@ -0,0 +1,215 @@ +package service + +import ( + "path/filepath" + + "goa.design/goa/v3/codegen" + "goa.design/goa/v3/expr" +) + +type ( + // ServiceInterceptorData contains all data needed for generating interceptor code + ServiceInterceptorData struct { + Service string + PkgName string + Methods []*MethodInterceptorData + ServerInterceptors []*InterceptorData + ClientInterceptors []*InterceptorData + AllInterceptors []*InterceptorData + HasPrivateImplementationTypes bool + } + + // MethodInterceptorData contains interceptor data for a single method + MethodInterceptorData struct { + Service string + Method string + MethodVarName string + PayloadRef string + ResultRef string + ServerInterceptors []*InterceptorData + ClientInterceptors []*InterceptorData + } + + // InterceptorData describes a single interceptor. + InterceptorData struct { + Name string + UnexportedName string + Description string + PayloadRef string + ResultRef string + ReadPayload []*AttributeData + WritePayload []*AttributeData + ReadResult []*AttributeData + WriteResult []*AttributeData + ServerStreamInputStruct string + ClientStreamInputStruct string + } + + // AttributeData describes a single attribute. + AttributeData struct { + Name string + TypeRef string + FieldPointer bool + } +) + +// InterceptorsFile returns the interceptors file for the given service. +func InterceptorsFile(genpkg string, service *expr.ServiceExpr) *codegen.File { + svc := Services.Get(service.Name) + data := interceptorsData(service) + if len(data.ServerInterceptors) == 0 && len(data.ClientInterceptors) == 0 { + return nil + } + + path := filepath.Join(codegen.Gendir, svc.PathName, "interceptors.go") + sections := []*codegen.SectionTemplate{ + codegen.Header(service.Name+" interceptors", svc.PkgName, []*codegen.ImportSpec{ + {Path: "context"}, + codegen.GoaImport(""), + }), + { + Name: "interceptors", + Source: readTemplate("interceptors"), + Data: data, + }, + } + + return &codegen.File{Path: path, SectionTemplates: sections} +} + +func interceptorsData(service *expr.ServiceExpr) *ServiceInterceptorData { + svc := Services.Get(service.Name) + scope := svc.Scope + + // Build method data first + methods := make([]*MethodInterceptorData, 0, len(service.Methods)) + seenInts := make(map[string]*InterceptorData) + var serviceServerInts, serviceClientInts, allInts []*InterceptorData + var hasTypes bool + + for _, m := range service.Methods { + methodServerInts, methodClientInts := buildMethodInterceptors(m, scope) + if len(methodServerInts) == 0 && len(methodClientInts) == 0 { + continue + } + hasTypes = hasTypes || hasPrivateImplementationTypes(methodServerInts) || hasPrivateImplementationTypes(methodClientInts) + + // Add method data + methods = append(methods, &MethodInterceptorData{ + Service: svc.Name, + Method: m.Name, + MethodVarName: codegen.Goify(m.Name, true), + PayloadRef: scope.GoFullTypeRef(m.Payload, ""), + ResultRef: scope.GoFullTypeRef(m.Result, ""), + ServerInterceptors: methodServerInts, + ClientInterceptors: methodClientInts, + }) + + // Collect unique interceptors + for _, i := range methodServerInts { + if _, ok := seenInts[i.Name]; !ok { + seenInts[i.Name] = i + serviceServerInts = append(serviceServerInts, i) + allInts = append(allInts, i) + } + } + for _, i := range methodClientInts { + if _, ok := seenInts[i.Name]; !ok { + seenInts[i.Name] = i + serviceClientInts = append(serviceClientInts, i) + allInts = append(allInts, i) + } + } + } + + return &ServiceInterceptorData{ + Service: service.Name, + PkgName: svc.PkgName, + Methods: methods, + ServerInterceptors: serviceServerInts, + ClientInterceptors: serviceClientInts, + AllInterceptors: allInts, + HasPrivateImplementationTypes: hasTypes, + } +} + +func buildMethodInterceptors(m *expr.MethodExpr, scope *codegen.NameScope) ([]*InterceptorData, []*InterceptorData) { + svc := Services.Get(m.Service.Name) + methodData := svc.Method(m.Name) + var serverEndpointStruct, clientEndpointStruct string + if methodData.ServerStream != nil { + serverEndpointStruct = methodData.ServerStream.EndpointStruct + } + if methodData.ClientStream != nil { + clientEndpointStruct = methodData.ClientStream.EndpointStruct + } + var hasPrivateImplementationTypes bool + buildInterceptor := func(intr *expr.InterceptorExpr) *InterceptorData { + hasPrivateImplementationTypes = hasPrivateImplementationTypes || + intr.ReadPayload != nil || intr.WritePayload != nil || intr.ReadResult != nil || intr.WriteResult != nil + + return &InterceptorData{ + Name: codegen.Goify(intr.Name, true), + UnexportedName: codegen.Goify(intr.Name, false), + Description: intr.Description, + PayloadRef: methodData.PayloadRef, + ResultRef: methodData.ResultRef, + ServerStreamInputStruct: serverEndpointStruct, + ClientStreamInputStruct: clientEndpointStruct, + ReadPayload: collectAttributes(intr.ReadPayload, m.Payload, scope), + WritePayload: collectAttributes(intr.WritePayload, m.Payload, scope), + ReadResult: collectAttributes(intr.ReadResult, m.Result, scope), + WriteResult: collectAttributes(intr.WriteResult, m.Result, scope), + } + } + + serverInts := make([]*InterceptorData, len(m.ServerInterceptors)) + for i, intr := range m.ServerInterceptors { + serverInts[i] = buildInterceptor(intr) + } + + clientInts := make([]*InterceptorData, len(m.ClientInterceptors)) + for i, intr := range m.ClientInterceptors { + clientInts[i] = buildInterceptor(intr) + } + + return serverInts, clientInts +} + +// hasPrivateImplementationTypes returns true if any of the interceptors have +// private implementation types. +func hasPrivateImplementationTypes(interceptors []*InterceptorData) bool { + for _, intr := range interceptors { + if intr.ReadPayload != nil || intr.WritePayload != nil || intr.ReadResult != nil || intr.WriteResult != nil { + return true + } + } + return false +} + +// collectAttributes builds AttributeData from an AttributeExpr +func collectAttributes(attrNames, parent *expr.AttributeExpr, scope *codegen.NameScope) []*AttributeData { + if attrNames == nil { + return nil + } + + obj := expr.AsObject(attrNames.Type) + if obj == nil { + return nil + } + + data := make([]*AttributeData, len(*obj)) + for i, nat := range *obj { + parentAttr := parent.Find(nat.Name) + if parentAttr == nil { + continue + } + + data[i] = &AttributeData{ + Name: codegen.Goify(nat.Name, true), + TypeRef: scope.GoTypeRef(parentAttr), + FieldPointer: parent.IsPrimitivePointer(nat.Name, true), + } + } + return data +} diff --git a/codegen/service/interceptors_test.go b/codegen/service/interceptors_test.go new file mode 100644 index 0000000000..0817ac2b31 --- /dev/null +++ b/codegen/service/interceptors_test.go @@ -0,0 +1,209 @@ +package service + +import ( + "bytes" + "flag" + "go/format" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "goa.design/goa/v3/codegen" + "goa.design/goa/v3/codegen/service/testdata" + "goa.design/goa/v3/expr" +) + +var updateGolden = false + +func init() { + flag.BoolVar(&updateGolden, "w", false, "update golden files") +} + +func TestInterceptors(t *testing.T) { + cases := []struct { + Name string + DSL func() + }{ + {"no-interceptors", testdata.NoInterceptorsDSL}, + {"single-api-server-interceptor", testdata.SingleAPIServerInterceptorDSL}, + {"single-service-server-interceptor", testdata.SingleServiceServerInterceptorDSL}, + {"single-method-server-interceptor", testdata.SingleMethodServerInterceptorDSL}, + {"single-client-interceptor", testdata.SingleClientInterceptorDSL}, + {"multiple-interceptors", testdata.MultipleInterceptorsDSL}, + {"interceptor-with-read-payload", testdata.InterceptorWithReadPayloadDSL}, + {"interceptor-with-write-payload", testdata.InterceptorWithWritePayloadDSL}, + {"interceptor-with-read-write-payload", testdata.InterceptorWithReadWritePayloadDSL}, + {"interceptor-with-read-result", testdata.InterceptorWithReadResultDSL}, + {"interceptor-with-write-result", testdata.InterceptorWithWriteResultDSL}, + {"interceptor-with-read-write-result", testdata.InterceptorWithReadWriteResultDSL}, + {"streaming-interceptors", testdata.StreamingInterceptorsDSL}, + {"streaming-interceptors-with-read-payload", testdata.StreamingInterceptorsWithReadPayloadDSL}, + {"streaming-interceptors-with-read-result", testdata.StreamingInterceptorsWithReadResultDSL}, + } + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + root := runDSL(t, c.DSL) + require.Len(t, root.Services, 1) + + fs := InterceptorsFile("goa.design/goa/example", root.Services[0]) + + if c.Name == "no-interceptors" { + assert.Nil(t, fs) + return + } + + require.NotNil(t, fs) + + buf := new(bytes.Buffer) + for _, s := range fs.SectionTemplates[1:] { + require.NoError(t, s.Write(buf)) + } + bs, err := format.Source(buf.Bytes()) + require.NoError(t, err, buf.String()) + code := strings.ReplaceAll(string(bs), "\r\n", "\n") + + golden := filepath.Join("testdata", "interceptors", c.Name+".golden") + compareOrUpdateGolden(t, code, golden) + }) + } +} + +func TestInvalidInterceptors(t *testing.T) { + cases := []struct { + Name string + DSL func() + ErrContains string + }{ + { + Name: "streaming-result-interceptor", + DSL: testdata.StreamingResultInterceptorDSL, + ErrContains: "cannot be applied because the method result is streaming", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + _, err := runDSLWithError(t, c.DSL) + require.Error(t, err) + assert.Contains(t, err.Error(), c.ErrContains) + }) + } +} + +func TestCollectAttributes(t *testing.T) { + cases := []struct { + name string + attrNames *expr.AttributeExpr + parent *expr.AttributeExpr + want []*AttributeData + }{ + { + name: "nil-attributes", + attrNames: nil, + parent: &expr.AttributeExpr{Type: &expr.Object{}}, + want: nil, + }, + { + name: "non-object-attributes", + attrNames: &expr.AttributeExpr{Type: expr.Primitive(expr.StringKind)}, + parent: &expr.AttributeExpr{Type: &expr.Object{}}, + want: nil, + }, + { + name: "simple-string-attribute", + attrNames: &expr.AttributeExpr{ + Type: &expr.Object{ + {Name: "name", Attribute: &expr.AttributeExpr{Type: expr.Primitive(expr.StringKind)}}, + }, + }, + parent: &expr.AttributeExpr{ + Type: &expr.Object{ + {Name: "name", Attribute: &expr.AttributeExpr{Type: expr.Primitive(expr.StringKind)}}, + }, + Validation: &expr.ValidationExpr{Required: []string{"name"}}, + }, + want: []*AttributeData{ + {Name: "Name", TypeRef: "string", FieldPointer: false}, + }, + }, + { + name: "pointer-primitive", + attrNames: &expr.AttributeExpr{ + Type: &expr.Object{ + {Name: "age", Attribute: &expr.AttributeExpr{Type: expr.Primitive(expr.IntKind)}}, + }, + }, + parent: &expr.AttributeExpr{ + Type: &expr.Object{ + {Name: "age", Attribute: &expr.AttributeExpr{Type: expr.Primitive(expr.IntKind), Meta: map[string][]string{"struct:field:pointer": {"true"}}}}, + }, + }, + want: []*AttributeData{ + {Name: "Age", TypeRef: "int", FieldPointer: true}, + }, + }, + { + name: "multiple-attributes", + attrNames: &expr.AttributeExpr{ + Type: &expr.Object{ + {Name: "name", Attribute: &expr.AttributeExpr{Type: expr.Primitive(expr.StringKind)}}, + {Name: "age", Attribute: &expr.AttributeExpr{Type: expr.Primitive(expr.IntKind)}}, + }, + }, + parent: &expr.AttributeExpr{ + Type: &expr.Object{ + {Name: "name", Attribute: &expr.AttributeExpr{Type: expr.Primitive(expr.StringKind)}}, + {Name: "age", Attribute: &expr.AttributeExpr{Type: expr.Primitive(expr.IntKind), Meta: map[string][]string{"struct:field:pointer": {"true"}}}}, + }, + Validation: &expr.ValidationExpr{Required: []string{"name"}}, + }, + want: []*AttributeData{ + {Name: "Name", TypeRef: "string", FieldPointer: false}, + {Name: "Age", TypeRef: "int", FieldPointer: true}, + }, + }, + { + name: "attribute-not-in-parent", + attrNames: &expr.AttributeExpr{ + Type: &expr.Object{ + {Name: "missing", Attribute: &expr.AttributeExpr{Type: expr.Primitive(expr.StringKind)}}, + }, + }, + parent: &expr.AttributeExpr{ + Type: &expr.Object{ + {Name: "name", Attribute: &expr.AttributeExpr{Type: expr.Primitive(expr.StringKind)}}, + }, + Validation: &expr.ValidationExpr{Required: []string{"name"}}, + }, + want: []*AttributeData{nil}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + scope := codegen.NewNameScope() + got := collectAttributes(tc.attrNames, tc.parent, scope) + assert.Equal(t, tc.want, got) + }) + } +} + +func compareOrUpdateGolden(t *testing.T, code, golden string) { + t.Helper() + if updateGolden { + require.NoError(t, os.MkdirAll(filepath.Dir(golden), 0750)) + require.NoError(t, os.WriteFile(golden, []byte(code), 0640)) + return + } + data, err := os.ReadFile(golden) + require.NoError(t, err) + if runtime.GOOS == "windows" { + data = bytes.ReplaceAll(data, []byte("\r\n"), []byte("\n")) + } + assert.Equal(t, string(data), code) +} diff --git a/codegen/service/service.go b/codegen/service/service.go index 8d920c7854..a6c45448d1 100644 --- a/codegen/service/service.go +++ b/codegen/service/service.go @@ -195,6 +195,11 @@ func Files(genpkg string, service *expr.ServiceExpr, userTypePkgs map[string][]s } files := []*codegen.File{{Path: svcPath, SectionTemplates: sections}} + // interceptor.go + if file := InterceptorsFile(genpkg, service); file != nil { + files = append(files, file) + } + // user types paths := make([]string, len(typeDefSections)) i := 0 diff --git a/codegen/service/templates/client_wrappers.go.tpl b/codegen/service/templates/client_wrappers.go.tpl new file mode 100644 index 0000000000..94364c38cd --- /dev/null +++ b/codegen/service/templates/client_wrappers.go.tpl @@ -0,0 +1,29 @@ +{{ comment (printf "Wrap%sClientEndpoint wraps the %s endpoint with the client interceptors defined in the design." .MethodVarName .Method) }} +func Wrap{{ .MethodVarName }}ClientEndpoint(endpoint goa.Endpoint, i ClientInterceptors) goa.Endpoint { + {{- range .ClientInterceptors }} + endpoint = wrapClient{{ .Name }}(endpoint, i, "{{ $.Method }}") + {{- end }} + return endpoint +} + +{{- range .ClientInterceptors }} +{{ comment (printf "wrapClient%s applies the %s interceptor to endpoints." .Name .Name) }} +func wrapClient{{ .Name }}(endpoint goa.Endpoint, i ClientInterceptors, method string) goa.Endpoint { + return func(ctx context.Context, req any) (any, error) { + info := &{{ .Name }}Info{ + Service: "{{ $.Service }}", + Method: method, + Endpoint: endpoint, + {{- if .ClientStreamInputStruct }} + RawPayload: req.(*{{ .ClientStreamInputStruct }}).Payload, + {{- else }} + RawPayload: req, + {{- end }} + } + next := func(ctx context.Context) (any, error) { + return endpoint(ctx, req) + } + return i.{{ .Name }}(ctx, info, next) + } +} +{{- end }} \ No newline at end of file diff --git a/codegen/service/templates/endpoint_wrappers.go.tpl b/codegen/service/templates/endpoint_wrappers.go.tpl new file mode 100644 index 0000000000..bae0b9fa25 --- /dev/null +++ b/codegen/service/templates/endpoint_wrappers.go.tpl @@ -0,0 +1,30 @@ +{{ comment (printf "Wrap%sEndpoint wraps the %s endpoint with the server-side interceptors defined in the design." .MethodVarName .Method) }} +func Wrap{{ .MethodVarName }}Endpoint(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint { + {{- range .ServerInterceptors }} + endpoint = wrap{{ .Name }}(endpoint, i, "{{ $.Method }}") + {{- end }} + return endpoint +} + +{{- range .ServerInterceptors }} +{{ comment (printf "wrap%s applies the %s interceptor to endpoints." .Name .Name) }} +func wrap{{ .Name }}(endpoint goa.Endpoint, i ServerInterceptors, method string) goa.Endpoint { + return func(ctx context.Context, req any) (any, error) { + info := &{{ .Name }}Info{ + Service: "{{ $.Service }}", + Method: method, + Endpoint: endpoint, + {{- if .ServerStreamInputStruct }} + RawPayload: req.(*{{ .ServerStreamInputStruct }}).Payload, + {{- else }} + RawPayload: req, + {{- end }} + } + next := func(ctx context.Context) (any, error) { + return endpoint(ctx, req) + } + return i.{{ .Name }}(ctx, info, next) + } +} + +{{- end }} \ No newline at end of file diff --git a/codegen/service/templates/interceptors.go.tpl b/codegen/service/templates/interceptors.go.tpl new file mode 100644 index 0000000000..46dfceedb0 --- /dev/null +++ b/codegen/service/templates/interceptors.go.tpl @@ -0,0 +1,153 @@ +{{- if .ServerInterceptors -}} +// ServerInterceptors defines the interface for all server-side interceptors. +// Server interceptors execute after the request is decoded and before the payload +// is sent to the service (request interceptors) or after the service returns and +// before the response is encoded (response interceptors). +type ServerInterceptors interface { +{{- range .ServerInterceptors }} + {{ comment .Description }} + {{ .Name }}(context.Context, *{{ .Name }}Info, goa.NextFunc) (any, error) +{{- end }} +} +{{- end }} + +{{- if .ClientInterceptors -}} +// ClientInterceptors defines the interface for all client-side interceptors. +// Client interceptors execute after the payload is encoded and before the request +// is sent to the server (request interceptors) or after the response is decoded +// and before the result is returned to the client (response interceptors). +type ClientInterceptors interface { +{{- range .ClientInterceptors }} + {{ comment .Description }} + {{ .Name }}(context.Context, *{{ .Name }}Info, goa.NextFunc) (any, error) +{{- end }} +} +{{- end }} + +// Access interfaces for interceptor payloads and results +type ( +{{- range .AllInterceptors }} + // {{ .Name }}Info provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + {{ .Name }}Info goa.InterceptorInfo + {{- if or .ReadPayload .WritePayload }} + + // {{ .Name }}PayloadAccess provides type-safe access to the method payload. + // It allows reading and writing specific fields of the payload as defined + // in the design. + {{ .Name }}PayloadAccess interface { + {{- range .ReadPayload }} + {{ .Name }}() {{ .TypeRef }} + {{- end }} + {{- range .WritePayload }} + Set{{ .Name }}({{ .TypeRef }}) + {{- end }} + } + {{- end }} + {{- if or .ReadResult .WriteResult }} + + // {{ .Name }}ResultAccess provides type-safe access to the method result. + // It allows reading and writing specific fields of the result as defined + // in the design. + {{ .Name }}ResultAccess interface { + {{- range .ReadResult }} + {{ .Name }}() {{ .TypeRef }} + {{- end }} + {{- range .WriteResult }} + Set{{ .Name }}({{ .TypeRef }}) + {{- end }} + } + {{- end }} +{{- end }} +) + +{{- if .HasPrivateImplementationTypes }} + +// Private implementation types +type ( + {{- range .AllInterceptors }} + {{- if or .ReadPayload .WritePayload }} + {{ .UnexportedName }}PayloadAccess struct { + payload {{ .PayloadRef }} + } + {{- end }} + + {{- if or .ReadResult .WriteResult }} + {{ .UnexportedName }}ResultAccess struct { + result {{ .ResultRef }} + } + {{- end }} + {{- end }} +) + +// Public accessor methods for Info types +{{- range .AllInterceptors }} + {{- if or .ReadPayload .WritePayload }} +// Payload returns a type-safe accessor for the method payload. +func (info *{{ .Name }}Info) Payload() {{ .Name }}PayloadAccess { + return &{{ .UnexportedName }}PayloadAccess{payload: info.RawPayload.({{ .PayloadRef }})} +} + {{- end }} + + {{- if or .ReadResult .WriteResult }} +// Result returns a type-safe accessor for the method result. +func (info *{{ .Name }}Info) Result(res any) {{ .Name }}ResultAccess { + return &{{ .UnexportedName }}ResultAccess{result: res.({{ .ResultRef }})} +} + {{- end }} +{{- end }} + +// Private implementation methods +{{- range .AllInterceptors }} + {{- $interceptor := . }} + {{- range .ReadPayload }} +func (p *{{ $interceptor.UnexportedName }}PayloadAccess) {{ .Name }}() {{ .TypeRef }} { + {{- if .FieldPointer }} + if p.payload.{{ .Name }} == nil { + var zero {{ .TypeRef }} + return zero + } + return *p.payload.{{ .Name }} + {{- else }} + return p.payload.{{ .Name }} + {{- end }} +} + {{- end }} + + {{- range .WritePayload }} +func (p *{{ $interceptor.UnexportedName }}PayloadAccess) Set{{ .Name }}(v {{ .TypeRef }}) { + {{- if .FieldPointer }} + p.payload.{{ .Name }} = &v + {{- else }} + p.payload.{{ .Name }} = v + {{- end }} +} + {{- end }} + + {{- range .ReadResult }} +func (r *{{ $interceptor.UnexportedName }}ResultAccess) {{ .Name }}() {{ .TypeRef }} { + {{- if .FieldPointer }} + if r.result.{{ .Name }} == nil { + var zero {{ .TypeRef }} + return zero + } + return *r.result.{{ .Name }} + {{- else }} + return r.result.{{ .Name }} + {{- end }} +} + {{- end }} + + {{- range .WriteResult }} +func (r *{{ $interceptor.UnexportedName }}ResultAccess) Set{{ .Name }}(v {{ .TypeRef }}) { + {{- if .FieldPointer }} + r.result.{{ .Name }} = &v + {{- else }} + r.result.{{ .Name }} = v + {{- end }} +} + {{- end }} +{{- end }} +{{- end }} + + diff --git a/codegen/service/templates/service_client_init.go.tpl b/codegen/service/templates/service_client_init.go.tpl index 3f2302c524..20621066a3 100644 --- a/codegen/service/templates/service_client_init.go.tpl +++ b/codegen/service/templates/service_client_init.go.tpl @@ -1,8 +1,8 @@ {{ printf "New%s initializes a %q service client given the endpoints." .ClientVarName .Name | comment }} -func New{{ .ClientVarName }}({{ .ClientInitArgs }} goa.Endpoint) *{{ .ClientVarName }} { - return &{{ .ClientVarName }}{ -{{- range .Methods }} - {{ .VarName }}Endpoint: {{ .ArgName }}, -{{- end }} - } +func New{{ .ClientVarName }}({{ .ClientInitArgs }} goa.Endpoint{{ if .HasClientInterceptors }}, ci ClientInterceptors{{ end }}) *{{ .ClientVarName }} { + return &{{ .ClientVarName }}{ + {{- range .Methods }} + {{ .VarName }}Endpoint: {{ if .ClientInterceptors }}Wrap{{ .VarName }}ClientEndpoint({{ end }}{{ .ArgName }}{{ if .ClientInterceptors }}, ci){{ end }}, + {{- end }} + } } diff --git a/codegen/service/templates/service_endpoints_init.go.tpl b/codegen/service/templates/service_endpoints_init.go.tpl index 64a9803a6b..3f46fc4d88 100644 --- a/codegen/service/templates/service_endpoints_init.go.tpl +++ b/codegen/service/templates/service_endpoints_init.go.tpl @@ -1,14 +1,25 @@ - {{ printf "New%s wraps the methods of the %q service with endpoints." .VarName .Name | comment }} -func New{{ .VarName }}(s {{ .ServiceVarName }}) *{{ .VarName }} { +func New{{ .VarName }}(s {{ .ServiceVarName }}{{ if .HasServerInterceptors }}, si ServerInterceptors{{ end }}) *{{ .VarName }} { {{- if .Schemes }} // Casting service to Auther interface a := s.(Auther) {{- end }} +{{- if .HasServerInterceptors }} + endpoints := &{{ .VarName }}{ +{{- else }} return &{{ .VarName }}{ +{{- end }} {{- range .Methods }} {{ .VarName }}: New{{ .VarName }}Endpoint(s{{ range .Schemes }}, a.{{ .Type }}Auth{{ end }}), {{- end }} } +{{- if .HasServerInterceptors }} + {{- range .Methods }} + {{- if .ServerInterceptors }} + endpoints.{{ .VarName }} = Wrap{{ .VarName }}Endpoint(endpoints.{{ .VarName }}, si) + {{- end }} + {{- end }} + return endpoints +{{- end }} } \ No newline at end of file diff --git a/codegen/service/templates/service_interceptor.go.tpl b/codegen/service/templates/service_interceptor.go.tpl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/codegen/service/testdata/client_code.go b/codegen/service/testdata/client_code.go index 2b2356134d..96f5a7b6a0 100644 --- a/codegen/service/testdata/client_code.go +++ b/codegen/service/testdata/client_code.go @@ -283,3 +283,50 @@ func (c *Client) BidirectionalStreamingNoPayloadMethod(ctx context.Context) (res return ires.(BidirectionalStreamingNoPayloadMethodClientStream), nil } ` + +const InterceptorClient = `// Client is the "ServiceWithClientInterceptor" service client. +type Client struct { + MethodEndpoint goa.Endpoint +} + +// NewClient initializes a "ServiceWithClientInterceptor" service client given +// the endpoints. +func NewClient(method goa.Endpoint, ci ClientInterceptors) *Client { + return &Client{ + MethodEndpoint: WrapMethodClientEndpoint(method, ci), + } +} + +// Method calls the "Method" endpoint of the "ServiceWithClientInterceptor" +// service. +func (c *Client) Method(ctx context.Context, p string) (res string, err error) { + var ires any + ires, err = c.MethodEndpoint(ctx, p) + if err != nil { + return + } + return ires.(string), nil +} + +// WrapMethodClientEndpoint wraps the Method endpoint with the client +// interceptors defined in the design. +func WrapMethodClientEndpoint(endpoint goa.Endpoint, i ClientInterceptors) goa.Endpoint { + endpoint = wrapClientTracing(endpoint, i, "Method") + return endpoint +} + +// wrapClientTracing applies the Tracing interceptor to endpoints. +func wrapClientTracing(endpoint goa.Endpoint, i ClientInterceptors, method string) goa.Endpoint { + return func(ctx context.Context, req any) (any, error) { + info := &TracingInfo{ + Service: "ServiceWithClientInterceptor", + Method: method, + Endpoint: endpoint, + RawPayload: req, + } + next := func(ctx context.Context) (any, error) { + return endpoint(ctx, req) + } + return i.Tracing(ctx, info, next) + } +}` diff --git a/codegen/service/testdata/endpoint_code.go b/codegen/service/testdata/endpoint_code.go index 978ff5e5f7..27967a6e08 100644 --- a/codegen/service/testdata/endpoint_code.go +++ b/codegen/service/testdata/endpoint_code.go @@ -513,3 +513,184 @@ func NewBidirectionalStreamingNoPayloadMethodEndpoint(s Service) goa.Endpoint { } } ` + +var EndpointWithServerInterceptor = `// Endpoints wraps the "ServiceWithServerInterceptor" service endpoints. +type Endpoints struct { + Method goa.Endpoint +} + +// NewEndpoints wraps the methods of the "ServiceWithServerInterceptor" service +// with endpoints. +func NewEndpoints(s Service, si ServerInterceptors) *Endpoints { + endpoints := &Endpoints{ + Method: NewMethodEndpoint(s), + } + endpoints.Method = WrapMethodEndpoint(endpoints.Method, si) + return endpoints +} + +// Use applies the given middleware to all the "ServiceWithServerInterceptor" +// service endpoints. +func (e *Endpoints) Use(m func(goa.Endpoint) goa.Endpoint) { + e.Method = m(e.Method) +} + +// NewMethodEndpoint returns an endpoint function that calls the method +// "Method" of service "ServiceWithServerInterceptor". +func NewMethodEndpoint(s Service) goa.Endpoint { + return func(ctx context.Context, req any) (any, error) { + p := req.(string) + return s.Method(ctx, p) + } +} + +// WrapMethodEndpoint wraps the Method endpoint with the server-side +// interceptors defined in the design. +func WrapMethodEndpoint(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint { + endpoint = wrapLogging(endpoint, i, "Method") + return endpoint +} + +// wrapLogging applies the Logging interceptor to endpoints. +func wrapLogging(endpoint goa.Endpoint, i ServerInterceptors, method string) goa.Endpoint { + return func(ctx context.Context, req any) (any, error) { + info := &LoggingInfo{ + Service: "ServiceWithServerInterceptor", + Method: method, + Endpoint: endpoint, + RawPayload: req, + } + next := func(ctx context.Context) (any, error) { + return endpoint(ctx, req) + } + return i.Logging(ctx, info, next) + } +}` + +var EndpointWithMultipleInterceptors = `// Endpoints wraps the "ServiceWithMultipleInterceptors" service endpoints. +type Endpoints struct { + Method goa.Endpoint +} + +// NewEndpoints wraps the methods of the "ServiceWithMultipleInterceptors" +// service with endpoints. +func NewEndpoints(s Service, si ServerInterceptors) *Endpoints { + endpoints := &Endpoints{ + Method: NewMethodEndpoint(s), + } + endpoints.Method = WrapMethodEndpoint(endpoints.Method, si) + return endpoints +} + +// Use applies the given middleware to all the +// "ServiceWithMultipleInterceptors" service endpoints. +func (e *Endpoints) Use(m func(goa.Endpoint) goa.Endpoint) { + e.Method = m(e.Method) +} + +// NewMethodEndpoint returns an endpoint function that calls the method +// "Method" of service "ServiceWithMultipleInterceptors". +func NewMethodEndpoint(s Service) goa.Endpoint { + return func(ctx context.Context, req any) (any, error) { + p := req.(string) + return s.Method(ctx, p) + } +} + +// WrapMethodEndpoint wraps the Method endpoint with the server-side +// interceptors defined in the design. +func WrapMethodEndpoint(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint { + endpoint = wrapLogging(endpoint, i, "Method") + endpoint = wrapMetrics(endpoint, i, "Method") + return endpoint +} + +// wrapLogging applies the Logging interceptor to endpoints. +func wrapLogging(endpoint goa.Endpoint, i ServerInterceptors, method string) goa.Endpoint { + return func(ctx context.Context, req any) (any, error) { + info := &LoggingInfo{ + Service: "ServiceWithMultipleInterceptors", + Method: method, + Endpoint: endpoint, + RawPayload: req, + } + next := func(ctx context.Context) (any, error) { + return endpoint(ctx, req) + } + return i.Logging(ctx, info, next) + } +} + +// wrapMetrics applies the Metrics interceptor to endpoints. +func wrapMetrics(endpoint goa.Endpoint, i ServerInterceptors, method string) goa.Endpoint { + return func(ctx context.Context, req any) (any, error) { + info := &MetricsInfo{ + Service: "ServiceWithMultipleInterceptors", + Method: method, + Endpoint: endpoint, + RawPayload: req, + } + next := func(ctx context.Context) (any, error) { + return endpoint(ctx, req) + } + return i.Metrics(ctx, info, next) + } +}` + +var EndpointStreamingWithInterceptor = `// Endpoints wraps the "ServiceStreamingWithInterceptor" service endpoints. +type Endpoints struct { + Method goa.Endpoint +} + +// MethodEndpointInput holds both the payload and the server stream of the +// "Method" method. +type MethodEndpointInput struct { + // Stream is the server stream used by the "Method" method to send data. + Stream MethodServerStream +} + +// NewEndpoints wraps the methods of the "ServiceStreamingWithInterceptor" service +// with endpoints. +func NewEndpoints(s Service, i ServerInterceptors) *Endpoints { + return &Endpoints{ + Method: WrapMethodEndpoint(NewMethodEndpoint(s), i), + } +} + +// Use applies the given middleware to all the "ServiceStreamingWithInterceptor" +// service endpoints. +func (e *Endpoints) Use(m func(goa.Endpoint) goa.Endpoint) { + e.Method = m(e.Method) +} + +// NewMethodEndpoint returns an endpoint function that calls the method "Method" +// of service "ServiceStreamingWithInterceptor". +func NewMethodEndpoint(s Service) goa.Endpoint { + return func(ctx context.Context, req any) (any, error) { + ep := req.(*MethodEndpointInput) + return nil, s.Method(ctx, ep.Stream) + } +} + +// WrapMethodEndpoint wraps the Method endpoint with the server-side +// interceptors defined in the design. +func WrapMethodEndpoint(endpoint goa.Endpoint, i ServerInterceptors) goa.Endpoint { + endpoint = wrapLogging(endpoint, i, "Method") + return endpoint +} + +// wrapLogging applies the Logging interceptor to endpoints. +func wrapLogging(endpoint goa.Endpoint, i ServerInterceptors, method string) goa.Endpoint { + return func(ctx context.Context, req any) (any, error) { + info := &LoggingInfo{ + Service: "ServiceStreamingWithInterceptor", + Method: method, + Endpoint: endpoint, + RawPayload: req.(*MethodEndpointInput).Payload, + } + next := func(ctx context.Context) (any, error) { + return endpoint(ctx, req) + } + return i.Logging(ctx, info, next) + } +}` diff --git a/codegen/service/testdata/endpoint_dsls.go b/codegen/service/testdata/endpoint_dsls.go index 4ca2c7b435..de3a220401 100644 --- a/codegen/service/testdata/endpoint_dsls.go +++ b/codegen/service/testdata/endpoint_dsls.go @@ -181,3 +181,47 @@ var BidirectionalStreamingEndpointDSL = func() { }) }) } + +var EndpointWithServerInterceptorDSL = func() { + Interceptor("logging") + Service("ServiceWithServerInterceptor", func() { + Method("Method", func() { + ServerInterceptor("logging") + Payload(String) + Result(String) + HTTP(func() { + POST("/") + }) + }) + }) +} + +var EndpointWithMultipleInterceptorsDSL = func() { + Interceptor("logging") + Interceptor("metrics") + Service("ServiceWithMultipleInterceptors", func() { + Method("Method", func() { + ServerInterceptor("logging") + ServerInterceptor("metrics") + Payload(String) + Result(String) + HTTP(func() { + POST("/") + }) + }) + }) +} + +var EndpointStreamingWithInterceptorDSL = func() { + Interceptor("logging") + Service("ServiceStreamingWithInterceptor", func() { + Method("Method", func() { + ServerInterceptor("logging") + StreamingPayload(String) + StreamingResult(String) + HTTP(func() { + GET("/") + }) + }) + }) +} diff --git a/codegen/service/testdata/interceptors/interceptor-with-read-payload.golden b/codegen/service/testdata/interceptors/interceptor-with-read-payload.golden new file mode 100644 index 0000000000..eb5152b895 --- /dev/null +++ b/codegen/service/testdata/interceptors/interceptor-with-read-payload.golden @@ -0,0 +1,41 @@ +// ServerInterceptors defines the interface for all server-side interceptors. +// Server interceptors execute after the request is decoded and before the payload +// is sent to the service (request interceptors) or after the service returns and +// before the response is encoded (response interceptors). +type ServerInterceptors interface { + Validation(context.Context, *ValidationInfo, goa.NextFunc) (any, error) +} + +// Access interfaces for interceptor payloads and results +type ( + // ValidationInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + ValidationInfo goa.InterceptorInfo + + // ValidationPayloadAccess provides type-safe access to the method payload. + // It allows reading and writing specific fields of the payload as defined + // in the design. + ValidationPayloadAccess interface { + Name() string + } +) + +// Private implementation types +type ( + validationPayloadAccess struct { + payload *MethodPayload + } +) + +// Public accessor methods for Info types +// Payload returns a type-safe accessor for the method payload. +func (info *ValidationInfo) Payload() ValidationPayloadAccess { + return &validationPayloadAccess{payload: info.RawPayload.(*MethodPayload)} +} + +// Private implementation methods +func (p *validationPayloadAccess) Name() string { + return p.payload.Name +} + + diff --git a/codegen/service/testdata/interceptors/interceptor-with-read-result.golden b/codegen/service/testdata/interceptors/interceptor-with-read-result.golden new file mode 100644 index 0000000000..254d1ddaf0 --- /dev/null +++ b/codegen/service/testdata/interceptors/interceptor-with-read-result.golden @@ -0,0 +1,41 @@ +// ServerInterceptors defines the interface for all server-side interceptors. +// Server interceptors execute after the request is decoded and before the payload +// is sent to the service (request interceptors) or after the service returns and +// before the response is encoded (response interceptors). +type ServerInterceptors interface { + Caching(context.Context, *CachingInfo, goa.NextFunc) (any, error) +} + +// Access interfaces for interceptor payloads and results +type ( + // CachingInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + CachingInfo goa.InterceptorInfo + + // CachingResultAccess provides type-safe access to the method result. + // It allows reading and writing specific fields of the result as defined + // in the design. + CachingResultAccess interface { + Data() string + } +) + +// Private implementation types +type ( + cachingResultAccess struct { + result *MethodResult + } +) + +// Public accessor methods for Info types +// Result returns a type-safe accessor for the method result. +func (info *CachingInfo) Result(res any) CachingResultAccess { + return &cachingResultAccess{result: res.(*MethodResult)} +} + +// Private implementation methods +func (r *cachingResultAccess) Data() string { + return r.result.Data +} + + diff --git a/codegen/service/testdata/interceptors/interceptor-with-read-write-payload.golden b/codegen/service/testdata/interceptors/interceptor-with-read-write-payload.golden new file mode 100644 index 0000000000..07d0f938cb --- /dev/null +++ b/codegen/service/testdata/interceptors/interceptor-with-read-write-payload.golden @@ -0,0 +1,45 @@ +// ServerInterceptors defines the interface for all server-side interceptors. +// Server interceptors execute after the request is decoded and before the payload +// is sent to the service (request interceptors) or after the service returns and +// before the response is encoded (response interceptors). +type ServerInterceptors interface { + Validation(context.Context, *ValidationInfo, goa.NextFunc) (any, error) +} + +// Access interfaces for interceptor payloads and results +type ( + // ValidationInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + ValidationInfo goa.InterceptorInfo + + // ValidationPayloadAccess provides type-safe access to the method payload. + // It allows reading and writing specific fields of the payload as defined + // in the design. + ValidationPayloadAccess interface { + Name() string + SetName(string) + } +) + +// Private implementation types +type ( + validationPayloadAccess struct { + payload *MethodPayload + } +) + +// Public accessor methods for Info types +// Payload returns a type-safe accessor for the method payload. +func (info *ValidationInfo) Payload() ValidationPayloadAccess { + return &validationPayloadAccess{payload: info.RawPayload.(*MethodPayload)} +} + +// Private implementation methods +func (p *validationPayloadAccess) Name() string { + return p.payload.Name +} +func (p *validationPayloadAccess) SetName(v string) { + p.payload.Name = v +} + + diff --git a/codegen/service/testdata/interceptors/interceptor-with-read-write-result.golden b/codegen/service/testdata/interceptors/interceptor-with-read-write-result.golden new file mode 100644 index 0000000000..202260177d --- /dev/null +++ b/codegen/service/testdata/interceptors/interceptor-with-read-write-result.golden @@ -0,0 +1,45 @@ +// ServerInterceptors defines the interface for all server-side interceptors. +// Server interceptors execute after the request is decoded and before the payload +// is sent to the service (request interceptors) or after the service returns and +// before the response is encoded (response interceptors). +type ServerInterceptors interface { + Caching(context.Context, *CachingInfo, goa.NextFunc) (any, error) +} + +// Access interfaces for interceptor payloads and results +type ( + // CachingInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + CachingInfo goa.InterceptorInfo + + // CachingResultAccess provides type-safe access to the method result. + // It allows reading and writing specific fields of the result as defined + // in the design. + CachingResultAccess interface { + Data() string + SetData(string) + } +) + +// Private implementation types +type ( + cachingResultAccess struct { + result *MethodResult + } +) + +// Public accessor methods for Info types +// Result returns a type-safe accessor for the method result. +func (info *CachingInfo) Result(res any) CachingResultAccess { + return &cachingResultAccess{result: res.(*MethodResult)} +} + +// Private implementation methods +func (r *cachingResultAccess) Data() string { + return r.result.Data +} +func (r *cachingResultAccess) SetData(v string) { + r.result.Data = v +} + + diff --git a/codegen/service/testdata/interceptors/interceptor-with-write-payload.golden b/codegen/service/testdata/interceptors/interceptor-with-write-payload.golden new file mode 100644 index 0000000000..7f0eabeb04 --- /dev/null +++ b/codegen/service/testdata/interceptors/interceptor-with-write-payload.golden @@ -0,0 +1,41 @@ +// ServerInterceptors defines the interface for all server-side interceptors. +// Server interceptors execute after the request is decoded and before the payload +// is sent to the service (request interceptors) or after the service returns and +// before the response is encoded (response interceptors). +type ServerInterceptors interface { + Validation(context.Context, *ValidationInfo, goa.NextFunc) (any, error) +} + +// Access interfaces for interceptor payloads and results +type ( + // ValidationInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + ValidationInfo goa.InterceptorInfo + + // ValidationPayloadAccess provides type-safe access to the method payload. + // It allows reading and writing specific fields of the payload as defined + // in the design. + ValidationPayloadAccess interface { + SetName(string) + } +) + +// Private implementation types +type ( + validationPayloadAccess struct { + payload *MethodPayload + } +) + +// Public accessor methods for Info types +// Payload returns a type-safe accessor for the method payload. +func (info *ValidationInfo) Payload() ValidationPayloadAccess { + return &validationPayloadAccess{payload: info.RawPayload.(*MethodPayload)} +} + +// Private implementation methods +func (p *validationPayloadAccess) SetName(v string) { + p.payload.Name = v +} + + diff --git a/codegen/service/testdata/interceptors/interceptor-with-write-result.golden b/codegen/service/testdata/interceptors/interceptor-with-write-result.golden new file mode 100644 index 0000000000..362ae77d6a --- /dev/null +++ b/codegen/service/testdata/interceptors/interceptor-with-write-result.golden @@ -0,0 +1,41 @@ +// ServerInterceptors defines the interface for all server-side interceptors. +// Server interceptors execute after the request is decoded and before the payload +// is sent to the service (request interceptors) or after the service returns and +// before the response is encoded (response interceptors). +type ServerInterceptors interface { + Caching(context.Context, *CachingInfo, goa.NextFunc) (any, error) +} + +// Access interfaces for interceptor payloads and results +type ( + // CachingInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + CachingInfo goa.InterceptorInfo + + // CachingResultAccess provides type-safe access to the method result. + // It allows reading and writing specific fields of the result as defined + // in the design. + CachingResultAccess interface { + SetData(string) + } +) + +// Private implementation types +type ( + cachingResultAccess struct { + result *MethodResult + } +) + +// Public accessor methods for Info types +// Result returns a type-safe accessor for the method result. +func (info *CachingInfo) Result(res any) CachingResultAccess { + return &cachingResultAccess{result: res.(*MethodResult)} +} + +// Private implementation methods +func (r *cachingResultAccess) SetData(v string) { + r.result.Data = v +} + + diff --git a/codegen/service/testdata/interceptors/multiple-interceptors.golden b/codegen/service/testdata/interceptors/multiple-interceptors.golden new file mode 100644 index 0000000000..111d148050 --- /dev/null +++ b/codegen/service/testdata/interceptors/multiple-interceptors.golden @@ -0,0 +1,30 @@ +// ServerInterceptors defines the interface for all server-side interceptors. +// Server interceptors execute after the request is decoded and before the payload +// is sent to the service (request interceptors) or after the service returns and +// before the response is encoded (response interceptors). +type ServerInterceptors interface { + Logging(context.Context, *LoggingInfo, goa.NextFunc) (any, error) + + Tracing(context.Context, *TracingInfo, goa.NextFunc) (any, error) +} // ClientInterceptors defines the interface for all client-side interceptors. +// Client interceptors execute after the payload is encoded and before the request +// is sent to the server (request interceptors) or after the response is decoded +// and before the result is returned to the client (response interceptors). +type ClientInterceptors interface { + Metrics(context.Context, *MetricsInfo, goa.NextFunc) (any, error) +} + +// Access interfaces for interceptor payloads and results +type ( + // LoggingInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + LoggingInfo goa.InterceptorInfo + // TracingInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + TracingInfo goa.InterceptorInfo + // MetricsInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + MetricsInfo goa.InterceptorInfo +) + + diff --git a/codegen/service/testdata/interceptors/single-api-server-interceptor.golden b/codegen/service/testdata/interceptors/single-api-server-interceptor.golden new file mode 100644 index 0000000000..ae003b0548 --- /dev/null +++ b/codegen/service/testdata/interceptors/single-api-server-interceptor.golden @@ -0,0 +1,16 @@ +// ServerInterceptors defines the interface for all server-side interceptors. +// Server interceptors execute after the request is decoded and before the payload +// is sent to the service (request interceptors) or after the service returns and +// before the response is encoded (response interceptors). +type ServerInterceptors interface { + Logging(context.Context, *LoggingInfo, goa.NextFunc) (any, error) +} + +// Access interfaces for interceptor payloads and results +type ( + // LoggingInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + LoggingInfo goa.InterceptorInfo +) + + diff --git a/codegen/service/testdata/interceptors/single-client-interceptor.golden b/codegen/service/testdata/interceptors/single-client-interceptor.golden new file mode 100644 index 0000000000..0fc4e0f1a9 --- /dev/null +++ b/codegen/service/testdata/interceptors/single-client-interceptor.golden @@ -0,0 +1,16 @@ +// ClientInterceptors defines the interface for all client-side interceptors. +// Client interceptors execute after the payload is encoded and before the request +// is sent to the server (request interceptors) or after the response is decoded +// and before the result is returned to the client (response interceptors). +type ClientInterceptors interface { + Tracing(context.Context, *TracingInfo, goa.NextFunc) (any, error) +} + +// Access interfaces for interceptor payloads and results +type ( + // TracingInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + TracingInfo goa.InterceptorInfo +) + + diff --git a/codegen/service/testdata/interceptors/single-method-server-interceptor.golden b/codegen/service/testdata/interceptors/single-method-server-interceptor.golden new file mode 100644 index 0000000000..ae003b0548 --- /dev/null +++ b/codegen/service/testdata/interceptors/single-method-server-interceptor.golden @@ -0,0 +1,16 @@ +// ServerInterceptors defines the interface for all server-side interceptors. +// Server interceptors execute after the request is decoded and before the payload +// is sent to the service (request interceptors) or after the service returns and +// before the response is encoded (response interceptors). +type ServerInterceptors interface { + Logging(context.Context, *LoggingInfo, goa.NextFunc) (any, error) +} + +// Access interfaces for interceptor payloads and results +type ( + // LoggingInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + LoggingInfo goa.InterceptorInfo +) + + diff --git a/codegen/service/testdata/interceptors/single-service-server-interceptor.golden b/codegen/service/testdata/interceptors/single-service-server-interceptor.golden new file mode 100644 index 0000000000..ae003b0548 --- /dev/null +++ b/codegen/service/testdata/interceptors/single-service-server-interceptor.golden @@ -0,0 +1,16 @@ +// ServerInterceptors defines the interface for all server-side interceptors. +// Server interceptors execute after the request is decoded and before the payload +// is sent to the service (request interceptors) or after the service returns and +// before the response is encoded (response interceptors). +type ServerInterceptors interface { + Logging(context.Context, *LoggingInfo, goa.NextFunc) (any, error) +} + +// Access interfaces for interceptor payloads and results +type ( + // LoggingInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + LoggingInfo goa.InterceptorInfo +) + + diff --git a/codegen/service/testdata/interceptors/streaming-interceptors-with-read-payload.golden b/codegen/service/testdata/interceptors/streaming-interceptors-with-read-payload.golden new file mode 100644 index 0000000000..68f1cf20cb --- /dev/null +++ b/codegen/service/testdata/interceptors/streaming-interceptors-with-read-payload.golden @@ -0,0 +1,45 @@ +// ServerInterceptors defines the interface for all server-side interceptors. +// Server interceptors execute after the request is decoded and before the payload +// is sent to the service (request interceptors) or after the service returns and +// before the response is encoded (response interceptors). +type ServerInterceptors interface { + Logging(context.Context, *LoggingInfo, goa.NextFunc) (any, error) +} + +// Access interfaces for interceptor payloads and results +type ( + // LoggingInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + LoggingInfo goa.InterceptorInfo + + // LoggingPayloadAccess provides type-safe access to the method payload. + // It allows reading and writing specific fields of the payload as defined + // in the design. + LoggingPayloadAccess interface { + Initial() string + } +) + +// Private implementation types +type ( + loggingPayloadAccess struct { + payload *MethodPayload + } +) + +// Public accessor methods for Info types +// Payload returns a type-safe accessor for the method payload. +func (info *LoggingInfo) Payload() LoggingPayloadAccess { + return &loggingPayloadAccess{payload: info.RawPayload.(*MethodPayload)} +} + +// Private implementation methods +func (p *loggingPayloadAccess) Initial() string { + if p.payload.Initial == nil { + var zero string + return zero + } + return *p.payload.Initial +} + + diff --git a/codegen/service/testdata/interceptors/streaming-interceptors-with-read-result.golden b/codegen/service/testdata/interceptors/streaming-interceptors-with-read-result.golden new file mode 100644 index 0000000000..4dec14cb90 --- /dev/null +++ b/codegen/service/testdata/interceptors/streaming-interceptors-with-read-result.golden @@ -0,0 +1,45 @@ +// ServerInterceptors defines the interface for all server-side interceptors. +// Server interceptors execute after the request is decoded and before the payload +// is sent to the service (request interceptors) or after the service returns and +// before the response is encoded (response interceptors). +type ServerInterceptors interface { + Logging(context.Context, *LoggingInfo, goa.NextFunc) (any, error) +} + +// Access interfaces for interceptor payloads and results +type ( + // LoggingInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + LoggingInfo goa.InterceptorInfo + + // LoggingResultAccess provides type-safe access to the method result. + // It allows reading and writing specific fields of the result as defined + // in the design. + LoggingResultAccess interface { + Data() string + } +) + +// Private implementation types +type ( + loggingResultAccess struct { + result *MethodResult + } +) + +// Public accessor methods for Info types +// Result returns a type-safe accessor for the method result. +func (info *LoggingInfo) Result(res any) LoggingResultAccess { + return &loggingResultAccess{result: res.(*MethodResult)} +} + +// Private implementation methods +func (r *loggingResultAccess) Data() string { + if r.result.Data == nil { + var zero string + return zero + } + return *r.result.Data +} + + diff --git a/codegen/service/testdata/interceptors/streaming-interceptors.golden b/codegen/service/testdata/interceptors/streaming-interceptors.golden new file mode 100644 index 0000000000..ae003b0548 --- /dev/null +++ b/codegen/service/testdata/interceptors/streaming-interceptors.golden @@ -0,0 +1,16 @@ +// ServerInterceptors defines the interface for all server-side interceptors. +// Server interceptors execute after the request is decoded and before the payload +// is sent to the service (request interceptors) or after the service returns and +// before the response is encoded (response interceptors). +type ServerInterceptors interface { + Logging(context.Context, *LoggingInfo, goa.NextFunc) (any, error) +} + +// Access interfaces for interceptor payloads and results +type ( + // LoggingInfo provides metadata about the current interception. + // It includes service name, method name, and access to the endpoint. + LoggingInfo goa.InterceptorInfo +) + + diff --git a/codegen/service/testdata/interceptors_dsls.go b/codegen/service/testdata/interceptors_dsls.go new file mode 100644 index 0000000000..dead879c9f --- /dev/null +++ b/codegen/service/testdata/interceptors_dsls.go @@ -0,0 +1,297 @@ +package testdata + +import ( + . "goa.design/goa/v3/dsl" +) + +var NoInterceptorsDSL = func() { + Service("NoInterceptors", func() { + Method("Method", func() { + HTTP(func() { GET("/") }) + }) + }) +} + +var SingleAPIServerInterceptorDSL = func() { + Interceptor("logging") + API("SingleAPIServerInterceptor", func() { + ServerInterceptor("logging") + }) + Service("SingleAPIServerInterceptor", func() { + Method("Method", func() { + HTTP(func() { GET("/1") }) + }) + Method("Method2", func() { + HTTP(func() { GET("/2") }) + }) + }) +} + +var SingleServiceServerInterceptorDSL = func() { + Interceptor("logging") + Service("SingleServerInterceptor", func() { + ServerInterceptor("logging") + Method("Method", func() { + HTTP(func() { + GET("/1") + }) + }) + Method("Method2", func() { + HTTP(func() { + GET("/2") + }) + }) + }) +} + +var SingleMethodServerInterceptorDSL = func() { + Interceptor("logging") + Service("SingleMethodServerInterceptor", func() { + Method("Method", func() { + ServerInterceptor("logging") + HTTP(func() { GET("/1") }) + }) + Method("Method2", func() { + HTTP(func() { GET("/2") }) + }) + }) +} + +var SingleClientInterceptorDSL = func() { + Interceptor("tracing") + Service("SingleClientInterceptor", func() { + ClientInterceptor("tracing") + Method("Method", func() { + Payload(func() { + Attribute("id", Int) + }) + Result(func() { + Attribute("value", String) + }) + HTTP(func() { GET("/") }) + }) + }) +} + +var MultipleInterceptorsDSL = func() { + Interceptor("logging") + Interceptor("tracing") + Interceptor("metrics") + Service("MultipleInterceptors", func() { + ServerInterceptor("logging") + ServerInterceptor("tracing") + ClientInterceptor("metrics") + Method("Method", func() { + Payload(func() { + Attribute("query", String) + }) + Result(func() { + Attribute("data", String) + }) + HTTP(func() { GET("/") }) + }) + }) +} + +var InterceptorWithReadPayloadDSL = func() { + Interceptor("validation", func() { + ReadPayload(func() { + Attribute("name") + }) + }) + Service("InterceptorWithReadPayload", func() { + ServerInterceptor("validation") + ClientInterceptor("validation") + Method("Method", func() { + Payload(func() { + Attribute("name", String) + Required("name") + }) + HTTP(func() { POST("/") }) + }) + }) +} + +var InterceptorWithWritePayloadDSL = func() { + Interceptor("validation", func() { + WritePayload(func() { + Attribute("name") + }) + }) + Service("InterceptorWithWritePayload", func() { + ServerInterceptor("validation") + ClientInterceptor("validation") + Method("Method", func() { + Payload(func() { + Attribute("name", String) + Required("name") + }) + HTTP(func() { POST("/") }) + }) + }) +} + +var InterceptorWithReadWritePayloadDSL = func() { + Interceptor("validation", func() { + ReadPayload(func() { + Attribute("name") + }) + WritePayload(func() { + Attribute("name") + }) + }) + Service("InterceptorWithReadWritePayload", func() { + ServerInterceptor("validation") + ClientInterceptor("validation") + Method("Method", func() { + Payload(func() { + Attribute("name", String) + Required("name") + }) + HTTP(func() { POST("/") }) + }) + }) +} + +var InterceptorWithReadResultDSL = func() { + Interceptor("caching", func() { + ReadResult(func() { + Attribute("data") + }) + }) + Service("InterceptorWithReadResult", func() { + ServerInterceptor("caching") + ClientInterceptor("caching") + Method("Method", func() { + Result(func() { + Attribute("data", String) + Required("data") + }) + HTTP(func() { GET("/") }) + }) + }) +} + +var InterceptorWithWriteResultDSL = func() { + Interceptor("caching", func() { + WriteResult(func() { + Attribute("data") + }) + }) + Service("InterceptorWithWriteResult", func() { + ServerInterceptor("caching") + ClientInterceptor("caching") + Method("Method", func() { + Result(func() { + Attribute("data", String) + Required("data") + }) + HTTP(func() { GET("/") }) + }) + }) +} + +var InterceptorWithReadWriteResultDSL = func() { + Interceptor("caching", func() { + ReadResult(func() { + Attribute("data") + }) + WriteResult(func() { + Attribute("data") + }) + }) + Service("InterceptorWithReadWriteResult", func() { + ServerInterceptor("caching") + ClientInterceptor("caching") + Method("Method", func() { + Result(func() { + Attribute("data", String) + Required("data") + }) + HTTP(func() { GET("/") }) + }) + }) +} + +var StreamingInterceptorsDSL = func() { + Interceptor("logging") + Service("StreamingInterceptors", func() { + ServerInterceptor("logging") + Method("Method", func() { + StreamingPayload(func() { + Attribute("chunk", String) + }) + StreamingResult(func() { + Attribute("data", String) + }) + HTTP(func() { GET("/stream") }) + }) + }) +} + +var StreamingInterceptorsWithReadPayloadDSL = func() { + Interceptor("logging", func() { + ReadPayload(func() { + Attribute("initial") + }) + }) + Service("StreamingInterceptorsWithReadPayload", func() { + ServerInterceptor("logging") + Method("Method", func() { + Payload(func() { + Attribute("initial", String) + }) + StreamingPayload(func() { + Attribute("chunk", String) + }) + HTTP(func() { + Header("initial") + GET("/stream") + }) + }) + }) +} + +var StreamingInterceptorsWithReadResultDSL = func() { + Interceptor("logging", func() { + ReadResult(func() { + Attribute("data") + }) + }) + Service("StreamingInterceptorsWithReadResult", func() { + ServerInterceptor("logging") + Method("Method", func() { + Payload(func() { + Attribute("initial", String) + }) + StreamingPayload(func() { + Attribute("chunk", String) + }) + Result(func() { + Attribute("data", String) + }) + HTTP(func() { + Header("initial") + GET("/stream") + }) + }) + }) +} + +// Invalid DSL +var StreamingResultInterceptorDSL = func() { + Interceptor("logging", func() { + ReadResult(func() { + Attribute("data") + }) + }) + Service("StreamingResultInterceptor", func() { + ServerInterceptor("logging") + Method("Method", func() { + StreamingResult(func() { + Attribute("data", String) + }) + HTTP(func() { GET("/stream") }) + }) + }) +} diff --git a/codegen/service/testdata/service_dsls.go b/codegen/service/testdata/service_dsls.go index 5fd9b06b2c..76f7f917b8 100644 --- a/codegen/service/testdata/service_dsls.go +++ b/codegen/service/testdata/service_dsls.go @@ -967,3 +967,17 @@ var PkgPathPayloadAttributeDSL = func() { }) }) } + +var EndpointWithClientInterceptorDSL = func() { + Interceptor("tracing") + Service("ServiceWithClientInterceptor", func() { + Method("Method", func() { + ClientInterceptor("tracing") + Payload(String) + Result(String) + HTTP(func() { + POST("/") + }) + }) + }) +} diff --git a/codegen/service/testing.go b/codegen/service/testing.go new file mode 100644 index 0000000000..28243ecfac --- /dev/null +++ b/codegen/service/testing.go @@ -0,0 +1,42 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "goa.design/goa/v3/eval" + "goa.design/goa/v3/expr" +) + +// initDSL initializes the DSL environment and returns the root. +func initDSL(t *testing.T) *expr.RootExpr { + // reset all roots and codegen data structures + Services = make(ServicesData) + eval.Reset() + expr.Root = new(expr.RootExpr) + expr.GeneratedResultTypes = new(expr.ResultTypesRoot) + expr.Root.API = expr.NewAPIExpr("test api", func() {}) + expr.Root.API.Servers = []*expr.ServerExpr{expr.Root.API.DefaultServer()} + root := expr.Root + require.NoError(t, eval.Register(root)) + require.NoError(t, eval.Register(expr.GeneratedResultTypes)) + return root +} + +// runDSL returns the DSL root resulting from running the given DSL. +func runDSL(t *testing.T, dsl func()) *expr.RootExpr { + root := initDSL(t) + require.True(t, eval.Execute(dsl, nil)) + require.NoError(t, eval.RunDSL()) + return root +} + +// runDSLWithError returns the DSL root and error from running the given DSL. +func runDSLWithError(t *testing.T, dsl func()) (*expr.RootExpr, error) { + root := initDSL(t) + require.True(t, eval.Execute(dsl, nil)) + err := eval.RunDSL() + require.Error(t, err) + return root, err +} diff --git a/dsl/description.go b/dsl/description.go index 931c4f2ab3..b8e9c78479 100644 --- a/dsl/description.go +++ b/dsl/description.go @@ -45,6 +45,8 @@ func Description(d string) { e.Description = d case *expr.GRPCResponseExpr: e.Description = d + case *expr.InterceptorExpr: + e.Description = d default: eval.IncompatibleDSL() } diff --git a/dsl/interceptor.go b/dsl/interceptor.go new file mode 100644 index 0000000000..28b99ce7dd --- /dev/null +++ b/dsl/interceptor.go @@ -0,0 +1,293 @@ +package dsl + +import ( + "goa.design/goa/v3/eval" + "goa.design/goa/v3/expr" +) + +// Interceptor defines a request interceptor. Interceptors provide a type-safe way +// to read and write from and to the request and response. +// +// Interceptor must appear in a API, Service or Method expression. +// +// Interceptor accepts two arguments: the name of the interceptor and the +// defining DSL. +// +// Example: +// +// var Cache = Interceptor("Cache", func() { +// Description("Server-side interceptor which implements a transparent cache for the loaded records") +// +// ReadPayload(func() { +// Attribute("id") +// }) +// +// WriteResult(func() { +// Attribute("cachedAt") +// }) +// }) +func Interceptor(name string, fn ...func()) *expr.InterceptorExpr { + if len(fn) > 1 { + eval.ReportError("interceptor %q cannot have multiple definitions", name) + return nil + } + i := &expr.InterceptorExpr{Name: name} + if name == "" { + eval.ReportError("interceptor name cannot be empty") + return i + } + if len(fn) > 0 { + if !eval.Execute(fn[0], i) { + return i + } + } + for _, i := range expr.Root.Interceptors { + if i.Name == name { + eval.ReportError("interceptor %q already defined", name) + return i + } + } + expr.Root.Interceptors = append(expr.Root.Interceptors, i) + return i +} + +// ReadPayload defines the payload attributes read by the interceptor. +// +// ReadPayload must appear in an interceptor DSL. +// +// ReadPayload takes a function as argument which can use the Attribute DSL to +// define the attributes read by the interceptor. +// +// Example: +// +// ReadPayload(func() { +// Attribute("id") +// }) +// +// ReadPayload also accepts user defined types: +// +// // Interceptor can read any payload field +// ReadPayload(MethodPayload) +func ReadPayload(arg any) { + setInterceptorAttribute(arg, func(i *expr.InterceptorExpr, attr *expr.AttributeExpr) { + i.ReadPayload = attr + }) +} + +// WritePayload defines the payload attributes written by the interceptor. +// +// WritePayload must appear in an interceptor DSL. +// +// WritePayload takes a function as argument which can use the Attribute DSL to +// define the attributes written by the interceptor. +// +// Example: +// +// WritePayload(func() { +// Attribute("auth") +// }) +// +// WritePayload also accepts user defined types: +// +// // Interceptor can write any payload field +// WritePayload(MethodPayload) +func WritePayload(arg any) { + setInterceptorAttribute(arg, func(i *expr.InterceptorExpr, attr *expr.AttributeExpr) { + i.WritePayload = attr + }) +} + +// ReadResult defines the result attributes read by the interceptor. +// +// ReadResult must appear in an interceptor DSL. +// +// ReadResult takes a function as argument which can use the Attribute DSL to +// define the attributes read by the interceptor. +// +// Example: +// +// ReadResult(func() { +// Attribute("cachedAt") +// }) +// +// ReadResult also accepts user defined types: +// +// // Interceptor can read any result field +// ReadResult(MethodResult) +func ReadResult(arg any) { + setInterceptorAttribute(arg, func(i *expr.InterceptorExpr, attr *expr.AttributeExpr) { + i.ReadResult = attr + }) +} + +// WriteResult defines the result attributes written by the interceptor. +// +// WriteResult must appear in an interceptor DSL. +// +// WriteResult takes a function as argument which can use the Attribute DSL to +// define the attributes written by the interceptor. +// +// Example: +// +// WriteResult(func() { +// Attribute("cachedAt") +// }) +// +// WriteResult also accepts user defined types: +// +// // Interceptor can write any result field +// WriteResult(MethodResult) +func WriteResult(arg any) { + setInterceptorAttribute(arg, func(i *expr.InterceptorExpr, attr *expr.AttributeExpr) { + i.WriteResult = attr + }) +} + +// ServerInterceptor lists the server-side interceptors that apply to all the +// API endpoints, all the service endpoints or a specific endpoint. +// +// ServerInterceptor must appear in a API, Service or Method expression. +// +// ServerInterceptor accepts one or more interceptor or interceptor names as +// arguments. ServerInterceptor can appear multiple times in the same DSL. +// +// Example: +// +// Method("get_record", func() { +// // Interceptor defined with the Interceptor DSL +// ServerInterceptor(SetDeadline) +// +// // Name of interceptor defined with the Interceptor DSL +// ServerInterceptor("Cache") +// +// // Interceptor defined inline +// ServerInterceptor(Interceptor("CheckUserID", func() { +// ReadPayload(func() { +// Attribute("auth") +// }) +// })) +// +// // ... rest of the method DSL +// }) +func ServerInterceptor(interceptors ...any) { + addInterceptors(interceptors, false) +} + +// ClientInterceptor lists the client-side interceptors that apply to all the +// API endpoints, all the service endpoints or a specific endpoint. +// +// ClientInterceptor must appear in a API, Service or Method expression. +// +// ClientInterceptor accepts one or more interceptor or interceptor names as +// arguments. ClientInterceptor can appear multiple times in the same DSL. +// +// Example: +// +// Method("get_record", func() { +// // Interceptor defined with the Interceptor DSL +// ClientInterceptor(Retry) +// +// // Name of interceptor defined with the Interceptor DSL +// ClientInterceptor("Cache") +// +// // Interceptor defined inline +// ClientInterceptor(Interceptor("Sign", func() { +// ReadPayload(func() { +// Attribute("user_id") +// }) +// WritePayload(func() { +// Attribute("auth") +// }) +// })) +// +// // ... rest of the method DSL +// }) +func ClientInterceptor(interceptors ...any) { + addInterceptors(interceptors, true) +} + +// setInterceptorAttribute is a helper function that handles the common logic for +// setting interceptor attributes (ReadPayload, WritePayload, ReadResult, WriteResult). +func setInterceptorAttribute(arg any, setter func(i *expr.InterceptorExpr, attr *expr.AttributeExpr)) { + i, ok := eval.Current().(*expr.InterceptorExpr) + if !ok { + eval.IncompatibleDSL() + return + } + + var attr *expr.AttributeExpr + switch fn := arg.(type) { + case func(): + attr = &expr.AttributeExpr{Type: &expr.Object{}} + if !eval.Execute(fn, attr) { + return + } + case *expr.AttributeExpr: + attr = fn + case expr.DataType: + attr = &expr.AttributeExpr{Type: fn} + default: + eval.InvalidArgError("type, attribute or func()", arg) + return + } + setter(i, attr) +} + +// addInterceptors is a helper function that validates and adds interceptors to +// the current expression. +func addInterceptors(interceptors []any, client bool) { + kind := "ServerInterceptor" + if client { + kind = "ClientInterceptor" + } + if len(interceptors) == 0 { + eval.ReportError("%s: at least one interceptor must be specified", kind) + return + } + + var ints []*expr.InterceptorExpr + for _, i := range interceptors { + switch i := i.(type) { + case *expr.InterceptorExpr: + ints = append(ints, i) + case string: + var found bool + for _, in := range expr.Root.Interceptors { + if in.Name == i { + ints = append(ints, in) + found = true + break + } + } + if !found { + eval.ReportError("%s: interceptor %q not found", kind, i) + } + default: + eval.ReportError("%s: invalid interceptor %v", kind, i) + } + } + + current := eval.Current() + switch actual := current.(type) { + case *expr.APIExpr: + if client { + actual.ClientInterceptors = append(actual.ClientInterceptors, ints...) + } else { + actual.ServerInterceptors = append(actual.ServerInterceptors, ints...) + } + case *expr.ServiceExpr: + if client { + actual.ClientInterceptors = append(actual.ClientInterceptors, ints...) + } else { + actual.ServerInterceptors = append(actual.ServerInterceptors, ints...) + } + case *expr.MethodExpr: + if client { + actual.ClientInterceptors = append(actual.ClientInterceptors, ints...) + } else { + actual.ServerInterceptors = append(actual.ServerInterceptors, ints...) + } + default: + eval.IncompatibleDSL() + } +} diff --git a/dsl/interceptor_test.go b/dsl/interceptor_test.go new file mode 100644 index 0000000000..7e519768d1 --- /dev/null +++ b/dsl/interceptor_test.go @@ -0,0 +1,250 @@ +package dsl_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + . "goa.design/goa/v3/dsl" + "goa.design/goa/v3/eval" + "goa.design/goa/v3/expr" +) + +func TestInterceptor(t *testing.T) { + cases := map[string]struct { + DSL func() + Assert func(t *testing.T, intr *expr.InterceptorExpr) + }{ + "valid-minimal": { + func() { + Interceptor("minimal", func() {}) + }, + func(t *testing.T, intr *expr.InterceptorExpr) { + require.NotNil(t, intr, "interceptor should not be nil") + assert.Equal(t, "minimal", intr.Name) + }, + }, + "valid-complete": { + func() { + Interceptor("complete", func() { + Description("test interceptor") + ReadPayload(func() { + Attribute("foo", String) + }) + WritePayload(func() { + Attribute("bar", String) + }) + ReadResult(func() { + Attribute("baz", String) + }) + WriteResult(func() { + Attribute("qux", String) + }) + }) + }, + func(t *testing.T, intr *expr.InterceptorExpr) { + require.NotNil(t, intr, "interceptor should not be nil") + assert.Equal(t, "test interceptor", intr.Description) + + require.NotNil(t, intr.ReadPayload, "ReadPayload should not be nil") + rp := expr.AsObject(intr.ReadPayload.Type) + require.NotNil(t, rp, "ReadPayload should be an object") + assert.NotNil(t, rp.Attribute("foo"), "ReadPayload should have a foo attribute") + + require.NotNil(t, intr.WritePayload, "WritePayload should not be nil") + wp := expr.AsObject(intr.WritePayload.Type) + require.NotNil(t, wp, "WritePayload should be an object") + assert.NotNil(t, wp.Attribute("bar"), "WritePayload should have a bar attribute") + + require.NotNil(t, intr.ReadResult, "ReadResult should not be nil") + rr := expr.AsObject(intr.ReadResult.Type) + require.NotNil(t, rr, "ReadResult should be an object") + assert.NotNil(t, rr.Attribute("baz"), "ReadResult should have a baz attribute") + + require.NotNil(t, intr.WriteResult, "WriteResult should not be nil") + wr := expr.AsObject(intr.WriteResult.Type) + require.NotNil(t, wr, "WriteResult should be an object") + assert.NotNil(t, wr.Attribute("qux"), "WriteResult should have a qux attribute") + }, + }, + "empty-name": { + func() { + Interceptor("", func() {}) + }, + func(t *testing.T, intr *expr.InterceptorExpr) { + assert.NotNil(t, eval.Context.Errors, "expected a validation error") + }, + }, + "duplicate-name": { + func() { + Interceptor("duplicate", func() {}) + Interceptor("duplicate", func() {}) + }, + func(t *testing.T, intr *expr.InterceptorExpr) { + if eval.Context.Errors == nil { + t.Error("expected a validation error, got none") + } + }, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + eval.Context = &eval.DSLContext{} + expr.Root = new(expr.RootExpr) + tc.DSL() + if len(expr.Root.Interceptors) > 0 { + tc.Assert(t, expr.Root.Interceptors[0]) + } + }) + } +} + +func TestServerInterceptor(t *testing.T) { + cases := map[string]struct { + DSL func() + Assert func(t *testing.T, svc *expr.ServiceExpr, err error) + }{ + "valid-reference": { + func() { + var testInterceptor = Interceptor("test", func() {}) + Service("Service", func() { + ServerInterceptor(testInterceptor) + }) + }, + func(t *testing.T, svc *expr.ServiceExpr, err error) { + require.NoError(t, err) + require.NotNil(t, svc) + require.Len(t, svc.ServerInterceptors, 1, "should have 1 server interceptor") + assert.Equal(t, "test", svc.ServerInterceptors[0].Name) + }, + }, + "valid-by-name": { + func() { + Interceptor("test", func() {}) + Service("Service", func() { + ServerInterceptor("test") + }) + }, + func(t *testing.T, svc *expr.ServiceExpr, err error) { + require.NoError(t, err) + require.NotNil(t, svc) + require.Len(t, svc.ServerInterceptors, 1, "should have 1 server interceptor") + assert.Equal(t, "test", svc.ServerInterceptors[0].Name) + }, + }, + "invalid-reference": { + func() { + Service("Service", func() { + ServerInterceptor(42) // Invalid type + }) + }, + func(t *testing.T, svc *expr.ServiceExpr, err error) { + require.Error(t, err) + }, + }, + "invalid-name": { + func() { + Service("Service", func() { + ServerInterceptor("invalid") + }) + }, + func(t *testing.T, svc *expr.ServiceExpr, err error) { + require.Error(t, err) + }, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + eval.Context = &eval.DSLContext{} + expr.Root = new(expr.RootExpr) + tc.DSL() + root, err := runDSL(t, tc.DSL) + tc.Assert(t, root.Services[0], err) + }) + } +} + +func TestClientInterceptor(t *testing.T) { + cases := map[string]struct { + DSL func() + Assert func(t *testing.T, svc *expr.ServiceExpr, err error) + }{ + "valid-reference": { + func() { + var testInterceptor = Interceptor("test", func() {}) + Service("Service", func() { + ClientInterceptor(testInterceptor) + }) + }, + func(t *testing.T, svc *expr.ServiceExpr, err error) { + require.NoError(t, err) + require.NotNil(t, svc) + require.Len(t, svc.ClientInterceptors, 1, "should have 1 client interceptor") + }, + }, + "valid-by-name": { + func() { + Interceptor("test", func() {}) + Service("Service", func() { + ClientInterceptor("test") + }) + }, + func(t *testing.T, svc *expr.ServiceExpr, err error) { + require.NoError(t, err) + require.NotNil(t, svc) + require.Len(t, svc.ClientInterceptors, 1, "should have 1 client interceptor") + }, + }, + "invalid-reference": { + func() { + Service("Service", func() { + ClientInterceptor(42) // Invalid type + }) + }, + func(t *testing.T, svc *expr.ServiceExpr, err error) { + require.Error(t, err) + }, + }, + "invalid-name": { + func() { + Service("Service", func() { + ClientInterceptor("invalid") + }) + }, + func(t *testing.T, svc *expr.ServiceExpr, err error) { + require.Error(t, err) + }, + }, + } + + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + eval.Context = &eval.DSLContext{} + expr.Root = new(expr.RootExpr) + tc.DSL() + root, err := runDSL(t, tc.DSL) + tc.Assert(t, root.Services[0], err) + }) + } +} + +// runDSL returns the DSL root resulting from running the given DSL. +func runDSL(t *testing.T, dsl func()) (*expr.RootExpr, error) { + t.Helper() + eval.Reset() + expr.Root = new(expr.RootExpr) + expr.GeneratedResultTypes = new(expr.ResultTypesRoot) + require.NoError(t, eval.Register(expr.Root)) + require.NoError(t, eval.Register(expr.GeneratedResultTypes)) + expr.Root.API = expr.NewAPIExpr("test api", func() {}) + expr.Root.API.Servers = []*expr.ServerExpr{expr.Root.API.DefaultServer()} + if eval.Execute(dsl, nil) { + return expr.Root, eval.RunDSL() + } else { + return expr.Root, errors.New(eval.Context.Error()) + } +} diff --git a/dsl/meta.go b/dsl/meta.go index 05f917f851..60ffb196f4 100644 --- a/dsl/meta.go +++ b/dsl/meta.go @@ -288,3 +288,42 @@ func Meta(name string, value ...string) { eval.IncompatibleDSL() } } + +// RemoveMeta removes a meta key from an object. +// +// RemoveMeta may appear where Meta can appear. +// +// RemoveMeta takes a single argument, the name of the meta key to remove. +func RemoveMeta(name string) { + switch e := eval.Current().(type) { + case *expr.APIExpr: + delete(e.Meta, name) + case *expr.ServerExpr: + delete(e.Meta, name) + case *expr.HostExpr: + delete(e.Meta, name) + case *expr.AttributeExpr: + delete(e.Meta, name) + case *expr.ResultTypeExpr: + delete(e.Meta, name) + case *expr.MethodExpr: + delete(e.Meta, name) + case *expr.ServiceExpr: + delete(e.Meta, name) + case *expr.HTTPServiceExpr: + delete(e.Meta, name) + case *expr.HTTPEndpointExpr: + delete(e.Meta, name) + case *expr.RouteExpr: + delete(e.Meta, name) + case *expr.HTTPFileServerExpr: + delete(e.Meta, name) + case *expr.HTTPResponseExpr: + delete(e.Meta, name) + case expr.CompositeExpr: + att := e.Attribute() + delete(att.Meta, name) + default: + eval.IncompatibleDSL() + } +} diff --git a/expr/api.go b/expr/api.go index 01ed202620..7d77fb975e 100644 --- a/expr/api.go +++ b/expr/api.go @@ -36,6 +36,10 @@ type ( // potentially multiple schemes. Incoming requests must validate // at least one requirement to be authorized. Requirements []*SecurityExpr + // ClientInterceptors is the list of API client interceptors. + ClientInterceptors []*InterceptorExpr + // ServerInterceptors is the list of API server interceptors. + ServerInterceptors []*InterceptorExpr // HTTP contains the HTTP specific API level expressions. HTTP *HTTPExpr // GRPC contains the gRPC specific API level expressions. diff --git a/expr/interceptor.go b/expr/interceptor.go new file mode 100644 index 0000000000..538d85463f --- /dev/null +++ b/expr/interceptor.go @@ -0,0 +1,84 @@ +package expr + +import ( + "goa.design/goa/v3/eval" +) + +type ( + // InterceptorExpr describes an interceptor definition in the design. + // Interceptors are used to inject user code into the request/response processing pipeline. + // There are four kinds of interceptors, in order of execution: + // * client-side payload: executes after the payload is encoded and before the request is sent to the server + // * server-side request: executes after the request is decoded and before the payload is sent to the service + // * server-side result: executes after the service returns a result and before the response is encoded + // * client-side response: executes after the response is decoded and before the result is sent to the client + InterceptorExpr struct { + // Name is the name of the interceptor + Name string + // Description is the optional description of the interceptor + Description string + // ReadPayload lists the payload attribute names read by the interceptor + ReadPayload *AttributeExpr + // WritePayload lists the payload attribute names written by the interceptor + WritePayload *AttributeExpr + // ReadResult lists the result attribute names read by the interceptor + ReadResult *AttributeExpr + // WriteResult lists the result attribute names written by the interceptor + WriteResult *AttributeExpr + } +) + +// EvalName returns the generic expression name used in error messages. +func (i *InterceptorExpr) EvalName() string { + return "interceptor " + i.Name +} + +// validate validates the interceptor. +func (i *InterceptorExpr) validate(m *MethodExpr) *eval.ValidationErrors { + verr := new(eval.ValidationErrors) + + if i.ReadPayload != nil || i.WritePayload != nil { + payloadObj := AsObject(m.Payload.Type) + if payloadObj == nil { + verr.Add(m, "interceptor %q cannot be applied because the method payload is not an object", i.Name) + } + if i.ReadPayload != nil { + i.validateAttributeAccess(m, "read payload", verr, payloadObj, i.ReadPayload) + } + if i.WritePayload != nil { + i.validateAttributeAccess(m, "write payload", verr, payloadObj, i.WritePayload) + } + } + + if i.ReadResult != nil || i.WriteResult != nil { + if m.IsResultStreaming() { + verr.Add(m, "interceptor %q cannot be applied because the method result is streaming", i.Name) + } + resultObj := AsObject(m.Result.Type) + if resultObj == nil { + verr.Add(m, "interceptor %q cannot be applied because the method result is not an object", i.Name) + } + if i.ReadResult != nil { + i.validateAttributeAccess(m, "read result", verr, resultObj, i.ReadResult) + } + if i.WriteResult != nil { + i.validateAttributeAccess(m, "write result", verr, resultObj, i.WriteResult) + } + } + + return verr +} + +// validateAttributeAccess validates that all attributes in attr exist in obj +func (i *InterceptorExpr) validateAttributeAccess(m *MethodExpr, source string, verr *eval.ValidationErrors, obj *Object, attr *AttributeExpr) { + attrObj := AsObject(attr.Type) + if attrObj == nil { + verr.Add(m, "interceptor %q %s attribute is not an object", i.Name, source) + return + } + for _, att := range *attrObj { + if obj.Attribute(att.Name) == nil { + verr.Add(m, "interceptor %q cannot %s attribute %q: attribute does not exist", i.Name, source, att.Name) + } + } +} diff --git a/expr/method.go b/expr/method.go index 3ba80728bd..94e3e59356 100644 --- a/expr/method.go +++ b/expr/method.go @@ -32,6 +32,10 @@ type ( // schemes. Incoming requests must validate at least one // requirement to be authorized. Requirements []*SecurityExpr + // ClientInterceptors is the list of client interceptors. + ClientInterceptors []*InterceptorExpr + // ServerInterceptors is the list of server interceptors. + ServerInterceptors []*InterceptorExpr // Service that owns method. Service *ServiceExpr // Meta is an arbitrary set of key/value pairs, see dsl.Meta @@ -84,7 +88,8 @@ func (m *MethodExpr) EvalName() string { } // Prepare makes sure the payload and result types are initialized (to the Empty -// type if nil). +// type if nil) and merges the method interceptors with the API and service level +// interceptors. func (m *MethodExpr) Prepare() { if m.Payload == nil { m.Payload = &AttributeExpr{Type: Empty} @@ -95,13 +100,57 @@ func (m *MethodExpr) Prepare() { if m.Result == nil { m.Result = &AttributeExpr{Type: Empty} } + + m.ClientInterceptors = mergeInterceptors(m.ClientInterceptors, m.Service.ClientInterceptors, Root.API.ClientInterceptors) + m.ServerInterceptors = mergeInterceptors(m.ServerInterceptors, m.Service.ServerInterceptors, Root.API.ServerInterceptors) +} + +// mergeInterceptors merges interceptors from different levels (method, service, API) +// while avoiding duplicates. The order of precedence is: method > service > API. +func mergeInterceptors(methodLevel, serviceLevel, apiLevel []*InterceptorExpr) []*InterceptorExpr { + existing := make(map[string]struct{}) + result := make([]*InterceptorExpr, 0, len(methodLevel)+len(serviceLevel)+len(apiLevel)) + + // Add method-level interceptors + for _, i := range methodLevel { + existing[i.Name] = struct{}{} + result = append(result, i) + } + + // Add service-level interceptors + for _, i := range serviceLevel { + if _, ok := existing[i.Name]; !ok { + result = append(result, i) + existing[i.Name] = struct{}{} + } + } + + // Add API-level interceptors + for _, i := range apiLevel { + if _, ok := existing[i.Name]; !ok { + result = append(result, i) + } + } + + return result } -// Validate validates the method payloads, results, and errors (if any). +// Validate validates the method payloads, results, errors, security +// requirements, and interceptors. func (m *MethodExpr) Validate() error { verr := new(eval.ValidationErrors) verr.Merge(m.Payload.Validate("payload", m)) - // validate security scheme requirements + verr.Merge(m.StreamingPayload.Validate("streaming_payload", m)) + verr.Merge(m.Result.Validate("result", m)) + verr.Merge(m.validateRequirements()) + verr.Merge(m.validateErrors()) + verr.Merge(m.validateInterceptors()) + return verr +} + +// validateRequirements validates the security requirements. +func (m *MethodExpr) validateRequirements() *eval.ValidationErrors { + verr := new(eval.ValidationErrors) var requirements []*SecurityExpr if len(m.Requirements) > 0 { requirements = m.Requirements @@ -185,12 +234,12 @@ func (m *MethodExpr) Validate() error { verr.Add(m, "payload of method %q of service %q defines a OAuth2 access token attribute, but no OAuth2 security scheme exist", m.Name, m.Service.Name) } } - if m.StreamingPayload.Type != Empty { - verr.Merge(m.StreamingPayload.Validate("streaming_payload", m)) - } - if m.Result.Type != Empty { - verr.Merge(m.Result.Validate("result", m)) - } + return verr +} + +// validateErrors validates the method errors. +func (m *MethodExpr) validateErrors() *eval.ValidationErrors { + verr := new(eval.ValidationErrors) for i, e := range m.Errors { if err := e.Validate(); err != nil { var verrs *eval.ValidationErrors @@ -220,6 +269,18 @@ func (m *MethodExpr) Validate() error { return verr } +// validateInterceptors validates the method interceptors. +func (m *MethodExpr) validateInterceptors() *eval.ValidationErrors { + verr := new(eval.ValidationErrors) + for _, i := range m.ClientInterceptors { + verr.Merge(i.validate(m)) + } + for _, i := range m.ServerInterceptors { + verr.Merge(i.validate(m)) + } + return verr +} + // hasTag is a helper function that traverses the given attribute and all its // bases recursively looking for an attribute with the given tag meta. This // recursion is only needed for attributes that have not been finalized yet. diff --git a/expr/root.go b/expr/root.go index 0240e54e1f..fb215c40f9 100644 --- a/expr/root.go +++ b/expr/root.go @@ -18,6 +18,8 @@ type ( API *APIExpr // Services contains the list of services exposed by the API. Services []*ServiceExpr + // Interceptors contains the list of interceptors. + Interceptors []*InterceptorExpr // Errors contains the list of errors returned by all the API // methods. Errors []*ErrorExpr diff --git a/expr/service.go b/expr/service.go index 1a7d88e9cf..ab906c4688 100644 --- a/expr/service.go +++ b/expr/service.go @@ -27,6 +27,10 @@ type ( // potentially multiple schemes. Incoming requests must validate // at least one requirement to be authorized. Requirements []*SecurityExpr + // ClientInterceptors is the list of client interceptors. + ClientInterceptors []*InterceptorExpr + // ServerInterceptors is the list of server interceptors. + ServerInterceptors []*InterceptorExpr // Meta is a set of key/value pairs with semantic that is // specific to each generator. Meta MetaExpr diff --git a/go.mod b/go.mod index fae2c06274..4e6ce461c3 100644 --- a/go.mod +++ b/go.mod @@ -1,7 +1,8 @@ module goa.design/goa/v3 -go 1.22.0 -toolchain go1.23.1 +go 1.22.7 + +toolchain go1.23.3 require ( github.com/dimfeld/httppath v0.0.0-20170720192232-ee938bf73598 @@ -15,7 +16,7 @@ require ( golang.org/x/text v0.20.0 golang.org/x/tools v0.27.0 google.golang.org/grpc v1.68.0 - google.golang.org/protobuf v1.35.2 + google.golang.org/protobuf v1.35.1 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 86c6271fd4..35afe7f9ff 100644 --- a/go.sum +++ b/go.sum @@ -64,8 +64,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1: google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= google.golang.org/grpc v1.68.0 h1:aHQeeJbo8zAkAa3pRzrVjZlbz6uSfeOXlJNQM0RAbz0= google.golang.org/grpc v1.68.0/go.mod h1:fmSPC5AsjSBCK54MyHRx48kpOti1/jRfOlwEWywNjWA= -google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io= -google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= +google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/http/codegen/service_data.go b/http/codegen/service_data.go index c1818960c9..b570c4828b 100644 --- a/http/codegen/service_data.go +++ b/http/codegen/service_data.go @@ -595,7 +595,7 @@ func (ServicesData) analyze(httpSvc *expr.HTTPServiceExpr) *ServiceData { scope := codegen.NewNameScope() scope.Unique("c") // 'c' is reserved as the client's receiver name. scope.Unique("v") // 'v' is reserved as the request builder payload argument name. - rd := &ServiceData{ + sd := &ServiceData{ Service: svc, ServerStruct: "Server", MountPointStruct: "MountPoint", @@ -641,7 +641,7 @@ func (ServicesData) analyze(httpSvc *expr.HTTPServiceExpr) *ServiceData { VarName: scope.Unique(codegen.Goify(s.FilePath, true)), ArgName: scope.Unique(fmt.Sprintf("fileSystem%s", codegen.Goify(s.FilePath, true))), } - rd.FileServers = append(rd.FileServers, data) + sd.FileServers = append(sd.FileServers, data) } for _, httpEndpoint := range httpSvc.HTTPEndpoints { @@ -672,10 +672,10 @@ func (ServicesData) analyze(httpSvc *expr.HTTPServiceExpr) *ServiceData { // Path params may override requiredness, need to check payload. pointer = httpEndpoint.MethodExpr.Payload.IsPrimitivePointer(arg, true) } - name := rd.Scope.Name(codegen.Goify(arg, false)) + name := sd.Scope.Name(codegen.Goify(arg, false)) var vcode string if att.Validation != nil { - ctx := httpContext("", rd.Scope, true, false) + ctx := httpContext("", sd.Scope, true, false) vcode = codegen.AttributeValidationCode(att, nil, ctx, true, expr.IsAlias(att.Type), name, arg) } initArgs[j] = &InitArgData{ @@ -686,8 +686,8 @@ func (ServicesData) analyze(httpSvc *expr.HTTPServiceExpr) *ServiceData { Description: att.Description, FieldName: codegen.Goify(arg, true), FieldType: patt.Type, - TypeName: rd.Scope.GoTypeName(att), - TypeRef: rd.Scope.GoTypeRef(att), + TypeName: sd.Scope.GoTypeName(att), + TypeRef: sd.Scope.GoTypeRef(att), Type: att.Type, Pointer: pointer, Required: true, @@ -727,7 +727,7 @@ func (ServicesData) analyze(httpSvc *expr.HTTPServiceExpr) *ServiceData { } } - payload := buildPayloadData(httpEndpoint, rd) + payload := buildPayloadData(httpEndpoint, sd) var ( reqs service.RequirementsData @@ -817,8 +817,8 @@ func (ServicesData) analyze(httpSvc *expr.HTTPServiceExpr) *ServiceData { ServiceVarName: svc.VarName, ServicePkgName: svc.PkgName, Payload: payload, - Result: buildResultData(httpEndpoint, rd), - Errors: buildErrorsData(httpEndpoint, rd), + Result: buildResultData(httpEndpoint, sd), + Errors: buildErrorsData(httpEndpoint, sd), HeaderSchemes: hsch, BodySchemes: bosch, QuerySchemes: qsch, @@ -837,7 +837,7 @@ func (ServicesData) analyze(httpSvc *expr.HTTPServiceExpr) *ServiceData { Requirements: reqs, } if httpEndpoint.MethodExpr.IsStreaming() { - initWebSocketData(ed, httpEndpoint, rd) + initWebSocketData(ed, httpEndpoint, sd) } if httpEndpoint.MultipartRequest { @@ -870,26 +870,26 @@ func (ServicesData) analyze(httpSvc *expr.HTTPServiceExpr) *ServiceData { } } - rd.Endpoints = append(rd.Endpoints, ed) + sd.Endpoints = append(sd.Endpoints, ed) } for _, a := range httpSvc.HTTPEndpoints { collectUserTypes(a.Body.Type, func(ut expr.UserType) { - if d := attributeTypeData(ut, true, true, true, rd); d != nil { - rd.ServerBodyAttributeTypes = append(rd.ServerBodyAttributeTypes, d) + if d := attributeTypeData(ut, true, true, true, sd); d != nil { + sd.ServerBodyAttributeTypes = append(sd.ServerBodyAttributeTypes, d) } - if d := attributeTypeData(ut, true, false, false, rd); d != nil { - rd.ClientBodyAttributeTypes = append(rd.ClientBodyAttributeTypes, d) + if d := attributeTypeData(ut, true, false, false, sd); d != nil { + sd.ClientBodyAttributeTypes = append(sd.ClientBodyAttributeTypes, d) } }) if a.MethodExpr.StreamingPayload.Type != expr.Empty { collectUserTypes(a.StreamingBody.Type, func(ut expr.UserType) { - if d := attributeTypeData(ut, true, true, true, rd); d != nil { - rd.ServerBodyAttributeTypes = append(rd.ServerBodyAttributeTypes, d) + if d := attributeTypeData(ut, true, true, true, sd); d != nil { + sd.ServerBodyAttributeTypes = append(sd.ServerBodyAttributeTypes, d) } - if d := attributeTypeData(ut, true, false, false, rd); d != nil { - rd.ClientBodyAttributeTypes = append(rd.ClientBodyAttributeTypes, d) + if d := attributeTypeData(ut, true, false, false, sd); d != nil { + sd.ClientBodyAttributeTypes = append(sd.ClientBodyAttributeTypes, d) } }) } @@ -900,8 +900,8 @@ func (ServicesData) analyze(httpSvc *expr.HTTPServiceExpr) *ServiceData { // NOTE: ServerBodyAttributeTypes for response body types are // collected in buildResponseBodyType because we have to generate // body types for each view in a result type. - if d := attributeTypeData(ut, false, true, false, rd); d != nil { - rd.ClientBodyAttributeTypes = append(rd.ClientBodyAttributeTypes, d) + if d := attributeTypeData(ut, false, true, false, sd); d != nil { + sd.ClientBodyAttributeTypes = append(sd.ClientBodyAttributeTypes, d) } }) } @@ -912,14 +912,14 @@ func (ServicesData) analyze(httpSvc *expr.HTTPServiceExpr) *ServiceData { // NOTE: ServerBodyAttributeTypes for error response body types are // collected in buildResponseBodyType because we have to generate // body types for each view in a result type. - if d := attributeTypeData(ut, false, true, false, rd); d != nil { - rd.ClientBodyAttributeTypes = append(rd.ClientBodyAttributeTypes, d) + if d := attributeTypeData(ut, false, true, false, sd); d != nil { + sd.ClientBodyAttributeTypes = append(sd.ClientBodyAttributeTypes, d) } }) } } - return rd + return sd } // makeHTTPType traverses the attribute recursively and performs these actions: diff --git a/pkg/interceptor.go b/pkg/interceptor.go new file mode 100644 index 0000000000..022d7a4de6 --- /dev/null +++ b/pkg/interceptor.go @@ -0,0 +1,22 @@ +package goa + +import "context" + +type ( + // InterceptorInfo contains information about the request shared between + // all interceptors in the service chain. It provides access to the service name, + // method name, endpoint function, and request payload. + InterceptorInfo struct { + // Name of service handling request + Service string + // Name of method handling request + Method string + // Endpoint of request, can be used for retrying + Endpoint Endpoint + // Payload of request + RawPayload any + } + + // NextFunc is a function that will continue the request processing chain. + NextFunc func(ctx context.Context) (any, error) +)