diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs index f8547cb408f8d2..cc33e85d73917c 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs @@ -1,7 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; +using System.Collections; using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; @@ -41,7 +41,7 @@ public abstract partial class ComWrappers private static readonly Guid IID_IWeakReferenceSource = new Guid(0x00000038, 0, 0, 0xC0, 0, 0, 0, 0, 0, 0, 0x46); private static readonly ConditionalWeakTable s_rcwTable = new ConditionalWeakTable(); - private static readonly List s_referenceTrackerNativeObjectWrapperCache = new List(); + private static readonly GCHandleSet s_referenceTrackerNativeObjectWrapperCache = new GCHandleSet(); private readonly ConditionalWeakTable _ccwTable = new ConditionalWeakTable(); private readonly Lock _lock = new Lock(useTrivialWaits: true); @@ -984,10 +984,7 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal( throw new NotSupportedException(); } _rcwCache.Add(identity, wrapper._proxyHandle); - if (wrapper is ReferenceTrackerNativeObjectWrapper referenceTrackerNativeObjectWrapper) - { - s_referenceTrackerNativeObjectWrapperCache.Add(referenceTrackerNativeObjectWrapper._nativeObjectWrapperWeakHandle); - } + AddWrapperToReferenceTrackerHandleCache(wrapper); return true; } } @@ -1040,10 +1037,7 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal( wrapper.Release(); throw new NotSupportedException(); } - if (wrapper is ReferenceTrackerNativeObjectWrapper referenceTrackerNativeObjectWrapper) - { - s_referenceTrackerNativeObjectWrapperCache.Add(referenceTrackerNativeObjectWrapper._nativeObjectWrapperWeakHandle); - } + AddWrapperToReferenceTrackerHandleCache(wrapper); return true; } @@ -1079,10 +1073,7 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal( throw new NotSupportedException(); } _rcwCache.Add(identity, wrapper._proxyHandle); - if (wrapper is ReferenceTrackerNativeObjectWrapper referenceTrackerNativeObjectWrapper) - { - s_referenceTrackerNativeObjectWrapperCache.Add(referenceTrackerNativeObjectWrapper._nativeObjectWrapperWeakHandle); - } + AddWrapperToReferenceTrackerHandleCache(wrapper); } } @@ -1090,6 +1081,14 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal( } #pragma warning restore IDE0060 + private static void AddWrapperToReferenceTrackerHandleCache(NativeObjectWrapper wrapper) + { + if (wrapper is ReferenceTrackerNativeObjectWrapper referenceTrackerNativeObjectWrapper) + { + s_referenceTrackerNativeObjectWrapperCache.Add(referenceTrackerNativeObjectWrapper._nativeObjectWrapperWeakHandle); + } + } + private void RemoveRCWFromCache(IntPtr comPointer, GCHandle expectedValue) { using (_lock.EnterScope()) @@ -1220,17 +1219,30 @@ internal static void ReleaseExternalObjectsFromCurrentThread() IntPtr contextToken = GetContextToken(); List objects = new List(); - foreach (GCHandle weakNativeObjectWrapperHandle in s_referenceTrackerNativeObjectWrapperCache) + + // Here we aren't part of a GC callback, so other threads can still be running + // who are adding and removing from the collection. This means we can possibly race + // with a handle being removed and freed and we can end up accessing a freed handle. + // To avoid this, we take a lock on modifications to the collection while we gather + // the objects. + using (s_referenceTrackerNativeObjectWrapperCache.ModificationLock.EnterScope()) { - ReferenceTrackerNativeObjectWrapper? nativeObjectWrapper = Unsafe.As(weakNativeObjectWrapperHandle.Target); - if (nativeObjectWrapper != null && - nativeObjectWrapper._contextToken == contextToken) + foreach (GCHandle weakNativeObjectWrapperHandle in s_referenceTrackerNativeObjectWrapperCache) { - objects.Add(nativeObjectWrapper._proxyHandle.Target); + ReferenceTrackerNativeObjectWrapper? nativeObjectWrapper = Unsafe.As(weakNativeObjectWrapperHandle.Target); + if (nativeObjectWrapper != null && + nativeObjectWrapper._contextToken == contextToken) + { + object? target = nativeObjectWrapper._proxyHandle.Target; + if (target != null) + { + objects.Add(target); + } - // Separate the wrapper from the tracker runtime prior to - // passing them. - nativeObjectWrapper.DisconnectTracker(); + // Separate the wrapper from the tracker runtime prior to + // passing them. + nativeObjectWrapper.DisconnectTracker(); + } } } @@ -1630,4 +1642,212 @@ private static unsafe IntPtr ObjectToComWeakRef(object target, out long wrapperI return IntPtr.Zero; } } + + // This is a GCHandle HashSet implementation based on LowLevelDictionary. + // It uses no locking for readers. While for writers (add / remove), + // it handles the locking itself. + // This implementation specifically makes sure that any readers of this + // collection during GC aren't impacted by other threads being + // frozen while in the middle of an write. It makes no guarantees on + // whether you will observe the element being added / removed, but does + // make sure the collection is in a good state and doesn't run into issues + // while iterating. + internal sealed class GCHandleSet : IEnumerable + { + private const int DefaultSize = 7; + + private Entry?[] _buckets = new Entry[DefaultSize]; + private int _numEntries; + private readonly Lock _lock = new Lock(useTrivialWaits: true); + + public Lock ModificationLock => _lock; + + public void Add(GCHandle handle) + { + using (_lock.EnterScope()) + { + int bucket = GetBucket(handle, _buckets.Length); + Entry? prev = null; + Entry? entry = _buckets[bucket]; + while (entry != null) + { + // Handle already exists, nothing to add. + if (handle.Equals(entry.m_value)) + { + return; + } + + prev = entry; + entry = entry.m_next; + } + + Entry newEntry = new Entry() + { + m_value = handle + }; + + if (prev == null) + { + _buckets[bucket] = newEntry; + } + else + { + prev.m_next = newEntry; + } + + // _numEntries is only maintained for the purposes of deciding whether to + // expand the bucket and is not used during iteration to handle the + // scenario where element is in bucket but _numEntries hasn't been incremented + // yet. + _numEntries++; + if (_numEntries > (_buckets.Length * 2)) + { + ExpandBuckets(); + } + } + } + + private void ExpandBuckets() + { + int newNumBuckets = _buckets.Length * 2 + 1; + Entry?[] newBuckets = new Entry[newNumBuckets]; + for (int i = 0; i < _buckets.Length; i++) + { + Entry? entry = _buckets[i]; + while (entry != null) + { + Entry? nextEntry = entry.m_next; + + int bucket = GetBucket(entry.m_value, newNumBuckets); + + // We are allocating new entries for the bucket to ensure that + // if there is an enumeration already in progress, we don't + // modify what it observes by changing next in existing instances. + Entry newEntry = new Entry() + { + m_value = entry.m_value, + m_next = newBuckets[bucket], + }; + newBuckets[bucket] = newEntry; + + entry = nextEntry; + } + } + _buckets = newBuckets; + } + + public void Remove(GCHandle handle) + { + using (_lock.EnterScope()) + { + int bucket = GetBucket(handle, _buckets.Length); + Entry? prev = null; + Entry? entry = _buckets[bucket]; + while (entry != null) + { + if (handle.Equals(entry.m_value)) + { + if (prev == null) + { + _buckets[bucket] = entry.m_next; + } + else + { + prev.m_next = entry.m_next; + } + _numEntries--; + return; + } + + prev = entry; + entry = entry.m_next; + } + } + } + + private static int GetBucket(GCHandle handle, int numBuckets) + { + int h = handle.GetHashCode(); + return (int)((uint)h % (uint)numBuckets); + } + + public Enumerator GetEnumerator() => new Enumerator(this); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)this).GetEnumerator(); + + private sealed class Entry + { + public GCHandle m_value; + public Entry? m_next; + } + + public struct Enumerator : IEnumerator + { + private readonly Entry?[] _buckets; + private int _currentIdx; + private Entry? _currentEntry; + + public Enumerator(GCHandleSet set) + { + // We hold onto the buckets of the set rather than the set itself + // so that if it is ever expanded, we are not impacted by that during + // enumeration. + _buckets = set._buckets; + Reset(); + } + + public GCHandle Current + { + get + { + if (_currentEntry == null) + { + throw new InvalidOperationException("InvalidOperation_EnumOpCantHappen"); + } + + return _currentEntry.m_value; + } + } + + object IEnumerator.Current => Current; + + public void Dispose() + { + } + + public bool MoveNext() + { + if (_currentEntry != null) + { + _currentEntry = _currentEntry.m_next; + } + + if (_currentEntry == null) + { + // Certain buckets might be empty, so loop until we find + // one with an entry. + while (++_currentIdx != _buckets.Length) + { + _currentEntry = _buckets[_currentIdx]; + if (_currentEntry != null) + { + return true; + } + } + + return false; + } + + return true; + } + + public void Reset() + { + _currentIdx = -1; + _currentEntry = null; + } + } + } }