diff --git a/AssemblyToProcess/DerivedClasses.cs b/AssemblyToProcess/DerivedClasses.cs new file mode 100644 index 0000000..7b83f45 --- /dev/null +++ b/AssemblyToProcess/DerivedClasses.cs @@ -0,0 +1,23 @@ +using System.Collections.Generic; +using DeepCopy; + +namespace AssemblyToProcess +{ + [AddDeepCopyConstructor] + public class DictionaryClass : Dictionary + { + public SomeObject SomeProperty { get; set; } + } + + [AddDeepCopyConstructor] + public class ListClass : List + { + public SomeObject SomeProperty { get; set; } + } + + [AddDeepCopyConstructor] + public class SetClass : HashSet + { + public SomeObject SomeProperty { get; set; } + } +} \ No newline at end of file diff --git a/DeepCopy.Fody/Copy.cs b/DeepCopy.Fody/Copy.cs index 73c79fb..61bda16 100644 --- a/DeepCopy.Fody/Copy.cs +++ b/DeepCopy.Fody/Copy.cs @@ -80,7 +80,8 @@ IEnumerable Getter() => new[] return instructions; } - private IEnumerable CopyValue(TypeReference type, Func> getterBuilder, Instruction followUp, bool nullableCheck = true) + private IEnumerable CopyValue(TypeReference type, Func> getterBuilder, Instruction followUp, + bool nullableCheck = true) { var list = new List(); list.AddRange(getterBuilder.Invoke()); @@ -107,8 +108,10 @@ private IEnumerable CopyValue(TypeReference type, Func CopyDictionary(PropertyDefinition property) { - var typeDictionary = property.PropertyType.Resolve(); + return CopyDictionary(property.PropertyType, property); + } + + private IEnumerable CopyDictionary(TypeReference type, PropertyDefinition property) + { + var typeDictionary = type.Resolve(); var typeInstance = (TypeReference) typeDictionary; - var typesArguments = property.PropertyType.SolveGenericArguments().Cast().ToArray(); + var typesArguments = type.SolveGenericArguments().Cast().ToArray(); var typeKeyValuePair = ImportType(typeof(KeyValuePair<,>), typesArguments); var methodGetEnumerator = ImportMethod(ImportType(typeof(IEnumerable<>), typeKeyValuePair), nameof(IEnumerable.GetEnumerator), typeKeyValuePair); @@ -30,19 +35,23 @@ private IEnumerable CopyDictionary(PropertyDefinition property) if (IsType(typeDictionary, typeof(IDictionary<,>))) typeInstance = ImportType(typeof(Dictionary<,>), typesArguments); else - throw new NotSupportedException(property.FullName); + throw new NotSupportedException(type); } else if (!typeDictionary.HasDefaultConstructor()) - throw new NotSupportedException(property.FullName); + throw new NotSupportedException(type); var constructor = ModuleDefinition.ImportReference(NewConstructor(typeInstance).MakeGeneric(typesArguments)); var list = new List(); - list.Add(Instruction.Create(OpCodes.Ldarg_0)); - list.Add(Instruction.Create(OpCodes.Newobj, constructor)); - list.Add(property.MakeSet()); + if (property != null) + { + list.Add(Instruction.Create(OpCodes.Ldarg_0)); + list.Add(Instruction.Create(OpCodes.Newobj, constructor)); + list.Add(property.MakeSet()); + } list.Add(Instruction.Create(OpCodes.Ldarg_1)); - list.Add(Instruction.Create(OpCodes.Callvirt, property.GetMethod)); + if (property != null) + list.Add(Instruction.Create(OpCodes.Callvirt, property.GetMethod)); list.Add(Instruction.Create(OpCodes.Callvirt, methodGetEnumerator)); list.Add(Instruction.Create(OpCodes.Stloc, varEnumerator)); @@ -57,7 +66,8 @@ private IEnumerable CopyDictionary(PropertyDefinition property) list.Add(Instruction.Create(OpCodes.Stloc, varKeyValuePair)); list.Add(Instruction.Create(OpCodes.Ldarg_0)); - list.Add(Instruction.Create(OpCodes.Call, property.GetMethod)); + if (property != null) + list.Add(Instruction.Create(OpCodes.Call, property.GetMethod)); IEnumerable GetterKey() => new[] { @@ -74,7 +84,6 @@ IEnumerable GetterValue() => new[] var setItem = Instruction.Create(OpCodes.Callvirt, ImportMethod(typeDictionary, "set_Item", typesArguments)); var getValue = CopyValue(typesArguments[1], GetterValue, setItem).ToList(); list.AddRange(CopyValue(typesArguments[0], GetterKey, getValue.First(), false)); - //list.AddRange(GetterKey()); list.AddRange(getValue); list.Add(setItem); diff --git a/DeepCopy.Fody/CopyList.cs b/DeepCopy.Fody/CopyList.cs index 02aaa80..1c51de0 100644 --- a/DeepCopy.Fody/CopyList.cs +++ b/DeepCopy.Fody/CopyList.cs @@ -1,4 +1,3 @@ -using System; using System.Collections.Generic; using Mono.Cecil; using Mono.Cecil.Cil; @@ -8,30 +7,39 @@ namespace DeepCopy.Fody public partial class ModuleWeaver { private IEnumerable CopyList(PropertyDefinition property) + { + return CopyList(property.PropertyType, property); + } + + private IEnumerable CopyList(TypeReference type, PropertyDefinition property) { var loopStart = Instruction.Create(OpCodes.Nop); var conditionStart = Instruction.Create(OpCodes.Ldloc, IndexVariable); - var listType = property.PropertyType.Resolve(); + var listType = type.Resolve(); var instanceType = (TypeReference) listType; - var argumentType = property.PropertyType.SolveGenericArgument(); + var argumentType = type.SolveGenericArgument(); if (listType.IsInterface) { if (IsType(listType, typeof(IList<>))) instanceType = ModuleDefinition.ImportReference(typeof(List<>)).MakeGeneric(argumentType); else - throw new NotSupportedException(property.FullName); + throw new NotSupportedException(property); } else if (!listType.HasDefaultConstructor()) - throw new NotSupportedException(property.FullName); + throw new NotSupportedException(property); var listConstructor = ModuleDefinition.ImportReference(NewConstructor(instanceType).MakeGeneric(argumentType)); var list = new List(); - list.Add(Instruction.Create(OpCodes.Ldarg_0)); - list.Add(Instruction.Create(OpCodes.Newobj, listConstructor)); - list.Add(property.MakeSet()); + if (property != null) + { + list.Add(Instruction.Create(OpCodes.Ldarg_0)); + list.Add(Instruction.Create(OpCodes.Newobj, listConstructor)); + list.Add(property.MakeSet()); + } + list.Add(Instruction.Create(OpCodes.Ldc_I4_0)); list.Add(Instruction.Create(OpCodes.Stloc, IndexVariable)); list.Add(Instruction.Create(OpCodes.Br_S, conditionStart)); @@ -48,7 +56,8 @@ private IEnumerable CopyList(PropertyDefinition property) // condition list.Add(conditionStart); list.Add(Instruction.Create(OpCodes.Ldarg_1)); - list.Add(Instruction.Create(OpCodes.Callvirt, property.GetMethod)); + if (property != null) + list.Add(Instruction.Create(OpCodes.Callvirt, property.GetMethod)); list.Add(Instruction.Create(OpCodes.Callvirt, ImportMethod(listType, "get_Count", argumentType))); list.Add(Instruction.Create(OpCodes.Clt)); list.Add(Instruction.Create(OpCodes.Stloc, BooleanVariable)); @@ -62,19 +71,21 @@ private IEnumerable CopyList(PropertyDefinition property) private IEnumerable CopyListItem(PropertyDefinition property, TypeDefinition listType, TypeDefinition argumentType) { - var list = new List - { - Instruction.Create(OpCodes.Ldarg_0), - Instruction.Create(OpCodes.Call, property.GetMethod) - }; + var list = new List(); + list.Add(Instruction.Create(OpCodes.Ldarg_0)); + if (property != null) + list.Add(Instruction.Create(OpCodes.Call, property.GetMethod)); - IEnumerable Getter() => new[] + IEnumerable Getter() { - Instruction.Create(OpCodes.Ldarg_1), - Instruction.Create(OpCodes.Callvirt, property.GetMethod), - Instruction.Create(OpCodes.Ldloc, IndexVariable), - Instruction.Create(OpCodes.Callvirt, ImportMethod(listType, "get_Item", argumentType)) - }; + var getter = new List(); + getter.Add(Instruction.Create(OpCodes.Ldarg_1)); + if (property != null) + getter.Add(Instruction.Create(OpCodes.Callvirt, property.GetMethod)); + getter.Add(Instruction.Create(OpCodes.Ldloc, IndexVariable)); + getter.Add(Instruction.Create(OpCodes.Callvirt, ImportMethod(listType, "get_Item", argumentType))); + return getter; + } var add = Instruction.Create(OpCodes.Callvirt, ImportMethod(listType, "Add", argumentType)); list.AddRange(CopyValue(argumentType, Getter, add)); diff --git a/DeepCopy.Fody/CopySet.cs b/DeepCopy.Fody/CopySet.cs index 24ea2ed..57cbefe 100644 --- a/DeepCopy.Fody/CopySet.cs +++ b/DeepCopy.Fody/CopySet.cs @@ -10,9 +10,14 @@ public partial class ModuleWeaver { private IEnumerable CopySet(PropertyDefinition property) { - var typeSet = property.PropertyType.Resolve(); + return CopySet(property.PropertyType, property); + } + + private IEnumerable CopySet(TypeReference type, PropertyDefinition property) + { + var typeSet = type.Resolve(); var typeInstance = (TypeReference) typeSet; - var typeArgument = property.PropertyType.SolveGenericArgument(); + var typeArgument = type.SolveGenericArgument(); var methodGetEnumerator = ImportMethod(ImportType(typeof(IEnumerable<>), typeArgument), nameof(IEnumerable.GetEnumerator), typeArgument); var typeEnumerator = ImportType(methodGetEnumerator.ReturnType, typeArgument); @@ -28,19 +33,23 @@ private IEnumerable CopySet(PropertyDefinition property) if (IsType(typeSet, typeof(ISet<>))) typeInstance = ImportType(typeof(HashSet<>), typeArgument); else - throw new NotSupportedException(property.FullName); + throw new NotSupportedException(property); } else if (!typeSet.HasDefaultConstructor()) - throw new NotSupportedException(property.FullName); + throw new NotSupportedException(property); var constructor = ModuleDefinition.ImportReference(NewConstructor(typeInstance).MakeGeneric(typeArgument)); var list = new List(); - list.Add(Instruction.Create(OpCodes.Ldarg_0)); - list.Add(Instruction.Create(OpCodes.Newobj, constructor)); - list.Add(property.MakeSet()); + if (property != null) + { + list.Add(Instruction.Create(OpCodes.Ldarg_0)); + list.Add(Instruction.Create(OpCodes.Newobj, constructor)); + list.Add(property.MakeSet()); + } list.Add(Instruction.Create(OpCodes.Ldarg_1)); - list.Add(Instruction.Create(OpCodes.Callvirt, property.GetMethod)); + if (property != null) + list.Add(Instruction.Create(OpCodes.Callvirt, property.GetMethod)); list.Add(Instruction.Create(OpCodes.Callvirt, methodGetEnumerator)); list.Add(Instruction.Create(OpCodes.Stloc, varEnumerator)); @@ -55,7 +64,8 @@ private IEnumerable CopySet(PropertyDefinition property) list.Add(Instruction.Create(OpCodes.Stloc, varCurrent)); list.Add(Instruction.Create(OpCodes.Ldarg_0)); - list.Add(Instruction.Create(OpCodes.Callvirt, property.GetMethod)); + if (property != null) + list.Add(Instruction.Create(OpCodes.Callvirt, property.GetMethod)); IEnumerable Getter() => new[] { diff --git a/DeepCopy.Fody/DeepCopyMethodExtension.cs b/DeepCopy.Fody/DeepCopyMethodExtension.cs index d5d5d88..a9fac47 100644 --- a/DeepCopy.Fody/DeepCopyMethodExtension.cs +++ b/DeepCopy.Fody/DeepCopyMethodExtension.cs @@ -139,7 +139,7 @@ private void BuildMultiTypeSwitchMethodBody(MethodDefinition method, TypeDefinit else { if (!IsCopyConstructorAvailable(baseType, out var constructor)) - throw new CopyConstructorRequiredException(baseType); + throw new NoCopyConstructorFoundException(baseType); processor.Emit(OpCodes.Ldarg_0); processor.Emit(OpCodes.Newobj, constructor); diff --git a/DeepCopy.Fody/Exceptions.cs b/DeepCopy.Fody/Exceptions.cs new file mode 100644 index 0000000..69f9915 --- /dev/null +++ b/DeepCopy.Fody/Exceptions.cs @@ -0,0 +1,26 @@ +using Fody; +using Mono.Cecil; + +namespace DeepCopy.Fody +{ + public class NotSupportedException : DeepCopyException + { + public NotSupportedException(MemberReference type) + : base($"{type.FullName} is not supported") { } + } + + public class NoCopyConstructorFoundException : DeepCopyException + { + public NoCopyConstructorFoundException(MemberReference type) + : base($"No copy constructor for {type.FullName} found") { } + } + + public abstract class DeepCopyException : WeavingException + { + protected DeepCopyException(string message) : base(message) { } + + public MemberReference ProcessingType { private get; set; } + + public override string Message => (ProcessingType == null ? "" : $"{ProcessingType.FullName} -> ") + base.Message; + } +} \ No newline at end of file diff --git a/DeepCopy.Fody/ModuleWeaver.cs b/DeepCopy.Fody/ModuleWeaver.cs index d7cc857..b14a61b 100644 --- a/DeepCopy.Fody/ModuleWeaver.cs +++ b/DeepCopy.Fody/ModuleWeaver.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Linq; using System.Threading; using Fody; @@ -35,7 +36,6 @@ private VariableDefinition BooleanVariable { if (_booleanVariable != null) return _booleanVariable; _booleanVariable = new VariableDefinition(ModuleDefinition.ImportReference(TypeSystem.BooleanDefinition)); - CurrentBody.Value.InitLocals = true; CurrentBody.Value.Variables.Add(_booleanVariable); return _booleanVariable; } @@ -47,7 +47,6 @@ private VariableDefinition IndexVariable { if (_indexVariable != null) return _indexVariable; _indexVariable = new VariableDefinition(ModuleDefinition.ImportReference(TypeSystem.Int32Definition)); - CurrentBody.Value.InitLocals = true; CurrentBody.Value.Variables.Add(_indexVariable); return _indexVariable; } @@ -114,7 +113,7 @@ private void ExecuteInjectDeepCopy() var constructorResolved = constructor.Resolve(); constructorResolved.Body.SimplifyMacros(); - InsertCopyInstructions(target, constructorResolved.Body); + InsertCopyInstructions(target, constructorResolved.Body, null); constructorResolved.CustomAttributes.Remove(constructorResolved.SingleAttribute(InjectDeepCopyAttribute)); } } @@ -126,6 +125,8 @@ private void AddDeepConstructor(TypeDefinition type) var processor = constructor.Body.GetILProcessor(); + Func> baseCopyFunc = null; + if (type.BaseType.Resolve().MetadataToken == TypeSystem.ObjectDefinition.MetadataToken) { processor.Emit(OpCodes.Ldarg_0); @@ -137,41 +138,74 @@ private void AddDeepConstructor(TypeDefinition type) processor.Emit(OpCodes.Ldarg_1); processor.Emit(OpCodes.Call, baseConstructor); } + else if (IsType(type.BaseType.GetElementType().Resolve(), typeof(Dictionary<,>))) + { + processor.Emit(OpCodes.Ldarg_0); + processor.Emit(OpCodes.Call, ImportDefaultConstructor(type.BaseType)); + baseCopyFunc = reference => CopyDictionary(reference, null); + } + else if (IsType(type.BaseType.GetElementType().Resolve(), typeof(List<>))) + { + processor.Emit(OpCodes.Ldarg_0); + processor.Emit(OpCodes.Call, ImportDefaultConstructor(type.BaseType)); + baseCopyFunc = reference => CopyList(reference, null); + } + else if (IsType(type.BaseType.GetElementType().Resolve(), typeof(HashSet<>))) + { + processor.Emit(OpCodes.Ldarg_0); + processor.Emit(OpCodes.Call, ImportDefaultConstructor(type.BaseType)); + baseCopyFunc = reference => CopySet(reference, null); + } else - throw new CopyConstructorRequiredException(type.BaseType); + throw new NoCopyConstructorFoundException(type.BaseType); - InsertCopyInstructions(type, constructor.Body); + InsertCopyInstructions(type, constructor.Body, baseCopyFunc); processor.Emit(OpCodes.Ret); type.Methods.Add(constructor); } - private void InsertCopyInstructions(TypeDefinition type, MethodBody body) + private void InsertCopyInstructions(TypeDefinition type, MethodBody body, Func> baseCopyInstruction) { - _booleanVariable = null; - _indexVariable = null; - CurrentBody.Value = body; + try + { + _booleanVariable = null; + _indexVariable = null; + CurrentBody.Value = body; - var baseConstructorCall = body.Instructions.Single(i => i.OpCode == OpCodes.Call && i.Operand is MethodReference method && method.Name == ".ctor"); - var index = body.Instructions.IndexOf(baseConstructorCall) + 1; - var properties = new List(); + var baseConstructorCall = body.Instructions.Single(i => i.OpCode == OpCodes.Call && i.Operand is MethodReference method && method.Name == ConstructorName); + var index = body.Instructions.IndexOf(baseConstructorCall) + 1; + var properties = new List(); - foreach (var property in type.Properties) - { - if (!TryCopy(property, out var instructions)) - continue; - properties.Add(property.Name); - foreach (var instruction in instructions) - body.Instructions.Insert(index++, instruction); - } + if (baseCopyInstruction != null) + foreach (var instruction in baseCopyInstruction.Invoke(type.BaseType)) + body.Instructions.Insert(index++, instruction); + + foreach (var property in type.Properties) + { + if (!TryCopy(property, out var instructions)) + continue; + properties.Add(property.Name); + foreach (var instruction in instructions) + body.Instructions.Insert(index++, instruction); + } - LogInfo.Invoke(properties.Count == 0 - ? $"DeepCopy {type.FullName} -> no properties" - : $"DeepCopy {type.FullName} -> {string.Join(", ", properties)}"); + LogInfo.Invoke($"{type.FullName} -> {(properties.Count == 0 ? "no properties" : string.Join(", ", properties))}"); - body.OptimizeMacros(); + if (body.HasVariables) + body.InitLocals = true; - CurrentBody.Value = null; + body.OptimizeMacros(); + } + catch (DeepCopyException exception) + { + exception.ProcessingType = type; + throw; + } + finally + { + CurrentBody.Value = null; + } } #region Setup diff --git a/DeepCopy.Fody/MonoCecilExtensions.cs b/DeepCopy.Fody/MonoCecilExtensions.cs index 3794415..a7a2520 100644 --- a/DeepCopy.Fody/MonoCecilExtensions.cs +++ b/DeepCopy.Fody/MonoCecilExtensions.cs @@ -42,7 +42,7 @@ public static MethodReference GetMethod(this TypeDefinition type, string name) if (TryFindMethod(type, name, out var method)) return method; - throw new NullReferenceException($"No method {name} found for type {type.FullName}"); + throw new MissingMethodException(type.FullName, name); } private static bool TryFindMethod(this TypeDefinition type, string name, out MethodReference method) @@ -109,8 +109,6 @@ public static IEnumerable SolveGenericArguments(this TypeReferen if (!type.IsGenericInstance) throw new ArgumentException(); var arguments = ((GenericInstanceType) type).GenericArguments; - if (arguments.Count != 2) - throw new ArgumentException(); return arguments.Select(a => a.GetElementType().Resolve()).ToArray(); } diff --git a/DeepCopy.Fody/Utilities.cs b/DeepCopy.Fody/Utilities.cs index 697b33c..387bf9f 100644 --- a/DeepCopy.Fody/Utilities.cs +++ b/DeepCopy.Fody/Utilities.cs @@ -22,6 +22,14 @@ private MethodReference ImportDefaultConstructor(TypeDefinition type) return ModuleDefinition.ImportReference(type.GetConstructors().Single(c => !c.HasParameters)); } + private MethodReference ImportDefaultConstructor(TypeReference type) + { + var constructor = type.Resolve().GetConstructors().Single(c => !c.HasParameters && !c.IsStatic); + return ModuleDefinition.ImportReference(type.IsGenericInstance + ? constructor.MakeGeneric(type.SolveGenericArguments().Cast().ToArray()) + : constructor); + } + private bool IsType(IMetadataTokenProvider typeDefinition, Type type) { return typeDefinition.MetadataToken == ModuleDefinition.ImportReference(type).Resolve().MetadataToken; diff --git a/Tests/CopyDerivedClassTests.cs b/Tests/CopyDerivedClassTests.cs index 6f62e06..78bb75f 100644 --- a/Tests/CopyDerivedClassTests.cs +++ b/Tests/CopyDerivedClassTests.cs @@ -18,5 +18,48 @@ public void TestDerivedClass() AssertCopyOfSomeClass(instance.Object, copy.Object); AssertCopyOfSomeClass(instance.BaseObject, copy.BaseObject); } + + [Fact] + public void TestDerivedDictionaryClass() + { + var type = GetTestType(typeof(DictionaryClass)); + dynamic instance = Activator.CreateInstance(type); + instance.SomeProperty = CreateSomeObject(); + instance["foo"] = CreateSomeObject(); + + var copy = Activator.CreateInstance(type, instance); + AssertCopyOfSomeClass(instance.SomeProperty, copy.SomeProperty); + AssertCopyOfSomeClass(instance["foo"], copy["foo"]); + } + + [Fact] + public void TestDerivedListClass() + { + var type = GetTestType(typeof(ListClass)); + dynamic instance = Activator.CreateInstance(type); + instance.SomeProperty = CreateSomeObject(); + instance.Add(CreateSomeObject()); + + var copy = Activator.CreateInstance(type, instance); + Assert.Equal(instance.Count, copy.Count); + AssertCopyOfSomeClass(instance.SomeProperty, copy.SomeProperty); + AssertCopyOfSomeClass(instance[0], copy[0]); + } + + [Fact] + public void TestDerivedSetClass() + { + var type = GetTestType(typeof(SetClass)); + dynamic instance = Activator.CreateInstance(type); + instance.SomeProperty = CreateSomeObject(); + instance.Add(CreateSomeObject()); + + var copy = Activator.CreateInstance(type, instance); + Assert.Equal(instance.Count, copy.Count); + AssertCopyOfSomeClass(instance.SomeProperty, copy.SomeProperty); + var instanceArray = ToArray(instance); + var copyArray = ToArray(copy); + AssertCopyOfSomeClass(instanceArray[0], copyArray[0]); + } } } \ No newline at end of file