Skip to content

Commit

Permalink
[release/7.0] Check for marking virtual method due to base only when …
Browse files Browse the repository at this point in the history
…state changes (#3094)

* Check for marking virtual method due to base only when state changes (#3073)

Instead of checking every virtual method to see if it should be kept due
to a base method every iteration of the MarkStep pipeline, check each
method only when its relevant state has changed.

Co-authored-by: Sven Boemer <[email protected]>

* Don't mark override of abstract base if the override's declaring type is not marked (#3098)

* Don't mark an override every time the base is abstract, only if the declaring type is also marked
Adds a condition to ShouldMarkOverrideForBase to exit early if the declaring type of the method is not marked.

* Add test case for #3112 with pseudo-circular reference with ifaces

* Link issue to TODO

* Adds a test for recursive generics on interfaces

This is a copy of the test added in #3156

Co-authored-by: Sven Boemer <[email protected]>
Co-authored-by: vitek-karas <[email protected]>
  • Loading branch information
3 people authored Jan 18, 2023
1 parent 19fa656 commit ae8160b
Show file tree
Hide file tree
Showing 15 changed files with 363 additions and 151 deletions.
220 changes: 103 additions & 117 deletions src/linker/Linker.Steps/MarkStep.cs

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions src/linker/Linker.Steps/SealerStep.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ void ProcessType (TypeDefinition type)
//
// cannot de-virtualize nor seal methods if something overrides them
//
if (IsAnyMarked (overrides))
if (IsAnyOverrideMarked (overrides))
continue;

SealMethod (method);
Expand All @@ -108,7 +108,7 @@ void ProcessType (TypeDefinition type)

var bases = Annotations.GetBaseMethods (method);
// Devirtualize if a method is not override to existing marked methods
if (!IsAnyMarked (bases))
if (!IsAnyBaseMarked (bases))
method.IsVirtual = method.IsFinal = method.IsNewSlot = false;
}
}
Expand All @@ -123,7 +123,7 @@ protected virtual void SealMethod (MethodDefinition method)
method.IsFinal = true;
}

bool IsAnyMarked (IEnumerable<OverrideInformation>? list)
bool IsAnyOverrideMarked (IEnumerable<OverrideInformation>? list)
{
if (list == null)
return false;
Expand All @@ -135,12 +135,13 @@ bool IsAnyMarked (IEnumerable<OverrideInformation>? list)
return false;
}

bool IsAnyMarked (List<MethodDefinition>? list)
bool IsAnyBaseMarked (IEnumerable<OverrideInformation>? list)
{
if (list == null)
return false;

foreach (var m in list) {
if (Annotations.IsMarked (m))
if (Annotations.IsMarked (m.Base))
return true;
}
return false;
Expand Down
10 changes: 5 additions & 5 deletions src/linker/Linker.Steps/ValidateVirtualMethodAnnotationsStep.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ protected override void Process ()
{
var annotations = Context.Annotations;
foreach (var method in annotations.VirtualMethodsWithAnnotationsToValidate) {
var baseMethods = annotations.GetBaseMethods (method);
if (baseMethods != null) {
foreach (var baseMethod in baseMethods) {
annotations.FlowAnnotations.ValidateMethodAnnotationsAreSame (method, baseMethod);
ValidateMethodRequiresUnreferencedCodeAreSame (method, baseMethod);
var baseOverrideInformations = annotations.GetBaseMethods (method);
if (baseOverrideInformations != null) {
foreach (var baseOv in baseOverrideInformations) {
annotations.FlowAnnotations.ValidateMethodAnnotationsAreSame (method, baseOv.Base);
ValidateMethodRequiresUnreferencedCodeAreSame (method, baseOv.Base);
}
}

Expand Down
12 changes: 11 additions & 1 deletion src/linker/Linker/Annotations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,9 @@ public bool IsPublic (IMetadataTokenProvider provider)
return public_api.Contains (provider);
}

/// <summary>
/// Returns a list of all known methods that override <paramref name="method"/>. The list may be incomplete if other overrides exist in assemblies that haven't been processed by TypeMapInfo yet
/// </summary>
public IEnumerable<OverrideInformation>? GetOverrides (MethodDefinition method)
{
return TypeMapInfo.GetOverrides (method);
Expand All @@ -446,7 +449,14 @@ public bool IsPublic (IMetadataTokenProvider provider)
return TypeMapInfo.GetDefaultInterfaceImplementations (method);
}

public List<MethodDefinition>? GetBaseMethods (MethodDefinition method)
/// <summary>
/// Returns all base methods that <paramref name="method"/> overrides.
/// This includes methods on <paramref name="method"/>'s declaring type's base type (but not methods higher up in the type hierarchy),
/// methods on an interface that <paramref name="method"/>'s delcaring type implements,
/// and methods an interface implemented by a derived type of <paramref name="method"/>'s declaring type if the derived type uses <paramref name="method"/> as the implementing method.
/// The list may be incomplete if there are derived types in assemblies that havent been processed yet that use <paramref name="method"/> to implement an interface.
/// </summary>
public List<OverrideInformation>? GetBaseMethods (MethodDefinition method)
{
return TypeMapInfo.GetBaseMethods (method);
}
Expand Down
26 changes: 18 additions & 8 deletions src/linker/Linker/TypeMapInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class TypeMapInfo
{
readonly HashSet<AssemblyDefinition> assemblies = new HashSet<AssemblyDefinition> ();
readonly LinkContext context;
protected readonly Dictionary<MethodDefinition, List<MethodDefinition>> base_methods = new Dictionary<MethodDefinition, List<MethodDefinition>> ();
protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> base_methods = new Dictionary<MethodDefinition, List<OverrideInformation>> ();
protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> override_methods = new Dictionary<MethodDefinition, List<OverrideInformation>> ();
protected readonly Dictionary<MethodDefinition, List<(TypeDefinition InstanceType, InterfaceImplementation ImplementationProvider)>> default_interface_implementations = new Dictionary<MethodDefinition, List<(TypeDefinition, InterfaceImplementation)>> ();

Expand All @@ -57,17 +57,27 @@ void EnsureProcessed (AssemblyDefinition assembly)
MapType (type);
}

/// <summary>
/// Returns a list of all known methods that override <paramref name="method"/>. The list may be incomplete if other overrides exist in assemblies that haven't been processed by TypeMapInfo yet
/// </summary>
public IEnumerable<OverrideInformation>? GetOverrides (MethodDefinition method)
{
EnsureProcessed (method.Module.Assembly);
override_methods.TryGetValue (method, out List<OverrideInformation>? overrides);
return overrides;
}

public List<MethodDefinition>? GetBaseMethods (MethodDefinition method)
/// <summary>
/// Returns all base methods that <paramref name="method"/> overrides.
/// This includes the closest overridden virtual method on <paramref name="method"/>'s base types
/// methods on an interface that <paramref name="method"/>'s declaring type implements,
/// and methods an interface implemented by a derived type of <paramref name="method"/>'s declaring type if the derived type uses <paramref name="method"/> as the implementing method.
/// The list may be incomplete if there are derived types in assemblies that havent been processed yet that use <paramref name="method"/> to implement an interface.
/// </summary>
public List<OverrideInformation>? GetBaseMethods (MethodDefinition method)
{
EnsureProcessed (method.Module.Assembly);
base_methods.TryGetValue (method, out List<MethodDefinition>? bases);
base_methods.TryGetValue (method, out List<OverrideInformation>? bases);
return bases;
}

Expand All @@ -77,14 +87,14 @@ void EnsureProcessed (AssemblyDefinition assembly)
return ret;
}

public void AddBaseMethod (MethodDefinition method, MethodDefinition @base)
public void AddBaseMethod (MethodDefinition method, MethodDefinition @base, InterfaceImplementation? matchingInterfaceImplementation)
{
if (!base_methods.TryGetValue (method, out List<MethodDefinition>? methods)) {
methods = new List<MethodDefinition> ();
if (!base_methods.TryGetValue (method, out List<OverrideInformation>? methods)) {
methods = new List<OverrideInformation> ();
base_methods[method] = methods;
}

methods.Add (@base);
methods.Add (new OverrideInformation (@base, method, context, matchingInterfaceImplementation));
}

public void AddOverride (MethodDefinition @base, MethodDefinition @override, InterfaceImplementation? matchingInterfaceImplementation = null)
Expand Down Expand Up @@ -204,7 +214,7 @@ void MapOverrides (MethodDefinition method)

void AnnotateMethods (MethodDefinition @base, MethodDefinition @override, InterfaceImplementation? matchingInterfaceImplementation = null)
{
AddBaseMethod (@override, @base);
AddBaseMethod (@override, @base, matchingInterfaceImplementation);
AddOverride (@base, @override, matchingInterfaceImplementation);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ public Task NeverInstantiatedTypeWithBaseInCopiedAssembly ()
return RunTest (allowMissingWarnings: true);
}

[Fact]
public Task OverrideInUnmarkedClassIsRemoved ()
{
return RunTest (allowMissingWarnings: true);
}

[Fact]
public Task UnusedTypeWithOverrideOfVirtualMethodIsRemoved ()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ namespace Mono.Linker.Tests.Cases.Attributes
[KeptMemberInAssembly ("impl", "Mono.Linker.Tests.Cases.Attributes.Dependencies.IReferencedAssemblyImpl", "Foo()")]
[KeptInterfaceOnTypeInAssembly ("impl", "Mono.Linker.Tests.Cases.Attributes.Dependencies.IReferencedAssemblyImpl",
"interface", "Mono.Linker.Tests.Cases.Attributes.Dependencies.IReferencedAssembly")]
[SetupLinkerTrimMode ("link")]
[IgnoreDescriptors (false)]
public class TypeWithDynamicInterfaceCastableImplementationAttributeIsKept
{
public static void Main ()
Expand Down Expand Up @@ -54,6 +56,7 @@ static IReferencedAssembly GetReferencedInterface (object obj)
#if NETCOREAPP
[Kept]
[KeptMember (".ctor()")]
[KeptInterface (typeof (IDynamicInterfaceCastable))]
class Foo : IDynamicInterfaceCastable
{
[Kept]
Expand All @@ -74,6 +77,7 @@ public bool IsInterfaceImplemented (RuntimeTypeHandle interfaceType, bool throwI

[Kept]
[KeptMember (".ctor()")]
[KeptInterface (typeof (IDynamicInterfaceCastable))]
class DynamicCastableImplementedInOtherAssembly : IDynamicInterfaceCastable
{
[Kept]
Expand Down
19 changes: 19 additions & 0 deletions test/Mono.Linker.Tests.Cases/DataFlow/GenericParameterDataFlow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ static void TestInterfaceTypeGenericRequirements ()
new InterfaceImplementationTypeWithInstantiationOverSelfOnBase ();
new InterfaceImplementationTypeWithOpenGenericOnBase<TestType> ();
new InterfaceImplementationTypeWithOpenGenericOnBaseWithRequirements<TestType> ();

RecursiveGenericWithInterfacesRequirement.Test ();
}

interface IGenericInterfaceTypeWithRequirements<[DynamicallyAccessedMembers (DynamicallyAccessedMemberTypes.PublicFields)] T>
Expand Down Expand Up @@ -345,6 +347,23 @@ class InterfaceImplementationTypeWithOpenGenericOnBaseWithRequirements<[Dynamica
{
}

class RecursiveGenericWithInterfacesRequirement
{
interface IFace<[DynamicallyAccessedMembers (DynamicallyAccessedMemberTypes.Interfaces)] T>
{
}

class TestType : IFace<TestType>
{
}

public static void Test ()
{
var a = typeof (IFace<string>);
var t = new TestType ();
}
}

static void TestTypeGenericRequirementsOnMembers ()
{
// Basically just root everything we need to test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using Mono.Linker.Tests.Cases.Expectations.Assertions;
using Mono.Linker.Tests.Cases.Expectations.Metadata;
using Mono.Linker.Tests.Cases.Inheritance.Interfaces.BaseProvidesInterfaceEdgeCase.Dependencies;

namespace Mono.Linker.Tests.Cases.Inheritance.Interfaces.BaseProvidesInterfaceEdgeCase
{
/// <summary>
/// Reproduces the issue found in https://github.com/dotnet/linker/issues/3112.
/// <see cref="Derived1"/> derives from <see cref="Base"/> and uses <see cref="Base"/>'s method to implement <see cref="IFoo"/>,
/// creating a psuedo-circular assembly reference (but not quite since <see cref="Base"/> doesn't implement IFoo itself).
/// In the linker, IsMethodNeededByInstantiatedTypeDueToPreservedScope would iterate through <see cref="Base"/>'s method's base methods,
/// and in the process would trigger the assembly of <see cref="IFoo"/> to be processed. Since that assembly also has <see cref="Derived2"/> that
/// inherits from <see cref="Base"/> and implements <see cref="IBar"/> using <see cref="Base"/>'s methods, the linker adds
/// <see cref="IBar"/>'s method as a base to <see cref="Base"/>'s method, which modifies the collection as it's being iterated, causing an exception.
/// </summary>
[SetupCompileBefore ("base.dll", new[] { "Dependencies/Base.cs" })] // Base Implements IFoo.Method (psuedo-reference to ifoo.dll)
[SetupCompileBefore ("ifoo.dll", new[] { "Dependencies/IFoo.cs" }, references: new[] { "base.dll" })] // Derived2 references base.dll (circular reference)
[SetupCompileBefore ("derived1.dll", new[] { "Dependencies/Derived1.cs" }, references: new[] { "ifoo.dll", "base.dll" })]
public class BaseProvidesInterfaceMethodCircularReference
{
[Kept]
public static void Main ()
{
_ = new Derived1 ();
Foo ();
}

[Kept]
public static void Foo ()
{
((IFoo) null).Method ();
object x = null;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;

namespace Mono.Linker.Tests.Cases.Inheritance.Interfaces.BaseProvidesInterfaceEdgeCase.Dependencies
{
public class Base
{
public virtual void Method()
{
throw new NotImplementedException();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Mono.Linker.Tests.Cases.Inheritance.Interfaces.BaseProvidesInterfaceEdgeCase.Dependencies
{
public class Derived1 : Base, IFoo
{
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Mono.Linker.Tests.Cases.Inheritance.Interfaces.BaseProvidesInterfaceEdgeCase.Dependencies
{
public interface IFoo
{
void Method();
}
public interface IBar
{
void Method();
}
public class Derived2 : Base, IBar
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,20 @@ public static void Main ()
t = typeof (UninstantiatedPublicClassWithPrivateInterface);
t = typeof (ImplementsUsedStaticInterface.InterfaceMethodUnused);

ImplementsUnusedStaticInterface.Test (); ;
ImplementsUnusedStaticInterface.Test ();
GenericMethodThatCallsInternalStaticInterfaceMethod
<ImplementsUsedStaticInterface.InterfaceMethodUsedThroughInterface> ();
// Use all public interfaces - they're marked as public only to denote them as "used"
typeof (IPublicInterface).RequiresPublicMethods ();
typeof (IPublicStaticInterface).RequiresPublicMethods ();
var ___ = new InstantiatedClassWithInterfaces ();
_ = new InstantiatedClassWithInterfaces ();
MarkIFormattable (null);
}

[Kept]
static void MarkIFormattable (IFormattable x)
{ }

[Kept]
internal static void GenericMethodThatCallsInternalStaticInterfaceMethod<T> () where T : IStaticInterfaceUsed
{
Expand Down Expand Up @@ -113,8 +118,8 @@ public static void Test ()
}
}

// Interfaces are kept despite being uninstantiated because it is relevant to variant casting
[Kept]
[KeptInterface (typeof (IEnumerator))]
[KeptInterface (typeof (IPublicInterface))]
[KeptInterface (typeof (IPublicStaticInterface))]
[KeptInterface (typeof (ICopyLibraryInterface))]
Expand Down Expand Up @@ -151,18 +156,12 @@ public static void InternalStaticInterfaceMethod () { }
static void IInternalStaticInterface.ExplicitImplementationInternalStaticInterfaceMethod () { }


[Kept]
[ExpectBodyModified]
bool IEnumerator.MoveNext () { throw new PlatformNotSupportedException (); }

[Kept]
object IEnumerator.Current {
[Kept]
[ExpectBodyModified]
get { throw new PlatformNotSupportedException (); }
}

[Kept]
void IEnumerator.Reset () { }

[Kept]
Expand Down Expand Up @@ -198,7 +197,6 @@ public string ToString (string format, IFormatProvider formatProvider)
}

[Kept]
[KeptInterface (typeof (IEnumerator))]
[KeptInterface (typeof (IPublicInterface))]
[KeptInterface (typeof (IPublicStaticInterface))]
[KeptInterface (typeof (ICopyLibraryInterface))]
Expand Down Expand Up @@ -235,13 +233,10 @@ public static void InternalStaticInterfaceMethod () { }

static void IInternalStaticInterface.ExplicitImplementationInternalStaticInterfaceMethod () { }

[Kept]
bool IEnumerator.MoveNext () { throw new PlatformNotSupportedException (); }

[Kept]
object IEnumerator.Current { [Kept] get { throw new PlatformNotSupportedException (); } }
object IEnumerator.Current { get { throw new PlatformNotSupportedException (); } }

[Kept]
void IEnumerator.Reset () { }

[Kept]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,15 @@ class MyType : IStaticInterfaceWithDefaultImpls
public int InstanceMethod () => 0;
}

// Keep MyType without marking it relevant to variant casting
[Kept]
static void KeepMyType (MyType x)
{ }

[Kept]
static void Test ()
{
var x = typeof (MyType); // The only use of MyType
KeepMyType (null);
}
}
}
Loading

0 comments on commit ae8160b

Please sign in to comment.