diff --git a/dataplane/standalone/apigen/apigen.go b/dataplane/standalone/apigen/apigen.go index e64e5ba1..e91ecc3b 100644 --- a/dataplane/standalone/apigen/apigen.go +++ b/dataplane/standalone/apigen/apigen.go @@ -196,52 +196,10 @@ func generate() error { return err } - protoData := &protoApiData{ - apis: make(map[string]*protoAPITmplData), - docInfo: xmlInfo, - common: &protoCommonTmplData{ - Enums: map[string]*protoEnum{}, - Lists: map[string]*protoTmplMessage{}, - }, - } - - for _, typeInfo := range saiTypeToProto { - if typeInfo.MessageDef != "" { - protoData.common.Messages = append(protoData.common.Messages, typeInfo.MessageDef) - } - } - for name, vals := range protoData.docInfo.enums { - protoName := trimSAIName(name, true, false) - unspecifiedName := trimSAIName(name, false, true) + "_UNSPECIFIED" - enum := &protoEnum{ - Name: protoName, - Values: []protoEnumValues{{Index: 0, Name: unspecifiedName}}, - } - for i, val := range vals { - if strings.TrimPrefix(val, "SAI_") == unspecifiedName { - continue - } - enum.Values = append(enum.Values, protoEnumValues{ - Index: i + 1, - Name: strings.TrimPrefix(val, "SAI_"), - }) - } - protoData.common.Enums[protoName] = enum - } - for _, attr := range protoData.docInfo.attrs { - for _, f := range attr.setFields { - name, isRepeated, err := saiTypeToProtoType(f.SaiType, protoData.docInfo, true) - if err != nil { - return err - } - if !isRepeated { - continue - } - msg := &protoTmplMessage{ - Name: name, - } - protoData.common.Lists[name] = msg - } + apis := make(map[string]*protoAPITmplData) + common, err := populateCommonTypes(xmlInfo) + if err != nil { + return err } for _, iface := range sai.ifaces { @@ -256,7 +214,7 @@ func generate() error { tf, isSwitchScoped, entry := createCCData(sai, fn) ccData.Funcs = append(ccData.Funcs, *tf) - err := populateTmplDataFromFunc(protoData, tf.Name, entry, tf.Operation, tf.TypeName, iface.name, isSwitchScoped) + err := populateTmplDataFromFunc(apis, xmlInfo, tf.Name, entry, tf.Operation, tf.TypeName, iface.name, isSwitchScoped) if err != nil { return err } @@ -280,7 +238,7 @@ func generate() error { if err := ccTmpl.Execute(impl, ccData); err != nil { return err } - if err := protoTmpl.Execute(proto, protoData.apis[iface.name]); err != nil { + if err := protoTmpl.Execute(proto, apis[iface.name]); err != nil { return err } } @@ -289,7 +247,7 @@ func generate() error { return err } - if err := protoCommonTmpl.Execute(protoCommonFile, protoData.common); err != nil { + if err := protoCommonTmpl.Execute(protoCommonFile, common); err != nil { return err } return nil diff --git a/dataplane/standalone/apigen/protogen.go b/dataplane/standalone/apigen/protogen.go index fcdc9da9..5be81c2e 100644 --- a/dataplane/standalone/apigen/protogen.go +++ b/dataplane/standalone/apigen/protogen.go @@ -23,10 +23,70 @@ import ( strcase "github.com/stoewer/go-strcase" ) +// populateCommonTypes fills the templates for all types that aren't attributes. +// These all reside in the common.proto file to simplify handling imports. +func populateCommonTypes(docInfo *protoGenInfo) (*protoCommonTmplData, error) { + common := &protoCommonTmplData{ + Enums: map[string]*protoEnum{}, + Lists: map[string]*protoTmplMessage{}, + } + // Generate the hand-crafted messages. + for _, typeInfo := range saiTypeToProto { + if typeInfo.MessageDef != "" { + common.Messages = append(common.Messages, typeInfo.MessageDef) + } + } + // Generate non-attribute enums. + for name, vals := range docInfo.enums { + protoName := trimSAIName(name, true, false) + unspecifiedName := trimSAIName(name, false, true) + "_UNSPECIFIED" + enum := &protoEnum{ + Name: protoName, + Values: []protoEnumValues{{Index: 0, Name: unspecifiedName}}, + } + for i, val := range vals { + if strings.TrimPrefix(val, "SAI_") == unspecifiedName { + continue + } + enum.Values = append(enum.Values, protoEnumValues{ + Index: i + 1, + Name: strings.TrimPrefix(val, "SAI_"), + }) + } + common.Enums[protoName] = enum + } + // Find all the repeated fields that appear in oneof and generate a list wrapper type. + for _, attr := range docInfo.attrs { + for _, f := range attr.setFields { + msgName, isRepeated, err := saiTypeToProtoType(f.SaiType, docInfo, true) + if err != nil { + return nil, err + } + if !isRepeated { + continue + } + repeatedName, _, err := saiTypeToProtoType(f.SaiType, docInfo, false) + if err != nil { + return nil, err + } + msg := &protoTmplMessage{ + Name: msgName, + Fields: []protoTmplField{{ + Index: 1, + Name: "list", + ProtoType: repeatedName, + }}, + } + common.Lists[msgName] = msg + } + } + return common, nil +} + // populateTmplDataFromFunc populatsd the protobuf template struct from a SAI function call. -func populateTmplDataFromFunc(protoData *protoApiData, funcName, entryType, operation, typeName, apiName string, isSwitchScoped bool) error { - if _, ok := protoData.apis[apiName]; !ok { - protoData.apis[apiName] = &protoAPITmplData{ +func populateTmplDataFromFunc(apis map[string]*protoAPITmplData, docInfo *protoGenInfo, funcName, entryType, operation, typeName, apiName string, isSwitchScoped bool) error { + if _, ok := apis[apiName]; !ok { + apis[apiName] = &protoAPITmplData{ Enums: make(map[string]protoEnum), ServiceName: trimSAIName(apiName, true, false), } @@ -66,7 +126,7 @@ func populateTmplDataFromFunc(protoData *protoApiData, funcName, entryType, oper req.Fields = append(req.Fields, idField) } requestIdx++ - attrs, err := createAttrs(requestIdx, protoData.docInfo, protoData.docInfo.attrs[typeName].createFields, false) + attrs, err := createAttrs(requestIdx, docInfo, docInfo.attrs[typeName].createFields, false) if err != nil { return err } @@ -79,13 +139,13 @@ func populateTmplDataFromFunc(protoData *protoApiData, funcName, entryType, oper }) case "set_attribute": // If there are no settable attributes, do nothing. - if len(protoData.docInfo.attrs[typeName].setFields) == 0 { + if len(docInfo.attrs[typeName].setFields) == 0 { return nil } req.Fields = append(req.Fields, idField) req.AttrsWrapperStart = "oneof attr {" req.AttrsWrapperEnd = "}" - attrs, err := createAttrs(2, protoData.docInfo, protoData.docInfo.attrs[typeName].setFields, true) + attrs, err := createAttrs(2, docInfo, docInfo.attrs[typeName].setFields, true) if err != nil { return err } @@ -104,15 +164,15 @@ func populateTmplDataFromFunc(protoData *protoApiData, funcName, entryType, oper } // For the attributes, generate code for the type if needed. - for i, attr := range protoData.docInfo.attrs[typeName].readFields { + for i, attr := range docInfo.attrs[typeName].readFields { attrEnum.Values = append(attrEnum.Values, protoEnumValues{ Index: i + 1, Name: strings.TrimPrefix(attr.EnumName, "SAI_"), }) } - protoData.apis[apiName].Enums[attrEnum.Name] = attrEnum + apis[apiName].Enums[attrEnum.Name] = attrEnum - attrs, err := createAttrs(1, protoData.docInfo, protoData.docInfo.attrs[typeName].readFields, false) + attrs, err := createAttrs(1, docInfo, docInfo.attrs[typeName].readFields, false) if err != nil { return err } @@ -120,8 +180,8 @@ func populateTmplDataFromFunc(protoData *protoApiData, funcName, entryType, oper default: return nil } - protoData.apis[apiName].Messages = append(protoData.apis[apiName].Messages, *req, *resp) - protoData.apis[apiName].RPCs = append(protoData.apis[apiName].RPCs, protoRPC{ + apis[apiName].Messages = append(apis[apiName].Messages, *req, *resp) + apis[apiName].RPCs = append(apis[apiName].RPCs, protoRPC{ RequestName: req.Name, ResponseName: resp.Name, Name: strcase.UpperCamelCase(funcName), @@ -132,11 +192,12 @@ func populateTmplDataFromFunc(protoData *protoApiData, funcName, entryType, oper func createAttrs(startIdx int, xmlInfo *protoGenInfo, attrs []attrTypeName, inOneof bool) ([]protoTmplField, error) { fields := []protoTmplField{} for _, attr := range attrs { - // Function pointers are attempted as streaming RPCs. + // Function pointers are implemented as streaming RPCs instead of settable attributes. + // TODO: Implement these. if strings.Contains(attr.SaiType, "sai_pointer_t") { continue } - // Proto field names can't beging with numbers, prepend _. + // Proto field names can't begin with numbers, prepend _. name := attr.MemberName if unicode.IsDigit(rune(attr.MemberName[0])) { name = fmt.Sprintf("_%s", name) @@ -223,13 +284,6 @@ message {{ .Name }} { `)) ) -// protoApiData contains the input and output for protobuf generation. -type protoApiData struct { - docInfo *protoGenInfo - apis map[string]*protoAPITmplData - common *protoCommonTmplData -} - // protoAPITmplData contains the formated information needed to render the protobuf template. type protoAPITmplData struct { Messages []protoTmplMessage