diff --git a/.golangci.yml b/.golangci.yml index d351035..d45ab38 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -73,3 +73,7 @@ issues: - forbidigo - varnamelen path: compatibility.go + - linters: + - dupl + - forcetypeassert + path: protopluginutil/source_retention_options_test.go diff --git a/go.mod b/go.mod index 7da5819..1bb9e9c 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.20 require ( github.com/bufbuild/protocompile v0.9.0 + github.com/google/go-cmp v0.6.0 github.com/stretchr/testify v1.9.0 google.golang.org/protobuf v1.33.0 ) diff --git a/go.sum b/go.sum index 0b354d7..4224150 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,7 @@ github.com/bufbuild/protocompile v0.9.0/go.mod h1:s89m1O8CqSYpyE/YaSGtg1r1YFMF5n github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= diff --git a/protopluginutil/source_retention_options.go b/protopluginutil/source_retention_options.go new file mode 100644 index 0000000..9cd9483 --- /dev/null +++ b/protopluginutil/source_retention_options.go @@ -0,0 +1,562 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package protopluginutil + +import ( + "fmt" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" +) + +const ( + fileMessagesTag = 4 + fileEnumsTag = 5 + fileServicesTag = 6 + fileExtensionsTag = 7 + fileOptionsTag = 8 + messageFieldsTag = 2 + messageNestedMessagesTag = 3 + messageEnumsTag = 4 + messageExtensionRangesTag = 5 + messageExtensionsTag = 6 + messageOptionsTag = 7 + messageOneofsTag = 8 + extensionRangeOptionsTag = 3 + fieldOptionsTag = 8 + oneofOptionsTag = 2 + enumValuesTag = 2 + enumOptionsTag = 3 + enumValOptionsTag = 3 + serviceMethodsTag = 2 + serviceOptionsTag = 3 + methodOptionsTag = 4 +) + +// StripSourceRetentionOptions returns a FileDescriptorProto that omits any source-retention options. + +// If the FileDescriptorProto has no source-retention options, the original FileDescriptorProto is returned. +// If the FileDescriptorProto has source-retention options, a new FileDescriptorProto is returned with +// the source-retention options stripped. +// +// Even when a copy is returned, it is not a deep copy: it may share data with the +// input FileDescriptorProto, and mutations to the returned FileDescriptorProto may impact +// the input FileDescriptorProto. +func StripSourceRetentionOptions(file *descriptorpb.FileDescriptorProto) (*descriptorpb.FileDescriptorProto, error) { + var path sourcePath + var removedPaths *sourcePathTrie + if file.GetSourceCodeInfo() != nil && len(file.GetSourceCodeInfo().GetLocation()) > 0 { + path = make(sourcePath, 0, 16) + removedPaths = &sourcePathTrie{} + } + var dirty bool + optionsPath := path.push(fileOptionsTag) + newOpts, err := stripSourceRetentionOptionsFromProtoMessage(file.GetOptions(), optionsPath, removedPaths) + if err != nil { + return nil, err + } + if newOpts != file.GetOptions() { + dirty = true + } + msgsPath := path.push(fileMessagesTag) + newMsgs, changed, err := stripOptionsFromAll(file.GetMessageType(), stripSourceRetentionOptionsFromMessage, msgsPath, removedPaths) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + enumsPath := path.push(fileEnumsTag) + newEnums, changed, err := stripOptionsFromAll(file.GetEnumType(), stripSourceRetentionOptionsFromEnum, enumsPath, removedPaths) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + extsPath := path.push(fileExtensionsTag) + newExts, changed, err := stripOptionsFromAll(file.GetExtension(), stripSourceRetentionOptionsFromField, extsPath, removedPaths) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + svcsPath := path.push(fileServicesTag) + newSvcs, changed, err := stripOptionsFromAll(file.GetService(), stripSourceRetentionOptionsFromService, svcsPath, removedPaths) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + + if !dirty { + return file, nil + } + + newFile, err := shallowCopy(file) + if err != nil { + return nil, err + } + newFile.Options = newOpts + newFile.MessageType = newMsgs + newFile.EnumType = newEnums + newFile.Extension = newExts + newFile.Service = newSvcs + newFile.SourceCodeInfo = stripSourcePathsForSourceRetentionOptions(newFile.GetSourceCodeInfo(), removedPaths) + return newFile, nil +} + +func stripSourceRetentionOptionsFromProtoMessage[M proto.Message]( + options M, + path sourcePath, + removedPaths *sourcePathTrie, +) (M, error) { + optionsRef := options.ProtoReflect() + // See if there are any options to strip. + var hasFieldToStrip bool + var numFieldsToKeep int + var err error + optionsRef.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool { + fieldOpts, ok := field.Options().(*descriptorpb.FieldOptions) + if !ok { + err = fmt.Errorf("field options is unexpected type: got %T, want %T", field.Options(), fieldOpts) + return false + } + if fieldOpts.GetRetention() == descriptorpb.FieldOptions_RETENTION_SOURCE { + hasFieldToStrip = true + } else { + numFieldsToKeep++ + } + return true + }) + var zero M + if err != nil { + return zero, err + } + if !hasFieldToStrip { + return options, nil + } + + if numFieldsToKeep == 0 { + // Stripping the message would remove *all* options. In that case, + // we'll clear out the options by returning the zero value (i.e. nil). + removedPaths.addPath(path) // clear out all source locations, too + return zero, nil + } + + // There is at least one option to remove. So we need to make a copy that does not have those options. + newOptions := optionsRef.New() + ret, ok := newOptions.Interface().(M) + if !ok { + return zero, fmt.Errorf("creating new message of same type resulted in unexpected type; got %T, want %T", newOptions.Interface(), zero) + } + optionsRef.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool { + fieldOpts, ok := field.Options().(*descriptorpb.FieldOptions) + if !ok { + err = fmt.Errorf("field options is unexpected type: got %T, want %T", field.Options(), fieldOpts) + return false + } + if fieldOpts.GetRetention() != descriptorpb.FieldOptions_RETENTION_SOURCE { + newOptions.Set(field, val) + } else { + removedPaths.addPath(path.push(int32(field.Number()))) + } + return true + }) + if err != nil { + return zero, err + } + return ret, nil +} + +func stripSourceRetentionOptionsFromMessage( + msg *descriptorpb.DescriptorProto, + path sourcePath, + removedPaths *sourcePathTrie, +) (*descriptorpb.DescriptorProto, error) { + var dirty bool + optionsPath := path.push(messageOptionsTag) + newOpts, err := stripSourceRetentionOptionsFromProtoMessage(msg.GetOptions(), optionsPath, removedPaths) + if err != nil { + return nil, err + } + if newOpts != msg.GetOptions() { + dirty = true + } + fieldsPath := path.push(messageFieldsTag) + newFields, changed, err := stripOptionsFromAll(msg.GetField(), stripSourceRetentionOptionsFromField, fieldsPath, removedPaths) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + oneofsPath := path.push(messageOneofsTag) + newOneofs, changed, err := stripOptionsFromAll(msg.GetOneofDecl(), stripSourceRetentionOptionsFromOneof, oneofsPath, removedPaths) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + extRangesPath := path.push(messageExtensionRangesTag) + newExtRanges, changed, err := stripOptionsFromAll(msg.GetExtensionRange(), stripSourceRetentionOptionsFromExtensionRange, extRangesPath, removedPaths) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + msgsPath := path.push(messageNestedMessagesTag) + newMsgs, changed, err := stripOptionsFromAll(msg.GetNestedType(), stripSourceRetentionOptionsFromMessage, msgsPath, removedPaths) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + enumsPath := path.push(messageEnumsTag) + newEnums, changed, err := stripOptionsFromAll(msg.GetEnumType(), stripSourceRetentionOptionsFromEnum, enumsPath, removedPaths) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + extsPath := path.push(messageExtensionsTag) + newExts, changed, err := stripOptionsFromAll(msg.GetExtension(), stripSourceRetentionOptionsFromField, extsPath, removedPaths) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + + if !dirty { + return msg, nil + } + + newMsg, err := shallowCopy(msg) + if err != nil { + return nil, err + } + newMsg.Options = newOpts + newMsg.Field = newFields + newMsg.OneofDecl = newOneofs + newMsg.ExtensionRange = newExtRanges + newMsg.NestedType = newMsgs + newMsg.EnumType = newEnums + newMsg.Extension = newExts + return newMsg, nil +} + +func stripSourceRetentionOptionsFromField( + field *descriptorpb.FieldDescriptorProto, + path sourcePath, + removedPaths *sourcePathTrie, +) (*descriptorpb.FieldDescriptorProto, error) { + optionsPath := path.push(fieldOptionsTag) + newOpts, err := stripSourceRetentionOptionsFromProtoMessage(field.GetOptions(), optionsPath, removedPaths) + if err != nil { + return nil, err + } + if newOpts == field.GetOptions() { + return field, nil + } + newField, err := shallowCopy(field) + if err != nil { + return nil, err + } + newField.Options = newOpts + return newField, nil +} + +func stripSourceRetentionOptionsFromOneof( + oneof *descriptorpb.OneofDescriptorProto, + path sourcePath, + removedPaths *sourcePathTrie, +) (*descriptorpb.OneofDescriptorProto, error) { + optionsPath := path.push(oneofOptionsTag) + newOpts, err := stripSourceRetentionOptionsFromProtoMessage(oneof.GetOptions(), optionsPath, removedPaths) + if err != nil { + return nil, err + } + if newOpts == oneof.GetOptions() { + return oneof, nil + } + newOneof, err := shallowCopy(oneof) + if err != nil { + return nil, err + } + newOneof.Options = newOpts + return newOneof, nil +} + +func stripSourceRetentionOptionsFromExtensionRange( + extRange *descriptorpb.DescriptorProto_ExtensionRange, + path sourcePath, + removedPaths *sourcePathTrie, +) (*descriptorpb.DescriptorProto_ExtensionRange, error) { + optionsPath := path.push(extensionRangeOptionsTag) + newOpts, err := stripSourceRetentionOptionsFromProtoMessage(extRange.GetOptions(), optionsPath, removedPaths) + if err != nil { + return nil, err + } + if newOpts == extRange.GetOptions() { + return extRange, nil + } + newExtRange, err := shallowCopy(extRange) + if err != nil { + return nil, err + } + newExtRange.Options = newOpts + return newExtRange, nil +} + +func stripSourceRetentionOptionsFromEnum( + enum *descriptorpb.EnumDescriptorProto, + path sourcePath, + removedPaths *sourcePathTrie, +) (*descriptorpb.EnumDescriptorProto, error) { + var dirty bool + optionsPath := path.push(enumOptionsTag) + newOpts, err := stripSourceRetentionOptionsFromProtoMessage(enum.GetOptions(), optionsPath, removedPaths) + if err != nil { + return nil, err + } + if newOpts != enum.GetOptions() { + dirty = true + } + valsPath := path.push(enumValuesTag) + newVals, changed, err := stripOptionsFromAll(enum.GetValue(), stripSourceRetentionOptionsFromEnumValue, valsPath, removedPaths) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + + if !dirty { + return enum, nil + } + + newEnum, err := shallowCopy(enum) + if err != nil { + return nil, err + } + newEnum.Options = newOpts + newEnum.Value = newVals + return newEnum, nil +} + +func stripSourceRetentionOptionsFromEnumValue( + enumVal *descriptorpb.EnumValueDescriptorProto, + path sourcePath, + removedPaths *sourcePathTrie, +) (*descriptorpb.EnumValueDescriptorProto, error) { + optionsPath := path.push(enumValOptionsTag) + newOpts, err := stripSourceRetentionOptionsFromProtoMessage(enumVal.GetOptions(), optionsPath, removedPaths) + if err != nil { + return nil, err + } + if newOpts == enumVal.GetOptions() { + return enumVal, nil + } + newEnumVal, err := shallowCopy(enumVal) + if err != nil { + return nil, err + } + newEnumVal.Options = newOpts + return newEnumVal, nil +} + +func stripSourceRetentionOptionsFromService( + svc *descriptorpb.ServiceDescriptorProto, + path sourcePath, + removedPaths *sourcePathTrie, +) (*descriptorpb.ServiceDescriptorProto, error) { + var dirty bool + optionsPath := path.push(serviceOptionsTag) + newOpts, err := stripSourceRetentionOptionsFromProtoMessage(svc.GetOptions(), optionsPath, removedPaths) + if err != nil { + return nil, err + } + if newOpts != svc.GetOptions() { + dirty = true + } + methodsPath := path.push(serviceMethodsTag) + newMethods, changed, err := stripOptionsFromAll(svc.GetMethod(), stripSourceRetentionOptionsFromMethod, methodsPath, removedPaths) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + + if !dirty { + return svc, nil + } + + newSvc, err := shallowCopy(svc) + if err != nil { + return nil, err + } + newSvc.Options = newOpts + newSvc.Method = newMethods + return newSvc, nil +} + +func stripSourceRetentionOptionsFromMethod( + method *descriptorpb.MethodDescriptorProto, + path sourcePath, + removedPaths *sourcePathTrie, +) (*descriptorpb.MethodDescriptorProto, error) { + optionsPath := path.push(methodOptionsTag) + newOpts, err := stripSourceRetentionOptionsFromProtoMessage(method.GetOptions(), optionsPath, removedPaths) + if err != nil { + return nil, err + } + if newOpts == method.GetOptions() { + return method, nil + } + newMethod, err := shallowCopy(method) + if err != nil { + return nil, err + } + newMethod.Options = newOpts + return newMethod, nil +} + +func stripSourcePathsForSourceRetentionOptions( + sourceInfo *descriptorpb.SourceCodeInfo, + removedPaths *sourcePathTrie, +) *descriptorpb.SourceCodeInfo { + if sourceInfo == nil || len(sourceInfo.GetLocation()) == 0 || removedPaths == nil { + // nothing to do + return sourceInfo + } + newLocations := make([]*descriptorpb.SourceCodeInfo_Location, len(sourceInfo.GetLocation())) + var i int + for _, loc := range sourceInfo.GetLocation() { + if removedPaths.isRemoved(loc.GetPath()) { + continue + } + newLocations[i] = loc + i++ + } + newLocations = newLocations[:i] + return &descriptorpb.SourceCodeInfo{Location: newLocations} +} + +func shallowCopy[M proto.Message](msg M) (M, error) { + msgRef := msg.ProtoReflect() + other := msgRef.New() + ret, ok := other.Interface().(M) + if !ok { + return ret, fmt.Errorf("creating new message of same type resulted in unexpected type; got %T, want %T", other.Interface(), ret) + } + msgRef.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool { + other.Set(field, val) + return true + }) + return ret, nil +} + +// stripOptionsFromAll applies the given function to each element in the given +// slice in order to remove source-retention options from it. It returns the new +// slice and a bool indicating whether anything was actually changed. If the +// second value is false, then the returned slice is the same slice as the input +// slice. Usually, T is a pointer type, in which case the given updateFunc should +// NOT mutate the input value. Instead, it should return the input value if only +// if there is no update needed. If a mutation is needed, it should return a new +// value. +func stripOptionsFromAll[T comparable]( + slice []T, + updateFunc func(T, sourcePath, *sourcePathTrie) (T, error), + path sourcePath, + removedPaths *sourcePathTrie, +) ([]T, bool, error) { + var updated []T // initialized lazily, only when/if a copy is needed + for i, item := range slice { + newItem, err := updateFunc(item, path.push(int32(i)), removedPaths) + if err != nil { + return nil, false, err + } + if updated != nil { + updated[i] = newItem + } else if newItem != item { + updated = make([]T, len(slice)) + copy(updated[:i], slice) + updated[i] = newItem + } + } + if updated != nil { + return updated, true, nil + } + return slice, false, nil +} + +type sourcePath protoreflect.SourcePath + +func (p sourcePath) push(element int32) sourcePath { + if p == nil { + return nil + } + return append(p, element) +} + +type sourcePathTrie struct { + removed bool + children map[int32]*sourcePathTrie +} + +func (t *sourcePathTrie) addPath(path sourcePath) { + if t == nil { + return + } + if len(path) == 0 { + t.removed = true + return + } + child := t.children[path[0]] + if child == nil { + if t.children == nil { + t.children = map[int32]*sourcePathTrie{} + } + child = &sourcePathTrie{} + t.children[path[0]] = child + } + child.addPath(path[1:]) +} + +func (t *sourcePathTrie) isRemoved(path []int32) bool { + if t == nil { + return false + } + if t.removed { + return true + } + if len(path) == 0 { + return false + } + child := t.children[path[0]] + if child == nil { + return false + } + return child.isRemoved(path[1:]) +} diff --git a/protopluginutil/source_retention_options_test.go b/protopluginutil/source_retention_options_test.go new file mode 100644 index 0000000..f07a65b --- /dev/null +++ b/protopluginutil/source_retention_options_test.go @@ -0,0 +1,686 @@ +// Copyright 2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package protopluginutil + +import ( + "errors" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/dynamicpb" +) + +func TestStripSourceRetentionOptions(t *testing.T) { + t.Parallel() + makeCustomOptionSet := func(startTag int32, extendee string, prefix string, label descriptorpb.FieldDescriptorProto_Label) []*descriptorpb.FieldDescriptorProto { + return []*descriptorpb.FieldDescriptorProto{ + { + Extendee: proto.String(extendee), + Name: proto.String(prefix + "no_retention"), + Number: proto.Int32(startTag), + Label: label.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + // No option + }, + { + Extendee: proto.String(extendee), + Name: proto.String(prefix + "unknown_retention"), + Number: proto.Int32(startTag + 1), + Label: label.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_BOOL.Enum(), + Options: &descriptorpb.FieldOptions{ + Retention: descriptorpb.FieldOptions_RETENTION_UNKNOWN.Enum(), + }, + }, + { + Extendee: proto.String(extendee), + Name: proto.String(prefix + "runtime_retention"), + Number: proto.Int32(startTag + 2), + Label: label.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_BYTES.Enum(), + Options: &descriptorpb.FieldOptions{ + Retention: descriptorpb.FieldOptions_RETENTION_RUNTIME.Enum(), + }, + }, + { + Extendee: proto.String(extendee), + Name: proto.String(prefix + "source_retention"), + Number: proto.Int32(startTag + 3), + Label: label.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_INT32.Enum(), + Options: &descriptorpb.FieldOptions{ + Retention: descriptorpb.FieldOptions_RETENTION_SOURCE.Enum(), + }, + }, + } + } + makeCustomOptions := func(extendee string, prefix string) []*descriptorpb.FieldDescriptorProto { + return append( + makeCustomOptionSet(10000, extendee, prefix, descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL), + makeCustomOptionSet(20000, extendee, prefix+"rep_", descriptorpb.FieldDescriptorProto_LABEL_REPEATED)..., + ) + } + + optsFileProto := &descriptorpb.FileDescriptorProto{ + Name: proto.String("opts.proto"), + Package: proto.String("foo.bar"), + Dependency: []string{"google/protobuf/descriptor.proto"}, + Extension: testCombineAll( + makeCustomOptions(".google.protobuf.FileOptions", "file_"), + makeCustomOptions(".google.protobuf.MessageOptions", "msg_"), + makeCustomOptions(".google.protobuf.FieldOptions", "field_"), + makeCustomOptions(".google.protobuf.OneofOptions", "oneof_"), + makeCustomOptions(".google.protobuf.ExtensionRangeOptions", "extrange_"), + makeCustomOptions(".google.protobuf.EnumOptions", "enum_"), + makeCustomOptions(".google.protobuf.EnumValueOptions", "enumval_"), + makeCustomOptions(".google.protobuf.ServiceOptions", "svc_"), + makeCustomOptions(".google.protobuf.MethodOptions", "method_"), + ), + } + optsFile, err := protodesc.NewFile(optsFileProto, protoregistry.GlobalFiles) + require.NoError(t, err) + + applyCustomOptionSet := func(all, retained protoreflect.Message, prefix protoreflect.Name, isList bool, file protoreflect.FileDescriptor) { + extType := dynamicpb.NewExtensionType(file.Extensions().ByName(prefix + "no_retention")) + var val protoreflect.Value + if isList { + listVal := extType.New().List() + listVal.Append(protoreflect.ValueOfString("foo")) + listVal.Append(protoreflect.ValueOfString("bar")) + val = protoreflect.ValueOfList(listVal) + } else { + val = protoreflect.ValueOfString("abc") + } + all.Set(extType.TypeDescriptor(), val) + retained.Set(extType.TypeDescriptor(), val) + + extType = dynamicpb.NewExtensionType(file.Extensions().ByName(prefix + "unknown_retention")) + if isList { + listVal := extType.New().List() + listVal.Append(protoreflect.ValueOfBool(false)) + listVal.Append(protoreflect.ValueOfBool(true)) + val = protoreflect.ValueOfList(listVal) + } else { + val = protoreflect.ValueOfBool(true) + } + all.Set(extType.TypeDescriptor(), val) + retained.Set(extType.TypeDescriptor(), val) + + extType = dynamicpb.NewExtensionType(file.Extensions().ByName(prefix + "runtime_retention")) + if isList { + listVal := extType.New().List() + listVal.Append(protoreflect.ValueOfBytes([]byte{0, 1, 2, 3})) + listVal.Append(protoreflect.ValueOfBytes([]byte{4, 5, 6, 7})) + val = protoreflect.ValueOfList(listVal) + } else { + val = protoreflect.ValueOfBytes([]byte{0, 1, 2, 3}) + } + all.Set(extType.TypeDescriptor(), val) + retained.Set(extType.TypeDescriptor(), val) + + extType = dynamicpb.NewExtensionType(file.Extensions().ByName(prefix + "source_retention")) + if isList { + listVal := extType.New().List() + listVal.Append(protoreflect.ValueOfInt32(123)) + listVal.Append(protoreflect.ValueOfInt32(-456)) + val = protoreflect.ValueOfList(listVal) + } else { + val = protoreflect.ValueOfInt32(123) + } + all.Set(extType.TypeDescriptor(), val) + // don't set retained because this is a source-only option (won't be retained) + } + applyCustomOptions := func(message proto.Message, prefix protoreflect.Name, file protoreflect.FileDescriptor) (proto.Message, proto.Message) { + allRef := message.ProtoReflect() + strippedRef := proto.Clone(message).ProtoReflect() + applyCustomOptionSet(allRef, strippedRef, prefix, false, file) + applyCustomOptionSet(allRef, strippedRef, prefix+"rep_", true, file) + return allRef.Interface(), strippedRef.Interface() + } + + fileOpts, fileOptsStripped := applyCustomOptions(&descriptorpb.FileOptions{}, "file_", optsFile) + msgOpts, msgOptsStripped := applyCustomOptions(&descriptorpb.MessageOptions{}, "msg_", optsFile) + fieldOpts, fieldOptsStripped := applyCustomOptions(&descriptorpb.FieldOptions{}, "field_", optsFile) + oneofOpts, oneofOptsStripped := applyCustomOptions(&descriptorpb.OneofOptions{}, "oneof_", optsFile) + extRangeOpts, extRangeOptsStripped := applyCustomOptions(&descriptorpb.ExtensionRangeOptions{}, "extrange_", optsFile) + enumOpts, enumOptsStripped := applyCustomOptions(&descriptorpb.EnumOptions{}, "enum_", optsFile) + enumValOpts, enumValOptsStripped := applyCustomOptions(&descriptorpb.EnumValueOptions{}, "enumval_", optsFile) + svcOpts, svcOptsStripped := applyCustomOptions(&descriptorpb.ServiceOptions{}, "svc_", optsFile) + methodOpts, methodOptsStripped := applyCustomOptions(&descriptorpb.MethodOptions{}, "method_", optsFile) + + allLocations := func(pathPrefix ...int32) []*descriptorpb.SourceCodeInfo_Location { + return []*descriptorpb.SourceCodeInfo_Location{ + {Path: append(pathPrefix, 10000)}, + {Path: append(pathPrefix, 10001)}, + {Path: append(pathPrefix, 10002)}, + {Path: append(pathPrefix, 10003)}, + {Path: append(pathPrefix, 10003, 1)}, + {Path: append(pathPrefix, 20000)}, + {Path: append(pathPrefix, 20001)}, + {Path: append(pathPrefix, 20002)}, + {Path: append(pathPrefix, 20003)}, + {Path: append(pathPrefix, 20003, 0, 1)}, + {Path: append(pathPrefix, 20003, 1, 1)}, + {Path: append(pathPrefix, 20003, 2, 1)}, + } + } + strippedLocations := func(pathPrefix ...int32) []*descriptorpb.SourceCodeInfo_Location { + return []*descriptorpb.SourceCodeInfo_Location{ + {Path: append(pathPrefix, 10000)}, + {Path: append(pathPrefix, 10001)}, + {Path: append(pathPrefix, 10002)}, + // 10003 is source retention + {Path: append(pathPrefix, 20000)}, + {Path: append(pathPrefix, 20001)}, + {Path: append(pathPrefix, 20002)}, + // 20003 is source retention + } + } + + beforeFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("foo.proto"), + Options: fileOpts.(*descriptorpb.FileOptions), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Message"), + Options: msgOpts.(*descriptorpb.MessageOptions), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: proto.String("field"), + Number: proto.Int32(1), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + JsonName: proto.String("field"), + Options: fieldOpts.(*descriptorpb.FieldOptions), + OneofIndex: proto.Int32(0), + }, + { + Name: proto.String("another_field"), + Number: proto.Int32(2), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + JsonName: proto.String("anotherField"), + Options: fieldOpts.(*descriptorpb.FieldOptions), + }, + }, + OneofDecl: []*descriptorpb.OneofDescriptorProto{ + { + Name: proto.String("oo"), + Options: oneofOpts.(*descriptorpb.OneofOptions), + }, + }, + ExtensionRange: []*descriptorpb.DescriptorProto_ExtensionRange{ + { + Start: proto.Int32(100), + End: proto.Int32(200), + Options: extRangeOpts.(*descriptorpb.ExtensionRangeOptions), + }, + }, + NestedType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("NestedMessage"), + Options: msgOpts.(*descriptorpb.MessageOptions), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: proto.String("field"), + Number: proto.Int32(1), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + JsonName: proto.String("field"), + Options: fieldOpts.(*descriptorpb.FieldOptions), + OneofIndex: proto.Int32(0), + }, + }, + }, + }, + EnumType: []*descriptorpb.EnumDescriptorProto{ + { + Name: proto.String("NestedEnum"), + Options: enumOpts.(*descriptorpb.EnumOptions), + Value: []*descriptorpb.EnumValueDescriptorProto{ + { + Name: proto.String("ZERO"), + Number: proto.Int32(0), + Options: enumValOpts.(*descriptorpb.EnumValueOptions), + }, + { + Name: proto.String("ONE"), + Number: proto.Int32(1), + Options: enumValOpts.(*descriptorpb.EnumValueOptions), + }, + }, + }, + }, + Extension: []*descriptorpb.FieldDescriptorProto{ + { + Extendee: proto.String(".Message"), + Name: proto.String("ext"), + Number: proto.Int32(101), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Options: fieldOpts.(*descriptorpb.FieldOptions), + }, + }, + }, + }, + EnumType: []*descriptorpb.EnumDescriptorProto{ + { + Name: proto.String("Enum"), + Options: enumOpts.(*descriptorpb.EnumOptions), + Value: []*descriptorpb.EnumValueDescriptorProto{ + { + Name: proto.String("ZERO"), + Number: proto.Int32(0), + Options: enumValOpts.(*descriptorpb.EnumValueOptions), + }, + { + Name: proto.String("ONE"), + Number: proto.Int32(1), + Options: enumValOpts.(*descriptorpb.EnumValueOptions), + }, + }, + }, + }, + Extension: []*descriptorpb.FieldDescriptorProto{ + { + Extendee: proto.String(".Message"), + Name: proto.String("ext"), + Number: proto.Int32(100), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Options: fieldOpts.(*descriptorpb.FieldOptions), + }, + }, + Service: []*descriptorpb.ServiceDescriptorProto{ + { + Name: proto.String("Service"), + Options: svcOpts.(*descriptorpb.ServiceOptions), + Method: []*descriptorpb.MethodDescriptorProto{ + { + Name: proto.String("Do"), + InputType: proto.String(".Message"), + OutputType: proto.String(".Message"), + Options: methodOpts.(*descriptorpb.MethodOptions), + }, + }, + }, + }, + SourceCodeInfo: &descriptorpb.SourceCodeInfo{ + Location: testCombineAll( + allLocations(fileOptionsTag), + allLocations(fileMessagesTag, 0, messageOptionsTag), + allLocations(fileMessagesTag, 0, messageFieldsTag, 0, fieldOptionsTag), + allLocations(fileMessagesTag, 0, messageFieldsTag, 1, fieldOptionsTag), + allLocations(fileMessagesTag, 0, messageOneofsTag, 0, oneofOptionsTag), + allLocations(fileMessagesTag, 0, messageExtensionRangesTag, 0, extensionRangeOptionsTag), + allLocations(fileMessagesTag, 0, messageNestedMessagesTag, 0, messageOptionsTag), + allLocations(fileMessagesTag, 0, messageNestedMessagesTag, 0, messageFieldsTag, 0, fieldOptionsTag), + allLocations(fileMessagesTag, 0, messageEnumsTag, 0, enumOptionsTag), + allLocations(fileMessagesTag, 0, messageEnumsTag, 0, enumValuesTag, 0, enumValOptionsTag), + allLocations(fileMessagesTag, 0, messageEnumsTag, 0, enumValuesTag, 1, enumValOptionsTag), + allLocations(fileMessagesTag, 0, messageExtensionsTag, 0, fieldOptionsTag), + allLocations(fileEnumsTag, 0, enumOptionsTag), + allLocations(fileEnumsTag, 0, enumValuesTag, 0, enumValOptionsTag), + allLocations(fileEnumsTag, 0, enumValuesTag, 1, enumValOptionsTag), + allLocations(fileExtensionsTag, 0, fieldOptionsTag), + allLocations(fileServicesTag, 0, serviceOptionsTag), + allLocations(fileServicesTag, 0, serviceMethodsTag, 0, methodOptionsTag), + ), + }, + } + + // This one is the same as above, but uses the stripped option messages + afterFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("foo.proto"), + Options: fileOptsStripped.(*descriptorpb.FileOptions), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Message"), + Options: msgOptsStripped.(*descriptorpb.MessageOptions), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: proto.String("field"), + Number: proto.Int32(1), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + JsonName: proto.String("field"), + Options: fieldOptsStripped.(*descriptorpb.FieldOptions), + OneofIndex: proto.Int32(0), + }, + { + Name: proto.String("another_field"), + Number: proto.Int32(2), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + JsonName: proto.String("anotherField"), + Options: fieldOptsStripped.(*descriptorpb.FieldOptions), + }, + }, + OneofDecl: []*descriptorpb.OneofDescriptorProto{ + { + Name: proto.String("oo"), + Options: oneofOptsStripped.(*descriptorpb.OneofOptions), + }, + }, + ExtensionRange: []*descriptorpb.DescriptorProto_ExtensionRange{ + { + Start: proto.Int32(100), + End: proto.Int32(200), + Options: extRangeOptsStripped.(*descriptorpb.ExtensionRangeOptions), + }, + }, + NestedType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("NestedMessage"), + Options: msgOptsStripped.(*descriptorpb.MessageOptions), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: proto.String("field"), + Number: proto.Int32(1), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + JsonName: proto.String("field"), + Options: fieldOptsStripped.(*descriptorpb.FieldOptions), + OneofIndex: proto.Int32(0), + }, + }, + }, + }, + EnumType: []*descriptorpb.EnumDescriptorProto{ + { + Name: proto.String("NestedEnum"), + Options: enumOptsStripped.(*descriptorpb.EnumOptions), + Value: []*descriptorpb.EnumValueDescriptorProto{ + { + Name: proto.String("ZERO"), + Number: proto.Int32(0), + Options: enumValOptsStripped.(*descriptorpb.EnumValueOptions), + }, + { + Name: proto.String("ONE"), + Number: proto.Int32(1), + Options: enumValOptsStripped.(*descriptorpb.EnumValueOptions), + }, + }, + }, + }, + Extension: []*descriptorpb.FieldDescriptorProto{ + { + Extendee: proto.String(".Message"), + Name: proto.String("ext"), + Number: proto.Int32(101), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Options: fieldOptsStripped.(*descriptorpb.FieldOptions), + }, + }, + }, + }, + EnumType: []*descriptorpb.EnumDescriptorProto{ + { + Name: proto.String("Enum"), + Options: enumOptsStripped.(*descriptorpb.EnumOptions), + Value: []*descriptorpb.EnumValueDescriptorProto{ + { + Name: proto.String("ZERO"), + Number: proto.Int32(0), + Options: enumValOptsStripped.(*descriptorpb.EnumValueOptions), + }, + { + Name: proto.String("ONE"), + Number: proto.Int32(1), + Options: enumValOptsStripped.(*descriptorpb.EnumValueOptions), + }, + }, + }, + }, + Extension: []*descriptorpb.FieldDescriptorProto{ + { + Extendee: proto.String(".Message"), + Name: proto.String("ext"), + Number: proto.Int32(100), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Options: fieldOptsStripped.(*descriptorpb.FieldOptions), + }, + }, + Service: []*descriptorpb.ServiceDescriptorProto{ + { + Name: proto.String("Service"), + Options: svcOptsStripped.(*descriptorpb.ServiceOptions), + Method: []*descriptorpb.MethodDescriptorProto{ + { + Name: proto.String("Do"), + InputType: proto.String(".Message"), + OutputType: proto.String(".Message"), + Options: methodOptsStripped.(*descriptorpb.MethodOptions), + }, + }, + }, + }, + SourceCodeInfo: &descriptorpb.SourceCodeInfo{ + Location: testCombineAll( + strippedLocations(fileOptionsTag), + strippedLocations(fileMessagesTag, 0, messageOptionsTag), + strippedLocations(fileMessagesTag, 0, messageFieldsTag, 0, fieldOptionsTag), + strippedLocations(fileMessagesTag, 0, messageFieldsTag, 1, fieldOptionsTag), + strippedLocations(fileMessagesTag, 0, messageOneofsTag, 0, oneofOptionsTag), + strippedLocations(fileMessagesTag, 0, messageExtensionRangesTag, 0, extensionRangeOptionsTag), + strippedLocations(fileMessagesTag, 0, messageNestedMessagesTag, 0, messageOptionsTag), + strippedLocations(fileMessagesTag, 0, messageNestedMessagesTag, 0, messageFieldsTag, 0, fieldOptionsTag), + strippedLocations(fileMessagesTag, 0, messageEnumsTag, 0, enumOptionsTag), + strippedLocations(fileMessagesTag, 0, messageEnumsTag, 0, enumValuesTag, 0, enumValOptionsTag), + strippedLocations(fileMessagesTag, 0, messageEnumsTag, 0, enumValuesTag, 1, enumValOptionsTag), + strippedLocations(fileMessagesTag, 0, messageExtensionsTag, 0, fieldOptionsTag), + strippedLocations(fileEnumsTag, 0, enumOptionsTag), + strippedLocations(fileEnumsTag, 0, enumValuesTag, 0, enumValOptionsTag), + strippedLocations(fileEnumsTag, 0, enumValuesTag, 1, enumValOptionsTag), + strippedLocations(fileExtensionsTag, 0, fieldOptionsTag), + strippedLocations(fileServicesTag, 0, serviceOptionsTag), + strippedLocations(fileServicesTag, 0, serviceMethodsTag, 0, methodOptionsTag), + ), + }, + } + + actualStrippedFile, err := StripSourceRetentionOptions(beforeFile) + require.NoError(t, err) + require.NotSame(t, actualStrippedFile, beforeFile) + require.Empty(t, cmp.Diff(afterFile, actualStrippedFile, protocmp.Transform())) + + // If we repeat the operation, we get back the same descriptor unchanged because + // it doesn't have any source-only options. + doubleStrippedFile, err := StripSourceRetentionOptions(actualStrippedFile) + require.NoError(t, err) + require.Same(t, doubleStrippedFile, actualStrippedFile) + require.Empty(t, cmp.Diff(afterFile, doubleStrippedFile, protocmp.Transform())) +} + +func TestStripSourceRetentionOptionsFromProtoMessage(t *testing.T) { + t.Parallel() + optsFileProto := &descriptorpb.FileDescriptorProto{ + Name: proto.String("opts.proto"), + Package: proto.String("foo.bar"), + Dependency: []string{"google/protobuf/descriptor.proto"}, + Extension: []*descriptorpb.FieldDescriptorProto{ + { + Extendee: proto.String(".google.protobuf.FileOptions"), + Name: proto.String("no_retention"), + Number: proto.Int32(10000), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + // No option + }, + { + Extendee: proto.String(".google.protobuf.FileOptions"), + Name: proto.String("unknown_retention"), + Number: proto.Int32(10001), + Label: descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Options: &descriptorpb.FieldOptions{ + Retention: descriptorpb.FieldOptions_RETENTION_UNKNOWN.Enum(), + }, + }, + { + Extendee: proto.String(".google.protobuf.FileOptions"), + Name: proto.String("runtime_retention"), + Number: proto.Int32(10002), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_BYTES.Enum(), + Options: &descriptorpb.FieldOptions{ + Retention: descriptorpb.FieldOptions_RETENTION_RUNTIME.Enum(), + }, + }, + { + Extendee: proto.String(".google.protobuf.FileOptions"), + Name: proto.String("source_retention"), + Number: proto.Int32(10003), + Label: descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_INT32.Enum(), + Options: &descriptorpb.FieldOptions{ + Retention: descriptorpb.FieldOptions_RETENTION_SOURCE.Enum(), + }, + }, + }, + } + optsFile, err := protodesc.NewFile(optsFileProto, protoregistry.GlobalFiles) + require.NoError(t, err) + extNoRetention := dynamicpb.NewExtensionType(optsFile.Extensions().ByName("no_retention")) + extUnknownRetention := dynamicpb.NewExtensionType(optsFile.Extensions().ByName("unknown_retention")) + extRuntimeRetention := dynamicpb.NewExtensionType(optsFile.Extensions().ByName("runtime_retention")) + extSourceRetention := dynamicpb.NewExtensionType(optsFile.Extensions().ByName("source_retention")) + + // Create a message with these options. + optionsMsg := &descriptorpb.FileOptions{} + options := optionsMsg.ProtoReflect() + options.Set(extNoRetention.TypeDescriptor(), protoreflect.ValueOfString("abc")) + listVal := extUnknownRetention.New().List() + listVal.Append(protoreflect.ValueOfString("foo")) + listVal.Append(protoreflect.ValueOfString("bar")) + options.Set(extUnknownRetention.TypeDescriptor(), protoreflect.ValueOfList(listVal)) + options.Set(extRuntimeRetention.TypeDescriptor(), protoreflect.ValueOfBytes([]byte("xyz"))) + // The above will be retained, so create a copy now to serve as the expected result. + optionsAfterStrip := proto.Clone(options.Interface()) + // The below option will get stripped because it's retention policy is source. + listVal = extSourceRetention.New().List() + listVal.Append(protoreflect.ValueOfInt32(123)) + listVal.Append(protoreflect.ValueOfInt32(-456)) + options.Set(extSourceRetention.TypeDescriptor(), protoreflect.ValueOfList(listVal)) + + actualOptionsAfterStrip, err := stripSourceRetentionOptionsFromProtoMessage(optionsMsg, nil, nil) + require.NoError(t, err) + + require.NotSame(t, actualOptionsAfterStrip, optionsMsg) + require.Empty(t, cmp.Diff(optionsAfterStrip, actualOptionsAfterStrip, protocmp.Transform())) + + // If we do it again, there are no changes to made (since source-only options were + // already stripped). So we should get back unmodified value. + optionsMsg = actualOptionsAfterStrip + actualOptionsAfterStrip, err = stripSourceRetentionOptionsFromProtoMessage(optionsMsg, nil, nil) + require.NoError(t, err) + + require.Same(t, actualOptionsAfterStrip, optionsMsg) + require.Empty(t, cmp.Diff(optionsAfterStrip, actualOptionsAfterStrip, protocmp.Transform())) + + // If we have an options message with ONLY source-retention fields, then + // stripping the options results in a nil message (effectively clearing + // the descriptor's options field). + optionsMsg.Reset() + options = optionsMsg.ProtoReflect() // weird that we have to call this again (bug in protobuf-go?) + options.Set(extSourceRetention.TypeDescriptor(), protoreflect.ValueOfList(listVal)) + + actualOptionsAfterStrip, err = stripSourceRetentionOptionsFromProtoMessage(optionsMsg, nil, nil) + require.NoError(t, err) + + require.Same(t, (*descriptorpb.FileOptions)(nil), actualOptionsAfterStrip) +} + +func TestStripOptionsFromAll(t *testing.T) { + t.Parallel() + + errInvalid := errors.New("invalid value") + updateFunc := func(value *int32, _ sourcePath, _ *sourcePathTrie) (*int32, error) { + if value == nil { + return proto.Int32(-1), nil + } + if *value <= -100 { + return nil, errInvalid + } + if *value > 5 { + return proto.Int32(*value * 2), nil + } + return value, nil + } + + vals := []*int32{ + proto.Int32(0), proto.Int32(1), proto.Int32(2), + proto.Int32(3), proto.Int32(4), proto.Int32(5), + proto.Int32(6), proto.Int32(7), proto.Int32(8), + } + newVals, changed, err := stripOptionsFromAll(vals, updateFunc, nil, nil) + require.NoError(t, err) + require.True(t, changed) + expected := []*int32{ + proto.Int32(0), proto.Int32(1), proto.Int32(2), + proto.Int32(3), proto.Int32(4), proto.Int32(5), + proto.Int32(12), proto.Int32(14), proto.Int32(16), + } + require.Equal(t, expected, newVals) + + vals = []*int32{ + nil, proto.Int32(1), proto.Int32(2), + proto.Int32(3), proto.Int32(4), proto.Int32(5), + } + newVals, changed, err = stripOptionsFromAll(vals, updateFunc, nil, nil) + require.NoError(t, err) + require.True(t, changed) + expected = []*int32{ + proto.Int32(-1), proto.Int32(1), proto.Int32(2), + proto.Int32(3), proto.Int32(4), proto.Int32(5), + } + require.Equal(t, expected, newVals) + + // No changes + vals = []*int32{ + proto.Int32(0), proto.Int32(1), proto.Int32(2), + proto.Int32(3), proto.Int32(4), proto.Int32(5), + } + newVals, changed, err = stripOptionsFromAll(vals, updateFunc, nil, nil) + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, vals, newVals) + + // Propagate error + vals = []*int32{ + proto.Int32(0), proto.Int32(1), proto.Int32(2), + proto.Int32(3), proto.Int32(-101), proto.Int32(5), + } + _, _, err = stripOptionsFromAll(vals, updateFunc, nil, nil) + require.ErrorIs(t, err, errInvalid) +} + +func testCombineAll[T any](slices ...[]T) []T { + result := slices[0] + for _, exts := range slices[1:] { + result = append(result, exts...) + } + return result +}