Skip to content

Commit

Permalink
feat: improve DbEntityResolver.cs
Browse files Browse the repository at this point in the history
  • Loading branch information
alexyakunin committed Sep 13, 2023
1 parent ed9fbe4 commit 70e3741
Show file tree
Hide file tree
Showing 10 changed files with 181 additions and 80 deletions.
7 changes: 7 additions & 0 deletions samples/HelloCart/AppBase.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Samples.HelloCart.V2;
using static System.Console;

namespace Samples.HelloCart;
Expand All @@ -13,6 +14,12 @@ public abstract class AppBase

public virtual async Task InitializeAsync(IServiceProvider services)
{
var dbContext = services.GetService<AppDbContext>();
if (dbContext != null) {
await dbContext.Database.EnsureDeletedAsync();
await dbContext.Database.EnsureCreatedAsync();
}

var commander = services.Commander();

var pApple = new Product { Id = "apple", Price = 2M };
Expand Down
26 changes: 8 additions & 18 deletions samples/HelloCart/v3/DbCartService2.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,12 @@

namespace Samples.HelloCart.V3;

public class DbCartService2 : ICartService
public class DbCartService2(
DbHub<AppDbContext> dbHub,
IProductService products,
IDbEntityResolver<string, DbCart> cartResolver
) : ICartService
{
private readonly DbHub<AppDbContext> _dbHub;
private readonly IDbEntityResolver<string, DbCart> _cartResolver;
private readonly IProductService _products;

public DbCartService2(
DbHub<AppDbContext> dbHub,
IProductService products,
IDbEntityResolver<string, DbCart> cartResolver)
{
_dbHub = dbHub;
_cartResolver = cartResolver;
_products = products;
}

public virtual async Task Edit(EditCommand<Cart> command, CancellationToken cancellationToken = default)
{
var (cartId, cart) = command;
Expand All @@ -29,7 +19,7 @@ public virtual async Task Edit(EditCommand<Cart> command, CancellationToken canc
return;
}

await using var dbContext = await _dbHub.CreateCommandDbContext(cancellationToken);
await using var dbContext = await dbHub.CreateCommandDbContext(cancellationToken);
var dbCart = await dbContext.Carts.FindAsync(DbKey.Compose(cartId), cancellationToken);
if (cart == null) {
if (dbCart != null)
Expand Down Expand Up @@ -69,7 +59,7 @@ public virtual async Task Edit(EditCommand<Cart> command, CancellationToken canc

public virtual async Task<Cart?> Get(string id, CancellationToken cancellationToken = default)
{
var dbCart = await _cartResolver.Get(id, cancellationToken);
var dbCart = await cartResolver.Get(id, cancellationToken);
return dbCart == null ? null : new Cart() {
Id = dbCart.Id,
Items = dbCart.Items.ToImmutableDictionary(i => i.DbProductId, i => i.Quantity),
Expand All @@ -82,7 +72,7 @@ public virtual async Task<decimal> GetTotal(string id, CancellationToken cancell
if (cart == null)
return 0;
var itemTotals = await Task.WhenAll(cart.Items.Select(async item => {
var product = await _products.Get(item.Key, cancellationToken);
var product = await products.Get(item.Key, cancellationToken);
return item.Value * (product?.Price ?? 0M);
}));
return itemTotals.Sum();
Expand Down
20 changes: 6 additions & 14 deletions samples/HelloCart/v3/DbProductService2.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,11 @@

namespace Samples.HelloCart.V3;

public class DbProductService2 : IProductService
public class DbProductService2(
DbHub<AppDbContext> dbHub,
IDbEntityResolver<string, DbProduct> productResolver
) : IProductService
{
private readonly DbHub<AppDbContext> _dbHub;
private readonly IDbEntityResolver<string, DbProduct> _productResolver;

public DbProductService2(
DbHub<AppDbContext> dbHub,
IDbEntityResolver<string, DbProduct> productResolver)
{
_dbHub = dbHub;
_productResolver = productResolver;
}

public virtual async Task Edit(EditCommand<Product> command, CancellationToken cancellationToken = default)
{
var (productId, product) = command;
Expand All @@ -26,7 +18,7 @@ public virtual async Task Edit(EditCommand<Product> command, CancellationToken c
return;
}

await using var dbContext = await _dbHub.CreateCommandDbContext(cancellationToken);
await using var dbContext = await dbHub.CreateCommandDbContext(cancellationToken);
var dbProduct = await dbContext.Products.FindAsync(DbKey.Compose(productId), cancellationToken);
if (product == null) {
if (dbProduct != null)
Expand All @@ -43,7 +35,7 @@ public virtual async Task Edit(EditCommand<Product> command, CancellationToken c

public virtual async Task<Product?> Get(string id, CancellationToken cancellationToken = default)
{
var dbProduct = await _productResolver.Get(id, cancellationToken);
var dbProduct = await productResolver.Get(id, cancellationToken);
return dbProduct == null ? null : new Product() {
Id = dbProduct.Id,
Price = dbProduct.Price
Expand Down
149 changes: 113 additions & 36 deletions src/Stl.Fusion.EntityFramework/DbEntityResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
using System.Globalization;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Stl.Fusion.EntityFramework.Internal;
using Stl.Internal;
using Stl.Multitenancy;
using Stl.Net;

Expand All @@ -13,8 +13,8 @@ public interface IDbEntityResolver<TKey, TDbEntity>
where TKey : notnull
where TDbEntity : class
{
public Func<Expression, Expression> KeyExtractorExpressionBuilder { get; }
public Func<TDbEntity, TKey> KeyExtractor { get; }
Func<TDbEntity, TKey> KeyExtractor { get; init; }
Expression<Func<TDbEntity, TKey>> KeyExtractorExpression { get; init; }

Task<TDbEntity?> Get(Symbol tenantId, TKey key, CancellationToken cancellationToken = default);
}
Expand All @@ -37,10 +37,10 @@ public record Options
{
public static Options Default { get; set; } = new();

public string? KeyPropertyName { get; init; }
public Func<Expression, Expression>? KeyExtractorExpressionBuilder { get; init; }
public Func<IQueryable<TDbEntity>, IQueryable<TDbEntity>> QueryTransformer { get; init; } = q => q;
public Expression<Func<TDbEntity, TKey>>? KeyExtractor { get; init; }
public Expression<Func<IQueryable<TDbEntity>, IQueryable<TDbEntity>>>? QueryTransformer { get; init; }
public Action<Dictionary<TKey, TDbEntity>> PostProcessor { get; init; } = _ => { };
public int BatchSize { get; init; } = 8;
public Action<BatchProcessor<TKey, TDbEntity?>>? ConfigureBatchProcessor { get; init; }
public TimeSpan? Timeout { get; init; } = TimeSpan.FromSeconds(1);
public IRetryDelayer RetryDelayer { get; init; } = new RetryDelayer() {
Expand All @@ -49,38 +49,55 @@ public record Options
};
}

protected static MethodInfo ContainsMethod { get; } = typeof(HashSet<TKey>).GetMethod(nameof(HashSet<TKey>.Contains))!;
// ReSharper disable once StaticMemberInGenericType
protected static MethodInfo DbContextSetMethod { get; } = typeof(DbContext)
.GetMethods(BindingFlags.Public | BindingFlags.Instance)
.Single(m => m.Name == nameof(DbContext.Set) && m.IsGenericMethod && m.GetParameters().Length == 0)

Check warning on line 55 in src/Stl.Fusion.EntityFramework/DbEntityResolver.cs

View workflow job for this annotation

GitHub Actions / build

.MakeGenericMethod(typeof(TDbEntity));
protected static MethodInfo QueryableWhereMethod { get; }
= new Func<IQueryable<TDbEntity>, Expression<Func<TDbEntity, bool>>, IQueryable<TDbEntity>>(Queryable.Where).Method;

private ConcurrentDictionary<Symbol, BatchProcessor<TKey, TDbEntity?>>? _batchProcessors;
private ITransientErrorDetector<TDbContext>? _transientErrorDetector;

protected Options Settings { get; }
protected string KeyPropertyName { get; init; }
protected (Func<TDbContext, TKey[], IAsyncEnumerable<TDbEntity>> Query, int BatchSize)[] Queries { get; init; }

public Func<Expression, Expression> KeyExtractorExpressionBuilder { get; init; }
public Func<TDbEntity, TKey> KeyExtractor { get; init; }
public Expression<Func<TDbEntity, TKey>> KeyExtractorExpression { get; init; }
public ITransientErrorDetector<TDbContext> TransientErrorDetector =>
_transientErrorDetector ??= Services.GetRequiredService<ITransientErrorDetector<TDbContext>>();

public DbEntityResolver(Options settings, IServiceProvider services) : base(services)
{
Settings = settings;
if (settings.KeyPropertyName == null) {
var keyExtractor = Settings.KeyExtractor;
if (keyExtractor == null) {
var dummyTenant = TenantRegistry.IsSingleTenant ? Tenant.Default : Tenant.Dummy;
using var dbContext = CreateDbContext(dummyTenant);
KeyPropertyName = dbContext.Model
var keyPropertyName = dbContext.Model
.FindEntityType(typeof(TDbEntity))!
.FindPrimaryKey()!
.Properties.Single().Name;

var pEntity = Expression.Parameter(typeof(TDbEntity), "e");
var eBody = Expression.PropertyOrField(pEntity, keyPropertyName);
keyExtractor = Expression.Lambda<Func<TDbEntity, TKey>>(eBody, pEntity);
}
else
KeyPropertyName = settings.KeyPropertyName;
KeyExtractorExpressionBuilder = settings.KeyExtractorExpressionBuilder
?? (eEntity => Expression.PropertyOrField(eEntity, KeyPropertyName));
var pEntity = Expression.Parameter(typeof(TDbEntity), "e");
var eBody = KeyExtractorExpressionBuilder(pEntity);
KeyExtractor = (Func<TDbEntity, TKey>) Expression.Lambda(eBody, pEntity).Compile();
KeyExtractorExpression = keyExtractor;
KeyExtractor = keyExtractor.Compile();
_batchProcessors = new();

var buffer = ArrayBuffer<(Func<TDbContext, TKey[], IAsyncEnumerable<TDbEntity>>, int)>.Lease(false);
try {
for (var batchSize = 2; batchSize < Settings.BatchSize; batchSize *= 2)
buffer.Add((CreateCompiledQuery(batchSize), batchSize));
buffer.Add((CreateCompiledQuery(Settings.BatchSize), Settings.BatchSize));
Queries = buffer.ToArray();
}
finally {
buffer.Release();
}
}

public async ValueTask DisposeAsync()
Expand All @@ -102,11 +119,66 @@ await batchProcessors.Values

// Protected methods

protected Func<TDbContext, TKey[], IAsyncEnumerable<TDbEntity>> CreateCompiledQuery(int batchSize)
{
var pDbContext = Expression.Parameter(typeof(TDbContext), "dbContext");
var pKeys = new ParameterExpression[batchSize];
for (var i = 0; i < batchSize; i++)
pKeys[i] = Expression.Parameter(typeof(TKey), $"key{i.ToString(CultureInfo.InvariantCulture)}");
var pEntity = Expression.Parameter(typeof(TDbEntity), "e");

// entity.Key expression
var eKey = KeyExtractorExpression.Body.Replace(KeyExtractorExpression.Parameters[0], pEntity);

// .Where predicate expression
var ePredicate = (Expression?)null;
for (var i = 0; i < batchSize; i++) {
var eCondition = Expression.Equal(eKey, pKeys[i]);
ePredicate = ePredicate == null ? eCondition : Expression.OrElse(ePredicate, eCondition);
}
var lPredicate = Expression.Lambda<Func<TDbEntity, bool>>(ePredicate!, pEntity);

// dbContext.Set<TDbEntity>().Where(...)
var eEntitySet = Expression.Call(pDbContext, DbContextSetMethod);
var eWhere = Expression.Call(null, QueryableWhereMethod, eEntitySet, Expression.Quote(lPredicate));

// Applying QueryTransformer
var qt = Settings.QueryTransformer;
var eBody = qt == null
? eWhere
: qt.Body.Replace(qt.Parameters[0], eWhere);

// Creating compiled query
var lambdaParameters = new ParameterExpression[batchSize + 1];
lambdaParameters[0] = pDbContext;
pKeys.CopyTo(lambdaParameters, 1);
var lambda = Expression.Lambda(eBody, lambdaParameters);
var query = new CompiledAsyncEnumerableQuery<TDbContext, TDbEntity>(lambda);

// Locating query.Execute methods
var mExecute = query.GetType()
.GetMethods()
.SingleOrDefault(m => m.Name == nameof(query.Execute)
&& m.IsGenericMethod
&& m.GetGenericArguments().Length == batchSize)
?.MakeGenericMethod(pKeys.Select(p => p.Type).ToArray());
if (mExecute == null)
throw Errors.BatchSizeIsTooLarge();

// Creating compiled query invoker
var pAllKeys = Expression.Parameter(typeof(TKey[]));
var eDbContext = Enumerable.Range(0, 1).Select(_ => (Expression)pDbContext);
var eAllKeys = Enumerable.Range(0, batchSize).Select(i => Expression.ArrayIndex(pAllKeys, Expression.Constant(i)));
var eExecuteCall = Expression.Call(Expression.Constant(query), mExecute, eDbContext.Concat(eAllKeys));
return (Func<TDbContext, TKey[], IAsyncEnumerable<TDbEntity>>)Expression.Lambda(eExecuteCall, pDbContext, pAllKeys).Compile();
}

protected BatchProcessor<TKey, TDbEntity?> GetBatchProcessor(Symbol tenantId)
{
var batchProcessors = _batchProcessors;
if (batchProcessors == null)
throw Stl.Internal.Errors.AlreadyDisposed(GetType());

return batchProcessors.GetOrAdd(tenantId,
static (tenantId1, self) => self.CreateBatchProcessor(tenantId1), this);
}
Expand All @@ -115,11 +187,14 @@ await batchProcessors.Values
{
var tenant = TenantRegistry.Get(tenantId);
var batchProcessor = new BatchProcessor<TKey, TDbEntity?> {
MaxBatchSize = 16,
BatchSize = Settings.BatchSize,
ConcurrencyLevel = 1,
Implementation = (batch, cancellationToken) => ProcessBatch(tenant, batch, cancellationToken),
};
Settings.ConfigureBatchProcessor?.Invoke(batchProcessor);
if (batchProcessor.BatchSize != Settings.BatchSize)
throw Errors.BatchSizeCannotBeChanged();

return batchProcessor;
}

Expand All @@ -138,41 +213,43 @@ protected virtual async Task ProcessBatch(
List<BatchItem<TKey, TDbEntity?>> batch,
CancellationToken cancellationToken)
{
if (batch.Count == 0)
return;

using var activity = StartProcessBatchActivity(tenant, batch.Count);
var (query, batchSize) = Queries.First(q => q.BatchSize >= batch.Count);
for (var tryIndex = 0;; tryIndex++) {
var dbContext = CreateDbContext(tenant);
await using var _ = dbContext.ConfigureAwait(false);
try {
var keys = new HashSet<TKey>();
foreach (var item in batch) {
var keys = new TKey[batchSize];
var i = 0;
foreach (var item in batch)
if (!item.TryCancel())
keys.Add(item.Input);
}
var pEntity = Expression.Parameter(typeof(TDbEntity), "e");
var eKey = KeyExtractorExpressionBuilder(pEntity);
var eBody = Expression.Call(Expression.Constant(keys), ContainsMethod, eKey);
var eLambda = (Expression<Func<TDbEntity, bool>>) Expression.Lambda(eBody, pEntity);
var query = Settings.QueryTransformer(dbContext.Set<TDbEntity>().Where(eLambda));
keys[i++] = item.Input;
var lastKey = keys[i - 1];
for (; i < batchSize; i++)
keys[i] = lastKey;

Dictionary<TKey, TDbEntity>? entities;
var entities = new Dictionary<TKey, TDbEntity>();
if (Settings.Timeout is { } timeout) {
using var cts = new CancellationTokenSource(timeout);
using var linkedCts = cancellationToken.LinkWith(cts.Token);
try {
entities = await query
.ToDictionaryAsync(KeyExtractor, linkedCts.Token)
.ConfigureAwait(false);
var result = query.Invoke(dbContext, keys);
await foreach (var e in result.WithCancellation(cancellationToken).ConfigureAwait(false))
entities.Add(KeyExtractor.Invoke(e), e);
}
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested) {
throw new TimeoutException();
}
}
else {
entities = await query
.ToDictionaryAsync(KeyExtractor, cancellationToken)
.ConfigureAwait(false);
var result = query.Invoke(dbContext, keys);
await foreach (var e in result.WithCancellation(cancellationToken).ConfigureAwait(false))
entities.Add(KeyExtractor.Invoke(e), e);
}
Settings.PostProcessor(entities);
Settings.PostProcessor.Invoke(entities);

foreach (var item in batch) {
var entity = entities.GetValueOrDefault(item.Input);
Expand Down
5 changes: 5 additions & 0 deletions src/Stl.Fusion.EntityFramework/Internal/Errors.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,9 @@ public static Exception UserIdRequired()

public static Exception UnsupportedDbHint(DbHint hint)
=> new NotSupportedException($"Unsupported DbHint: {hint}");

public static Exception BatchSizeCannotBeChanged()
=> new InvalidOperationException("ConfigureBatchProcessor delegate cannot change BatchProcessor's BatchSize.");
public static Exception BatchSizeIsTooLarge()
=> new InvalidOperationException("DbEntityResolver's BatchSize is too large.");
}
22 changes: 22 additions & 0 deletions src/Stl.Fusion.EntityFramework/Internal/ExpressionExt.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query;

namespace Stl.Fusion.EntityFramework.Internal;

public static class ExpressionExt
{
public static Expression Replace(this Expression source,
Expression from, Expression to)
=> new ReplacingExpressionVisitor(new[] { from }, new [] { to }).Visit(source);

public static Expression Replace(this Expression source,
Expression from1, Expression to1,
Expression from2, Expression to2)
=> new ReplacingExpressionVisitor(new[] { from1, from2 }, new [] { to1, to2 }).Visit(source);

public static Expression Replace(this Expression source,
Expression from1, Expression to1,
Expression from2, Expression to2,
Expression from3, Expression to3)
=> new ReplacingExpressionVisitor(new[] { from1, from2, from3 }, new [] { to1, to2, to3 }).Visit(source);
}
Loading

0 comments on commit 70e3741

Please sign in to comment.