diff --git a/internal/examples/protoc-gen-protogen-simple/main.go b/internal/examples/protoc-gen-protogen-simple/main.go index aed54fa..a694a61 100644 --- a/internal/examples/protoc-gen-protogen-simple/main.go +++ b/internal/examples/protoc-gen-protogen-simple/main.go @@ -37,7 +37,11 @@ func handle( responseWriter *protoplugin.ResponseWriter, request *protoplugin.Request, ) error { - plugin, err := protogen.Options{}.New(request.CodeGeneratorRequest()) + codeGeneratorRequest, err := request.CodeGeneratorRequest() + if err != nil { + return err + } + plugin, err := protogen.Options{}.New(codeGeneratorRequest) if err != nil { return err } diff --git a/internal/examples/protoc-gen-simple/main.go b/internal/examples/protoc-gen-simple/main.go index 38e7cc7..241234c 100644 --- a/internal/examples/protoc-gen-simple/main.go +++ b/internal/examples/protoc-gen-simple/main.go @@ -40,7 +40,11 @@ func handle( // plugin has not indicated it will support it. responseWriter.AddFeatureProto3Optional() - for _, fileDescriptorProto := range request.GenerateFileDescriptorProtos() { + fileDescriptorProtos, err := request.GenerateFileDescriptorProtos() + if err != nil { + return err + } + for _, fileDescriptorProto := range fileDescriptorProtos { topLevelMessageNames := make([]string, len(fileDescriptorProto.GetMessageType())) for i, descriptorProto := range fileDescriptorProto.GetMessageType() { topLevelMessageNames[i] = descriptorProto.GetName() diff --git a/request.go b/request.go index 90485a8..baa7546 100644 --- a/request.go +++ b/request.go @@ -15,9 +15,12 @@ package protoplugin import ( + "errors" + "fmt" "slices" "sync" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" @@ -27,7 +30,7 @@ import ( // Request wraps a CodeGeneratorRequest. // -// The backing CodeGeneratorRequest has been validated: +// The backing CodeGeneratorRequest has been validated to conform to the following: // // - The CodeGeneratorRequest will not be nil. // - FileToGenerate and ProtoFile will be non-empty. @@ -35,7 +38,8 @@ import ( // - Each value of FileToGenerate will be a valid path. // - Each value of FileToGenerate will have a corresponding value in ProtoFile. // - Each FileDescriptorProto in SourceFileDescriptors will have a valid path as the name field. -// - The values of FileToGenerate will have a 1-1 mapping to the names in SourceFileDescriptors. +// - SourceFileDescriptors is either empty, or the values of FileToGenerate will have a 1-1 mapping +// to the names in SourceFileDescriptors. // // Paths are considered valid if they are non-empty, relative, use '/' as the path separator, do not jump context, // and have `.proto` as the file extension. @@ -46,12 +50,21 @@ type Request struct { getSourceFileDescriptorNameToFileDescriptorProtoMap func() map[string]*descriptorpb.FileDescriptorProto } +// Parameter returns the value of the parameter field on the CodeGeneratorRequest. func (r *Request) Parameter() string { return r.codeGeneratorRequest.GetParameter() } // GenerateFileDescriptors returns the FileDescriptors for the files specified by the -// file_to_generate field. +// file_to_generate field on the CodeGeneratorRequest. +// +// If WithSourceRetentionOptions is specified and the source_file_descriptors field was +// not present on the CodeGeneratorRequest, an error is returned. +// +// The caller can assume that all FileDescriptors have a valid path as the name field. +// +// Paths are considered valid if they are non-empty, relative, use '/' as the path separator, do not jump context, +// and have `.proto` as the file extension. func (r *Request) GenerateFileDescriptors(options ...GenerateFileDescriptorsOption) ([]protoreflect.FileDescriptor, error) { requestFileOptions := newRequestFileOptions() for _, option := range options { @@ -73,6 +86,18 @@ func (r *Request) GenerateFileDescriptors(options ...GenerateFileDescriptorsOpti } // AllFiles returns the a Files registry for all files in the CodeGeneratorRequest. +// +// This matches with the proto_file field on the CodeGeneratorRequest, with the FileDescriptorProtos +// from the source_file_descriptors field used for the files in file_to_geneate if WithSourceRetentionOptions +// is specified. +// +// If WithSourceRetentionOptions is specified and the source_file_descriptors field was +// not present on the CodeGeneratorRequest, an error is returned. +// +// The caller can assume that all FileDescriptors have a valid path as the name field. +// +// Paths are considered valid if they are non-empty, relative, use '/' as the path separator, do not jump context, +// and have `.proto` as the file extension. func (r *Request) AllFiles(options ...AllFilesOption) (*protoregistry.Files, error) { requestFileOptions := newRequestFileOptions() for _, option := range options { @@ -83,7 +108,15 @@ func (r *Request) AllFiles(options ...AllFilesOption) (*protoregistry.Files, err // GenerateFileDescriptorProtos returns the FileDescriptors for the files specified by the // file_to_generate field. -func (r *Request) GenerateFileDescriptorProtos(options ...GenerateFileDescriptorProtosOption) []*descriptorpb.FileDescriptorProto { +// +// If WithSourceRetentionOptions is specified and the source_file_descriptors field was +// not present on the CodeGeneratorRequest, an error is returned. +// +// The caller can assume that all FileDescriptorProtoss have a valid path as the name field. +// +// Paths are considered valid if they are non-empty, relative, use '/' as the path separator, do not jump context, +// and have `.proto` as the file extension. +func (r *Request) GenerateFileDescriptorProtos(options ...GenerateFileDescriptorProtosOption) ([]*descriptorpb.FileDescriptorProto, error) { requestFileOptions := newRequestFileOptions() for _, option := range options { option.applyGenerateFileDescriptorProtosOption(requestFileOptions) @@ -91,8 +124,20 @@ func (r *Request) GenerateFileDescriptorProtos(options ...GenerateFileDescriptor return r.generateFileDescriptorProtos(requestFileOptions.sourceRetentionOptions) } -// AllFileDescriptorProtos returns the FileDescriptors for all files in the CodeGeneratorRequest. -func (r *Request) AllFileDescriptorProtos(options ...AllFileDescriptorProtosOption) []*descriptorpb.FileDescriptorProto { +// AllFileDescriptorProtos returns the FileDescriptorProtos for all files in the CodeGeneratorRequest. +// +// This matches with the proto_file field on the CodeGeneratorRequest, with the FileDescriptorProtos +// from the source_file_descriptors field used for the files in file_to_geneate if WithSourceRetentionOptions +// is specified. +// +// If WithSourceRetentionOptions is specified and the source_file_descriptors field was +// not present on the CodeGeneratorRequest, an error is returned. +// +// The caller can assume that all FileDescriptorProtoss have a valid path as the name field. +// +// Paths are considered valid if they are non-empty, relative, use '/' as the path separator, do not jump context, +// and have `.proto` as the file extension. +func (r *Request) AllFileDescriptorProtos(options ...AllFileDescriptorProtosOption) ([]*descriptorpb.FileDescriptorProto, error) { requestFileOptions := newRequestFileOptions() for _, option := range options { option.applyAllFileDescriptorProtosOption(requestFileOptions) @@ -100,7 +145,11 @@ func (r *Request) AllFileDescriptorProtos(options ...AllFileDescriptorProtosOpti return r.allFileDescriptorProtos(requestFileOptions.sourceRetentionOptions) } -// CompilerVersion returns the specified compiler version on the request, if it is present. +// CompilerVersion returns the specified compiler_version on the CodeGeneratorRequest. +// +// If the compiler_version field was not present, nil is returned. +// +// The caller can assume that the major, minor, and patch values are non-negative. func (r *Request) CompilerVersion() *CompilerVersion { if version := r.codeGeneratorRequest.GetCompilerVersion(); version != nil { return &CompilerVersion{ @@ -113,9 +162,15 @@ func (r *Request) CompilerVersion() *CompilerVersion { return nil } -// CodeGeneratorRequest returns the underlying CodeGeneratorRequest. -func (r *Request) CodeGeneratorRequest() *pluginpb.CodeGeneratorRequest { - return r.codeGeneratorRequest +// CodeGeneratorRequest returns the raw underlying CodeGeneratorRequest. +// +// The returned CodeGeneratorRequest is a copy - you can freely modiify it. +func (r *Request) CodeGeneratorRequest() (*pluginpb.CodeGeneratorRequest, error) { + clone, ok := proto.Clone(r.codeGeneratorRequest).(*pluginpb.CodeGeneratorRequest) + if !ok { + return nil, fmt.Errorf("proto.Clone on %T returned a %T", r.codeGeneratorRequest, clone) + } + return clone, nil } // *** PRIVATE *** @@ -135,17 +190,20 @@ func newRequest(codeGeneratorRequest *pluginpb.CodeGeneratorRequest) (*Request, } func (r *Request) allFiles(sourceRetentionOptions bool) (*protoregistry.Files, error) { - return protodesc.NewFiles( - &descriptorpb.FileDescriptorSet{ - File: r.allFileDescriptorProtos(sourceRetentionOptions), - }, - ) + fileDescriptorProtos, err := r.allFileDescriptorProtos(sourceRetentionOptions) + if err != nil { + return nil, err + } + return protodesc.NewFiles(&descriptorpb.FileDescriptorSet{File: fileDescriptorProtos}) } -func (r *Request) generateFileDescriptorProtos(sourceRetentionOptions bool) []*descriptorpb.FileDescriptorProto { +func (r *Request) generateFileDescriptorProtos(sourceRetentionOptions bool) ([]*descriptorpb.FileDescriptorProto, error) { // If we want source-retention options, source_file_descriptors is all we need. if sourceRetentionOptions { - return slices.Clone(r.codeGeneratorRequest.GetSourceFileDescriptors()) + if err := r.validateSourceFileDescriptorsPresent(); err != nil { + return nil, err + } + return slices.Clone(r.codeGeneratorRequest.GetSourceFileDescriptors()), nil } // Otherwise, we need to get the values in proto_file that are in file_to_generate. filesToGenerateMap := r.getFilesToGenerateMap() @@ -155,13 +213,16 @@ func (r *Request) generateFileDescriptorProtos(sourceRetentionOptions bool) []*d fileDescriptorProtos = append(fileDescriptorProtos, protoFile) } } - return fileDescriptorProtos + return fileDescriptorProtos, nil } -func (r *Request) allFileDescriptorProtos(sourceRetentionOptions bool) []*descriptorpb.FileDescriptorProto { +func (r *Request) allFileDescriptorProtos(sourceRetentionOptions bool) ([]*descriptorpb.FileDescriptorProto, error) { // If we do not want source-retention options, proto_file is all we need. if !sourceRetentionOptions { - return slices.Clone(r.codeGeneratorRequest.GetProtoFile()) + return slices.Clone(r.codeGeneratorRequest.GetProtoFile()), nil + } + if err := r.validateSourceFileDescriptorsPresent(); err != nil { + return nil, err } // Otherwise, we need to replace the values in proto_file that are in file_to_generate // with the values from source_file_descriptors. @@ -175,7 +236,15 @@ func (r *Request) allFileDescriptorProtos(sourceRetentionOptions bool) []*descri } fileDescriptorProtos[i] = protoFile } - return fileDescriptorProtos + return fileDescriptorProtos, nil +} + +func (r *Request) validateSourceFileDescriptorsPresent() error { + if len(r.codeGeneratorRequest.GetSourceFileDescriptors()) == 0 && + len(r.codeGeneratorRequest.GetProtoFile()) > 0 { + return errors.New("source_file_descriptors not set on CodeGeneratorRequest but source-retention options requested - you likely need to upgrade your protobuf compiler") + } + return nil } func (r *Request) getFilesToGenerateMapUncached() map[string]struct{} {