From 25edaffbf175e272a9ac931a5356bc3c8dd18212 Mon Sep 17 00:00:00 2001 From: Ingvar Stepanyan Date: Tue, 21 May 2024 16:49:46 +0100 Subject: [PATCH] Simplify/optimise parts with new equality rules in mind --- src/ClientCache.cs | 113 +++++-------------------------------- src/SATS/AlgebraicValue.cs | 29 ---------- src/SpacetimeDBClient.cs | 61 ++++++++------------ 3 files changed, 37 insertions(+), 166 deletions(-) diff --git a/src/ClientCache.cs b/src/ClientCache.cs index 93f82577..fb66925f 100644 --- a/src/ClientCache.cs +++ b/src/ClientCache.cs @@ -3,8 +3,7 @@ using System.Collections.Generic; using System.Reflection; using SpacetimeDB.SATS; -using System.Numerics; -using System.Runtime.CompilerServices; +using System.Linq; namespace SpacetimeDB { @@ -12,7 +11,6 @@ public class ClientCache { public class TableCache { - private readonly string name; private readonly Type clientTableType; private readonly AlgebraicType rowSchema; @@ -20,7 +18,7 @@ public class TableCache private Func decoderFunc; // Maps from primary key to type value - public readonly Dictionary entries; + public readonly Dictionary entries = new(ByteArrayComparer.Instance); public Type ClientTableType { @@ -33,14 +31,7 @@ public Type ClientTableType public Action BeforeDeleteCallback; public Action DeleteCallback; public Action UpdateCallback; - public Func ComparePrimaryKeyFunc; - public Func GetPrimaryKeyValueFunc; - public Func GetPrimaryKeyTypeFunc; - - public string Name - { - get => name; - } + public Func GetPrimaryKeyValueFunc; public AlgebraicType RowSchema { @@ -49,7 +40,6 @@ public AlgebraicType RowSchema public TableCache(Type clientTableType, AlgebraicType rowSchema, Func decoderFunc) { - name = clientTableType.Name; this.clientTableType = clientTableType; this.rowSchema = rowSchema; @@ -60,13 +50,8 @@ public TableCache(Type clientTableType, AlgebraicType rowSchema, Func)clientTableType.GetMethod("OnBeforeDeleteEvent")?.CreateDelegate(typeof(Action)); DeleteCallback = (Action)clientTableType.GetMethod("OnDeleteEvent")?.CreateDelegate(typeof(Action)); UpdateCallback = (Action)clientTableType.GetMethod("OnUpdateEvent")?.CreateDelegate(typeof(Action)); - ComparePrimaryKeyFunc = (Func)clientTableType.GetMethod("ComparePrimaryKey", BindingFlags.Static | BindingFlags.Public) - ?.CreateDelegate(typeof(Func)); - GetPrimaryKeyValueFunc = (Func)clientTableType.GetMethod("GetPrimaryKeyValue", BindingFlags.Static | BindingFlags.Public) - ?.CreateDelegate(typeof(Func)); - GetPrimaryKeyTypeFunc = (Func)clientTableType.GetMethod("GetPrimaryKeyType", BindingFlags.Static | BindingFlags.Public) - ?.CreateDelegate(typeof(Func)); - entries = new Dictionary(new ByteArrayComparer()); + GetPrimaryKeyValueFunc = (Func)clientTableType.GetMethod("GetPrimaryKeyValue", BindingFlags.NonPublic | BindingFlags.Static) + ?.CreateDelegate(typeof(Func)); } /// @@ -83,19 +68,9 @@ public void SetAndForgetDecodedValue(AlgebraicValue value, out object obj) /// Inserts the value into the table. There can be no existing value with the provided BSATN bytes. /// /// The BSATN encoded bytes of the row to retrieve. - /// The parsed AlgebraicValue of the row encoded by the rowBytes. + /// The parsed row encoded by the rowBytes. /// True if the row was inserted, false if the row wasn't inserted because it was a duplicate. - public bool InsertEntry(byte[] rowBytes, AlgebraicValue value) - { - if (entries.ContainsKey(rowBytes)) - { - return false; - } - - // Insert the row into our table - entries[rowBytes] = (value, decoderFunc(value)); - return true; - } + public bool InsertEntry(byte[] rowBytes, object value) => entries.TryAdd(rowBytes, value); /// /// Deletes a value from the table. @@ -104,9 +79,8 @@ public bool InsertEntry(byte[] rowBytes, AlgebraicValue value) /// True if and only if the value was previously resident and has been deleted. public bool DeleteEntry(byte[] rowBytes) { - if (entries.TryGetValue(rowBytes, out var value)) + if (entries.Remove(rowBytes)) { - entries.Remove(rowBytes); return true; } @@ -114,37 +88,9 @@ public bool DeleteEntry(byte[] rowBytes) return false; } - /// - /// Gets a value from the table - /// - /// The BSATN encoded bytes of the row to retrieve. - /// Output: the parsed domain type corresponding to the rowBytes, or null if the row was not present in the cache. - /// True if and only if the value is resident and was stored in value. - public bool TryGetValue(byte[] rowBytes, out object value) - { - if (entries.TryGetValue(rowBytes, out var v)) - { - value = v.Item2; - return true; - } - - value = null; - return false; - } - - public bool ComparePrimaryKey(AlgebraicValue v1, AlgebraicValue v2) - { - return (bool)ComparePrimaryKeyFunc.Invoke(rowSchema, v1, v2); - } - - public AlgebraicValue GetPrimaryKeyValue(AlgebraicValue row) - { - return GetPrimaryKeyValueFunc != null ? GetPrimaryKeyValueFunc.Invoke(row) : null; - } - - public AlgebraicType GetPrimaryKeyType() + public object? GetPrimaryKeyValue(object row) { - return GetPrimaryKeyTypeFunc != null ? GetPrimaryKeyTypeFunc.Invoke(rowSchema) : null; + return GetPrimaryKeyValueFunc?.Invoke(row); } } @@ -165,33 +111,7 @@ public void AddTable(Type clientTableType, AlgebraicType tableRowDef, Func GetObjects(string name) - { - if (!tables.TryGetValue(name, out var table)) - { - yield break; - } - - foreach (var entry in table.entries) - { - yield return entry.Value.Item2; - } - } - - public IEnumerable<(AlgebraicValue, object)> GetEntries(string name) - { - if (!tables.TryGetValue(name, out var table)) - { - yield break; - } - - foreach (var entry in table.entries) - { - yield return entry.Value; - } - } - - public TableCache GetTable(string name) + public TableCache? GetTable(string name) { if (tables.TryGetValue(name, out var table)) { @@ -202,17 +122,12 @@ public TableCache GetTable(string name) return null; } - public int Count(string name) + public IEnumerable GetObjects(string name) { - if (!tables.TryGetValue(name, out var table)) - { - return 0; - } - - return table.entries.Count; + return GetTable(name)?.entries.Values ?? Enumerable.Empty(); } - public IEnumerable GetTableNames() => tables.Keys; + public int Count(string name) => GetTable(name)?.entries.Count ?? 0; public IEnumerable GetTables() => tables.Values; } diff --git a/src/SATS/AlgebraicValue.cs b/src/SATS/AlgebraicValue.cs index 9d78de35..594d15a6 100644 --- a/src/SATS/AlgebraicValue.cs +++ b/src/SATS/AlgebraicValue.cs @@ -1,5 +1,4 @@ using System; -using System.Linq; using System.Collections.Generic; using System.IO; @@ -476,33 +475,5 @@ public void Serialize(AlgebraicType type, BinaryWriter writer) throw new NotImplementedException(); } } - - public class AlgebraicValueComparer : IEqualityComparer - { - private AlgebraicType type; - public AlgebraicValueComparer(AlgebraicType type) - { - this.type = type; - } - - public bool Equals(AlgebraicValue l, AlgebraicValue r) - { - return AlgebraicValue.Compare(type, l, r); - } - - public int GetHashCode(AlgebraicValue value) - { - var stream = new MemoryStream(); - var writer = new BinaryWriter(stream); - value.Serialize(type, writer); - var s = stream.ToArray(); - if (s.Length >= 4) - { - return BitConverter.ToInt32(s, 0); - } - return s.Sum(b => b); - } - } - } } diff --git a/src/SpacetimeDBClient.cs b/src/SpacetimeDBClient.cs index c7dade65..ea6b3489 100644 --- a/src/SpacetimeDBClient.cs +++ b/src/SpacetimeDBClient.cs @@ -46,7 +46,6 @@ struct DbOp public ClientCache.TableCache table; public DbValue? delete; public DbValue? insert; - public AlgebraicValue rowValue; } /// @@ -215,7 +214,7 @@ struct PreProcessedMessage { public Message message; public List dbOps; - public Dictionary> inserts; + public Dictionary> inserts; } private readonly BlockingCollection _messageQueue = @@ -254,27 +253,16 @@ PreProcessedMessage PreProcessMessage(byte[] bytes) using var reader = new BinaryReader(stream); // This is all of the inserts - Dictionary> subscriptionInserts = null; + Dictionary> subscriptionInserts = null; // All row updates that have a primary key, this contains inserts, deletes and updates - var primaryKeyChanges = new Dictionary>(); + var primaryKeyChanges = new Dictionary<(System.Type tableType, object primaryKeyValue), DbOp>(); - Dictionary GetPrimaryKeyLookup(string tableName, AlgebraicType schema) + HashSet GetInsertHashSet(System.Type tableType, int tableSize) { - if (!primaryKeyChanges.TryGetValue(tableName, out var value)) - { - value = new Dictionary(new AlgebraicValue.AlgebraicValueComparer(schema)); - primaryKeyChanges[tableName] = value; - } - - return value; - } - - HashSet GetInsertHashSet(string tableName, int tableSize) - { - if (!subscriptionInserts.TryGetValue(tableName, out var hashSet)) + if (!subscriptionInserts.TryGetValue(tableType, out var hashSet)) { hashSet = new HashSet(capacity: tableSize, comparer: new ByteArrayComparer()); - subscriptionInserts[tableName] = hashSet; + subscriptionInserts[tableType] = hashSet; } return hashSet; @@ -285,13 +273,11 @@ HashSet GetInsertHashSet(string tableName, int tableSize) { case ClientApi.Message.TypeOneofCase.SubscriptionUpdate: subscriptionUpdate = message.SubscriptionUpdate; - subscriptionInserts = new Dictionary>( - capacity: subscriptionUpdate.TableUpdates.Sum(a => a.TableRowOperations.Count)); + subscriptionInserts = new(capacity: subscriptionUpdate.TableUpdates.Sum(a => a.TableRowOperations.Count)); // First apply all of the state foreach (var update in subscriptionUpdate.TableUpdates) { var tableName = update.TableName; - var hashSet = GetInsertHashSet(tableName, subscriptionUpdate.TableUpdates.Count); var table = clientDB.GetTable(tableName); if (table == null) { @@ -299,6 +285,8 @@ HashSet GetInsertHashSet(string tableName, int tableSize) continue; } + var hashSet = GetInsertHashSet(table.ClientTableType, subscriptionUpdate.TableUpdates.Count); + foreach (var row in update.TableRowOperations) { var rowBytes = row.Row.ToByteArray(); @@ -323,7 +311,6 @@ HashSet GetInsertHashSet(string tableName, int tableSize) { table = table, insert = new(obj, rowBytes), - rowValue = deserializedRow, }; if (!hashSet.Add(rowBytes)) @@ -365,15 +352,10 @@ HashSet GetInsertHashSet(string tableName, int tableSize) throw new Exception("Failed to deserialize row"); } - var primaryKeyValue = table.GetPrimaryKeyValue(deserializedRow); - var primaryKeyType = table.GetPrimaryKeyType(); table.SetAndForgetDecodedValue(deserializedRow, out var obj); + var primaryKeyValue = table.GetPrimaryKeyValue(obj); - var op = new DbOp - { - table = table, - rowValue = deserializedRow, - }; + var op = new DbOp { table = table }; var dbValue = new DbValue(obj, rowBytes); @@ -386,10 +368,13 @@ HashSet GetInsertHashSet(string tableName, int tableSize) op.delete = dbValue; } - if (primaryKeyType != null) + if (primaryKeyValue != null) { - var primaryKeyLookup = GetPrimaryKeyLookup(tableName, primaryKeyType); - if (primaryKeyLookup.TryGetValue(primaryKeyValue, out var oldOp)) + // Compound key that we use for lookup. + // Consists of type of the table (for faster comparison that string names) + actual primary key of the row. + var key = (table.ClientTableType, primaryKeyValue); + + if (primaryKeyChanges.TryGetValue(key, out var oldOp)) { if ((op.insert is not null && oldOp.insert is not null) || (op.delete is not null && oldOp.delete is not null)) { @@ -406,10 +391,10 @@ HashSet GetInsertHashSet(string tableName, int tableSize) table = insertOp.table, delete = deleteOp.delete, insert = insertOp.insert, - rowValue = insertOp.rowValue, }; } - primaryKeyLookup[primaryKeyValue] = op; + + primaryKeyChanges[key] = op; } else { @@ -419,7 +404,7 @@ HashSet GetInsertHashSet(string tableName, int tableSize) } // Combine primary key updates and non-primary key updates - dbOps.AddRange(primaryKeyChanges.Values.SelectMany(a => a.Values)); + dbOps.AddRange(primaryKeyChanges.Values); // Convert the generic event arguments in to a domain specific event object, this gets fed back into // the message.TransactionUpdate.Event.FunctionCall.CallInfo field. @@ -497,7 +482,7 @@ void ExecuteStateDiff() { foreach (var table in clientDB.GetTables()) { - if (!preProcessedMessage.inserts.TryGetValue(table.Name, out var hashSet)) + if (!preProcessedMessage.inserts.TryGetValue(table.ClientTableType, out var hashSet)) { continue; } @@ -509,7 +494,7 @@ void ExecuteStateDiff() table = table, // This is a row that we had before, but we do not have it now. // This must have been a delete. - delete = new(oldValue.Item2, rowBytes), + delete = new(oldValue, rowBytes), }); } } @@ -607,7 +592,7 @@ private void OnMessageProcessCompleteUpdate(Message message, List dbOps) if (update.insert is {} insert) { - if (update.table.InsertEntry(insert.bytes, update.rowValue)) + if (update.table.InsertEntry(insert.bytes, insert.value)) { update.table.InternalValueInsertedCallback(insert.value); }