Skip to content

Commit

Permalink
Simplify/optimise parts with new equality rules in mind
Browse files Browse the repository at this point in the history
  • Loading branch information
RReverser committed May 21, 2024
1 parent 42c5c5a commit 25edaff
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 166 deletions.
113 changes: 14 additions & 99 deletions src/ClientCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,22 @@
using System.Collections.Generic;
using System.Reflection;
using SpacetimeDB.SATS;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Linq;

namespace SpacetimeDB
{
public class ClientCache
{
public class TableCache
{
private readonly string name;
private readonly Type clientTableType;
private readonly AlgebraicType rowSchema;

// The function to use for decoding a type value
private Func<AlgebraicValue, object> decoderFunc;

// Maps from primary key to type value
public readonly Dictionary<byte[], (AlgebraicValue, object)> entries;
public readonly Dictionary<byte[], object> entries = new(ByteArrayComparer.Instance);

public Type ClientTableType
{
Expand All @@ -33,14 +31,7 @@ public Type ClientTableType
public Action<object, ClientApi.Event> BeforeDeleteCallback;
public Action<object, ClientApi.Event> DeleteCallback;
public Action<object, object, ClientApi.Event> UpdateCallback;
public Func<AlgebraicType, AlgebraicValue, AlgebraicValue, bool> ComparePrimaryKeyFunc;
public Func<AlgebraicValue, AlgebraicValue> GetPrimaryKeyValueFunc;
public Func<AlgebraicType, AlgebraicType> GetPrimaryKeyTypeFunc;

public string Name
{
get => name;
}
public Func<object, object> GetPrimaryKeyValueFunc;

public AlgebraicType RowSchema
{
Expand All @@ -49,7 +40,6 @@ public AlgebraicType RowSchema

public TableCache(Type clientTableType, AlgebraicType rowSchema, Func<AlgebraicValue, object> decoderFunc)
{
name = clientTableType.Name;
this.clientTableType = clientTableType;

this.rowSchema = rowSchema;
Expand All @@ -60,13 +50,8 @@ public TableCache(Type clientTableType, AlgebraicType rowSchema, Func<AlgebraicV
BeforeDeleteCallback = (Action<object, ClientApi.Event>)clientTableType.GetMethod("OnBeforeDeleteEvent")?.CreateDelegate(typeof(Action<object, ClientApi.Event>));
DeleteCallback = (Action<object, ClientApi.Event>)clientTableType.GetMethod("OnDeleteEvent")?.CreateDelegate(typeof(Action<object, ClientApi.Event>));
UpdateCallback = (Action<object, object, ClientApi.Event>)clientTableType.GetMethod("OnUpdateEvent")?.CreateDelegate(typeof(Action<object, object, ClientApi.Event>));
ComparePrimaryKeyFunc = (Func<AlgebraicType, AlgebraicValue, AlgebraicValue, bool>)clientTableType.GetMethod("ComparePrimaryKey", BindingFlags.Static | BindingFlags.Public)
?.CreateDelegate(typeof(Func<AlgebraicType, AlgebraicValue, AlgebraicValue, bool>));
GetPrimaryKeyValueFunc = (Func<AlgebraicValue, AlgebraicValue>)clientTableType.GetMethod("GetPrimaryKeyValue", BindingFlags.Static | BindingFlags.Public)
?.CreateDelegate(typeof(Func<AlgebraicValue, AlgebraicValue>));
GetPrimaryKeyTypeFunc = (Func<AlgebraicType, AlgebraicType>)clientTableType.GetMethod("GetPrimaryKeyType", BindingFlags.Static | BindingFlags.Public)
?.CreateDelegate(typeof(Func<AlgebraicType, AlgebraicType>));
entries = new Dictionary<byte[], (AlgebraicValue, object)>(new ByteArrayComparer());
GetPrimaryKeyValueFunc = (Func<object, object>)clientTableType.GetMethod("GetPrimaryKeyValue", BindingFlags.NonPublic | BindingFlags.Static)
?.CreateDelegate(typeof(Func<object, object>));
}

/// <summary>
Expand All @@ -83,19 +68,9 @@ public void SetAndForgetDecodedValue(AlgebraicValue value, out object obj)
/// Inserts the value into the table. There can be no existing value with the provided BSATN bytes.
/// </summary>
/// <param name="rowBytes">The BSATN encoded bytes of the row to retrieve.</param>
/// <param name="value">The parsed AlgebraicValue of the row encoded by the <paramref>rowBytes</paramref>.</param>
/// <param name="value">The parsed row encoded by the <paramref>rowBytes</paramref>.</param>
/// <returns>True if the row was inserted, false if the row wasn't inserted because it was a duplicate.</returns>
public bool InsertEntry(byte[] rowBytes, AlgebraicValue value)
{
if (entries.ContainsKey(rowBytes))
{
return false;
}

// Insert the row into our table
entries[rowBytes] = (value, decoderFunc(value));
return true;
}
public bool InsertEntry(byte[] rowBytes, object value) => entries.TryAdd(rowBytes, value);

/// <summary>
/// Deletes a value from the table.
Expand All @@ -104,47 +79,18 @@ public bool InsertEntry(byte[] rowBytes, AlgebraicValue value)
/// <returns>True if and only if the value was previously resident and has been deleted.</returns>
public bool DeleteEntry(byte[] rowBytes)
{
if (entries.TryGetValue(rowBytes, out var value))
if (entries.Remove(rowBytes))
{
entries.Remove(rowBytes);
return true;
}

Logger.LogWarning("Deleting value that we don't have (no cached value available)");
return false;
}

/// <summary>
/// Gets a value from the table
/// </summary>
/// <param name="rowBytes">The BSATN encoded bytes of the row to retrieve.</param>
/// <param name="value">Output: the parsed domain type corresponding to the <paramref>rowBytes</paramref>, or <c>null</c> if the row was not present in the cache.</param>
/// <returns>True if and only if the value is resident and was stored in <paramref>value</paramref>.</returns>
public bool TryGetValue(byte[] rowBytes, out object value)
{
if (entries.TryGetValue(rowBytes, out var v))
{
value = v.Item2;
return true;
}

value = null;
return false;
}

public bool ComparePrimaryKey(AlgebraicValue v1, AlgebraicValue v2)
{
return (bool)ComparePrimaryKeyFunc.Invoke(rowSchema, v1, v2);
}

public AlgebraicValue GetPrimaryKeyValue(AlgebraicValue row)
{
return GetPrimaryKeyValueFunc != null ? GetPrimaryKeyValueFunc.Invoke(row) : null;
}

public AlgebraicType GetPrimaryKeyType()
public object? GetPrimaryKeyValue(object row)
{
return GetPrimaryKeyTypeFunc != null ? GetPrimaryKeyTypeFunc.Invoke(rowSchema) : null;
return GetPrimaryKeyValueFunc?.Invoke(row);
}
}

Expand All @@ -165,33 +111,7 @@ public void AddTable(Type clientTableType, AlgebraicType tableRowDef, Func<Algeb
tables[name] = new TableCache(clientTableType, tableRowDef, decodeFunc);
}

public IEnumerable<object> GetObjects(string name)
{
if (!tables.TryGetValue(name, out var table))
{
yield break;
}

foreach (var entry in table.entries)
{
yield return entry.Value.Item2;
}
}

public IEnumerable<(AlgebraicValue, object)> GetEntries(string name)
{
if (!tables.TryGetValue(name, out var table))
{
yield break;
}

foreach (var entry in table.entries)
{
yield return entry.Value;
}
}

public TableCache GetTable(string name)
public TableCache? GetTable(string name)
{
if (tables.TryGetValue(name, out var table))
{
Expand All @@ -202,17 +122,12 @@ public TableCache GetTable(string name)
return null;
}

public int Count(string name)
public IEnumerable<object> GetObjects(string name)
{
if (!tables.TryGetValue(name, out var table))
{
return 0;
}

return table.entries.Count;
return GetTable(name)?.entries.Values ?? Enumerable.Empty<object>();
}

public IEnumerable<string> GetTableNames() => tables.Keys;
public int Count(string name) => GetTable(name)?.entries.Count ?? 0;

public IEnumerable<TableCache> GetTables() => tables.Values;
}
Expand Down
29 changes: 0 additions & 29 deletions src/SATS/AlgebraicValue.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System;
using System.Linq;
using System.Collections.Generic;
using System.IO;

Expand Down Expand Up @@ -476,33 +475,5 @@ public void Serialize(AlgebraicType type, BinaryWriter writer)
throw new NotImplementedException();
}
}

public class AlgebraicValueComparer : IEqualityComparer<AlgebraicValue>
{
private AlgebraicType type;
public AlgebraicValueComparer(AlgebraicType type)
{
this.type = type;
}

public bool Equals(AlgebraicValue l, AlgebraicValue r)
{
return AlgebraicValue.Compare(type, l, r);
}

public int GetHashCode(AlgebraicValue value)
{
var stream = new MemoryStream();
var writer = new BinaryWriter(stream);
value.Serialize(type, writer);
var s = stream.ToArray();
if (s.Length >= 4)
{
return BitConverter.ToInt32(s, 0);
}
return s.Sum(b => b);
}
}

}
}
Loading

0 comments on commit 25edaff

Please sign in to comment.