diff --git a/example/message.proto b/example/message.proto index e492991..e1c8959 100644 --- a/example/message.proto +++ b/example/message.proto @@ -34,6 +34,7 @@ message GetItemsResponse { ItemType type = 2 [(google.api.field_behavior) = REQUIRED]; string name = 3 [(google.api.field_behavior) = REQUIRED]; google.protobuf.Timestamp created_at = 4 [(google.api.field_behavior) = REQUIRED]; + repeated NestedItem nested_item = 5; } repeated Item items = 1 [(google.api.field_behavior) = REQUIRED]; diff --git a/example/openapi.yaml b/example/openapi.yaml index 4ee548e..61ca7b4 100644 --- a/example/openapi.yaml +++ b/example/openapi.yaml @@ -173,6 +173,10 @@ components: createdAt: type: string format: date-time + nestedItem: + type: array + items: + $ref: '#/components/schemas/GetItemsResponse.NestedItem' required: - id - type @@ -202,6 +206,8 @@ components: enum: - ITEM_TYPE_UNSPECIFIED - ITEM_TYPE_BASIC + NestedEnum: + type: object NestedEnum.ItemType: type: string enum: diff --git a/internal/gen/generator.go b/internal/gen/generator.go index 45115cf..8488c8c 100644 --- a/internal/gen/generator.go +++ b/internal/gen/generator.go @@ -70,6 +70,12 @@ func NewGenerator(files []*protogen.File, opts ...GeneratorOption) (*Generator, } } } + } + + for _, f := range files { + if !f.Generate { + continue + } for _, m := range f.Messages { name := descriptorName(m.Desc) @@ -93,9 +99,10 @@ func NewGenerator(files []*protogen.File, opts ...GeneratorOption) (*Generator, // Generator instance. type Generator struct { - spec *ogen.Spec - indent int - requests map[string]struct{} + spec *ogen.Spec + indent int + requests map[string]struct{} + descriptorNames map[string]struct{} } // YAML returns OpenAPI specification bytes. @@ -121,6 +128,7 @@ func (g *Generator) init() { g.spec = ogen.NewSpec() g.spec.Init() g.requests = make(map[string]struct{}) + g.descriptorNames = make(map[string]struct{}) } func (g *Generator) mkMethod(rule HTTPRule, m *protogen.Method) (string, *ogen.Operation, error) { @@ -143,7 +151,7 @@ func (g *Generator) mkMethod(rule HTTPRule, m *protogen.Method) (string, *ogen.O func (g *Generator) mkInput(rule HTTPRule, m *protogen.Method, op *ogen.Operation) (string, error) { name := descriptorName(m.Input.Desc) - g.requests[name] = struct{}{} + g.setRequest(name) var ( fields = collectFields(m.Input) @@ -403,8 +411,29 @@ func (g *Generator) hasSchema(s string) bool { return ok } -func (g *Generator) hasRequest(r string) bool { - _, ok := g.requests[r] +func (g *Generator) setRequest(s string) { + if g.hasRequest(s) { + return + } + + g.requests[s] = struct{}{} +} + +func (g *Generator) hasRequest(s string) bool { + _, ok := g.requests[s] + return ok +} + +func (g *Generator) setDescriptorName(s string) { + if g.hasDescriptorName(s) { + return + } + + g.descriptorNames[s] = struct{}{} +} + +func (g *Generator) hasDescriptorName(s string) bool { + _, ok := g.descriptorNames[s] return ok } diff --git a/internal/gen/schema.go b/internal/gen/schema.go index 2bb1336..a699e91 100644 --- a/internal/gen/schema.go +++ b/internal/gen/schema.go @@ -59,6 +59,15 @@ func (g *Generator) mkSchema(msg *protogen.Message) error { } if field.Message != nil { + name := descriptorName(field.Desc) + if g.hasDescriptorName(name) { + s.SetRef(descriptorRef(field.Message.Desc)) + + continue + } + + g.setDescriptorName(name) + if err := g.mkSchema(field.Message); err != nil { return err }