Skip to content

Commit

Permalink
Fixes hang in WinUI apps published to AOT (dotnet#104583)
Browse files Browse the repository at this point in the history
* Fix exception during GC callback in WinUI scenarios due to removing from and adding to the nativeObjectReference list around the same time and causing removal to fail.

* Move common logic

* Change collection type to HashSet to match with JIT implementation and to improve performance of remove given it now takes a lock

* Move to using a custom collection to protect against GC freezing threads during add / remove.

* Address PR feedback by using alternative approach to handle race

* Address PR feedback

* Address PR feedback around handle being freed.
  • Loading branch information
manodasanW committed Sep 12, 2024
1 parent f34e9ac commit 5dc834c
Showing 1 changed file with 242 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
Expand Down Expand Up @@ -39,7 +39,7 @@ public abstract partial class ComWrappers
private static readonly Guid IID_IWeakReferenceSource = new Guid(0x00000038, 0, 0, 0xC0, 0, 0, 0, 0, 0, 0, 0x46);

private static readonly ConditionalWeakTable<object, NativeObjectWrapper> s_rcwTable = new ConditionalWeakTable<object, NativeObjectWrapper>();
private static readonly List<GCHandle> s_referenceTrackerNativeObjectWrapperCache = new List<GCHandle>();
private static readonly GCHandleSet s_referenceTrackerNativeObjectWrapperCache = new GCHandleSet();

private readonly ConditionalWeakTable<object, ManagedObjectWrapperHolder> _ccwTable = new ConditionalWeakTable<object, ManagedObjectWrapperHolder>();
private readonly Lock _lock = new Lock();
Expand Down Expand Up @@ -1011,10 +1011,7 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal(
throw new NotSupportedException();
}
_rcwCache.Add(identity, wrapper._proxyHandle);
if (wrapper is ReferenceTrackerNativeObjectWrapper referenceTrackerNativeObjectWrapper)
{
s_referenceTrackerNativeObjectWrapperCache.Add(referenceTrackerNativeObjectWrapper._nativeObjectWrapperWeakHandle);
}
AddWrapperToReferenceTrackerHandleCache(wrapper);
return true;
}
}
Expand All @@ -1040,10 +1037,7 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal(
wrapper.Release();
throw new NotSupportedException();
}
if (wrapper is ReferenceTrackerNativeObjectWrapper referenceTrackerNativeObjectWrapper)
{
s_referenceTrackerNativeObjectWrapperCache.Add(referenceTrackerNativeObjectWrapper._nativeObjectWrapperWeakHandle);
}
AddWrapperToReferenceTrackerHandleCache(wrapper);
return true;
}

Expand Down Expand Up @@ -1079,17 +1073,22 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal(
throw new NotSupportedException();
}
_rcwCache.Add(identity, wrapper._proxyHandle);
if (wrapper is ReferenceTrackerNativeObjectWrapper referenceTrackerNativeObjectWrapper)
{
s_referenceTrackerNativeObjectWrapperCache.Add(referenceTrackerNativeObjectWrapper._nativeObjectWrapperWeakHandle);
}
AddWrapperToReferenceTrackerHandleCache(wrapper);
}
}

return true;
}
#pragma warning restore IDE0060

private static void AddWrapperToReferenceTrackerHandleCache(NativeObjectWrapper wrapper)
{
if (wrapper is ReferenceTrackerNativeObjectWrapper referenceTrackerNativeObjectWrapper)
{
s_referenceTrackerNativeObjectWrapperCache.Add(referenceTrackerNativeObjectWrapper._nativeObjectWrapperWeakHandle);
}
}

private void RemoveRCWFromCache(IntPtr comPointer, GCHandle expectedValue)
{
using (LockHolder.Hold(_lock))
Expand Down Expand Up @@ -1220,17 +1219,30 @@ internal static void ReleaseExternalObjectsFromCurrentThread()
IntPtr contextToken = GetContextToken();

List<object> objects = new List<object>();
foreach (GCHandle weakNativeObjectWrapperHandle in s_referenceTrackerNativeObjectWrapperCache)

// Here we aren't part of a GC callback, so other threads can still be running
// who are adding and removing from the collection. This means we can possibly race
// with a handle being removed and freed and we can end up accessing a freed handle.
// To avoid this, we take a lock on modifications to the collection while we gather
// the objects.
using (s_referenceTrackerNativeObjectWrapperCache.ModificationLock.EnterScope())
{
ReferenceTrackerNativeObjectWrapper? nativeObjectWrapper = Unsafe.As<ReferenceTrackerNativeObjectWrapper?>(weakNativeObjectWrapperHandle.Target);
if (nativeObjectWrapper != null &&
nativeObjectWrapper._contextToken == contextToken)
foreach (GCHandle weakNativeObjectWrapperHandle in s_referenceTrackerNativeObjectWrapperCache)
{
objects.Add(nativeObjectWrapper._proxyHandle.Target);
ReferenceTrackerNativeObjectWrapper? nativeObjectWrapper = Unsafe.As<ReferenceTrackerNativeObjectWrapper?>(weakNativeObjectWrapperHandle.Target);
if (nativeObjectWrapper != null &&
nativeObjectWrapper._contextToken == contextToken)
{
object? target = nativeObjectWrapper._proxyHandle.Target;
if (target != null)
{
objects.Add(target);
}

// Separate the wrapper from the tracker runtime prior to
// passing them.
nativeObjectWrapper.DisconnectTracker();
// Separate the wrapper from the tracker runtime prior to
// passing them.
nativeObjectWrapper.DisconnectTracker();
}
}
}

Expand Down Expand Up @@ -1630,4 +1642,212 @@ private static unsafe IntPtr ObjectToComWeakRef(object target, out long wrapperI
return IntPtr.Zero;
}
}

// This is a GCHandle HashSet implementation based on LowLevelDictionary.
// It uses no locking for readers. While for writers (add / remove),
// it handles the locking itself.
// This implementation specifically makes sure that any readers of this
// collection during GC aren't impacted by other threads being
// frozen while in the middle of an write. It makes no guarantees on
// whether you will observe the element being added / removed, but does
// make sure the collection is in a good state and doesn't run into issues
// while iterating.
internal sealed class GCHandleSet : IEnumerable<GCHandle>
{
private const int DefaultSize = 7;

private Entry?[] _buckets = new Entry[DefaultSize];
private int _numEntries;
private readonly Lock _lock = new Lock(useTrivialWaits: true);

public Lock ModificationLock => _lock;

public void Add(GCHandle handle)
{
using (_lock.EnterScope())
{
int bucket = GetBucket(handle, _buckets.Length);
Entry? prev = null;
Entry? entry = _buckets[bucket];
while (entry != null)
{
// Handle already exists, nothing to add.
if (handle.Equals(entry.m_value))
{
return;
}

prev = entry;
entry = entry.m_next;
}

Entry newEntry = new Entry()
{
m_value = handle
};

if (prev == null)
{
_buckets[bucket] = newEntry;
}
else
{
prev.m_next = newEntry;
}

// _numEntries is only maintained for the purposes of deciding whether to
// expand the bucket and is not used during iteration to handle the
// scenario where element is in bucket but _numEntries hasn't been incremented
// yet.
_numEntries++;
if (_numEntries > (_buckets.Length * 2))
{
ExpandBuckets();
}
}
}

private void ExpandBuckets()
{
int newNumBuckets = _buckets.Length * 2 + 1;
Entry?[] newBuckets = new Entry[newNumBuckets];
for (int i = 0; i < _buckets.Length; i++)
{
Entry? entry = _buckets[i];
while (entry != null)
{
Entry? nextEntry = entry.m_next;

int bucket = GetBucket(entry.m_value, newNumBuckets);

// We are allocating new entries for the bucket to ensure that
// if there is an enumeration already in progress, we don't
// modify what it observes by changing next in existing instances.
Entry newEntry = new Entry()
{
m_value = entry.m_value,
m_next = newBuckets[bucket],
};
newBuckets[bucket] = newEntry;

entry = nextEntry;
}
}
_buckets = newBuckets;
}

public void Remove(GCHandle handle)
{
using (_lock.EnterScope())
{
int bucket = GetBucket(handle, _buckets.Length);
Entry? prev = null;
Entry? entry = _buckets[bucket];
while (entry != null)
{
if (handle.Equals(entry.m_value))
{
if (prev == null)
{
_buckets[bucket] = entry.m_next;
}
else
{
prev.m_next = entry.m_next;
}
_numEntries--;
return;
}

prev = entry;
entry = entry.m_next;
}
}
}

private static int GetBucket(GCHandle handle, int numBuckets)
{
int h = handle.GetHashCode();
return (int)((uint)h % (uint)numBuckets);
}

public Enumerator GetEnumerator() => new Enumerator(this);

IEnumerator<GCHandle> IEnumerable<GCHandle>.GetEnumerator() => GetEnumerator();

IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable<GCHandle>)this).GetEnumerator();

private sealed class Entry
{
public GCHandle m_value;
public Entry? m_next;
}

public struct Enumerator : IEnumerator<GCHandle>
{
private readonly Entry?[] _buckets;
private int _currentIdx;
private Entry? _currentEntry;

public Enumerator(GCHandleSet set)
{
// We hold onto the buckets of the set rather than the set itself
// so that if it is ever expanded, we are not impacted by that during
// enumeration.
_buckets = set._buckets;
Reset();
}

public GCHandle Current
{
get
{
if (_currentEntry == null)
{
throw new InvalidOperationException("InvalidOperation_EnumOpCantHappen");
}

return _currentEntry.m_value;
}
}

object IEnumerator.Current => Current;

public void Dispose()
{
}

public bool MoveNext()
{
if (_currentEntry != null)
{
_currentEntry = _currentEntry.m_next;
}

if (_currentEntry == null)
{
// Certain buckets might be empty, so loop until we find
// one with an entry.
while (++_currentIdx != _buckets.Length)
{
_currentEntry = _buckets[_currentIdx];
if (_currentEntry != null)
{
return true;
}
}

return false;
}

return true;
}

public void Reset()
{
_currentIdx = -1;
_currentEntry = null;
}
}
}
}

0 comments on commit 5dc834c

Please sign in to comment.