From 8ac1b70e45fbdd4285e57d12b13e38ad2756bdf8 Mon Sep 17 00:00:00 2001 From: mellamokb Date: Wed, 23 Aug 2023 10:58:15 -0400 Subject: [PATCH] New thread-safe Item Collection Fixes #95 Added new ThreadSafeSessionStateItemCollection to replace built-in Microsoft one. This fixes the race condition in the indexers. Also commented out all serialization since this is intended to only be used with the in-memory session, so serialization is not necessary. --- .../InProcSessionStateStoreAsync.cs | 2 +- .../ThreadSafeSessionStateItemCollection.cs | 297 ++++++++++++++++++ 2 files changed, 298 insertions(+), 1 deletion(-) create mode 100644 src/SessionStateModule/ThreadSafeSessionStateItemCollection.cs diff --git a/src/SessionStateModule/InProcSessionStateStoreAsync.cs b/src/SessionStateModule/InProcSessionStateStoreAsync.cs index 5f43c14..1227b73 100644 --- a/src/SessionStateModule/InProcSessionStateStoreAsync.cs +++ b/src/SessionStateModule/InProcSessionStateStoreAsync.cs @@ -437,7 +437,7 @@ private SessionStateStoreData CreateLegitStoreData( { if (sessionItems == null) { - sessionItems = new SessionStateItemCollection(); + sessionItems = new ThreadSafeSessionStateItemCollection(); } if (staticObjects == null && context != null) diff --git a/src/SessionStateModule/ThreadSafeSessionStateItemCollection.cs b/src/SessionStateModule/ThreadSafeSessionStateItemCollection.cs new file mode 100644 index 0000000..1d8cfa5 --- /dev/null +++ b/src/SessionStateModule/ThreadSafeSessionStateItemCollection.cs @@ -0,0 +1,297 @@ +using System; +using System.Collections; +using System.Collections.Specialized; +using System.Globalization; +using System.IO; +using HttpRuntime = System.Web.HttpRuntime; +using ISessionStateItemCollection = System.Web.SessionState.ISessionStateItemCollection; + +namespace Microsoft.AspNet.SessionState +{ + /// A collection of objects stored in session state. This class cannot be inherited. + public sealed class ThreadSafeSessionStateItemCollection : NameObjectCollectionBase, ISessionStateItemCollection, ICollection, IEnumerable + { + private static Hashtable s_immutableTypes; + + private const int NO_NULL_KEY = -1; + + private const int SIZE_OF_INT32 = 4; + + private bool _dirty; + + private KeyedCollection _serializedItems; + + private Stream _stream; + + private int _iLastOffset; + + private object _serializedItemsLock = new object(); + + /// Gets or sets a value indicating whether the collection has been marked as changed. + /// true if the contents have been changed; otherwise, false. + public bool Dirty + { + get + { + return this._dirty; + } + set + { + this._dirty = value; + } + } + + /// Gets or sets a value in the collection by name. + /// The value in the collection with the specified name. If the specified key is not found, attempting to get it returns null, and attempting to set it creates a new element using the specified key. + /// The key name of the value in the collection. + public object this[string name] + { + get + { + lock (this._serializedItemsLock) + { + //this.DeserializeItem(name, true); + object obj = base.BaseGet(name); + if (obj != null && !IsImmutable(obj)) + { + this._dirty = true; + } + return obj; + } + } + set + { + lock (this._serializedItemsLock) + { + //this.MarkItemDeserialized(name); + base.BaseSet(name, value); + this._dirty = true; + } + } + } + + /// Gets or sets a value in the collection by numerical index. + /// The value in the collection stored at the specified index. If the specified key is not found, attempting to get it returns null, and attempting to set it creates a new element using the specified key. + /// The numerical index of the value in the collection. + public object this[int index] + { + get + { + lock (this._serializedItemsLock) + { + //this.DeserializeItem(index); + object obj = base.BaseGet(index); + if (obj != null && !IsImmutable(obj)) + { + this._dirty = true; + } + return obj; + } + } + set + { + lock (this._serializedItemsLock) + { + //this.MarkItemDeserialized(index); + base.BaseSet(index, value); + this._dirty = true; + } + } + } + + /// Gets a collection of the variable names for all values stored in the collection. + /// The collection that contains all the collection keys. + public override NameObjectCollectionBase.KeysCollection Keys + { + get + { + //this.DeserializeAllItems(); + return base.Keys; + } + } + + static ThreadSafeSessionStateItemCollection() + { + s_immutableTypes = new Hashtable(19); + Type type = typeof(string); + s_immutableTypes.Add(type, type); + type = typeof(int); + s_immutableTypes.Add(type, type); + type = typeof(bool); + s_immutableTypes.Add(type, type); + type = typeof(DateTime); + s_immutableTypes.Add(type, type); + type = typeof(decimal); + s_immutableTypes.Add(type, type); + type = typeof(byte); + s_immutableTypes.Add(type, type); + type = typeof(char); + s_immutableTypes.Add(type, type); + type = typeof(float); + s_immutableTypes.Add(type, type); + type = typeof(double); + s_immutableTypes.Add(type, type); + type = typeof(sbyte); + s_immutableTypes.Add(type, type); + type = typeof(short); + s_immutableTypes.Add(type, type); + type = typeof(long); + s_immutableTypes.Add(type, type); + type = typeof(ushort); + s_immutableTypes.Add(type, type); + type = typeof(uint); + s_immutableTypes.Add(type, type); + type = typeof(ulong); + s_immutableTypes.Add(type, type); + type = typeof(TimeSpan); + s_immutableTypes.Add(type, type); + type = typeof(Guid); + s_immutableTypes.Add(type, type); + type = typeof(IntPtr); + s_immutableTypes.Add(type, type); + type = typeof(UIntPtr); + s_immutableTypes.Add(type, type); + } + + /// Creates a new, empty object. + public ThreadSafeSessionStateItemCollection() : base(Misc.CaseInsensitiveInvariantKeyComparer) + { + } + + /// Removes all values and keys from the session-state collection. + public void Clear() + { + lock (this._serializedItemsLock) + { + if (this._serializedItems != null) + { + this._serializedItems.Clear(); + } + base.BaseClear(); + this._dirty = true; + } + } + + /// Returns an enumerator that can be used to read all the key names in the collection. + /// An that can iterate through the variable names in the session-state collection. + public override IEnumerator GetEnumerator() + { + //this.DeserializeAllItems(); + return base.GetEnumerator(); + } + + internal static bool IsImmutable(object o) + { + return s_immutableTypes[o.GetType()] != null; + } + + /// Deletes an item from the collection. + /// The name of the item to delete from the collection. + public void Remove(string name) + { + lock (this._serializedItemsLock) + { + if (this._serializedItems != null) + { + this._serializedItems.Remove(name); + } + base.BaseRemove(name); + this._dirty = true; + } + } + + /// Deletes an item at a specified index from the collection. + /// The index of the item to remove from the collection. + /// + /// is less than zero.- or - is equal to or greater than . + public void RemoveAt(int index) + { + lock (this._serializedItemsLock) + { + if (this._serializedItems != null && index < this._serializedItems.Count) + { + this._serializedItems.RemoveAt(index); + } + base.BaseRemoveAt(index); + this._dirty = true; + } + } + + private class KeyedCollection : NameObjectCollectionBase + { + internal object this[string name] + { + get + { + return base.BaseGet(name); + } + set + { + if (base.BaseGet(name) == null && value == null) + { + return; + } + base.BaseSet(name, value); + } + } + + internal object this[int index] + { + get + { + return base.BaseGet(index); + } + } + + internal KeyedCollection(int count) : base(count, Misc.CaseInsensitiveInvariantKeyComparer) + { + } + + internal void Clear() + { + base.BaseClear(); + } + + internal bool ContainsKey(string name) + { + return base.BaseGet(name) != null; + } + + internal string GetKey(int index) + { + return base.BaseGetKey(index); + } + + internal void Remove(string name) + { + base.BaseRemove(name); + } + + internal void RemoveAt(int index) + { + base.BaseRemoveAt(index); + } + } + + internal sealed class Misc + { + private static StringComparer s_caseInsensitiveInvariantKeyComparer; + + internal static StringComparer CaseInsensitiveInvariantKeyComparer + { + get + { + if (Misc.s_caseInsensitiveInvariantKeyComparer == null) + { + Misc.s_caseInsensitiveInvariantKeyComparer = StringComparer.Create(CultureInfo.InvariantCulture, true); + } + return Misc.s_caseInsensitiveInvariantKeyComparer; + } + } + + public Misc() + { + } + } + } +} \ No newline at end of file