diff --git a/Mono.Cecil/Import.cs b/Mono.Cecil/Import.cs index 272e96076..fac6aa58e 100644 --- a/Mono.Cecil/Import.cs +++ b/Mono.Cecil/Import.cs @@ -305,7 +305,7 @@ public virtual AssemblyNameReference ImportReference (SR.AssemblyName name) Mixin.CheckName (name); AssemblyNameReference reference; - if (TryGetAssemblyNameReference (name, out reference)) + if (Mixin.TryGetAssemblyNameReference (module, name, out reference)) return reference; reference = new AssemblyNameReference (name.Name, name.Version) @@ -320,23 +320,6 @@ public virtual AssemblyNameReference ImportReference (SR.AssemblyName name) return reference; } - bool TryGetAssemblyNameReference (SR.AssemblyName name, out AssemblyNameReference assembly_reference) - { - var references = module.AssemblyReferences; - - for (int i = 0; i < references.Count; i++) { - var reference = references [i]; - if (name.FullName != reference.FullName) // TODO compare field by field - continue; - - assembly_reference = reference; - return true; - } - - assembly_reference = null; - return false; - } - FieldReference ImportField (SR.FieldInfo field, ImportGenericContext context) { var declaring_type = ImportType (field.DeclaringType, context); @@ -756,57 +739,316 @@ public static void CheckModule (ModuleDefinition module) public static bool TryGetAssemblyNameReference (this ModuleDefinition module, AssemblyNameReference name_reference, out AssemblyNameReference assembly_reference) { - var references = module.AssemblyReferences; + return TryGetAssemblyNameReference (module, (AssemblyWrapper)name_reference, out assembly_reference); + } + + public static bool TryGetAssemblyNameReference (this ModuleDefinition module, SR.AssemblyName name_reference, out AssemblyNameReference assembly_reference) + { + return TryGetAssemblyNameReference (module, (AssemblyWrapper)name_reference, out assembly_reference); + } - for (int i = 0; i < references.Count; i++) { - var reference = references [i]; - if (!Equals (name_reference, reference)) - continue; + static bool TryGetAssemblyNameReference (this ModuleDefinition module, AssemblyWrapper name_reference, out AssemblyNameReference assembly_reference) + { + // Try to resolve the assembly using direct reference first + if (module.TryGetDirectAssemblyNameReference (name_reference, out assembly_reference)) { + return true; + } + + // If direct resolution fails, try to resolve using forwarded types + if (module.TryGetForwardedAssemblyNameReference (name_reference, out assembly_reference, new HashSet ())) { + return true; + } - assembly_reference = reference; + // As a fallback, if the assembly is still unresolved, + // check if it could be a system assembly like 'System.Private.CoreLib', + // which can be safely mapped to 'System.Runtime' in certain .NET platforms. + // This handles the internal details of .NET Core and .NET 5+, + // where 'System.Private.CoreLib' is the core library but it is not referenceable + // and 'System.Runtime' is the referenceable contract assembly. + if (module.TryGetSystemAssemblyNameReference (name_reference, out assembly_reference)) { return true; } + // If resolution fails, set the output reference to null assembly_reference = null; return false; } - static bool Equals (byte [] a, byte [] b) + static bool TryGetDirectAssemblyNameReference (this ModuleDefinition module, AssemblyWrapper name_reference, out AssemblyNameReference assembly_reference) { - if (ReferenceEquals (a, b)) - return true; - if (a == null) - return false; - if (a.Length != b.Length) - return false; - for (int i = 0; i < a.Length; i++) - if (a [i] != b [i]) - return false; - return true; + // Check each assembly reference in the module for a match by full name + foreach (var reference in module.AssemblyReferences) { + if (reference.FullName == name_reference.FullName) { + assembly_reference = reference; + return true; + } + } + + // If no direct reference is found, set the output reference to null + assembly_reference = null; + return false; } - static bool Equals (T a, T b) where T : class, IEquatable + static bool TryGetForwardedAssemblyNameReference ( + this ModuleDefinition module, + AssemblyWrapper name_reference, + out AssemblyNameReference assembly_reference, + HashSet checked_assemblies) { - if (ReferenceEquals (a, b)) - return true; - if (a == null) - return false; - return a.Equals (b); + // Initialize the output parameter to null + assembly_reference = null; + + // Iterate through all assembly references + foreach (var asm_ref in module.AssemblyReferences) { + if (!checked_assemblies.Add (asm_ref.FullName)) + continue; + + AssemblyDefinition resolved_assembly; + try { + // Attempt to resolve the assembly reference + resolved_assembly = module.AssemblyResolver.Resolve (asm_ref); + } + catch (AssemblyResolutionException) { + // Skip the assembly if resolution fails + continue; + } + + // Check exported types for type forwarding within the assembly + foreach (var module_def in resolved_assembly.Modules) { + foreach (var exported_type in module_def.ExportedTypes) { + // Check if the exported type has a scope that matches the target assembly name + var scope = exported_type.Scope as AssemblyNameReference; + if (scope == null) + continue; + + if (!checked_assemblies.Add (scope.FullName)) + continue; + + if (AssemblyWrapper.Equals (scope, name_reference)) { + // If a match is found, return the assembly reference from which the type was forwarded + assembly_reference = asm_ref; + return true; + } else { + if (TryGetForwardedAssemblyNameReference (module, (AssemblyWrapper)scope, out assembly_reference, checked_assemblies)) + return true; + } + } + } + } + + return false; } - static bool Equals (AssemblyNameReference a, AssemblyNameReference b) + static bool TryGetSystemAssemblyNameReference (this ModuleDefinition module, AssemblyWrapper name_reference, out AssemblyNameReference assembly_reference) { - if (ReferenceEquals (a, b)) - return true; - if (a.Name != b.Name) + assembly_reference = null; + + var possible_system_assemblies = new [] { + "System.Runtime", // The main system assembly for .NET Core and .NET 5+ + "mscorlib", // The main system assembly for the .NET Framework + "netstandard", // The pseudo-assembly for .NET Standard API surface + + // Additional potential system assemblies used in specific contexts: + "System.Private.CoreLib", // The implementation of the BCL in .NET Core and .NET 5+ + }; + + var core_libs = new List (); + foreach (var asm_ref in module.AssemblyReferences) { + foreach (var possible_system_assembly in possible_system_assemblies) { + if (string.Equals (asm_ref.Name, possible_system_assembly, StringComparison.Ordinal)) { + core_libs.Add (asm_ref); + break; + } + } + } + + var possible_core_lib = false; + foreach (var possible_system_assembly in possible_system_assemblies) { + if (string.Equals (name_reference.Name, possible_system_assembly, StringComparison.Ordinal)) { + possible_core_lib = true; + break; + } + } + + if (!possible_core_lib) return false; - if (!Equals (a.Version, b.Version)) + + if (core_libs.Count != 1) return false; - if (a.Culture != b.Culture) + + var core_lib = core_libs [0]; + AssemblyDefinition resolved_core_lib; + + try { + resolved_core_lib = module.AssemblyResolver.Resolve (name_reference.GetAssemblyNameReference()); + } + catch (AssemblyResolutionException) { return false; - if (!Equals (a.PublicKeyToken, b.PublicKeyToken)) + } + + if (resolved_core_lib == null) return false; - return true; + + foreach (var module_def in resolved_core_lib.Modules) { + if (module_def.GetType (typeof (object).FullName) != null) { + assembly_reference = core_lib; + return true; + } + + foreach (var exported_type in module_def.ExportedTypes) { + if (string.Equals (exported_type.FullName, typeof (object).FullName, StringComparison.Ordinal)) { + + var scope = exported_type.Scope as AssemblyNameReference; + if (scope != null) { + foreach (var possible_system_assembly in possible_system_assemblies) { + if (string.Equals (scope.Name, possible_system_assembly, StringComparison.Ordinal)) { + assembly_reference = core_lib; + return true; + } + } + } + } + } + } + + return false; + } + + sealed class AssemblyWrapper { + private readonly SR.AssemblyName assembly_name; + private readonly AssemblyNameReference assembly_name_reference; + + public AssemblyWrapper (SR.AssemblyName assembly_name) + { + CheckName (assembly_name); + this.assembly_name = assembly_name; + } + + public AssemblyWrapper (AssemblyNameReference assembly_name_reference) + { + CheckName (assembly_name_reference); + this.assembly_name_reference = assembly_name_reference; + } + + public string FullName { + get { + if (assembly_name == null) { + return assembly_name_reference.FullName; + } else { + return assembly_name.FullName; + } + } + } + + public string Name { + get { + if (assembly_name == null) { + return assembly_name_reference.Name; + } else { + return assembly_name.Name; + } + } + } + + public Version Version { + get { + if (assembly_name == null) { + return assembly_name_reference.Version; + } else { + return assembly_name.Version; + } + } + } + + public byte [] PublicKeyToken { + get { + if (assembly_name == null) { + return assembly_name_reference.PublicKeyToken; + } else { + return assembly_name.GetPublicKeyToken (); + } + } + } + + public string Culture { + get { + if (assembly_name == null) { + return assembly_name_reference.Culture; + } else { + return assembly_name.CultureInfo.Name; + } + } + } + + public AssemblyNameReference GetAssemblyNameReference () + { + if (this.assembly_name != null) { + return new AssemblyNameReference (this.assembly_name.Name, this.assembly_name.Version) { + PublicKeyToken = this.assembly_name.GetPublicKeyToken (), + Culture = this.assembly_name.CultureInfo.Name, + HashAlgorithm = (AssemblyHashAlgorithm)this.assembly_name.HashAlgorithm, + }; + } else { + return this.assembly_name_reference; + } + } + + public override string ToString () + { + return this.Name; + } + + public static bool Equals (AssemblyWrapper a, AssemblyWrapper b) + { + if (ReferenceEquals (a, b)) + return true; + if (a.assembly_name != null && ReferenceEquals (a.assembly_name, b.assembly_name)) + return true; + if (a.assembly_name_reference != null && ReferenceEquals (a.assembly_name_reference, b.assembly_name_reference)) + return true; + if (a.Name != b.Name) + return false; + if (!Equals (a.Version, b.Version)) + return false; + if (a.Culture != b.Culture) + return false; + if (!Equals (a.PublicKeyToken, b.PublicKeyToken)) + return false; + return true; + } + + static bool Equals (byte [] a, byte [] b) + { + if (ReferenceEquals (a, b)) + return true; + if (a == null) + return false; + if (a.Length != b.Length) + return false; + for (int i = 0; i < a.Length; i++) + if (a [i] != b [i]) + return false; + return true; + } + + static bool Equals (T a, T b) where T : class, IEquatable + { + if (ReferenceEquals (a, b)) + return true; + if (a == null) + return false; + return a.Equals (b); + } + + public static explicit operator AssemblyWrapper (SR.AssemblyName assembly_name) + { + return new AssemblyWrapper (assembly_name); + } + + public static explicit operator AssemblyWrapper (AssemblyNameReference assembly_name_reference) + { + return new AssemblyWrapper (assembly_name_reference); + } } } }